├── CRF.py ├── Image ├── figure.png └── test ├── README.md ├── __init__.py ├── __pycache__ └── __init__.cpython-38.pyc ├── base ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── initialization.cpython-38.pyc ├── _utils.py ├── initialization.py └── modules.py └── new_model ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── builder.cpython-38.pyc ├── init_func.cpython-38.pyc ├── modules.cpython-38.pyc └── net_utils.cpython-38.pyc ├── decoders ├── MLPAlignDecoder.py ├── MLPDecoder.py ├── UPernet.py ├── __pycache__ │ ├── MLPAlignDecoder.cpython-38.pyc │ ├── MLPDecoder.cpython-38.pyc │ ├── UPernet.cpython-38.pyc │ ├── condnet.cpython-38.pyc │ ├── deeplabv3plus.cpython-38.pyc │ ├── fapn.cpython-38.pyc │ ├── fcnhead.cpython-38.pyc │ ├── fpn.cpython-38.pyc │ ├── hem.cpython-38.pyc │ └── lawin.cpython-38.pyc ├── condnet.py ├── deeplabv3plus.py ├── fapn.py ├── fcnhead.py ├── fpn.py ├── fpn_head.py ├── hem.py ├── lawin.py └── sfnet.py ├── encoders ├── Transformer │ ├── BiFormer │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── biformer.cpython-38.pyc │ │ │ ├── bra_legacy.cpython-38.pyc │ │ │ ├── dual_biformer.cpython-38.pyc │ │ │ └── modules.cpython-38.pyc │ │ ├── bra_legacy.py │ │ ├── dual_biformer.py │ │ └── modules.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── dual_cswin.cpython-38.pyc │ │ ├── dual_segformer.cpython-38.pyc │ │ ├── dual_swin.cpython-38.pyc │ │ └── dual_uniformer.cpython-38.pyc │ ├── dual_cswin.py │ ├── dual_dilateformer.py │ ├── dual_segformer.py │ ├── dual_swin.py │ └── dual_uniformer.py ├── __init__.py └── __pycache__ │ └── __init__.cpython-38.pyc ├── init_func.py └── modules.py /CRF.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pydensecrf.densecrf as dcrf 3 | import pydensecrf.utils as utils 4 | 5 | 6 | class DenseCRF(object): 7 | def __init__(self, iter_max, pos_w, pos_xy_std, bi_w, bi_xy_std, bi_rgb_std): 8 | self.iter_max = iter_max 9 | self.pos_w = pos_w 10 | self.pos_xy_std = pos_xy_std 11 | self.bi_w = bi_w 12 | self.bi_xy_std = bi_xy_std 13 | self.bi_rgb_std = bi_rgb_std 14 | 15 | def __call__(self, image, probmap): 16 | C, H, W = probmap.shape 17 | 18 | U = utils.unary_from_softmax(probmap) 19 | U = np.ascontiguousarray(U) 20 | 21 | image = np.ascontiguousarray(image) # 内存不连续的图像转换为内存连续的图像 22 | 23 | d = dcrf.DenseCRF2D(W, H, C) 24 | d.setUnaryEnergy(U) 25 | d.addPairwiseGaussian(sxy=self.pos_xy_std, compat=self.pos_w) 26 | d.addPairwiseBilateral( 27 | sxy=self.bi_xy_std, srgb=self.bi_rgb_std, rgbim=image, compat=self.bi_w 28 | ) 29 | 30 | Q = d.inference(self.iter_max) 31 | Q = np.array(Q).reshape((C, H, W)) 32 | 33 | return Q 34 | -------------------------------------------------------------------------------- /Image/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/Image/figure.png -------------------------------------------------------------------------------- /Image/test: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CFFormer 2 | This repo holds code for [CFFormer: A Cross-Fusion Transformer Framework for the Semantic Segmentation of Multi-Source 
Remote Sensing Images](https://ieeexplore.ieee.org/document/10786275) 3 | # The overall architecture 4 | We propose a novel network framework based on a transformer model, which uses the FCM and FFM to facilitate the fusion of heterogeneous data sources and achieve more accurate semantic segmentation. The algorithmic framework of this paper is shown in Figure. In detail, the proposed approach relies on the classical encoder-decoder architecture, where the encoder incorporates feature extraction networks without weight sharing: the FCM for filtering diverse modal noise and differences, and the FFM for enhancing the information interaction and fusion. The decoder part aggregates the multi-scale features to generate the final result. Other common methods such as ResNet can be employed as an alternative for the feature extraction network. 5 | ![overall architecture](Image/figure.png) 6 | # Credits 7 | If you find this work useful, please consider citing: 8 | 9 | ```bibtex 10 | @ARTICLE{10786275, 11 | author={Zhao, Jinqi and Zhang, Ming and Zhou, Zhonghuai and Wang, Zixuan and Lang, Fengkai and Shi, Hongtao and Zheng, Nanshan}, 12 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 13 | title={CFFormer: A Cross-Fusion Transformer Framework for the Semantic Segmentation of Multisource Remote Sensing Images}, 14 | year={2025}, 15 | volume={63}, 16 | number={}, 17 | pages={1-17}, 18 | keywords={Feature extraction;Optical imaging;Adaptive optics;Optical sensors;Semantic segmentation;Transformers;Remote sensing;Correlation;Noise;Fuses;Feature correction module (FCM);feature fusion module (FFM);multisource remote sensing images (RSIs);semantic segmentation;vision transformer}, 19 | doi={10.1109/TGRS.2024.3507274}} 20 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/__init__.py -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import conv_bn, conv_bn_relu, CPCAAttention 2 | -------------------------------------------------------------------------------- /base/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/base/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /base/__pycache__/initialization.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/base/__pycache__/initialization.cpython-38.pyc -------------------------------------------------------------------------------- /base/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def patch_first_conv(model, 
in_channel1, in_channel2): 6 | """Change first convolution layer input channels. 7 | In case: 8 | in_channels == 1 or in_channels == 2 -> reuse original weights 9 | in_channels > 3 -> make random kaiming normal initialization 10 | """ 11 | 12 | conv1_found = False 13 | conv2_found = False 14 | 15 | for name, module in model.named_modules(): 16 | if not conv1_found and isinstance(module, nn.Conv2d) and "conv1" in name: 17 | conv1_found = True 18 | module.in_channels = in_channel1 19 | weight = module.weight.detach() 20 | reset = False 21 | 22 | if in_channel1 == 1: 23 | weight = weight.sum(1, keepdim=True) 24 | elif in_channel1 == 2: 25 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 26 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 27 | weight = weight[:, :2] 28 | else: 29 | for i in range(3, in_channel1): 30 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 31 | weight *= (3 / in_channel1) 32 | 33 | module.weight = nn.parameter.Parameter(weight) 34 | 35 | if not conv2_found and isinstance(module, nn.Conv2d) and "hha_conv1" in name: 36 | conv2_found = True 37 | module.in_channels = in_channel2 38 | weight = module.weight.detach() 39 | reset = False 40 | 41 | if in_channel2 == 1: 42 | weight = weight.sum(1, keepdim=True) 43 | elif in_channel2 == 2: 44 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 45 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 46 | weight = weight[:, :2] 47 | else: 48 | for i in range(3, in_channel2): 49 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 50 | weight *= (3 / in_channel2) 51 | 52 | module.weight = nn.parameter.Parameter(weight) 53 | 54 | if conv1_found and conv2_found: 55 | break 56 | -------------------------------------------------------------------------------- /base/initialization.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def weights_init(net, init_type='normal', init_gain=0.02): 5 | def init_func(m): 6 | classname = m.__class__.__name__ 7 | if hasattr(m, 'weight') and classname.find('Conv') != -1: 8 | if init_type == 'normal': 9 | nn.init.normal_(m.weight.data, 0.0, init_gain) 10 | elif init_type == 'xavier': 11 | nn.init.xavier_normal_(m.weight.data, gain=init_gain) 12 | elif init_type == 'kaiming': 13 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 14 | elif init_type == 'orthogonal': 15 | nn.init.orthogonal_(m.weight.data, gain=init_gain) 16 | else: 17 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 18 | elif classname.find('BatchNorm2d') != -1: 19 | nn.init.normal_(m.weight.data, 1.0, 0.02) 20 | nn.init.constant_(m.bias.data, 0.0) 21 | 22 | print('initialize network with %s type' % init_type) 23 | net.apply(init_func) 24 | 25 | 26 | def initialize_weights(*models): 27 | for model in models: 28 | for module in model.modules(): 29 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 30 | nn.init.kaiming_normal_(module.weight) 31 | if module.bias is not None: 32 | module.bias.data.zero_() 33 | elif isinstance(module, nn.BatchNorm2d): 34 | module.weight.data.fill_(1) 35 | module.bias.data.zero_() 36 | 37 | 38 | def initialize_decoder(module): 39 | for m in module.modules(): 40 | 41 | if isinstance(m, nn.Conv2d): 42 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 43 | if m.bias is not None: 44 | nn.init.constant_(m.bias, 0) 45 | 46 | elif isinstance(m, nn.BatchNorm2d): 47 | nn.init.constant_(m.weight, 1) 48 | 
nn.init.constant_(m.bias, 0) 49 | 50 | elif isinstance(m, nn.Linear): 51 | nn.init.xavier_uniform_(m.weight) 52 | if m.bias is not None: 53 | nn.init.constant_(m.bias, 0) 54 | 55 | 56 | def initialize_head(module): 57 | for m in module.modules(): 58 | if isinstance(m, (nn.Linear, nn.Conv2d)): 59 | nn.init.xavier_uniform_(m.weight) 60 | if m.bias is not None: 61 | nn.init.constant_(m.bias, 0) 62 | -------------------------------------------------------------------------------- /base/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): 7 | result = nn.Sequential() 8 | result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 9 | kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, 10 | bias=False)) 11 | result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) 12 | return result 13 | 14 | 15 | def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups=1): 16 | result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 17 | padding=padding, groups=groups) 18 | result.add_module('relu', nn.ReLU()) 19 | return result 20 | 21 | 22 | ########################################################################## 23 | # Channel prior convolutional attention for medical image segmentation 24 | class ChannelAttention(nn.Module): 25 | 26 | def __init__(self, input_channels, internal_neurons): 27 | super(ChannelAttention, self).__init__() 28 | # the two 1x1 convolutions below act as the so-called shared MLP 29 | self.fc1 = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, 30 | bias=True) 31 | self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1, 32 | bias=True) 33 | self.input_channels = input_channels 34 | 35 | def forward(self, inputs): 36 | x1 = F.adaptive_avg_pool2d(inputs, output_size=(1, 1)) 37 | # print('x:', x.shape) 38 | x1 = self.fc1(x1) 39 | x1 = F.relu(x1, inplace=True) 40 | x1 = self.fc2(x1) 41 | x1 = torch.sigmoid(x1) 42 | x2 = F.adaptive_max_pool2d(inputs, output_size=(1, 1)) 43 | # print('x:', x.shape) 44 | x2 = self.fc1(x2) 45 | x2 = F.relu(x2, inplace=True) 46 | x2 = self.fc2(x2) 47 | x2 = torch.sigmoid(x2) 48 | x = x1 + x2 49 | x = x.view(-1, self.input_channels, 1, 1) 50 | return x 51 | 52 | 53 | class CPCAAttention(nn.Module): 54 | def __init__(self, in_channels, out_channels, channel_attention_reduce=4): 55 | super().__init__() 56 | 57 | self.C = in_channels 58 | self.O = out_channels 59 | 60 | assert in_channels == out_channels 61 | self.ca = ChannelAttention(input_channels=in_channels, internal_neurons=in_channels // channel_attention_reduce) 62 | self.dconv5_5 = nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels) 63 | # stacking a (1, 7) convolution and a (7, 1) convolution extracts features along the horizontal and vertical directions separately, which is more flexible and fine-grained, 64 | # and suits cases that need finer-grained feature information; a single (7, 7) convolution instead captures local features more holistically, 65 | # which suits cases that need a wider receptive field and more global context -- combined with the matching padding and groups it seems to work better 66 | self.dconv1_7 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 7), padding=(0, 3), groups=in_channels) 67 | self.dconv7_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(7, 1), padding=(3, 0), groups=in_channels) 68 | self.dconv1_11 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 11), padding=(0, 5), groups=in_channels) 69 |
self.dconv11_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(11, 1), padding=(5, 0), groups=in_channels) 70 | self.dconv1_21 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 21), padding=(0, 10), groups=in_channels) 71 | self.dconv21_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(21, 1), padding=(10, 0), groups=in_channels) 72 | self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1), padding=0) 73 | self.act = nn.GELU() 74 | 75 | def forward(self, inputs): 76 | # Global Perceptron 77 | # the same 1x1 conv is reused three times -- is every one of these uses really just channel mixing? 78 | # channel mixing? 79 | inputs = self.conv(inputs) 80 | inputs = self.act(inputs) 81 | 82 | channel_att_vec = self.ca(inputs) 83 | inputs = channel_att_vec * inputs 84 | 85 | x_init = self.dconv5_5(inputs) 86 | x_1 = self.dconv1_7(x_init) 87 | x_1 = self.dconv7_1(x_1) 88 | x_2 = self.dconv1_11(x_init) 89 | x_2 = self.dconv11_1(x_2) 90 | x_3 = self.dconv1_21(x_init) 91 | x_3 = self.dconv21_1(x_3) 92 | x = x_1 + x_2 + x_3 + x_init 93 | # channel mixing? 94 | spatial_att = self.conv(x) 95 | out = spatial_att * inputs 96 | # channel mixing? 97 | out = self.conv(out) 98 | return out 99 | 100 | 101 | ########################################################################## 102 | # FlowNet: Learning Optical Flow with Convolutional Networks 103 | def predict_flow(in_planes): 104 | # optical-flow prediction: the 2 output channels are the two components of the flow field 105 | return nn.Conv2d(in_planes, 2, kernel_size=3, stride=1, padding=1, bias=False) 106 | 107 | 108 | def deconv(in_planes, out_planes): 109 | # transposed convolution (upsampling) 110 | return nn.Sequential( 111 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=False), 112 | nn.LeakyReLU(0.1, inplace=True) 113 | ) 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /new_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/__init__.py -------------------------------------------------------------------------------- /new_model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/__pycache__/builder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/__pycache__/builder.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/__pycache__/init_func.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/__pycache__/init_func.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/__pycache__/net_utils.cpython-38.pyc:
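The CPCAAttention block defined in base/modules.py above applies channel attention followed by multi-scale depth-wise strip convolutions. A minimal usage sketch (not part of the repository; it assumes the repo root is on PYTHONPATH, and the tensor shape is illustrative):

```python
import torch

from base.modules import CPCAAttention  # exported via base/__init__.py

# in_channels must equal out_channels (the module asserts this)
attn = CPCAAttention(in_channels=64, out_channels=64, channel_attention_reduce=4)
feat = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
out = attn(feat)                    # channel attention, then 5x5 + strip depth-wise convs
print(out.shape)                    # torch.Size([2, 64, 32, 32]) -- shape is preserved
```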
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/__pycache__/net_utils.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/MLPAlignDecoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | 5 | from torch.nn.modules import module 6 | import torch.nn.functional as F 7 | 8 | 9 | class MLP(nn.Module): 10 | """ 11 | Linear Embedding: 12 | """ 13 | 14 | def __init__(self, input_dim=2048, embed_dim=768): 15 | super().__init__() 16 | self.proj = nn.Linear(input_dim, embed_dim) 17 | 18 | def forward(self, x): 19 | # B C H W -> B HW C 20 | x = x.flatten(2).transpose(1, 2) 21 | x = self.proj(x) 22 | return x 23 | 24 | 25 | class AlignedModule(nn.Module): 26 | 27 | def __init__(self, inplane, outplane, kernel_size=3): 28 | super(AlignedModule, self).__init__() 29 | self.down_h = nn.Conv2d(inplane, outplane, 1, bias=False) 30 | self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False) 31 | self.flow_make = nn.Conv2d(outplane * 2, 2, kernel_size=kernel_size, padding=1, bias=False) 32 | 33 | def forward(self, x1, x2): 34 | low_feature, h_feature = x1, x2 35 | h_feature_orign = h_feature 36 | h, w = low_feature.size()[2:] 37 | size = (h, w) 38 | low_feature = self.down_l(low_feature) 39 | h_feature = self.down_h(h_feature) 40 | h_feature = F.interpolate(h_feature, size=size, mode="bilinear", align_corners=True) 41 | flow = self.flow_make(torch.cat([h_feature, low_feature], 1)) 42 | h_feature = self.flow_warp(h_feature_orign, flow, size=size) 43 | 44 | return h_feature 45 | 46 | def flow_warp(self, input, flow, size): 47 | out_h, out_w = size 48 | n, c, h, w = input.size() 49 | # n, c, h, w 50 | # n, 2, h, w 51 | 52 | norm = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device) 53 | h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w) 54 | w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1) 55 | grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2) 56 | grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device) 57 | grid = grid + flow.permute(0, 2, 3, 1) / norm 58 | 59 | output = F.grid_sample(input, grid, align_corners=True) 60 | return output 61 | 62 | 63 | class DecoderHead(nn.Module): 64 | def __init__(self, 65 | in_channels=[64, 128, 320, 512], 66 | num_classes=40, 67 | dropout_ratio=0.1, 68 | norm_layer=nn.BatchNorm2d, 69 | embed_dim=768, 70 | align_corners=False): 71 | 72 | super(DecoderHead, self).__init__() 73 | self.num_classes = num_classes 74 | self.dropout_ratio = dropout_ratio 75 | self.align_corners = align_corners 76 | 77 | self.in_channels = in_channels 78 | 79 | if dropout_ratio > 0: 80 | self.dropout = nn.Dropout2d(dropout_ratio) 81 | else: 82 | self.dropout = None 83 | 84 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 85 | 86 | embedding_dim = embed_dim 87 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 88 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 89 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 90 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 91 | 92 | self.Aligned_c4 = AlignedModule(embedding_dim, embedding_dim // 2) 93 | self.Aligned_c3 = AlignedModule(embedding_dim, embedding_dim // 2) 94 | self.Aligned_c2 
= AlignedModule(embedding_dim, embedding_dim // 2) 95 | 96 | self.linear_fuse = nn.Sequential( 97 | nn.Conv2d(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1), 98 | norm_layer(embedding_dim), 99 | nn.ReLU(inplace=True) 100 | ) 101 | 102 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 103 | 104 | def forward(self, inputs): 105 | 106 | c1, c2, c3, c4 = inputs 107 | 108 | n, _, h, w = c4.shape 109 | 110 | # B HW C -> B C H W 111 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 112 | 113 | # B HW C -> B C H W 114 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 115 | _c2 = self.Aligned_c2(_c1, _c2) 116 | 117 | # B HW C -> B C H W 118 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 119 | _c3 = self.Aligned_c3(_c1, _c3) 120 | 121 | # B HW C -> B C H W 122 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 123 | _c4 = self.Aligned_c4(_c1, _c4) 124 | 125 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 126 | x = self.dropout(_c) 127 | x = self.linear_pred(x) 128 | 129 | return x 130 | -------------------------------------------------------------------------------- /new_model/decoders/MLPDecoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | 5 | from torch.nn.modules import module 6 | import torch.nn.functional as F 7 | 8 | 9 | class MLP(nn.Module): 10 | """ 11 | Linear Embedding: 12 | """ 13 | 14 | def __init__(self, input_dim=2048, embed_dim=768): 15 | super().__init__() 16 | self.proj = nn.Linear(input_dim, embed_dim) 17 | 18 | def forward(self, x): 19 | # B C H W -> B HW C 20 | x = x.flatten(2).transpose(1, 2) 21 | x = self.proj(x) 22 | return x 23 | 24 | 25 | class DecoderHead(nn.Module): 26 | def __init__(self, 27 | in_channels=[64, 128, 320, 512], 28 | num_classes=40, 29 | dropout_ratio=0.1, 30 | norm_layer=nn.BatchNorm2d, 31 | embed_dim=768, 32 | align_corners=False): 33 | 34 | super(DecoderHead, self).__init__() 35 | self.num_classes = num_classes 36 | self.dropout_ratio = dropout_ratio 37 | self.align_corners = align_corners 38 | 39 | self.in_channels = in_channels 40 | 41 | if dropout_ratio > 0: 42 | self.dropout = nn.Dropout2d(dropout_ratio) 43 | else: 44 | self.dropout = None 45 | 46 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 47 | 48 | embedding_dim = embed_dim 49 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 50 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 51 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 52 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 53 | 54 | self.linear_fuse = nn.Sequential( 55 | nn.Conv2d(in_channels=embedding_dim * 4, out_channels=embedding_dim, kernel_size=1), 56 | norm_layer(embedding_dim), 57 | nn.ReLU(inplace=True) 58 | ) 59 | 60 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 61 | 62 | def forward(self, inputs): 63 | # len=4, 1/4,1/8,1/16,1/32 64 | c1, c2, c3, c4 = inputs 65 | 66 | ############## MLP decoder on C1-C4 ########### 67 | n, _, h, w = c4.shape 68 | 69 | # B HW C -> B C H W 70 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 71 | _c4 = F.interpolate(_c4, size=c1.size()[2:], mode='bilinear', 
align_corners=self.align_corners) 72 | 73 | # B HW C -> B C H W 74 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 75 | _c3 = F.interpolate(_c3, size=c1.size()[2:], mode='bilinear', align_corners=self.align_corners) 76 | 77 | # B HW C -> B C H W 78 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 79 | _c2 = F.interpolate(_c2, size=c1.size()[2:], mode='bilinear', align_corners=self.align_corners) 80 | 81 | # B HW C -> B C H W 82 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 83 | 84 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 85 | x = self.dropout(_c) 86 | x = self.linear_pred(x) 87 | 88 | return x 89 | -------------------------------------------------------------------------------- /new_model/decoders/UPernet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | 5 | from torch.nn.modules import module 6 | import torch.nn.functional as F 7 | 8 | 9 | class UPerHead(nn.Module): 10 | """Unified Perceptual Parsing for Scene Understanding. 11 | This head is the implementation of `UPerNet 12 | `_. 13 | Args: 14 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 15 | Module applied on the last feature. Default: (1, 2, 3, 6). 16 | """ 17 | 18 | def __init__(self, in_channels=[96, 192, 384, 768], num_classes=40, channels=512, pool_scales=(1, 2, 3, 6), 19 | norm_layer=nn.BatchNorm2d, dropout_ratio=0.1, align_corners=False): 20 | super(UPerHead, self).__init__() 21 | self.in_channels = in_channels 22 | self.channels = channels 23 | self.align_corners = align_corners 24 | # PSP Module 25 | self.psp_modules = PPM( 26 | pool_scales, 27 | self.in_channels[-1], 28 | self.channels, 29 | norm_layer=norm_layer, 30 | align_corners=align_corners) 31 | self.bottleneck = nn.Sequential( 32 | nn.Conv2d(self.in_channels[-1] + len(pool_scales) * self.channels, self.channels, 3, padding=1), 33 | norm_layer(self.channels), 34 | nn.ReLU(inplace=True) 35 | ) 36 | # FPN Module 37 | self.lateral_convs = nn.ModuleList() 38 | self.fpn_convs = nn.ModuleList() 39 | for in_channels in self.in_channels[:-1]: # skip the top layer 40 | l_conv = nn.Sequential( 41 | nn.Conv2d(in_channels, self.channels, 1), 42 | norm_layer(self.channels), 43 | nn.ReLU(inplace=False) 44 | ) 45 | fpn_conv = nn.Sequential( 46 | nn.Conv2d(self.channels, self.channels, 3, padding=1), 47 | norm_layer(self.channels), 48 | nn.ReLU(inplace=False) 49 | ) 50 | self.lateral_convs.append(l_conv) 51 | self.fpn_convs.append(fpn_conv) 52 | 53 | self.fpn_bottleneck = nn.Sequential( 54 | nn.Conv2d(len(self.in_channels) * self.channels, self.channels, 3, padding=1), 55 | norm_layer(self.channels), 56 | nn.ReLU(inplace=True) 57 | ) 58 | self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) 59 | 60 | def psp_forward(self, inputs): 61 | """Forward function of PSP module.""" 62 | x = inputs[-1] 63 | psp_outs = [x] 64 | psp_outs.extend(self.psp_modules(x)) 65 | psp_outs = torch.cat(psp_outs, dim=1) 66 | output = self.bottleneck(psp_outs) 67 | 68 | return output 69 | 70 | def forward(self, inputs): 71 | # build laterals 72 | laterals = [ 73 | lateral_conv(inputs[i]) 74 | for i, lateral_conv in enumerate(self.lateral_convs) 75 | ] 76 | laterals.append(self.psp_forward(inputs)) 77 | 78 | # build top-down path 79 | used_backbone_levels = len(laterals) 80 | for i in range(used_backbone_levels - 1, 0, -1): 81 | prev_shape = 
laterals[i - 1].shape[2:] 82 | laterals[i - 1] = laterals[i - 1] + F.interpolate( 83 | laterals[i], 84 | size=prev_shape, 85 | mode='bilinear', 86 | align_corners=self.align_corners) 87 | 88 | # build outputs 89 | fpn_outs = [ 90 | self.fpn_convs[i](laterals[i]) 91 | for i in range(used_backbone_levels - 1) 92 | ] 93 | # append psp feature 94 | fpn_outs.append(laterals[-1]) 95 | 96 | for i in range(used_backbone_levels - 1, 0, -1): 97 | fpn_outs[i] = F.interpolate( 98 | fpn_outs[i], 99 | size=fpn_outs[0].shape[2:], 100 | mode='bilinear', 101 | align_corners=self.align_corners) 102 | fpn_outs = torch.cat(fpn_outs, dim=1) 103 | output = self.fpn_bottleneck(fpn_outs) 104 | output = self.conv_seg(output) 105 | 106 | return output 107 | 108 | 109 | class PPM(nn.ModuleList): 110 | """Pooling Pyramid Module used in PSPNet. 111 | Args: 112 | pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid 113 | Module. 114 | in_channels (int): Input channels. 115 | channels (int): Channels after modules, before conv_seg. 116 | conv_cfg (dict|None): Config of conv layers. 117 | norm_cfg (dict|None): Config of norm layers. 118 | act_cfg (dict): Config of activation layers. 119 | align_corners (bool): align_corners argument of F.interpolate. 120 | """ 121 | 122 | def __init__(self, pool_scales, in_channel, channels, norm_layer, align_corners=False): 123 | super(PPM, self).__init__() 124 | self.pool_scales = pool_scales 125 | self.align_corners = align_corners 126 | self.in_channel = in_channel 127 | self.channels = channels 128 | for pool_scale in pool_scales: 129 | self.append( 130 | nn.Sequential( 131 | nn.AdaptiveAvgPool2d(pool_scale), 132 | nn.Conv2d(self.in_channel, self.channels, 1), 133 | norm_layer(self.channels), 134 | nn.ReLU(inplace=True) 135 | )) 136 | 137 | def forward(self, x): 138 | """Forward function.""" 139 | ppm_outs = [] 140 | for ppm in self: 141 | ppm_out = ppm(x) 142 | upsampled_ppm_out = F.interpolate( 143 | ppm_out, 144 | size=x.size()[2:], 145 | mode='bilinear', 146 | align_corners=self.align_corners) 147 | ppm_outs.append(upsampled_ppm_out) 148 | return ppm_outs 149 | -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/MLPAlignDecoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/MLPAlignDecoder.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/MLPDecoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/MLPDecoder.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/UPernet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/UPernet.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/condnet.cpython-38.pyc: -------------------------------------------------------------------------------- 
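A minimal sketch of how the UPerHead decoder above consumes a four-level feature pyramid. The channel counts match the defaults in UPernet.py, while the batch size and spatial sizes (a 512x512 input) are illustrative assumptions, not values fixed by the code:

```python
import torch
import torch.nn.functional as F

from new_model.decoders.UPernet import UPerHead  # assumes the repo root is on PYTHONPATH

head = UPerHead(in_channels=[96, 192, 384, 768], num_classes=40)
feats = [torch.randn(2, 96, 128, 128),    # 1/4 scale
         torch.randn(2, 192, 64, 64),     # 1/8 scale
         torch.randn(2, 384, 32, 32),     # 1/16 scale
         torch.randn(2, 768, 16, 16)]     # 1/32 scale
logits = head(feats)                      # (2, 40, 128, 128): predictions at 1/4 of the input size
logits = F.interpolate(logits, size=(512, 512), mode='bilinear', align_corners=False)
```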
https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/condnet.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/deeplabv3plus.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/deeplabv3plus.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/fapn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/fapn.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/fcnhead.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/fcnhead.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/fpn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/fpn.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/hem.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/hem.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/__pycache__/lawin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/decoders/__pycache__/lawin.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/decoders/condnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | 5 | 6 | class ConvModule(nn.Sequential): 7 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 8 | super().__init__( 9 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 10 | nn.BatchNorm2d(c2), 11 | nn.ReLU(True) 12 | ) 13 | 14 | 15 | class CondHead(nn.Module): 16 | def __init__(self, in_channel: int = 2048, channel: int = 512, num_classes: int = 19): 17 | super().__init__() 18 | self.num_classes = num_classes 19 | self.weight_num = channel * num_classes 20 | self.bias_num = num_classes 21 | 22 | self.conv = ConvModule(in_channel, channel, 1) 23 | self.dropout = nn.Dropout2d(0.1) 24 | 25 | self.guidance_project = nn.Conv2d(channel, num_classes, 1) 26 | self.filter_project = nn.Conv2d(channel * num_classes, self.weight_num + self.bias_num, 1, groups=num_classes) 27 | 28 | def forward(self, features) -> Tensor: 29 | x = self.dropout(self.conv(features[-1])) 30 | B, C, H, W = x.shape 31 | guidance_mask = self.guidance_project(x) 32 | cond_logit = 
guidance_mask 33 | 34 | key = x 35 | value = x 36 | guidance_mask = guidance_mask.softmax(dim=1).view(*guidance_mask.shape[:2], -1) 37 | key = key.view(B, C, -1).permute(0, 2, 1) 38 | 39 | cond_filters = torch.matmul(guidance_mask, key) 40 | cond_filters /= H * W 41 | cond_filters = cond_filters.view(B, -1, 1, 1) 42 | cond_filters = self.filter_project(cond_filters) 43 | cond_filters = cond_filters.view(B, -1) 44 | 45 | weight, bias = torch.split(cond_filters, [self.weight_num, self.bias_num], dim=1) 46 | weight = weight.reshape(B * self.num_classes, -1, 1, 1) 47 | bias = bias.reshape(B * self.num_classes) 48 | 49 | value = value.view(-1, H, W).unsqueeze(0) 50 | seg_logit = F.conv2d(value, weight, bias, 1, 0, groups=B).view(B, self.num_classes, H, W) 51 | 52 | # if self.training: 53 | # return cond_logit, seg_logit 54 | return seg_logit 55 | 56 | 57 | # if __name__ == '__main__': 58 | # from semseg.models.backbones import ResNetD 59 | # 60 | # backbone = ResNetD('50') 61 | # head = CondHead() 62 | # x = torch.randn(2, 3, 224, 224) 63 | # features = backbone(x) 64 | # outs = head(features) 65 | # for out in outs: 66 | # out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False) 67 | # print(out.shape) 68 | -------------------------------------------------------------------------------- /new_model/decoders/deeplabv3plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DeepLabV3Plus(nn.Module): 7 | def __init__(self, in_channels=[256, 512, 1024, 2048], num_classes=40, norm_layer=nn.BatchNorm2d): 8 | super(DeepLabV3Plus, self).__init__() 9 | self.num_classes = num_classes 10 | 11 | self.aspp = ASPP(in_channels=in_channels[3], atrous_rates=[12, 24, 36], norm_layer=norm_layer, separable=True) 12 | self.low_level = nn.Sequential( 13 | SeparableConv2d(in_channels[0], 48, kernel_size=3, stride=1, padding=1, bias=False), 14 | # nn.Conv2d(in_channels[0], 48, kernel_size=3, stride=1, padding=1, bias=False), 15 | norm_layer(48), 16 | nn.ReLU(inplace=True) 17 | ) 18 | self.block = nn.Sequential( 19 | SeparableConv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 20 | # nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 21 | norm_layer(256), 22 | nn.ReLU(inplace=True), 23 | # nn.Dropout(0.5), 24 | # nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), 25 | # norm_layer(256), 26 | # nn.ReLU(inplace=True), 27 | nn.Dropout(0.1), 28 | nn.Conv2d(256, num_classes, 1)) 29 | 30 | def forward(self, inputs): 31 | c1, _, _, c4 = inputs 32 | c1 = self.low_level(c1) 33 | c4 = self.aspp(c4) 34 | c4 = F.interpolate(c4, c1.size()[2:], mode='bilinear', align_corners=True) 35 | output = self.block(torch.cat([c4, c1], dim=1)) 36 | return output 37 | 38 | 39 | class ASPPConv(nn.Module): 40 | def __init__(self, in_channels, out_channels, atrous_rate, norm_layer): 41 | super(ASPPConv, self).__init__() 42 | self.block = nn.Sequential( 43 | nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False), 44 | norm_layer(out_channels), 45 | nn.ReLU(True) 46 | ) 47 | 48 | def forward(self, x): 49 | return self.block(x) 50 | 51 | 52 | class ASPPSeparableConv(nn.Sequential): 53 | def __init__(self, in_channels, out_channels, atrous_rate, norm_layer): 54 | super().__init__( 55 | SeparableConv2d( 56 | in_channels, 57 | out_channels, 58 | kernel_size=3, 59 | padding=atrous_rate, 60 | dilation=atrous_rate, 61 | 
bias=False, 62 | ), 63 | norm_layer(out_channels), 64 | nn.ReLU(), 65 | ) 66 | 67 | 68 | class SeparableConv2d(nn.Sequential): 69 | 70 | def __init__( 71 | self, 72 | in_channels, 73 | out_channels, 74 | kernel_size, 75 | stride=1, 76 | padding=0, 77 | dilation=1, 78 | bias=True, 79 | ): 80 | dephtwise_conv = nn.Conv2d( 81 | in_channels, 82 | in_channels, 83 | kernel_size, 84 | stride=stride, 85 | padding=padding, 86 | dilation=dilation, 87 | groups=in_channels, 88 | bias=False, 89 | ) 90 | pointwise_conv = nn.Conv2d( 91 | in_channels, 92 | out_channels, 93 | kernel_size=1, 94 | bias=bias, 95 | ) 96 | super().__init__(dephtwise_conv, pointwise_conv) 97 | 98 | 99 | class AsppPooling(nn.Module): 100 | def __init__(self, in_channels, out_channels, norm_layer): 101 | super(AsppPooling, self).__init__() 102 | self.gap = nn.Sequential( 103 | nn.AdaptiveAvgPool2d(1), 104 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 105 | norm_layer(out_channels), 106 | nn.ReLU(True) 107 | ) 108 | 109 | def forward(self, x): 110 | size = x.size()[2:] 111 | pool = self.gap(x) 112 | out = F.interpolate(pool, size, mode='bilinear', align_corners=True) 113 | return out 114 | 115 | 116 | class ASPP(nn.Module): 117 | def __init__(self, in_channels, atrous_rates, norm_layer, separable=False): 118 | super(ASPP, self).__init__() 119 | out_channels = 256 120 | self.b0 = nn.Sequential( 121 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 122 | norm_layer(out_channels), 123 | nn.ReLU(True) 124 | ) 125 | 126 | rate1, rate2, rate3 = tuple(atrous_rates) 127 | ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv 128 | 129 | self.b1 = ASPPConvModule(in_channels, out_channels, rate1, norm_layer) 130 | self.b2 = ASPPConvModule(in_channels, out_channels, rate2, norm_layer) 131 | self.b3 = ASPPConvModule(in_channels, out_channels, rate3, norm_layer) 132 | self.b4 = AsppPooling(in_channels, out_channels, norm_layer=norm_layer) 133 | 134 | self.project = nn.Sequential( 135 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 136 | norm_layer(out_channels), 137 | nn.ReLU(True), 138 | nn.Dropout(0.5) 139 | ) 140 | 141 | def forward(self, x): 142 | feat1 = self.b0(x) 143 | feat2 = self.b1(x) 144 | feat3 = self.b2(x) 145 | feat4 = self.b3(x) 146 | feat5 = self.b4(x) 147 | x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) 148 | x = self.project(x) 149 | return x 150 | -------------------------------------------------------------------------------- /new_model/decoders/fapn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from torchvision.ops import DeformConv2d 5 | 6 | 7 | class ConvModule(nn.Sequential): 8 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 9 | super().__init__( 10 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 11 | nn.BatchNorm2d(c2), 12 | nn.ReLU(True) 13 | ) 14 | 15 | 16 | class DCNv2(nn.Module): 17 | def __init__(self, c1, c2, k, s, p, g=1): 18 | super().__init__() 19 | self.dcn = DeformConv2d(c1, c2, k, s, p, groups=g) 20 | self.offset_mask = nn.Conv2d(c2, g * 3 * k * k, k, s, p) 21 | self._init_offset() 22 | 23 | def _init_offset(self): 24 | self.offset_mask.weight.data.zero_() 25 | self.offset_mask.bias.data.zero_() 26 | 27 | def forward(self, x, offset): 28 | out = self.offset_mask(offset) 29 | o1, o2, mask = torch.chunk(out, 3, dim=1) 30 | offset = torch.cat([o1, o2], dim=1) 31 | mask = mask.sigmoid() 32 | return self.dcn(x, offset, mask) 33 
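A minimal sketch of driving the DeepLabV3Plus head above with multi-scale backbone features. The shapes mimic a ResNet-50-style encoder on a 512x512 input and are assumptions for illustration, not values fixed by deeplabv3plus.py:

```python
import torch

from new_model.decoders.deeplabv3plus import DeepLabV3Plus  # assumes the repo root is on PYTHONPATH

head = DeepLabV3Plus(in_channels=[256, 512, 1024, 2048], num_classes=40)
feats = [torch.randn(2, 256, 128, 128),   # c1, 1/4 scale: low-level branch
         torch.randn(2, 512, 64, 64),     # c2 (not used by this head)
         torch.randn(2, 1024, 32, 32),    # c3 (not used by this head)
         torch.randn(2, 2048, 16, 16)]    # c4, 1/32 scale: ASPP branch
logits = head(feats)                      # (2, 40, 128, 128); upsample to the input size afterwards
```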
| 34 | 35 | class FSM(nn.Module): 36 | def __init__(self, c1, c2): 37 | super().__init__() 38 | self.conv_atten = nn.Conv2d(c1, c1, 1, bias=False) 39 | self.conv = nn.Conv2d(c1, c2, 1, bias=False) 40 | 41 | def forward(self, x: Tensor) -> Tensor: 42 | atten = self.conv_atten(F.avg_pool2d(x, x.shape[2:])).sigmoid() 43 | feat = torch.mul(x, atten) 44 | x = x + feat 45 | return self.conv(x) 46 | 47 | 48 | class FAM(nn.Module): 49 | def __init__(self, c1, c2): 50 | super().__init__() 51 | self.lateral_conv = FSM(c1, c2) 52 | self.offset = nn.Conv2d(c2 * 2, c2, 1, bias=False) 53 | self.dcpack_l2 = DCNv2(c2, c2, 3, 1, 1, 8) 54 | 55 | def forward(self, feat_l, feat_s): 56 | feat_up = feat_s 57 | if feat_l.shape[2:] != feat_s.shape[2:]: 58 | feat_up = F.interpolate(feat_s, size=feat_l.shape[2:], mode='bilinear', align_corners=False) 59 | 60 | feat_arm = self.lateral_conv(feat_l) 61 | offset = self.offset(torch.cat([feat_arm, feat_up * 2], dim=1)) 62 | 63 | feat_align = F.relu(self.dcpack_l2(feat_up, offset)) 64 | return feat_align + feat_arm 65 | 66 | 67 | class FaPNHead(nn.Module): 68 | def __init__(self, in_channels, channel=128, num_classes=19): 69 | super().__init__() 70 | in_channels = in_channels[::-1] 71 | self.align_modules = nn.ModuleList([ConvModule(in_channels[0], channel, 1)]) 72 | self.output_convs = nn.ModuleList([]) 73 | 74 | for ch in in_channels[1:]: 75 | self.align_modules.append(FAM(ch, channel)) 76 | self.output_convs.append(ConvModule(channel, channel, 3, 1, 1)) 77 | 78 | self.conv_seg = nn.Conv2d(channel, num_classes, 1) 79 | self.dropout = nn.Dropout2d(0.1) 80 | 81 | def forward(self, features) -> Tensor: 82 | features = features[::-1] 83 | out = self.align_modules[0](features[0]) 84 | 85 | for feat, align_module, output_conv in zip(features[1:], self.align_modules[1:], self.output_convs): 86 | out = align_module(feat, out) 87 | out = output_conv(out) 88 | out = self.conv_seg(self.dropout(out)) 89 | return out 90 | 91 | 92 | # if __name__ == '__main__': 93 | # from semseg.models.backbones import ResNet 94 | # 95 | # backbone = ResNet('50') 96 | # head = FaPNHead([256, 512, 1024, 2048], 128, 19) 97 | # x = torch.randn(2, 3, 224, 224) 98 | # features = backbone(x) 99 | # out = head(features) 100 | # out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False) 101 | # print(out.shape) 102 | -------------------------------------------------------------------------------- /new_model/decoders/fcnhead.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MainFCNHead(nn.Module): 5 | def __init__(self, in_channels=384, channels=None, kernel_size=3, dilation=1, 6 | num_classes=40, norm_layer=nn.BatchNorm2d): 7 | super(MainFCNHead, self).__init__() 8 | self.kernel_size = kernel_size 9 | self.in_channels = in_channels 10 | self.channels = channels or in_channels // 4 11 | 12 | conv_padding = (kernel_size // 2) * dilation 13 | self.conv = nn.Sequential( 14 | nn.Conv2d(self.in_channels, self.channels, kernel_size, padding=conv_padding, bias=False), 15 | norm_layer(self.channels), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | self.classifier = nn.Conv2d(self.channels, num_classes, kernel_size=1) 20 | 21 | def forward(self, x): 22 | output = self.conv(x[-1]) 23 | output = self.classifier(output) 24 | return output 25 | 26 | 27 | class AuxFCNHead(nn.Module): 28 | def __init__(self, in_channels=384, channels=None, kernel_size=3, dilation=1, 29 | num_classes=40, norm_layer=nn.BatchNorm2d): 30 | 
super(AuxFCNHead, self).__init__() 31 | self.kernel_size = kernel_size 32 | self.in_channels = in_channels 33 | self.channels = channels or in_channels // 4 34 | 35 | conv_padding = (kernel_size // 2) * dilation 36 | self.conv = nn.Sequential( 37 | nn.Conv2d(self.in_channels, self.channels, kernel_size, padding=conv_padding, bias=False), 38 | norm_layer(self.channels), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | self.classifier = nn.Conv2d(self.channels, num_classes, kernel_size=1) 43 | 44 | def forward(self, x): 45 | output = self.conv(x) 46 | output = self.classifier(output) 47 | return output 48 | -------------------------------------------------------------------------------- /new_model/decoders/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | 5 | 6 | class ConvModule(nn.Sequential): 7 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 8 | super().__init__( 9 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 10 | nn.BatchNorm2d(c2), 11 | nn.ReLU(True) 12 | ) 13 | 14 | 15 | class FPNHead(nn.Module): 16 | """Panoptic Feature Pyramid Networks 17 | https://arxiv.org/abs/1901.02446 18 | """ 19 | 20 | def __init__(self, in_channels, channel=128, num_classes=19): 21 | super().__init__() 22 | self.lateral_convs = nn.ModuleList([]) 23 | self.output_convs = nn.ModuleList([]) 24 | 25 | for ch in in_channels[::-1]: 26 | self.lateral_convs.append(ConvModule(ch, channel, 1)) 27 | self.output_convs.append(ConvModule(channel, channel, 3, 1, 1)) 28 | 29 | self.conv_seg = nn.Conv2d(channel, num_classes, 1) 30 | self.dropout = nn.Dropout2d(0.1) 31 | 32 | def forward(self, features) -> Tensor: 33 | features = features[::-1] 34 | out = self.lateral_convs[0](features[0]) 35 | 36 | for i in range(1, len(features)): 37 | out = F.interpolate(out, scale_factor=2.0, mode='nearest') 38 | out = out + self.lateral_convs[i](features[i]) 39 | out = self.output_convs[i](out) 40 | out = self.conv_seg(self.dropout(out)) 41 | return out 42 | 43 | 44 | # if __name__ == '__main__': 45 | # from semseg.models.backbones import ResNet 46 | # 47 | # backbone = ResNet('50') 48 | # head = FPNHead([256, 512, 1024, 2048], 128, 19) 49 | # x = torch.randn(2, 3, 224, 224) 50 | # features = backbone(x) 51 | # out = head(features) 52 | # out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False) 53 | # print(out.shape) 54 | -------------------------------------------------------------------------------- /new_model/decoders/fpn_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AlignedModule_v1(nn.Module): 7 | 8 | def __init__(self, inplane, outplane=256, kernel_size=3, eps=1e-8): 9 | super(AlignedModule_v1, self).__init__() 10 | self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False) 11 | self.flow_make = nn.Conv2d(outplane * 2, 2, kernel_size=kernel_size, padding=1, bias=False) 12 | # 自定义可训练权重参数 13 | self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) 14 | self.eps = eps 15 | 16 | def forward(self, low_feature, h_feature): 17 | h_feature_orign = h_feature 18 | h, w = low_feature.size()[2:] 19 | size = (h, w) 20 | low_feature = self.down_l(low_feature) 21 | h_feature = F.interpolate(h_feature, size=size, mode="bilinear", align_corners=True) 22 | flow = self.flow_make(torch.cat([h_feature, low_feature], 1)) 23 | 
h_feature = self.flow_warp(h_feature_orign, flow, size=size) 24 | 25 | weights = nn.ReLU()(self.weights) 26 | fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps) 27 | fuse_feature = fuse_weights[0] * h_feature + fuse_weights[1] * low_feature 28 | 29 | return fuse_feature 30 | 31 | def flow_warp(self, input, flow, size): 32 | out_h, out_w = size 33 | n, c, h, w = input.size() 34 | 35 | norm = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device) 36 | h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w) 37 | w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1) 38 | grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2) 39 | grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device) 40 | grid = grid + flow.permute(0, 2, 3, 1) / norm 41 | 42 | output = F.grid_sample(input, grid, align_corners=True) 43 | return output 44 | 45 | 46 | class AlignedModule_v2(nn.Module): 47 | 48 | def __init__(self, inplane, outplane=256, kernel_size=3, eps=1e-8): 49 | super(AlignedModule_v2, self).__init__() 50 | self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False) 51 | self.l_flow_make = nn.Conv2d(outplane * 2, 2, kernel_size=kernel_size, padding=1, bias=False) 52 | self.h_flow_make = nn.Conv2d(outplane * 2, 2, kernel_size=kernel_size, padding=1, bias=False) 53 | 54 | # 自定义可训练权重参数 55 | self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) 56 | self.eps = eps 57 | 58 | def forward(self, low_feature, high_feature): 59 | h, w = low_feature.size()[2:] 60 | size = (h, w) 61 | l_feature = self.down_l(low_feature) 62 | 63 | h_feature = F.interpolate(high_feature, size=size, mode="bilinear", align_corners=True) 64 | concat = torch.cat([h_feature, l_feature], 1) 65 | 66 | l_flow = self.l_flow_make(concat) 67 | h_flow = self.h_flow_make(concat) 68 | 69 | l_feature_warp = self.flow_warp(l_feature, l_flow, size=size) 70 | h_feature_warp = self.flow_warp(high_feature, h_flow, size=size) 71 | 72 | weights = nn.ReLU()(self.weights) 73 | fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps) 74 | fuse_feature = fuse_weights[0] * h_feature_warp + fuse_weights[1] * l_feature_warp 75 | 76 | return fuse_feature 77 | 78 | def flow_warp(self, input, flow, size): 79 | out_h, out_w = size 80 | n, c, h, w = input.size() 81 | 82 | norm = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device) 83 | h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w) 84 | w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1) 85 | grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2) 86 | grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device) 87 | grid = grid + flow.permute(0, 2, 3, 1) / norm 88 | 89 | output = F.grid_sample(input, grid, align_corners=True) 90 | return output 91 | 92 | 93 | class AlignedModule_v3(nn.Module): 94 | 95 | def __init__(self, inplane, outplane=256, kernel_size=3, eps=1e-8): 96 | super(AlignedModule_v3, self).__init__() 97 | self.down_l = nn.Conv2d(inplane, outplane, 1, bias=False) 98 | self.flow_make = nn.Conv2d(outplane * 2, 4, kernel_size=kernel_size, padding=1, bias=False) 99 | 100 | # 自定义可训练权重参数 101 | self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) 102 | self.eps = eps 103 | 104 | def forward(self, low_feature, high_feature): 105 | h, w = low_feature.size()[2:] 106 | size = (h, w) 107 | l_feature = self.down_l(low_feature) 108 | 109 | h_feature = F.interpolate(high_feature, size=size, mode="bilinear", align_corners=True) 110 | concat = torch.cat([h_feature, l_feature], 1) 
111 | 112 | flow = self.flow_make(concat) 113 | flow_up, flow_down = flow[:, :2, :, :], flow[:, 2:, :, :] 114 | 115 | l_feature_warp = self.flow_warp(l_feature, flow_down, size=size) 116 | h_feature_warp = self.flow_warp(high_feature, flow_up, size=size) 117 | 118 | weights = nn.ReLU()(self.weights) 119 | fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps) 120 | fuse_feature = fuse_weights[0] * h_feature_warp + fuse_weights[1] * l_feature_warp 121 | 122 | return fuse_feature 123 | 124 | def flow_warp(self, input, flow, size): 125 | out_h, out_w = size 126 | n, c, h, w = input.size() 127 | 128 | norm = torch.tensor([[[[out_w, out_h]]]]).type_as(input).to(input.device) 129 | h = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w) 130 | w = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1) 131 | grid = torch.cat((w.unsqueeze(2), h.unsqueeze(2)), 2) 132 | grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device) 133 | grid = grid + flow.permute(0, 2, 3, 1) / norm 134 | 135 | output = F.grid_sample(input, grid, align_corners=True) 136 | return output 137 | 138 | 139 | class Conv3x3GNReLU(nn.Module): 140 | def __init__(self, in_channels, out_channels, upsample=False): 141 | super().__init__() 142 | self.upsample = upsample 143 | self.block = nn.Sequential( 144 | nn.Conv2d( 145 | in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False 146 | ), 147 | nn.GroupNorm(32, out_channels), 148 | nn.ReLU(inplace=True), 149 | ) 150 | 151 | def forward(self, x): 152 | x = self.block(x) 153 | if self.upsample: 154 | x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True) 155 | return x 156 | 157 | 158 | # 先上采样再调整channel 159 | class FPNBlock(nn.Module): 160 | def __init__(self, skip_channels, pyramid_channels): 161 | super().__init__() 162 | self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) 163 | 164 | def forward(self, skip, x): 165 | x = F.interpolate(x, scale_factor=2, mode="nearest") 166 | skip = self.skip_conv(skip) 167 | x = x + skip 168 | return x 169 | 170 | 171 | class SegmentationBlock(nn.Module): 172 | def __init__(self, in_channels, out_channels, n_upsamples=0): 173 | super().__init__() 174 | 175 | blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] 176 | 177 | if n_upsamples > 1: 178 | for _ in range(1, n_upsamples): 179 | blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) 180 | 181 | self.block = nn.Sequential(*blocks) 182 | 183 | def forward(self, x): 184 | return self.block(x) 185 | 186 | 187 | class MergeBlock(nn.Module): 188 | def __init__(self, policy): 189 | super().__init__() 190 | if policy not in ["add", "cat"]: 191 | raise ValueError( 192 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format( 193 | policy 194 | ) 195 | ) 196 | self.policy = policy 197 | 198 | def forward(self, x): 199 | if self.policy == 'add': 200 | return sum(x) 201 | elif self.policy == 'cat': 202 | return torch.cat(x, dim=1) 203 | else: 204 | raise ValueError( 205 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy) 206 | ) 207 | 208 | 209 | class AlignedFPNDecoder(nn.Module): 210 | def __init__( 211 | self, 212 | encoder_channels=[64, 128, 320, 512], 213 | num_classes=40, 214 | pyramid_channels=512, 215 | segmentation_channels=256, 216 | dropout=0.1, 217 | merge_policy="cat", 218 | ): 219 | super().__init__() 220 | self.num_classes = num_classes 221 | self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 4 222 
| 223 | self.p5 = nn.Conv2d(encoder_channels[3], pyramid_channels, kernel_size=1) 224 | self.p4 = AlignedModule_v1(encoder_channels[2], pyramid_channels) 225 | self.p3 = AlignedModule_v1(encoder_channels[1], pyramid_channels) 226 | self.p2 = AlignedModule_v1(encoder_channels[0], pyramid_channels) 227 | 228 | self.seg_blocks = nn.ModuleList([ 229 | SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples) 230 | for n_upsamples in [3, 2, 1, 0] 231 | ]) 232 | 233 | self.merge = MergeBlock(merge_policy) 234 | self.dropout = nn.Dropout2d(p=dropout) 235 | self.pred = nn.Conv2d(self.out_channels, self.num_classes, kernel_size=1) 236 | 237 | def forward(self, features): 238 | # len=4, 1/4,1/8,1/16,1/32 239 | c2, c3, c4, c5 = features 240 | 241 | p5 = self.p5(c5) 242 | p4 = self.p4(c4, p5) 243 | p3 = self.p3(c3, p4) 244 | p2 = self.p2(c2, p3) 245 | 246 | feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p5, p4, p3, p2])] 247 | x = self.merge(feature_pyramid) 248 | x = self.dropout(x) 249 | x = self.pred(x) 250 | 251 | return x 252 | -------------------------------------------------------------------------------- /new_model/decoders/hem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from fvcore.nn import flop_count_table, FlopCountAnalysis 5 | 6 | 7 | class ConvModule(nn.Sequential): 8 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 9 | super().__init__( 10 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 11 | nn.BatchNorm2d(c2), 12 | nn.ReLU(True) 13 | ) 14 | 15 | 16 | class _MatrixDecomposition2DBase(nn.Module): 17 | def __init__(self, args=dict()): 18 | super().__init__() 19 | 20 | self.spatial = args.setdefault('SPATIAL', True) 21 | 22 | self.S = args.setdefault('MD_S', 1) 23 | self.D = args.setdefault('MD_D', 512) 24 | self.R = args.setdefault('MD_R', 64) 25 | 26 | self.train_steps = args.setdefault('TRAIN_STEPS', 6) 27 | self.eval_steps = args.setdefault('EVAL_STEPS', 7) 28 | 29 | self.inv_t = args.setdefault('INV_T', 100) 30 | self.eta = args.setdefault('ETA', 0.9) 31 | 32 | self.rand_init = args.setdefault('RAND_INIT', True) 33 | 34 | print('spatial', self.spatial) 35 | print('S', self.S) 36 | print('D', self.D) 37 | print('R', self.R) 38 | print('train_steps', self.train_steps) 39 | print('eval_steps', self.eval_steps) 40 | print('inv_t', self.inv_t) 41 | print('eta', self.eta) 42 | print('rand_init', self.rand_init) 43 | 44 | def _build_bases(self, B, S, D, R, cuda=False): 45 | raise NotImplementedError 46 | 47 | def local_step(self, x, bases, coef): 48 | raise NotImplementedError 49 | 50 | # @torch.no_grad() 51 | def local_inference(self, x, bases): 52 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 53 | coef = torch.bmm(x.transpose(1, 2), bases) 54 | coef = F.softmax(self.inv_t * coef, dim=-1) 55 | 56 | steps = self.train_steps if self.training else self.eval_steps 57 | for _ in range(steps): 58 | bases, coef = self.local_step(x, bases, coef) 59 | 60 | return bases, coef 61 | 62 | def compute_coef(self, x, bases, coef): 63 | raise NotImplementedError 64 | 65 | def forward(self, x, return_bases=False): 66 | B, C, H, W = x.shape 67 | 68 | # (B, C, H, W) -> (B * S, D, N) 69 | if self.spatial: 70 | D = C // self.S 71 | N = H * W 72 | x = x.view(B * self.S, D, N) 73 | else: 74 | D = H * W 75 | N = C // self.S 76 | x = x.view(B * self.S, N, D).transpose(1, 2) 77 | 78 | if not self.rand_init and not hasattr(self, 
'bases'): 79 | bases = self._build_bases(1, self.S, D, self.R, cuda=True) 80 | self.register_buffer('bases', bases) 81 | 82 | # (S, D, R) -> (B * S, D, R) 83 | if self.rand_init: 84 | bases = self._build_bases(B, self.S, D, self.R, cuda=True) 85 | else: 86 | bases = self.bases.repeat(B, 1, 1) 87 | 88 | bases, coef = self.local_inference(x, bases) 89 | 90 | # (B * S, N, R) 91 | coef = self.compute_coef(x, bases, coef) 92 | 93 | # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) 94 | x = torch.bmm(bases, coef.transpose(1, 2)) 95 | 96 | # (B * S, D, N) -> (B, C, H, W) 97 | if self.spatial: 98 | x = x.view(B, C, H, W) 99 | else: 100 | x = x.transpose(1, 2).view(B, C, H, W) 101 | 102 | # (B * H, D, R) -> (B, H, N, D) 103 | bases = bases.view(B, self.S, D, self.R) 104 | 105 | return x 106 | 107 | 108 | class NMF2D(_MatrixDecomposition2DBase): 109 | def __init__(self, args=dict()): 110 | super().__init__(args) 111 | 112 | self.inv_t = 1 113 | 114 | def _build_bases(self, B, S, D, R, cuda=False): 115 | if cuda: 116 | bases = torch.rand((B * S, D, R)).cuda() 117 | else: 118 | bases = torch.rand((B * S, D, R)) 119 | 120 | bases = F.normalize(bases, dim=1) 121 | 122 | return bases 123 | 124 | # @torch.no_grad() 125 | def local_step(self, x, bases, coef): 126 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 127 | numerator = torch.bmm(x.transpose(1, 2), bases) 128 | # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) 129 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) 130 | # Multiplicative Update 131 | coef = coef * numerator / (denominator + 1e-6) 132 | 133 | # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) 134 | numerator = torch.bmm(x, coef) 135 | # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) 136 | denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) 137 | # Multiplicative Update 138 | bases = bases * numerator / (denominator + 1e-6) 139 | 140 | return bases, coef 141 | 142 | def compute_coef(self, x, bases, coef): 143 | # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) 144 | numerator = torch.bmm(x.transpose(1, 2), bases) 145 | # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) 146 | denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) 147 | # multiplication update 148 | coef = coef * numerator / (denominator + 1e-6) 149 | 150 | return coef 151 | 152 | 153 | class Hamburger(nn.Module): 154 | def __init__(self, ham_channels=512, ham_kwargs=dict(), norm_cfg=None): 155 | super().__init__() 156 | self.ham_in = ConvModule(ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None) 157 | self.ham = NMF2D(ham_kwargs) 158 | self.ham_out = ConvModule(ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None) 159 | 160 | def forward(self, x): 161 | enjoy = self.ham_in(x) 162 | enjoy = F.relu(enjoy, inplace=True) 163 | enjoy = self.ham(enjoy) 164 | enjoy = self.ham_out(enjoy) 165 | ham = F.relu(x + enjoy, inplace=True) 166 | return ham 167 | 168 | 169 | class LightHamHead(nn.Module): 170 | def __init__(self, in_channels=[64, 128, 320, 512], ham_channels=512, ham_kwargs=dict(), num_classes=25): 171 | super().__init__() 172 | self.in_channels = in_channels[1:] 173 | self.in_index = [1, 2, 3] 174 | self.ham_channels = self.channels = ham_channels 175 | self.conv_cfg = None 176 | self.norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) 177 | self.act_cfg = dict(type='ReLU') 178 | 179 | self.ham_channels = ham_channels 180 | self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1, conv_cfg=self.conv_cfg, 
181 | norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) 182 | self.hamburger = Hamburger(ham_channels, ham_kwargs, self.norm_cfg) 183 | self.align = ConvModule(self.ham_channels, self.channels, 1, conv_cfg=self.conv_cfg, 184 | norm_cfg=self.norm_cfg, act_cfg=self.act_cfg) 185 | self.conv_seg = nn.Conv2d(self.channels, num_classes, kernel_size=1) 186 | 187 | def forward(self, inputs): 188 | """Forward function.""" 189 | inputs = [inputs[i] for i in self.in_index] 190 | 191 | inputs = [F.interpolate(level, size=inputs[0].shape[2:], mode='bilinear', align_corners=False) for level in 192 | inputs] 193 | 194 | inputs = torch.cat(inputs, dim=1) 195 | x = self.squeeze(inputs) 196 | 197 | x = self.hamburger(x) 198 | 199 | output = self.align(x) 200 | output = self.conv_seg(output) 201 | return output 202 | 203 | 204 | if __name__ == '__main__': 205 | model = LightHamHead(num_classes=25) 206 | model = model.cuda() 207 | x = [torch.zeros(1, 64, 256, 256), torch.ones(1, 128, 128, 128), torch.ones(1, 320, 64, 64) * 2, 208 | torch.ones(1, 512, 32, 32) * 3] 209 | x = [xi.cuda() for xi in x] 210 | outs = model(x) 211 | print(model) 212 | for y in outs: 213 | print(y.shape) 214 | print(flop_count_table(FlopCountAnalysis(model, x))) 215 | -------------------------------------------------------------------------------- /new_model/decoders/lawin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | from einops import rearrange 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, dim=2048, embed_dim=768): 9 | super().__init__() 10 | self.proj = nn.Linear(dim, embed_dim) 11 | 12 | def forward(self, x: Tensor) -> Tensor: 13 | x = x.flatten(2).transpose(1, 2) 14 | x = self.proj(x) 15 | return x 16 | 17 | 18 | class PatchEmbed(nn.Module): 19 | def __init__(self, patch_size=4, in_ch=3, dim=96, type='pool') -> None: 20 | super().__init__() 21 | self.patch_size = patch_size 22 | self.type = type 23 | self.dim = dim 24 | 25 | if type == 'conv': 26 | self.proj = nn.Conv2d(in_ch, dim, patch_size, patch_size, groups=patch_size * patch_size) 27 | else: 28 | self.proj = nn.ModuleList([ 29 | nn.MaxPool2d(patch_size, patch_size), 30 | nn.AvgPool2d(patch_size, patch_size) 31 | ]) 32 | 33 | self.norm = nn.LayerNorm(dim) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | _, _, H, W = x.shape 37 | if W % self.patch_size != 0: 38 | x = F.pad(x, (0, self.patch_size - W % self.patch_size)) 39 | if H % self.patch_size != 0: 40 | x = F.pad(x, (0, 0, 0, self.patch_size - H % self.patch_size)) 41 | 42 | if self.type == 'conv': 43 | x = self.proj(x) 44 | else: 45 | x = 0.5 * (self.proj[0](x) + self.proj[1](x)) 46 | Wh, Ww = x.size(2), x.size(3) 47 | x = x.flatten(2).transpose(1, 2) 48 | x = self.norm(x) 49 | x = x.transpose(1, 2).view(-1, self.dim, Wh, Ww) 50 | return x 51 | 52 | 53 | class LawinAttn(nn.Module): 54 | def __init__(self, in_ch=512, head=4, patch_size=8, reduction=2) -> None: 55 | super().__init__() 56 | self.head = head 57 | 58 | self.position_mixing = nn.ModuleList([ 59 | nn.Linear(patch_size * patch_size, patch_size * patch_size) 60 | for _ in range(self.head)]) 61 | 62 | self.inter_channels = max(in_ch // reduction, 1) 63 | self.g = nn.Conv2d(in_ch, self.inter_channels, 1) 64 | self.theta = nn.Conv2d(in_ch, self.inter_channels, 1) 65 | self.phi = nn.Conv2d(in_ch, self.inter_channels, 1) 66 | self.conv_out = nn.Sequential( 67 | nn.Conv2d(self.inter_channels, in_ch, 1, bias=False), 68 | 
nn.BatchNorm2d(in_ch) 69 | ) 70 | 71 | def forward(self, query: Tensor, context: Tensor) -> Tensor: 72 | B, C, H, W = context.shape 73 | context = context.reshape(B, C, -1) 74 | context_mlp = [] 75 | 76 | for i, pm in enumerate(self.position_mixing): 77 | context_crt = context[:, (C // self.head) * i:(C // self.head) * (i + 1), :] 78 | context_mlp.append(pm(context_crt)) 79 | 80 | context_mlp = torch.cat(context_mlp, dim=1) 81 | context = context + context_mlp 82 | context = context.reshape(B, C, H, W) 83 | 84 | g_x = self.g(context).view(B, self.inter_channels, -1) 85 | g_x = rearrange(g_x, "b (h dim) n -> (b h) dim n", h=self.head) 86 | g_x = g_x.permute(0, 2, 1) 87 | 88 | theta_x = self.theta(query).view(B, self.inter_channels, -1) 89 | theta_x = rearrange(theta_x, "b (h dim) n -> (b h) dim n", h=self.head) 90 | theta_x = theta_x.permute(0, 2, 1) 91 | 92 | phi_x = self.phi(context).view(B, self.inter_channels, -1) 93 | phi_x = rearrange(phi_x, "b (h dim) n -> (b h) dim n", h=self.head) 94 | 95 | pairwise_weight = torch.matmul(theta_x, phi_x) 96 | pairwise_weight /= theta_x.shape[-1] ** 0.5 97 | pairwise_weight = pairwise_weight.softmax(dim=-1) 98 | 99 | y = torch.matmul(pairwise_weight, g_x) 100 | y = rearrange(y, "(b h) n dim -> b n (h dim)", h=self.head) 101 | y = y.permute(0, 2, 1).contiguous().reshape(B, self.inter_channels, *query.shape[-2:]) 102 | 103 | output = query + self.conv_out(y) 104 | return output 105 | 106 | 107 | class ConvModule(nn.Module): 108 | def __init__(self, c1, c2): 109 | super().__init__() 110 | self.conv = nn.Conv2d(c1, c2, 1, bias=False) 111 | self.bn = nn.BatchNorm2d(c2) # use SyncBN in original 112 | self.activate = nn.ReLU(True) 113 | 114 | def forward(self, x: Tensor) -> Tensor: 115 | return self.activate(self.bn(self.conv(x))) 116 | 117 | 118 | class LawinHead(nn.Module): 119 | def __init__(self, in_channels: list, embed_dim=512, num_classes=19) -> None: 120 | super().__init__() 121 | for i, dim in enumerate(in_channels): 122 | self.add_module(f"linear_c{i + 1}", MLP(dim, 48 if i == 0 else embed_dim)) 123 | 124 | self.lawin_8 = LawinAttn(embed_dim, 64) 125 | self.lawin_4 = LawinAttn(embed_dim, 16) 126 | self.lawin_2 = LawinAttn(embed_dim, 4) 127 | self.ds_8 = PatchEmbed(8, embed_dim, embed_dim) 128 | self.ds_4 = PatchEmbed(4, embed_dim, embed_dim) 129 | self.ds_2 = PatchEmbed(2, embed_dim, embed_dim) 130 | 131 | self.image_pool = nn.Sequential( 132 | nn.AdaptiveAvgPool2d(1), 133 | ConvModule(embed_dim, embed_dim) 134 | ) 135 | self.linear_fuse = ConvModule(embed_dim * 3, embed_dim) 136 | self.short_path = ConvModule(embed_dim, embed_dim) 137 | self.cat = ConvModule(embed_dim * 5, embed_dim) 138 | 139 | self.low_level_fuse = ConvModule(embed_dim + 48, embed_dim) 140 | self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1) 141 | self.dropout = nn.Dropout2d(0.1) 142 | 143 | def get_lawin_att_feats(self, x: Tensor, patch_size: int): 144 | _, _, H, W = x.shape 145 | query = F.unfold(x, patch_size, stride=patch_size) 146 | query = rearrange(query, 'b (c ph pw) (nh nw) -> (b nh nw) c ph pw', ph=patch_size, pw=patch_size, 147 | nh=H // patch_size, nw=W // patch_size) 148 | outs = [] 149 | 150 | for r in [8, 4, 2]: 151 | context = F.unfold(x, patch_size * r, stride=patch_size, padding=int((r - 1) / 2 * patch_size)) 152 | context = rearrange(context, "b (c ph pw) (nh nw) -> (b nh nw) c ph pw", ph=patch_size * r, 153 | pw=patch_size * r, nh=H // patch_size, nw=W // patch_size) 154 | context = getattr(self, f"ds_{r}")(context) 155 | output = getattr(self, 
f"lawin_{r}")(query, context) 156 | output = rearrange(output, "(b nh nw) c ph pw -> b c (nh ph) (nw pw)", ph=patch_size, pw=patch_size, 157 | nh=H // patch_size, nw=W // patch_size) 158 | outs.append(output) 159 | return outs 160 | 161 | def forward(self, features): 162 | B, _, H, W = features[1].shape 163 | outs = [self.linear_c2(features[1]).permute(0, 2, 1).reshape(B, -1, *features[1].shape[-2:])] 164 | 165 | for i, feature in enumerate(features[2:]): 166 | cf = eval(f"self.linear_c{i + 3}")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:]) 167 | outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False)) 168 | 169 | feat = self.linear_fuse(torch.cat(outs[::-1], dim=1)) 170 | B, _, H, W = feat.shape 171 | 172 | ## Lawin attention spatial pyramid pooling 173 | feat_short = self.short_path(feat) 174 | feat_pool = F.interpolate(self.image_pool(feat), size=(H, W), mode='bilinear', align_corners=False) 175 | feat_lawin = self.get_lawin_att_feats(feat, 8) 176 | output = self.cat(torch.cat([feat_short, feat_pool, *feat_lawin], dim=1)) 177 | 178 | ## Low-level feature enhancement 179 | c1 = self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:]) 180 | output = F.interpolate(output, size=features[0].shape[-2:], mode='bilinear', align_corners=False) 181 | fused = self.low_level_fuse(torch.cat([output, c1], dim=1)) 182 | 183 | seg = self.linear_pred(self.dropout(fused)) 184 | return seg 185 | -------------------------------------------------------------------------------- /new_model/decoders/sfnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import functional as F 4 | 5 | 6 | class ConvModule(nn.Sequential): 7 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 8 | super().__init__( 9 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 10 | nn.BatchNorm2d(c2), 11 | nn.ReLU(True) 12 | ) 13 | 14 | 15 | class ConvModule(nn.Sequential): 16 | def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1): 17 | super().__init__( 18 | nn.Conv2d(c1, c2, k, s, p, d, g, bias=False), 19 | nn.BatchNorm2d(c2), 20 | nn.ReLU(True) 21 | ) 22 | 23 | 24 | class PPM(nn.Module): 25 | """Pyramid Pooling Module in PSPNet 26 | """ 27 | 28 | def __init__(self, c1, c2=128, scales=(1, 2, 3, 6)): 29 | super().__init__() 30 | self.stages = nn.ModuleList([ 31 | nn.Sequential( 32 | nn.AdaptiveAvgPool2d(scale), 33 | ConvModule(c1, c2, 1) 34 | # ConvModule(c1, c2, 1, p=1) 35 | ) 36 | for scale in scales]) 37 | 38 | self.bottleneck = ConvModule(c1 + c2 * len(scales), c2, 3, 1, 1) 39 | 40 | def forward(self, x: Tensor) -> Tensor: 41 | outs = [] 42 | for stage in self.stages: 43 | outs.append(F.interpolate(stage(x), size=x.shape[-2:], mode='bilinear', align_corners=True)) 44 | 45 | outs = [x] + outs[::-1] 46 | out = self.bottleneck(torch.cat(outs, dim=1)) 47 | return out 48 | 49 | 50 | class AlignedModule(nn.Module): 51 | def __init__(self, c1, c2, k=3): 52 | super().__init__() 53 | self.down_h = nn.Conv2d(c1, c2, 1, bias=False) 54 | self.down_l = nn.Conv2d(c1, c2, 1, bias=False) 55 | self.flow_make = nn.Conv2d(c2 * 2, 2, k, 1, 1, bias=False) 56 | 57 | def forward(self, low_feature: Tensor, high_feature: Tensor) -> Tensor: 58 | high_feature_origin = high_feature 59 | H, W = low_feature.shape[-2:] 60 | low_feature = self.down_l(low_feature) 61 | high_feature = self.down_h(high_feature) 62 | high_feature = F.interpolate(high_feature, size=(H, W), mode='bilinear', 
align_corners=True) 63 | flow = self.flow_make(torch.cat([high_feature, low_feature], dim=1)) 64 | high_feature = self.flow_warp(high_feature_origin, flow, (H, W)) 65 | return high_feature 66 | 67 | def flow_warp(self, x: Tensor, flow: Tensor, size: tuple) -> Tensor: 68 | norm = torch.tensor([[[[*size]]]]).type_as(x).to(x.device) 69 | H = torch.linspace(-1.0, 1.0, size[0]).view(-1, 1).repeat(1, size[1]) 70 | W = torch.linspace(-1.0, 1.0, size[1]).repeat(size[0], 1) 71 | grid = torch.cat((W.unsqueeze(2), H.unsqueeze(2)), dim=2) 72 | grid = grid.repeat(x.shape[0], 1, 1, 1).type_as(x).to(x.device) 73 | grid = grid + flow.permute(0, 2, 3, 1) / norm 74 | output = F.grid_sample(x, grid, align_corners=False) 75 | return output 76 | 77 | 78 | class SFHead(nn.Module): 79 | def __init__(self, in_channels, channel=256, num_classes=19, scales=(1, 2, 3, 6)): 80 | super().__init__() 81 | self.ppm = PPM(in_channels[-1], channel, scales) 82 | 83 | self.fpn_in = nn.ModuleList([]) 84 | self.fpn_out = nn.ModuleList([]) 85 | self.fpn_out_align = nn.ModuleList([]) 86 | 87 | for in_ch in in_channels[:-1]: 88 | self.fpn_in.append(ConvModule(in_ch, channel, 1)) 89 | self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1)) 90 | self.fpn_out_align.append(AlignedModule(channel, channel // 2)) 91 | 92 | self.bottleneck = ConvModule(len(in_channels) * channel, channel, 3, 1, 1) 93 | self.dropout = nn.Dropout2d(0.1) 94 | self.conv_seg = nn.Conv2d(channel, num_classes, 1) 95 | 96 | def forward(self, features: list) -> Tensor: 97 | f = self.ppm(features[-1]) 98 | fpn_features = [f] 99 | 100 | for i in reversed(range(len(features) - 1)): 101 | feature = self.fpn_in[i](features[i]) 102 | f = feature + self.fpn_out_align[i](feature, f) 103 | fpn_features.append(self.fpn_out[i](f)) 104 | 105 | fpn_features.reverse() 106 | 107 | for i in range(1, len(fpn_features)): 108 | fpn_features[i] = F.interpolate(fpn_features[i], size=fpn_features[0].shape[-2:], mode='bilinear', 109 | align_corners=True) 110 | 111 | output = self.bottleneck(torch.cat(fpn_features, dim=1)) 112 | output = self.conv_seg(self.dropout(output)) 113 | return output 114 | -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/__init__.py: -------------------------------------------------------------------------------- 1 | from .bra_legacy import * 2 | from .modules import * 3 | from .dual_biformer import * -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/BiFormer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/__pycache__/biformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/BiFormer/__pycache__/biformer.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/__pycache__/bra_legacy.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/BiFormer/__pycache__/bra_legacy.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/__pycache__/dual_biformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/BiFormer/__pycache__/dual_biformer.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/BiFormer/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/bra_legacy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core of BiFormer, Bi-Level Routing Attention. 3 | 4 | To be refactored. 5 | 6 | author: ZHU Lei 7 | github: https://github.com/rayleizhu 8 | email: ray.leizhu@outlook.com 9 | 10 | This source code is licensed under the license found in the 11 | LICENSE file in the root directory of this source tree. 12 | """ 13 | from typing import Tuple 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from einops import rearrange 19 | from torch import Tensor 20 | 21 | 22 | class TopkRouting(nn.Module): 23 | """ 24 | differentiable topk routing with scaling 25 | Args: 26 | qk_dim: int, feature dimension of query and key 27 | topk: int, the 'topk' 28 | qk_scale: int or None, temperature (multiply) of softmax activation 29 | with_param: bool, wether inorporate learnable params in routing unit 30 | diff_routing: bool, wether make routing differentiable 31 | soft_routing: bool, wether make output value multiplied by routing weights 32 | """ 33 | 34 | def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False): 35 | super().__init__() 36 | self.topk = topk 37 | self.qk_dim = qk_dim 38 | self.scale = qk_scale or qk_dim ** -0.5 39 | self.diff_routing = diff_routing 40 | # TODO: norm layer before/after linear? 
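# ---------------------------------------------------------------------------
# A toy, self-contained sketch of what this routing unit computes: a
# window-level affinity matrix, a top-k selection, and a softmax over the
# surviving logits. The sizes (4 windows, 8-dim descriptors, topk=2) are
# illustrative only.
import torch
import torch.nn.functional as F

q_win = torch.randn(1, 4, 8)                                 # (n, p^2, c): one descriptor per window
k_win = torch.randn(1, 4, 8)
affinity = (q_win * 8 ** -0.5) @ k_win.transpose(-2, -1)     # (n, p^2, p^2)
topk_logit, topk_index = torch.topk(affinity, k=2, dim=-1)   # both (n, p^2, 2)
r_weight = F.softmax(topk_logit, dim=-1)                     # routing weights over the kept windows
# `topk_index` tells each query window which k key/value windows to gather
# (see KVGather below); `r_weight` optionally re-weights them in 'soft' mode.
# ---------------------------------------------------------------------------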
41 | self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity() 42 | # routing activation 43 | self.routing_act = nn.Softmax(dim=-1) 44 | 45 | def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor]: 46 | """ 47 | Args: 48 | q, k: (n, p^2, c) tensor 49 | Return: 50 | r_weight, topk_index: (n, p^2, topk) tensor 51 | """ 52 | if not self.diff_routing: 53 | query, key = query.detach(), key.detach() 54 | query_hat, key_hat = self.emb(query), self.emb(key) # per-window pooling -> (n, p^2, c) 55 | attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1) # (n, p^2, p^2) 56 | topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1) # (n, p^2, k), (n, p^2, k) 57 | r_weight = self.routing_act(topk_attn_logit) # (n, p^2, k) 58 | 59 | return r_weight, topk_index 60 | 61 | 62 | class KVGather(nn.Module): 63 | def __init__(self, mul_weight='none'): 64 | super().__init__() 65 | assert mul_weight in ['none', 'soft', 'hard'] 66 | self.mul_weight = mul_weight 67 | 68 | def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor): 69 | """ 70 | r_idx: (n, p^2, topk) tensor 71 | r_weight: (n, p^2, topk) tensor 72 | kv: (n, p^2, w^2, c_kq+c_v) 73 | 74 | Return: 75 | (n, p^2, topk, w^2, c_kq+c_v) tensor 76 | """ 77 | # select kv according to routing index 78 | n, p2, w2, c_kv = kv.size() 79 | topk = r_idx.size(-1) 80 | # print(r_idx.size(), r_weight.size()) 81 | # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel? 82 | topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1), 83 | # (n, p^2, p^2, w^2, c_kv) without mem cpy 84 | dim=2, 85 | index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv) 86 | # (n, p^2, k, w^2, c_kv) 87 | ) 88 | 89 | if self.mul_weight == 'soft': 90 | topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv # (n, p^2, k, w^2, c_kv) 91 | elif self.mul_weight == 'hard': 92 | raise NotImplementedError('differentiable hard routing TBA') 93 | # else: #'none' 94 | # topk_kv = topk_kv # do nothing 95 | 96 | return topk_kv 97 | 98 | 99 | class QKVLinear(nn.Module): 100 | def __init__(self, dim, qk_dim, bias=True): 101 | super().__init__() 102 | self.dim = dim 103 | self.qk_dim = qk_dim 104 | self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias) 105 | 106 | def forward(self, x): 107 | q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1) 108 | return q, kv 109 | # q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1) 110 | # return q, k, v 111 | 112 | 113 | class BiLevelRoutingAttention(nn.Module): 114 | """ 115 | n_win: number of windows in one side (so the actual number of windows is n_win*n_win) 116 | kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win. 
117 | topk: topk for window filtering 118 | param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention 119 | param_routing: extra linear for routing 120 | diff_routing: wether to set routing differentiable 121 | soft_routing: wether to multiply soft routing weights 122 | """ 123 | 124 | def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None, 125 | kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity', 126 | topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, 127 | side_dwconv=3, 128 | auto_pad=False): 129 | super().__init__() 130 | # local attention setting 131 | self.dim = dim 132 | self.n_win = n_win # Wh, Ww 133 | self.num_heads = num_heads 134 | self.qk_dim = qk_dim or dim 135 | assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, 'qk_dim and dim must be divisible by num_heads!' 136 | self.scale = qk_scale or self.qk_dim ** -0.5 137 | 138 | ################side_dwconv (i.e. LCE in ShuntedTransformer)########### 139 | self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2, 140 | groups=dim) if side_dwconv > 0 else \ 141 | lambda x: torch.zeros_like(x) 142 | 143 | ################ global routing setting ################# 144 | self.topk = topk 145 | self.param_routing = param_routing 146 | self.diff_routing = diff_routing 147 | self.soft_routing = soft_routing 148 | # router 149 | assert not (self.param_routing and not self.diff_routing) # cannot be with_param=True and diff_routing=False 150 | self.router = TopkRouting(qk_dim=self.qk_dim, 151 | qk_scale=self.scale, 152 | topk=self.topk, 153 | diff_routing=self.diff_routing, 154 | param_routing=self.param_routing) 155 | if self.soft_routing: # soft routing, always diffrentiable (if no detach) 156 | mul_weight = 'soft' 157 | elif self.diff_routing: # hard differentiable routing 158 | mul_weight = 'hard' 159 | else: # hard non-differentiable routing 160 | mul_weight = 'none' 161 | self.kv_gather = KVGather(mul_weight=mul_weight) 162 | 163 | # qkv mapping (shared by both global routing and local attention) 164 | self.param_attention = param_attention 165 | if self.param_attention == 'qkvo': 166 | self.qkv = QKVLinear(self.dim, self.qk_dim) 167 | self.wo = nn.Linear(dim, dim) 168 | elif self.param_attention == 'qkv': 169 | self.qkv = QKVLinear(self.dim, self.qk_dim) 170 | self.wo = nn.Identity() 171 | else: 172 | raise ValueError(f'param_attention mode {self.param_attention} is not surpported!') 173 | 174 | self.kv_downsample_mode = kv_downsample_mode 175 | self.kv_per_win = kv_per_win 176 | self.kv_downsample_ratio = kv_downsample_ratio 177 | self.kv_downsample_kenel = kv_downsample_kernel 178 | if self.kv_downsample_mode == 'ada_avgpool': 179 | assert self.kv_per_win is not None 180 | self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win) 181 | elif self.kv_downsample_mode == 'ada_maxpool': 182 | assert self.kv_per_win is not None 183 | self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win) 184 | elif self.kv_downsample_mode == 'maxpool': 185 | assert self.kv_downsample_ratio is not None 186 | self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity() 187 | elif self.kv_downsample_mode == 'avgpool': 188 | assert self.kv_downsample_ratio is not None 189 | self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity() 190 | elif self.kv_downsample_mode == 'identity': # no kv downsampling 191 | 
self.kv_down = nn.Identity() 192 | elif self.kv_downsample_mode == 'fracpool': 193 | # assert self.kv_downsample_ratio is not None 194 | # assert self.kv_downsample_kenel is not None 195 | # TODO: fracpool 196 | # 1. kernel size should be input size dependent 197 | # 2. there is a random factor, need to avoid independent sampling for k and v 198 | raise NotImplementedError('fracpool policy is not implemented yet!') 199 | elif kv_downsample_mode == 'conv': 200 | # TODO: need to consider the case where k != v so that need two downsample modules 201 | raise NotImplementedError('conv policy is not implemented yet!') 202 | else: 203 | raise ValueError(f'kv_down_sample_mode {self.kv_downsaple_mode} is not surpported!') 204 | 205 | # softmax for local attention 206 | self.attn_act = nn.Softmax(dim=-1) 207 | 208 | self.auto_pad = auto_pad 209 | 210 | def forward(self, x, ret_attn_mask=False): 211 | """ 212 | x: NHWC tensor 213 | 214 | Return: 215 | NHWC tensor 216 | """ 217 | # NOTE: use padding for semantic segmentation 218 | ################################################### 219 | if self.auto_pad: 220 | N, H_in, W_in, C = x.size() 221 | 222 | pad_l = pad_t = 0 223 | pad_r = (self.n_win - W_in % self.n_win) % self.n_win 224 | pad_b = (self.n_win - H_in % self.n_win) % self.n_win 225 | x = F.pad(x, (0, 0, # dim=-1 226 | pad_l, pad_r, # dim=-2 227 | pad_t, pad_b)) # dim=-3 228 | _, H, W, _ = x.size() # padded size 229 | else: 230 | N, H, W, C = x.size() 231 | assert H % self.n_win == 0 and W % self.n_win == 0 # 232 | ################################################### 233 | 234 | # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size 235 | x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win) 236 | 237 | #################qkv projection################### 238 | # q: (n, p^2, w, w, c_qk) 239 | # kv: (n, p^2, w, w, c_qk+c_v) 240 | # NOTE: separte kv if there were memory leak issue caused by gather 241 | q, kv = self.qkv(x) 242 | 243 | # pixel-wise qkv 244 | # q_pix: (n, p^2, w^2, c_qk) 245 | # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v) 246 | q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c') 247 | kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w')) 248 | kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win) 249 | 250 | q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean( 251 | [2, 3]) # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk) 252 | 253 | ##################side_dwconv(lepe)################## 254 | # NOTE: call contiguous to avoid gradient warning when using ddp 255 | lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, 256 | i=self.n_win).contiguous()) 257 | lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win) 258 | 259 | ############ gather q dependent k/v ################# 260 | 261 | r_weight, r_idx = self.router(q_win, k_win) # both are (n, p^2, topk) tensors 262 | 263 | kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix) # (n, p^2, topk, h_kv*w_kv, c_qk+c_v) 264 | k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1) 265 | # kv_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk) 266 | # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v) 267 | 268 | ######### do attention as normal #################### 269 | k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', 270 | m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here? 
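# ---------------------------------------------------------------------------
# The remaining steps are ordinary scaled-dot-product attention, batched per
# window over the gathered top-k keys/values. A stand-alone shape sketch with
# illustrative sizes (4 window-batches, 2 heads, 7x7 windows, topk=2); the
# module above uses `self.scale` as the softmax temperature.
import torch
from einops import rearrange

nb, m, w2, c = 4, 2, 49, 64             # n*p^2 window-batches, heads, tokens per window, channels
q = torch.randn(nb, m, w2, c // m)      # queries of one window
k = torch.randn(nb, m, c // m, 2 * w2)  # keys gathered from the topk=2 routed windows
v = torch.randn(nb, m, 2 * w2, c // m)
attn = torch.softmax(q @ k * (c // m) ** -0.5, dim=-1)   # (nb, m, w2, 2*w2)
out = attn @ v                                           # (nb, m, w2, c//m)
out = rearrange(out, 'b m w2 c -> b w2 (m c)')           # back to per-window tokens, (nb, w2, c)
# ---------------------------------------------------------------------------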
271 | v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', 272 | m=self.num_heads) # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m) 273 | q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', 274 | m=self.num_heads) # to BMLC tensor (n*p^2, m, w^2, c_qk//m) 275 | 276 | # param-free multihead attention 277 | attn_weight = ( 278 | q_pix * self.scale) @ k_pix_sel # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv) 279 | attn_weight = self.attn_act(attn_weight) 280 | out = attn_weight @ v_pix_sel # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c) 281 | out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win, 282 | h=H // self.n_win, w=W // self.n_win) 283 | 284 | out = out + lepe 285 | # output linear 286 | out = self.wo(out) 287 | 288 | # NOTE: use padding for semantic segmentation 289 | # crop padded region 290 | if self.auto_pad and (pad_r > 0 or pad_b > 0): 291 | out = out[:, :H_in, :W_in, :].contiguous() 292 | 293 | if ret_attn_mask: 294 | return out, r_weight, r_idx, attn_weight 295 | else: 296 | return out 297 | -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/dual_biformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | from collections import OrderedDict 4 | from functools import partial 5 | from typing import Optional, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from einops import rearrange 11 | from einops.layers.torch import Rearrange 12 | from fairscale.nn.checkpoint import checkpoint_wrapper 13 | from timm.models import register_model 14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 15 | from timm.models.vision_transformer import _cfg 16 | from .bra_legacy import BiLevelRoutingAttention 17 | from .modules import Attention, AttentionLePE, DWConv 18 | 19 | from models.new_model.modules import FeatureFusion as FFM 20 | from models.new_model.modules import FeatureCorrection_s2c as FCM 21 | from thop import clever_format, profile 22 | 23 | 24 | def get_pe_layer(emb_dim, pe_dim=None, name='none'): 25 | if name == 'none': 26 | return nn.Identity() 27 | # if name == 'sum': 28 | # return Summer(PositionalEncodingPermute2D(emb_dim)) 29 | # elif name == 'npe.sin': 30 | # return NeuralPE(emb_dim=emb_dim, pe_dim=pe_dim, mode='sin') 31 | # elif name == 'npe.coord': 32 | # return NeuralPE(emb_dim=emb_dim, pe_dim=pe_dim, mode='coord') 33 | # elif name == 'hpe.conv': 34 | # return HybridPE(emb_dim=emb_dim, pe_dim=pe_dim, mode='conv', res_shortcut=True) 35 | # elif name == 'hpe.dsconv': 36 | # return HybridPE(emb_dim=emb_dim, pe_dim=pe_dim, mode='dsconv', res_shortcut=True) 37 | # elif name == 'hpe.pointconv': 38 | # return HybridPE(emb_dim=emb_dim, pe_dim=pe_dim, mode='pointconv', res_shortcut=True) 39 | else: 40 | raise ValueError(f'PE name {name} is not surpported!') 41 | 42 | 43 | class Block(nn.Module): 44 | def __init__(self, dim, drop_path=0., layer_scale_init_value=-1, 45 | num_heads=8, n_win=7, qk_dim=None, qk_scale=None, 46 | kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='ada_avgpool', 47 | topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, 48 | mlp_ratio=4, mlp_dwconv=False, 49 | side_dwconv=5, before_attn_dwconv=3, pre_norm=True, auto_pad=False): 50 | 
super().__init__() 51 | qk_dim = qk_dim or dim 52 | 53 | # modules 54 | if before_attn_dwconv > 0: 55 | self.pos_embed = nn.Conv2d(dim, dim, kernel_size=before_attn_dwconv, padding=1, groups=dim) 56 | else: 57 | self.pos_embed = lambda x: 0 58 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) # important to avoid attention collapsing 59 | if topk > 0: 60 | self.attn = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim, 61 | qk_scale=qk_scale, kv_per_win=kv_per_win, 62 | kv_downsample_ratio=kv_downsample_ratio, 63 | kv_downsample_kernel=kv_downsample_kernel, 64 | kv_downsample_mode=kv_downsample_mode, 65 | topk=topk, param_attention=param_attention, param_routing=param_routing, 66 | diff_routing=diff_routing, soft_routing=soft_routing, 67 | side_dwconv=side_dwconv, 68 | auto_pad=auto_pad) 69 | elif topk == -1: 70 | self.attn = Attention(dim=dim) 71 | elif topk == -2: 72 | self.attn = AttentionLePE(dim=dim, side_dwconv=side_dwconv) 73 | elif topk == 0: 74 | self.attn = nn.Sequential(Rearrange('n h w c -> n c h w'), # compatiability 75 | nn.Conv2d(dim, dim, 1), # pseudo qkv linear 76 | nn.Conv2d(dim, dim, 5, padding=2, groups=dim), # pseudo attention 77 | nn.Conv2d(dim, dim, 1), # pseudo out linear 78 | Rearrange('n c h w -> n h w c') 79 | ) 80 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 81 | self.mlp = nn.Sequential(nn.Linear(dim, int(mlp_ratio * dim)), 82 | DWConv(int(mlp_ratio * dim)) if mlp_dwconv else nn.Identity(), 83 | nn.GELU(), 84 | nn.Linear(int(mlp_ratio * dim), dim) 85 | ) 86 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 87 | 88 | # tricks: layer scale & pre_norm/post_norm 89 | if layer_scale_init_value > 0: 90 | self.use_layer_scale = True 91 | self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 92 | self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 93 | else: 94 | self.use_layer_scale = False 95 | self.pre_norm = pre_norm 96 | 97 | def forward(self, x): 98 | """ 99 | x: NCHW tensor 100 | """ 101 | # conv pos embedding 102 | x = x + self.pos_embed(x) 103 | # permute to NHWC tensor for attention & mlp 104 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 105 | 106 | # attention & mlp 107 | if self.pre_norm: 108 | if self.use_layer_scale: 109 | x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x))) # (N, H, W, C) 110 | x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) # (N, H, W, C) 111 | else: 112 | x = x + self.drop_path(self.attn(self.norm1(x))) # (N, H, W, C) 113 | x = x + self.drop_path(self.mlp(self.norm2(x))) # (N, H, W, C) 114 | else: # https://kexue.fm/archives/9009 115 | if self.use_layer_scale: 116 | x = self.norm1(x + self.drop_path(self.gamma1 * self.attn(x))) # (N, H, W, C) 117 | x = self.norm2(x + self.drop_path(self.gamma2 * self.mlp(x))) # (N, H, W, C) 118 | else: 119 | x = self.norm1(x + self.drop_path(self.attn(x))) # (N, H, W, C) 120 | x = self.norm2(x + self.drop_path(self.mlp(x))) # (N, H, W, C) 121 | 122 | # permute back 123 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 124 | return x 125 | 126 | 127 | class DualBiFormer(nn.Module): 128 | def __init__(self, depth=[3, 4, 8, 3], in_chans=3, embed_dim=[64, 128, 320, 512], 129 | head_dim=64, qk_scale=None, 130 | drop_path_rate=0., drop_rate=0., 131 | use_checkpoint_stages=[], 132 | ######## 133 | n_win=7, 134 | kv_downsample_mode='ada_avgpool', 135 | kv_per_wins=[2, 2, -1, -1], 136 | topks=[8, 8, -1, -1], 137 | side_dwconv=5, 138 | 
layer_scale_init_value=-1, 139 | qk_dims=[None, None, None, None], 140 | param_routing=False, diff_routing=False, soft_routing=False, 141 | pre_norm=True, 142 | pe=None, 143 | pe_stages=[0], 144 | before_attn_dwconv=3, 145 | auto_pad=False, 146 | # ----------------------- 147 | kv_downsample_kernels=[4, 2, 1, 1], 148 | kv_downsample_ratios=[4, 2, 1, 1], # -> kv_per_win = [2, 2, 2, 1] 149 | mlp_ratios=[4, 4, 4, 4], 150 | sr_ratios=[8, 4, 2, 1], 151 | norm_fuse=nn.BatchNorm2d, 152 | param_attention='qkvo', 153 | mlp_dwconv=False): 154 | """ 155 | Args: 156 | depth (list): depth of each stage 157 | img_size (int, tuple): input image size 158 | in_chans (int): number of input channels 159 | embed_dim (list): embedding dimension of each stage 160 | head_dim (int): head dimension 161 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 162 | qkv_bias (bool): enable bias for qkv if True 163 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 164 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 165 | drop_rate (float): dropout rate 166 | attn_drop_rate (float): attention dropout rate 167 | drop_path_rate (float): stochastic depth rate 168 | norm_layer (nn.Module): normalization layer 169 | conv_stem (bool): whether use overlapped patch stem 170 | """ 171 | super().__init__() 172 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 173 | 174 | self.downsample_layers = nn.ModuleList() 175 | self.aux_downsample_layers = nn.ModuleList() 176 | # NOTE: uniformer uses two 3*3 conv, while in many other transformers this is one 7*7 conv 177 | stem = nn.Sequential( 178 | nn.Conv2d(in_chans, embed_dim[0] // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 179 | nn.BatchNorm2d(embed_dim[0] // 2), 180 | nn.GELU(), 181 | nn.Conv2d(embed_dim[0] // 2, embed_dim[0], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 182 | nn.BatchNorm2d(embed_dim[0]), 183 | ) 184 | aux_stem = nn.Sequential( 185 | nn.Conv2d(in_chans, embed_dim[0] // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 186 | nn.BatchNorm2d(embed_dim[0] // 2), 187 | nn.GELU(), 188 | nn.Conv2d(embed_dim[0] // 2, embed_dim[0], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 189 | nn.BatchNorm2d(embed_dim[0]), 190 | ) 191 | if (pe is not None) and 0 in pe_stages: 192 | stem.append(get_pe_layer(emb_dim=embed_dim[0], name=pe)) 193 | aux_stem.append(get_pe_layer(emb_dim=embed_dim[0], name=pe)) 194 | if use_checkpoint_stages: 195 | stem = checkpoint_wrapper(stem) 196 | aux_stem = checkpoint_wrapper(aux_stem) 197 | self.downsample_layers.append(stem) 198 | self.aux_downsample_layers.append(aux_stem) 199 | 200 | for i in range(3): 201 | downsample_layer = nn.Sequential( 202 | nn.Conv2d(embed_dim[i], embed_dim[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 203 | nn.BatchNorm2d(embed_dim[i + 1]) 204 | ) 205 | aux_downsample_layer = nn.Sequential( 206 | nn.Conv2d(embed_dim[i], embed_dim[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), 207 | nn.BatchNorm2d(embed_dim[i + 1]) 208 | ) 209 | if (pe is not None) and i + 1 in pe_stages: 210 | downsample_layer.append(get_pe_layer(emb_dim=embed_dim[i + 1], name=pe)) 211 | aux_downsample_layer.append(get_pe_layer(emb_dim=embed_dim[i + 1], name=pe)) 212 | if use_checkpoint_stages: 213 | downsample_layer = checkpoint_wrapper(downsample_layer) 214 | aux_downsample_layer = checkpoint_wrapper(aux_downsample_layer) 215 | 
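# ---------------------------------------------------------------------------
# Each stream halves the resolution four times: the stem's two stride-2 convs
# give 1/4, and the three stride-2 downsample layers give 1/8, 1/16 and 1/32,
# matching the 1/4..1/32 pyramid the decoders expect. A quick stand-alone
# shape check of such a stem, with illustrative channel sizes:
import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(32),
    nn.GELU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    nn.BatchNorm2d(64),
)
print(stem(torch.randn(1, 3, 256, 256)).shape)   # torch.Size([1, 64, 64, 64]) -> 1/4 resolution
# ---------------------------------------------------------------------------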
self.downsample_layers.append(downsample_layer) 216 | self.aux_downsample_layers.append(aux_downsample_layer) 217 | 218 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 219 | self.aux_stages = nn.ModuleList() 220 | self.norm = nn.ModuleList() 221 | self.aux_norm = nn.ModuleList() 222 | nheads = [dim // head_dim for dim in qk_dims] 223 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] 224 | cur = 0 225 | for i in range(4): 226 | stage = nn.Sequential( 227 | *[Block(dim=embed_dim[i], drop_path=dp_rates[cur + j], 228 | layer_scale_init_value=layer_scale_init_value, 229 | topk=topks[i], 230 | num_heads=nheads[i], 231 | n_win=n_win, 232 | qk_dim=qk_dims[i], 233 | qk_scale=qk_scale, 234 | kv_per_win=kv_per_wins[i], 235 | kv_downsample_ratio=kv_downsample_ratios[i], 236 | kv_downsample_kernel=kv_downsample_kernels[i], 237 | kv_downsample_mode=kv_downsample_mode, 238 | param_attention=param_attention, 239 | param_routing=param_routing, 240 | diff_routing=diff_routing, 241 | soft_routing=soft_routing, 242 | mlp_ratio=mlp_ratios[i], 243 | mlp_dwconv=mlp_dwconv, 244 | side_dwconv=side_dwconv, 245 | before_attn_dwconv=before_attn_dwconv, 246 | pre_norm=pre_norm, 247 | auto_pad=auto_pad) for j in range(depth[i])], 248 | ) 249 | self.norm.append(nn.LayerNorm(embed_dim[i])) 250 | aux_stage = nn.Sequential( 251 | *[Block(dim=embed_dim[i], drop_path=dp_rates[cur + j], 252 | layer_scale_init_value=layer_scale_init_value, 253 | topk=topks[i], 254 | num_heads=nheads[i], 255 | n_win=n_win, 256 | qk_dim=qk_dims[i], 257 | qk_scale=qk_scale, 258 | kv_per_win=kv_per_wins[i], 259 | kv_downsample_ratio=kv_downsample_ratios[i], 260 | kv_downsample_kernel=kv_downsample_kernels[i], 261 | kv_downsample_mode=kv_downsample_mode, 262 | param_attention=param_attention, 263 | param_routing=param_routing, 264 | diff_routing=diff_routing, 265 | soft_routing=soft_routing, 266 | mlp_ratio=mlp_ratios[i], 267 | mlp_dwconv=mlp_dwconv, 268 | side_dwconv=side_dwconv, 269 | before_attn_dwconv=before_attn_dwconv, 270 | pre_norm=pre_norm, 271 | auto_pad=auto_pad) for j in range(depth[i])], 272 | ) 273 | self.aux_norm.append(nn.LayerNorm(embed_dim[i])) 274 | if i in use_checkpoint_stages: 275 | stage = checkpoint_wrapper(stage) 276 | aux_stage = checkpoint_wrapper(aux_stage) 277 | self.stages.append(stage) 278 | self.aux_stages.append(aux_stage) 279 | cur += depth[i] 280 | 281 | self.FCMs = nn.ModuleList([ 282 | FCM(dim=embed_dim[0], reduction=1), 283 | FCM(dim=embed_dim[1], reduction=1), 284 | FCM(dim=embed_dim[2], reduction=1), 285 | FCM(dim=embed_dim[3], reduction=1)]) 286 | 287 | self.FFMs = nn.ModuleList([ 288 | FFM(dim=embed_dim[0], reduction=1, num_heads=nheads[0], norm_layer=norm_fuse, sr_ratio=sr_ratios[0]), 289 | FFM(dim=embed_dim[1], reduction=1, num_heads=nheads[1], norm_layer=norm_fuse, sr_ratio=sr_ratios[1]), 290 | FFM(dim=embed_dim[2], reduction=1, num_heads=nheads[2], norm_layer=norm_fuse, sr_ratio=sr_ratios[2]), 291 | FFM(dim=embed_dim[3], reduction=1, num_heads=nheads[3], norm_layer=norm_fuse, sr_ratio=sr_ratios[3])]) 292 | 293 | self.apply(self._init_weights) 294 | 295 | def _init_weights(self, m): 296 | if isinstance(m, nn.Linear): 297 | trunc_normal_(m.weight, std=.02) 298 | if isinstance(m, nn.Linear) and m.bias is not None: 299 | nn.init.constant_(m.bias, 0) 300 | elif isinstance(m, nn.LayerNorm): 301 | nn.init.constant_(m.bias, 0) 302 | nn.init.constant_(m.weight, 1.0) 303 | elif isinstance(m, nn.Conv2d): 304 | fan_out = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 305 | fan_out //= m.groups 306 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 307 | if m.bias is not None: 308 | m.bias.data.zero_() 309 | 310 | def init_weights(self, pretrained=None): 311 | 312 | if isinstance(pretrained, str): 313 | load_dualpath_model(self, pretrained) 314 | else: 315 | raise TypeError('pretrained must be a str or None') 316 | 317 | def forward_features(self, x1, x2): 318 | 319 | outs = [] 320 | for i in range(4): 321 | x1 = self.downsample_layers[i](x1) # res = (56, 28, 14, 7), wins = (64, 16, 4, 1) 322 | x1 = self.stages[i](x1) 323 | 324 | x2 = self.aux_downsample_layers[i](x2) # res = (56, 28, 14, 7), wins = (64, 16, 4, 1) 325 | x2 = self.aux_stages[i](x2) 326 | 327 | x1, x2 = self.FCMs[i](x1, x2) 328 | 329 | x1_1 = self.norm[i](x1.permute(0, 2, 3, 1).contiguous()) 330 | x2_1 = self.aux_norm[i](x2.permute(0, 2, 3, 1).contiguous()) 331 | 332 | fuse = self.FFMs[i](x1_1.permute(0, 3, 1, 2).contiguous(), x2_1.permute(0, 3, 1, 2).contiguous()) 333 | 334 | outs.append(fuse) 335 | 336 | return tuple(outs) 337 | 338 | def forward(self, x1, x2): 339 | x = self.forward_features(x1, x2) 340 | 341 | return x 342 | 343 | 344 | def load_dualpath_model(model, model_file, is_restore=False): 345 | # load raw state_dict 346 | t_start = time.time() 347 | if isinstance(model_file, str): 348 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu')) 349 | # raw_state_dict = torch.load(model_file) 350 | if 'model' in raw_state_dict.keys(): 351 | raw_state_dict = raw_state_dict['model'] 352 | else: 353 | raw_state_dict = model_file 354 | 355 | state_dict = {} 356 | for k, v in raw_state_dict.items(): 357 | if k.find('downsample_layers') >= 0: 358 | state_dict[k] = v 359 | state_dict[k.replace('downsample_layers', 'aux_downsample_layers')] = v 360 | elif k.find('stages') >= 0: 361 | state_dict[k] = v 362 | state_dict[k.replace('stages', 'aux_stages')] = v 363 | elif k.find('norm') >= 0: 364 | state_dict[k] = v 365 | state_dict[k.replace('norm', 'aux_norm')] = v 366 | 367 | t_ioend = time.time() 368 | 369 | if is_restore: 370 | new_state_dict = OrderedDict() 371 | for k, v in state_dict.items(): 372 | name = 'module.' 
+ k 373 | new_state_dict[name] = v 374 | state_dict = new_state_dict 375 | 376 | model.load_state_dict(state_dict, strict=False) 377 | 378 | del state_dict 379 | t_end = time.time() 380 | print("Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( 381 | t_ioend - t_start, t_end - t_ioend)) 382 | 383 | 384 | class biformer_t(DualBiFormer): 385 | # "biformer_tiny_in1k": "https://api.onedrive.com/v1.0/shares/s!AkBbczdRlZvChHEOoGkgwgQzEDlM/root/content" 386 | def __init__(self, **kwargs): 387 | super(biformer_t, self).__init__(depth=[2, 2, 8, 2], 388 | embed_dim=[64, 128, 256, 512], mlp_ratios=[3, 3, 3, 3], 389 | n_win=8, 390 | kv_downsample_mode='identity', 391 | kv_per_wins=[-1, -1, -1, -1], 392 | topks=[1, 4, 16, -2], 393 | side_dwconv=5, 394 | before_attn_dwconv=3, 395 | layer_scale_init_value=-1, 396 | qk_dims=[64, 128, 256, 512], 397 | head_dim=32, 398 | param_routing=False, diff_routing=False, soft_routing=False, 399 | pre_norm=True, 400 | pe=None, **kwargs) 401 | 402 | 403 | class biformer_s(DualBiFormer): 404 | # "biformer_small_in1k": "https://api.onedrive.com/v1.0/shares/s!AkBbczdRlZvChHDyM-x9KWRBZ832/root/content" 405 | def __init__(self, **kwargs): 406 | super(biformer_s, self).__init__(depth=[4, 4, 18, 4], 407 | embed_dim=[64, 128, 256, 512], mlp_ratios=[3, 3, 3, 3], 408 | # ------------------------------ 409 | n_win=8, 410 | kv_downsample_mode='identity', 411 | kv_per_wins=[-1, -1, -1, -1], 412 | topks=[1, 4, 16, -2], 413 | side_dwconv=5, 414 | before_attn_dwconv=3, 415 | layer_scale_init_value=-1, 416 | qk_dims=[64, 128, 256, 512], 417 | head_dim=32, 418 | param_routing=False, diff_routing=False, soft_routing=False, 419 | pre_norm=True, 420 | pe=None, **kwargs) 421 | 422 | 423 | class biformer_b(DualBiFormer): 424 | # "biformer_base_in1k": "https://api.onedrive.com/v1.0/shares/s!AkBbczdRlZvChHI_XPhoadjaNxtO/root/content" 425 | def __init__(self, **kwargs): 426 | super(biformer_b, self).__init__(depth=[4, 4, 18, 4], 427 | embed_dim=[96, 192, 384, 768], mlp_ratios=[3, 3, 3, 3], 428 | # use_checkpoint_stages=[0, 1, 2, 3], 429 | use_checkpoint_stages=[], 430 | # ------------------------------ 431 | n_win=8, 432 | kv_downsample_mode='identity', 433 | kv_per_wins=[-1, -1, -1, -1], 434 | topks=[1, 4, 16, -2], 435 | side_dwconv=5, 436 | before_attn_dwconv=3, 437 | layer_scale_init_value=-1, 438 | qk_dims=[96, 192, 384, 768], 439 | head_dim=32, 440 | param_routing=False, diff_routing=False, soft_routing=False, 441 | pre_norm=True, 442 | pe=None, **kwargs) 443 | 444 | 445 | if __name__ == '__main__': 446 | model = biformer_t().cuda() 447 | # print(model) 448 | left = torch.randn(1, 3, 256, 256).cuda() 449 | right = torch.randn(1, 3, 256, 256).cuda() 450 | 451 | # summary(model, [(4, 256, 256), (1, 256, 256)]) 452 | flops, params = profile(model, (left, right), verbose=False) 453 | 454 | flops = flops * 2 455 | flops, params = clever_format([flops, params], "%.3f") 456 | print('Total GFLOPS: %s' % flops) 457 | print('Total params: %s' % params) 458 | -------------------------------------------------------------------------------- /new_model/encoders/Transformer/BiFormer/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | 6 | class DWConv(nn.Module): 7 | def __init__(self, dim=768): 8 | super(DWConv, self).__init__() 9 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 10 | 11 | def forward(self, x): 12 | """ 13 | x: NHWC tensor 14 | 
""" 15 | x = x.permute(0, 3, 1, 2) # NCHW 16 | x = self.dwconv(x) 17 | x = x.permute(0, 2, 3, 1) # NHWC 18 | 19 | return x 20 | 21 | 22 | class Attention(nn.Module): 23 | """ 24 | vanilla attention 25 | """ 26 | 27 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 28 | super().__init__() 29 | self.num_heads = num_heads 30 | head_dim = dim // num_heads 31 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 32 | self.scale = qk_scale or head_dim ** -0.5 33 | 34 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 35 | self.attn_drop = nn.Dropout(attn_drop) 36 | self.proj = nn.Linear(dim, dim) 37 | self.proj_drop = nn.Dropout(proj_drop) 38 | 39 | def forward(self, x): 40 | """ 41 | args: 42 | x: NHWC tensor 43 | return: 44 | NHWC tensor 45 | """ 46 | _, H, W, _ = x.size() 47 | x = rearrange(x, 'n h w c -> n (h w) c') 48 | 49 | ####################################### 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 53 | 54 | attn = (q @ k.transpose(-2, -1)) * self.scale 55 | attn = attn.softmax(dim=-1) 56 | attn = self.attn_drop(attn) 57 | 58 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 59 | x = self.proj(x) 60 | x = self.proj_drop(x) 61 | ####################################### 62 | 63 | x = rearrange(x, 'n (h w) c -> n h w c', h=H, w=W) 64 | return x 65 | 66 | 67 | class AttentionLePE(nn.Module): 68 | """ 69 | vanilla attention 70 | """ 71 | 72 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5): 73 | super().__init__() 74 | self.num_heads = num_heads 75 | head_dim = dim // num_heads 76 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 77 | self.scale = qk_scale or head_dim ** -0.5 78 | 79 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 80 | self.attn_drop = nn.Dropout(attn_drop) 81 | self.proj = nn.Linear(dim, dim) 82 | self.proj_drop = nn.Dropout(proj_drop) 83 | self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2, 84 | groups=dim) if side_dwconv > 0 else \ 85 | lambda x: torch.zeros_like(x) 86 | 87 | def forward(self, x): 88 | """ 89 | args: 90 | x: NHWC tensor 91 | return: 92 | NHWC tensor 93 | """ 94 | _, H, W, _ = x.size() 95 | x = rearrange(x, 'n h w c -> n (h w) c') 96 | 97 | ####################################### 98 | B, N, C = x.shape 99 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 100 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 101 | 102 | lepe = self.lepe(rearrange(x, 'n (h w) c -> n c h w', h=H, w=W)) 103 | lepe = rearrange(lepe, 'n c h w -> n (h w) c') 104 | 105 | attn = (q @ k.transpose(-2, -1)) * self.scale 106 | attn = attn.softmax(dim=-1) 107 | attn = self.attn_drop(attn) 108 | 109 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 110 | x = x + lepe 111 | 112 | x = self.proj(x) 113 | x = self.proj_drop(x) 114 | ####################################### 115 | 116 | x = rearrange(x, 'n (h w) c -> n h w c', h=H, w=W) 117 | return x 118 | 119 | 120 | class nchwAttentionLePE(nn.Module): 121 | """ 122 | Attention with LePE, takes nchw input 123 | """ 124 | 125 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 
side_dwconv=5): 126 | super().__init__() 127 | self.num_heads = num_heads 128 | self.head_dim = dim // num_heads 129 | self.scale = qk_scale or self.head_dim ** -0.5 130 | 131 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=qkv_bias) 132 | self.attn_drop = nn.Dropout(attn_drop) 133 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 134 | self.proj_drop = nn.Dropout(proj_drop) 135 | self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2, 136 | groups=dim) if side_dwconv > 0 else \ 137 | lambda x: torch.zeros_like(x) 138 | 139 | def forward(self, x: torch.Tensor): 140 | """ 141 | args: 142 | x: NCHW tensor 143 | return: 144 | NCHW tensor 145 | """ 146 | B, C, H, W = x.size() 147 | q, k, v = self.qkv.forward(x).chunk(3, dim=1) # B, C, H, W 148 | 149 | attn = q.view(B, self.num_heads, self.head_dim, H * W).transpose(-1, -2) @ \ 150 | k.view(B, self.num_heads, self.head_dim, H * W) 151 | attn = torch.softmax(attn * self.scale, dim=-1) 152 | attn = self.attn_drop(attn) 153 | 154 | # (B, nhead, HW, HW) @ (B, nhead, HW, head_dim) -> (B, nhead, HW, head_dim) 155 | output: torch.Tensor = attn @ v.view(B, self.num_heads, self.head_dim, H * W).transpose(-1, -2) 156 | output = output.permute(0, 1, 3, 2).reshape(B, C, H, W) 157 | output = output + self.lepe(v) 158 | 159 | output = self.proj_drop(self.proj(output)) 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /new_model/encoders/Transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/__init__.py -------------------------------------------------------------------------------- /new_model/encoders/Transformer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/__pycache__/dual_cswin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/__pycache__/dual_cswin.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/__pycache__/dual_segformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/__pycache__/dual_segformer.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/__pycache__/dual_swin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/__pycache__/dual_swin.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/__pycache__/dual_uniformer.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/Transformer/__pycache__/dual_uniformer.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/encoders/Transformer/dual_cswin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 7 | from timm.models.helpers import load_pretrained 8 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 9 | from timm.models.resnet import resnet26d, resnet50d 10 | from timm.models.registry import register_model 11 | from einops.layers.torch import Rearrange 12 | import numpy as np 13 | 14 | from fvcore.nn import FlopCountAnalysis 15 | from fvcore.nn import flop_count_table 16 | from thop import clever_format, profile 17 | import time 18 | import math 19 | from collections import OrderedDict 20 | 21 | from models.new_model.modules import FeatureFusion as FFM 22 | from models.new_model.modules import FeatureCorrection_s2c as FCM 23 | 24 | 25 | class Mlp(nn.Module): 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class LePEAttention(nn.Module): 45 | def __init__(self, dim, resolution, idx, split_size=7, dim_out=None, num_heads=8, qkv_bias=False, qk_scale=None, 46 | attn_drop=0., proj_drop=0.): 47 | """Not supported now, since we have cls_tokens now..... 
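Convention used below: idx == -1 attends over the whole feature map (H_sp = W_sp = resolution; used when last_stage=True),
idx == 0 uses full-height stripes of width split_size, and idx == 1 uses full-width stripes of height split_size.
CSWinBlock runs the two stripe branches in parallel on the two channel halves, giving the cross-shaped attention window,
and get_v (a 3x3 depthwise convolution on v) produces the locally-enhanced positional encoding that is added to attn @ v.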
48 | """ 49 | super().__init__() 50 | self.dim = dim 51 | self.dim_out = dim_out or dim 52 | self.resolution = resolution 53 | self.split_size = split_size 54 | self.num_heads = num_heads 55 | head_dim = dim // num_heads 56 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 57 | self.scale = qk_scale or head_dim ** -0.5 58 | self.idx = idx 59 | if idx == -1: 60 | H_sp, W_sp = self.resolution, self.resolution 61 | elif idx == 0: 62 | H_sp, W_sp = self.resolution, self.split_size 63 | elif idx == 1: 64 | W_sp, H_sp = self.resolution, self.split_size 65 | else: 66 | print("ERROR MODE", idx) 67 | exit(0) 68 | self.H_sp = H_sp 69 | self.W_sp = W_sp 70 | 71 | self.H_sp_ = self.H_sp 72 | self.W_sp_ = self.W_sp 73 | 74 | stride = 1 75 | self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) 76 | 77 | self.attn_drop = nn.Dropout(attn_drop) 78 | 79 | def im2cswin(self, x): 80 | B, C, H, W = x.shape 81 | x = img2windows(x, self.H_sp, self.W_sp) 82 | x = x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() 83 | return x 84 | 85 | def get_rpe(self, x, func): 86 | B, C, H, W = x.shape 87 | H_sp, W_sp = self.H_sp, self.W_sp 88 | x = x.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 89 | x = x.permute(0, 2, 4, 1, 3, 5).contiguous().reshape(-1, C, H_sp, W_sp) ### B', C, H', W' 90 | 91 | rpe = func(x) ### B', C, H', W' 92 | rpe = rpe.reshape(-1, self.num_heads, C // self.num_heads, H_sp * W_sp).permute(0, 1, 3, 2).contiguous() 93 | 94 | x = x.reshape(-1, self.num_heads, C // self.num_heads, self.H_sp * self.W_sp).permute(0, 1, 3, 2).contiguous() 95 | return x, rpe 96 | 97 | def forward(self, temp): 98 | """ 99 | x: B N C 100 | mask: B N N 101 | """ 102 | B, _, C, H, W = temp.shape 103 | 104 | idx = self.idx 105 | if idx == -1: 106 | H_sp, W_sp = H, W 107 | elif idx == 0: 108 | H_sp, W_sp = H, self.split_size 109 | elif idx == 1: 110 | H_sp, W_sp = self.split_size, W 111 | else: 112 | print("ERROR MODE in forward", idx) 113 | exit(0) 114 | self.H_sp = H_sp 115 | self.W_sp = W_sp 116 | 117 | ### padding for split window 118 | H_pad = (self.H_sp - H % self.H_sp) % self.H_sp 119 | W_pad = (self.W_sp - W % self.W_sp) % self.W_sp 120 | top_pad = H_pad // 2 121 | down_pad = H_pad - top_pad 122 | left_pad = W_pad // 2 123 | right_pad = W_pad - left_pad 124 | H_ = H + H_pad 125 | W_ = W + W_pad 126 | 127 | qkv = F.pad(temp, (left_pad, right_pad, top_pad, down_pad)) ### B,3,C,H',W' 128 | qkv = qkv.permute(1, 0, 2, 3, 4) 129 | q, k, v = qkv[0], qkv[1], qkv[2] 130 | 131 | q = self.im2cswin(q) 132 | k = self.im2cswin(k) 133 | v, rpe = self.get_rpe(v, self.get_v) 134 | 135 | ### Local attention 136 | q = q * self.scale 137 | attn = (q @ k.transpose(-2, -1)) # B head N C @ B head C N --> B head N N 138 | 139 | attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype) 140 | 141 | attn = self.attn_drop(attn) 142 | 143 | x = (attn @ v) + rpe 144 | x = x.transpose(1, 2).reshape(-1, self.H_sp * self.W_sp, C) # B head N N @ B head N C 145 | 146 | ### Window2Img 147 | x = windows2img(x, self.H_sp, self.W_sp, H_, W_) # B H_ W_ C 148 | x = x[:, top_pad:H + top_pad, left_pad:W + left_pad, :] 149 | x = x.reshape(B, -1, C) 150 | 151 | return x 152 | 153 | 154 | class CSWinBlock(nn.Module): 155 | 156 | def __init__(self, dim, patches_resolution, num_heads, 157 | split_size=7, mlp_ratio=4., qkv_bias=False, qk_scale=None, 158 | drop=0., attn_drop=0., drop_path=0., 159 | act_layer=nn.GELU, 
norm_layer=nn.LayerNorm, 160 | last_stage=False): 161 | super().__init__() 162 | self.dim = dim 163 | self.num_heads = num_heads 164 | self.patches_resolution = patches_resolution 165 | self.split_size = split_size 166 | self.mlp_ratio = mlp_ratio 167 | self.qkv = nn.Linear(dim, dim * 3, bias=True) 168 | self.norm1 = norm_layer(dim) 169 | 170 | if last_stage: 171 | self.branch_num = 1 172 | else: 173 | self.branch_num = 2 174 | self.proj = nn.Linear(dim, dim) 175 | self.proj_drop = nn.Dropout(drop) 176 | 177 | if last_stage: 178 | self.attns = nn.ModuleList([ 179 | LePEAttention( 180 | dim, resolution=self.patches_resolution, idx=-1, 181 | split_size=split_size, num_heads=num_heads, dim_out=dim, 182 | qkv_bias=qkv_bias, qk_scale=qk_scale, 183 | attn_drop=attn_drop, proj_drop=drop) 184 | for i in range(self.branch_num)]) 185 | else: 186 | self.attns = nn.ModuleList([ 187 | LePEAttention( 188 | dim // 2, resolution=self.patches_resolution, idx=i, 189 | split_size=split_size, num_heads=num_heads // 2, dim_out=dim // 2, 190 | qkv_bias=qkv_bias, qk_scale=qk_scale, 191 | attn_drop=attn_drop, proj_drop=drop) 192 | for i in range(self.branch_num)]) 193 | mlp_hidden_dim = int(dim * mlp_ratio) 194 | 195 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 196 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer, 197 | drop=drop) 198 | self.norm2 = norm_layer(dim) 199 | 200 | atten_mask_matrix = None 201 | 202 | self.register_buffer("atten_mask_matrix", atten_mask_matrix) 203 | self.H = None 204 | self.W = None 205 | 206 | def forward(self, x): 207 | """ 208 | x: B, H*W, C 209 | """ 210 | B, L, C = x.shape 211 | H = self.H 212 | W = self.W 213 | assert L == H * W, "flatten img_tokens has wrong size" 214 | img = self.norm1(x) 215 | temp = self.qkv(img).reshape(B, H, W, 3, C).permute(0, 3, 4, 1, 2) 216 | 217 | if self.branch_num == 2: 218 | x1 = self.attns[0](temp[:, :, :C // 2, :, :]) 219 | x2 = self.attns[1](temp[:, :, C // 2:, :, :]) 220 | attened_x = torch.cat([x1, x2], dim=2) 221 | else: 222 | attened_x = self.attns[0](temp) 223 | attened_x = self.proj(attened_x) 224 | x = x + self.drop_path(attened_x) 225 | x = x + self.drop_path(self.mlp(self.norm2(x))) 226 | 227 | return x 228 | 229 | 230 | def img2windows(img, H_sp, W_sp): 231 | """ 232 | img: B C H W 233 | """ 234 | B, C, H, W = img.shape 235 | img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp) 236 | img_perm = img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C) 237 | return img_perm 238 | 239 | 240 | def windows2img(img_splits_hw, H_sp, W_sp, H, W): 241 | """ 242 | img_splits_hw: B' H W C 243 | """ 244 | B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp)) 245 | 246 | img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1) 247 | img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 248 | return img 249 | 250 | 251 | class Merge_Block(nn.Module): 252 | def __init__(self, dim, dim_out, norm_layer=nn.LayerNorm): 253 | super().__init__() 254 | self.conv = nn.Conv2d(dim, dim_out, 3, 2, 1) 255 | self.norm = norm_layer(dim_out) 256 | 257 | def forward(self, x, H, W): 258 | B, new_HW, C = x.shape 259 | x = x.transpose(-2, -1).contiguous().view(B, C, H, W) 260 | x = self.conv(x) 261 | B, C, H, W = x.shape 262 | x = x.view(B, C, -1).transpose(-2, -1).contiguous() 263 | x = self.norm(x) 264 | 265 | return x, H, W 266 | 267 | 268 | class DualCSWin(nn.Module): 269 | """ Vision Transformer with support for patch or 
hybrid CNN input stage 270 | """ 271 | 272 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=64, depth=[1, 2, 21, 1], split_size=7, 273 | num_heads=[1, 2, 4, 8], mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 274 | drop_path_rate=0., norm_layer=nn.LayerNorm, norm_fuse=nn.BatchNorm2d, sr_ratios=[8, 4, 2, 1]): 275 | super().__init__() 276 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 277 | 278 | heads = num_heads 279 | 280 | self.stage1_conv_embed = nn.Sequential( 281 | nn.Conv2d(in_chans, embed_dim, 7, 4, 2), 282 | Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4), 283 | nn.LayerNorm(embed_dim) 284 | ) 285 | 286 | self.aux_stage1_conv_embed = nn.Sequential( 287 | nn.Conv2d(in_chans, embed_dim, 7, 4, 2), 288 | Rearrange('b c h w -> b (h w) c', h=img_size // 4, w=img_size // 4), 289 | nn.LayerNorm(embed_dim) 290 | ) 291 | 292 | self.norm1 = nn.LayerNorm(embed_dim) 293 | self.aux_norm1 = nn.LayerNorm(embed_dim) 294 | 295 | curr_dim = embed_dim 296 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))] # stochastic depth decay rule 297 | self.stage1 = nn.ModuleList([ 298 | CSWinBlock( 299 | dim=curr_dim, num_heads=heads[0], patches_resolution=224 // 4, mlp_ratio=mlp_ratio, 300 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[0], 301 | drop=drop_rate, attn_drop=attn_drop_rate, 302 | drop_path=dpr[i], norm_layer=norm_layer) 303 | for i in range(depth[0])]) 304 | self.aux_stage1 = nn.ModuleList([ 305 | CSWinBlock( 306 | dim=curr_dim, num_heads=heads[0], patches_resolution=224 // 4, mlp_ratio=mlp_ratio, 307 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[0], 308 | drop=drop_rate, attn_drop=attn_drop_rate, 309 | drop_path=dpr[i], norm_layer=norm_layer) 310 | for i in range(depth[0])]) 311 | 312 | self.merge1 = Merge_Block(curr_dim, curr_dim * (heads[1] // heads[0])) 313 | self.aux_merge1 = Merge_Block(curr_dim, curr_dim * (heads[1] // heads[0])) 314 | curr_dim = curr_dim * (heads[1] // heads[0]) 315 | self.norm2 = nn.LayerNorm(curr_dim) 316 | self.aux_norm2 = nn.LayerNorm(curr_dim) 317 | self.stage2 = nn.ModuleList( 318 | [CSWinBlock( 319 | dim=curr_dim, num_heads=heads[1], patches_resolution=224 // 8, mlp_ratio=mlp_ratio, 320 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[1], 321 | drop=drop_rate, attn_drop=attn_drop_rate, 322 | drop_path=dpr[np.sum(depth[:1]) + i], norm_layer=norm_layer) 323 | for i in range(depth[1])]) 324 | self.aux_stage2 = nn.ModuleList( 325 | [CSWinBlock( 326 | dim=curr_dim, num_heads=heads[1], patches_resolution=224 // 8, mlp_ratio=mlp_ratio, 327 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[1], 328 | drop=drop_rate, attn_drop=attn_drop_rate, 329 | drop_path=dpr[np.sum(depth[:1]) + i], norm_layer=norm_layer) 330 | for i in range(depth[1])]) 331 | 332 | self.merge2 = Merge_Block(curr_dim, curr_dim * (heads[2] // heads[1])) 333 | self.aux_merge2 = Merge_Block(curr_dim, curr_dim * (heads[2] // heads[1])) 334 | curr_dim = curr_dim * (heads[2] // heads[1]) 335 | self.norm3 = nn.LayerNorm(curr_dim) 336 | self.aux_norm3 = nn.LayerNorm(curr_dim) 337 | temp_stage3 = [] 338 | aux_temp_stage3 = [] 339 | temp_stage3.extend( 340 | [CSWinBlock( 341 | dim=curr_dim, num_heads=heads[2], patches_resolution=224 // 16, mlp_ratio=mlp_ratio, 342 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[2], 343 | drop=drop_rate, attn_drop=attn_drop_rate, 344 | drop_path=dpr[np.sum(depth[:2]) + 
i], norm_layer=norm_layer) 345 | for i in range(depth[2])]) 346 | aux_temp_stage3.extend( 347 | [CSWinBlock( 348 | dim=curr_dim, num_heads=heads[2], patches_resolution=224 // 16, mlp_ratio=mlp_ratio, 349 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[2], 350 | drop=drop_rate, attn_drop=attn_drop_rate, 351 | drop_path=dpr[np.sum(depth[:2]) + i], norm_layer=norm_layer) 352 | for i in range(depth[2])]) 353 | self.stage3 = nn.ModuleList(temp_stage3) 354 | self.aux_stage3 = nn.ModuleList(aux_temp_stage3) 355 | 356 | self.merge3 = Merge_Block(curr_dim, curr_dim * (heads[3] // heads[2])) 357 | self.aux_merge3 = Merge_Block(curr_dim, curr_dim * (heads[3] // heads[2])) 358 | curr_dim = curr_dim * (heads[3] // heads[2]) 359 | self.stage4 = nn.ModuleList( 360 | [CSWinBlock( 361 | dim=curr_dim, num_heads=heads[3], patches_resolution=224 // 32, mlp_ratio=mlp_ratio, 362 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[-1], 363 | drop=drop_rate, attn_drop=attn_drop_rate, 364 | drop_path=dpr[np.sum(depth[:-1]) + i], norm_layer=norm_layer, last_stage=True) 365 | for i in range(depth[-1])]) 366 | self.aux_stage4 = nn.ModuleList( 367 | [CSWinBlock( 368 | dim=curr_dim, num_heads=heads[3], patches_resolution=224 // 32, mlp_ratio=mlp_ratio, 369 | qkv_bias=qkv_bias, qk_scale=qk_scale, split_size=split_size[-1], 370 | drop=drop_rate, attn_drop=attn_drop_rate, 371 | drop_path=dpr[np.sum(depth[:-1]) + i], norm_layer=norm_layer, last_stage=True) 372 | for i in range(depth[-1])]) 373 | self.norm4 = norm_layer(curr_dim) 374 | self.aux_norm4 = norm_layer(curr_dim) 375 | 376 | self.FCMs = nn.ModuleList([ 377 | FCM(dim=embed_dim, reduction=1), 378 | FCM(dim=embed_dim * 2, reduction=1), 379 | FCM(dim=embed_dim * 4, reduction=1), 380 | FCM(dim=embed_dim * 8, reduction=1)]) 381 | 382 | self.FFMs = nn.ModuleList([ 383 | FFM(dim=embed_dim, reduction=1, num_heads=num_heads[0], norm_layer=norm_fuse, sr_ratio=sr_ratios[0]), 384 | FFM(dim=embed_dim * 2, reduction=1, num_heads=num_heads[1], norm_layer=norm_fuse, sr_ratio=sr_ratios[1]), 385 | FFM(dim=embed_dim * 4, reduction=1, num_heads=num_heads[2], norm_layer=norm_fuse, sr_ratio=sr_ratios[2]), 386 | FFM(dim=embed_dim * 8, reduction=1, num_heads=num_heads[3], norm_layer=norm_fuse, sr_ratio=sr_ratios[3])]) 387 | 388 | self.apply(self._init_weights) 389 | 390 | def _init_weights(self, m): 391 | if isinstance(m, nn.Linear): 392 | trunc_normal_(m.weight, std=.02) 393 | if isinstance(m, nn.Linear) and m.bias is not None: 394 | nn.init.constant_(m.bias, 0) 395 | elif isinstance(m, nn.LayerNorm): 396 | nn.init.constant_(m.bias, 0) 397 | nn.init.constant_(m.weight, 1.0) 398 | elif isinstance(m, nn.Conv2d): 399 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 400 | fan_out //= m.groups 401 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 402 | if m.bias is not None: 403 | m.bias.data.zero_() 404 | 405 | def init_weights(self, pretrained=None): 406 | 407 | if isinstance(pretrained, str): 408 | load_dualpath_model(self, pretrained) 409 | else: 410 | raise TypeError('pretrained must be a str or None') 411 | 412 | def save_out(self, x, norm, H, W): 413 | x = norm(x) 414 | B, N, C = x.shape 415 | x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() 416 | return x 417 | 418 | def forward_features(self, x1, x2): 419 | 420 | outs = [] 421 | x1 = self.stage1_conv_embed[0](x1) ### B, C, H, W 422 | x2 = self.aux_stage1_conv_embed[0](x2) 423 | B, C, H1, W1 = x1.size() 424 | B, C, H2, W2 = x2.size() 425 | x1 = x1.reshape(B, C, -1).transpose(-1, 
-2).contiguous() 426 | x2 = x2.reshape(B, C, -1).transpose(-1, -2).contiguous() 427 | x1 = self.stage1_conv_embed[2](x1) 428 | x2 = self.aux_stage1_conv_embed[2](x2) 429 | 430 | for blk in self.stage1: 431 | blk.H = H1 432 | blk.W = W1 433 | x1 = blk(x1) 434 | for blk in self.aux_stage1: 435 | blk.H = H2 436 | blk.W = W2 437 | x2 = blk(x2) 438 | 439 | x1, x2 = self.FCMs[0](self.save_out(x1, self.norm1, H1, W1), self.save_out(x2, self.aux_norm1, H2, W2)) 440 | fuse = self.FFMs[0](x1, x2) 441 | outs.append(fuse) 442 | 443 | for pre, aux_pre, blocks, aux_blocks, norm, aux_norm, fcm, ffm in zip([self.merge1, self.merge2, self.merge3], 444 | [self.aux_merge1, self.aux_merge2, 445 | self.aux_merge3], 446 | [self.stage2, self.stage3, self.stage4], 447 | [self.aux_stage2, self.aux_stage3, 448 | self.aux_stage4], 449 | [self.norm2, self.norm3, self.norm4], 450 | [self.aux_norm2, self.aux_norm3, 451 | self.aux_norm4], 452 | [self.FCMs[1], self.FCMs[2], 453 | self.FCMs[3]], 454 | [self.FFMs[1], self.FFMs[2], 455 | self.FFMs[3]]): 456 | x1 = x1.flatten(2).transpose(1, 2) 457 | x2 = x2.flatten(2).transpose(1, 2) 458 | x1, H1, W1 = pre(x1, H1, W1) 459 | x2, H2, W2 = aux_pre(x2, H2, W2) 460 | for blk in blocks: 461 | blk.H = H1 462 | blk.W = W1 463 | x1 = blk(x1) 464 | for blk in aux_blocks: 465 | blk.H = H2 466 | blk.W = W2 467 | x2 = blk(x2) 468 | 469 | x1, x2 = fcm(self.save_out(x1, norm, H1, W1), self.save_out(x2, aux_norm, H2, W2)) 470 | fuse = ffm(x1, x2) 471 | outs.append(fuse) 472 | 473 | return tuple(outs) 474 | 475 | def forward(self, x1, x2): 476 | x = self.forward_features(x1, x2) 477 | return x 478 | 479 | 480 | def load_dualpath_model(model, model_file, is_restore=False): 481 | # load raw state_dict 482 | t_start = time.time() 483 | if isinstance(model_file, str): 484 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu')) 485 | # raw_state_dict = torch.load(model_file) 486 | if 'model' in raw_state_dict.keys(): 487 | raw_state_dict = raw_state_dict['model'] 488 | else: 489 | raw_state_dict = model_file 490 | 491 | state_dict = {} 492 | for k, v in raw_state_dict['state_dict_ema'].items(): 493 | if k.find('stage1_conv_embed') >= 0: 494 | state_dict[k] = v 495 | state_dict[k.replace('stage1_conv_embed', 'aux_stage1_conv_embed')] = v 496 | elif k.find('merge') >= 0: 497 | state_dict[k] = v 498 | state_dict[k.replace('merge', 'aux_merge')] = v 499 | elif k.find('stage') >= 0: 500 | state_dict[k] = v 501 | state_dict[k.replace('stage', 'aux_stage')] = v 502 | elif k.find('norm') >= 0: 503 | state_dict[k] = v 504 | state_dict[k.replace('norm', 'aux_norm')] = v 505 | 506 | t_ioend = time.time() 507 | 508 | if is_restore: 509 | new_state_dict = OrderedDict() 510 | for k, v in state_dict.items(): 511 | name = 'module.' 
+ k 512 | new_state_dict[name] = v 513 | state_dict = new_state_dict 514 | 515 | model.load_state_dict(state_dict, strict=False) 516 | 517 | del state_dict 518 | 519 | t_end = time.time() 520 | print("Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( 521 | t_ioend - t_start, t_end - t_ioend)) 522 | 523 | 524 | def _conv_filter(state_dict, patch_size=16): 525 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 526 | out_dict = {} 527 | for k, v in state_dict.items(): 528 | if 'patch_embed.proj.weight' in k: 529 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 530 | out_dict[k] = v 531 | return out_dict 532 | 533 | 534 | class cswin_t(DualCSWin): 535 | def __init__(self, **kwargs): 536 | super(cswin_t, self).__init__( 537 | embed_dim=64, depth=[1, 2, 21, 1], 538 | split_size=[1, 2, 7, 7], num_heads=[2, 4, 8, 16], mlp_ratio=4., drop_path_rate=0.3, **kwargs) 539 | 540 | 541 | class cswin_s(DualCSWin): 542 | def __init__(self, **kwargs): 543 | super(cswin_s, self).__init__( 544 | embed_dim=64, depth=[2, 4, 32, 2], 545 | split_size=[1, 2, 7, 7], num_heads=[2, 4, 8, 16], mlp_ratio=4., drop_path_rate=0.4, **kwargs) 546 | 547 | 548 | class cswin_b(DualCSWin): 549 | def __init__(self, **kwargs): 550 | super(cswin_b, self).__init__( 551 | embed_dim=96, depth=[2, 4, 32, 2], 552 | split_size=[1, 2, 7, 7], num_heads=[4, 8, 16, 32], mlp_ratio=4., drop_path_rate=0.6, **kwargs) 553 | 554 | 555 | if __name__ == '__main__': 556 | model = cswin_t().cuda() 557 | # print(model) 558 | left = torch.randn(1, 3, 256, 256).cuda() 559 | right = torch.randn(1, 3, 256, 256).cuda() 560 | 561 | flops = FlopCountAnalysis(model, (left, right)) 562 | print(flop_count_table(flops)) 563 | # summary(model, [(4, 256, 256), (1, 256, 256)]) 564 | flops, params = profile(model, (left, right), verbose=False) 565 | 566 | flops = flops * 2 567 | flops, params = clever_format([flops, params], "%.3f") 568 | print('Total GFLOPS: %s' % flops) 569 | print('Total params: %s' % params) 570 | -------------------------------------------------------------------------------- /new_model/encoders/Transformer/dual_segformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | 6 | # from models.new_model.modules import FeatureFusion as FFM 7 | # from models.new_model.modules import FeatureCorrection_s2c as FCM 8 | 9 | from models.new_model.net_utils import FeatureFusionModule as FFM 10 | from models.new_model.net_utils import FeatureRectifyModule as FCM 11 | 12 | import math 13 | import time 14 | 15 | from collections import OrderedDict 16 | 17 | 18 | class DWConv(nn.Module): 19 | """ 20 | Depthwise convolution bloc: input: x with size(B N C); output size (B N C) 21 | """ 22 | 23 | def __init__(self, dim=768): 24 | super(DWConv, self).__init__() 25 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=True, groups=dim) 26 | 27 | def forward(self, x, H, W): 28 | B, N, C = x.shape 29 | x = x.permute(0, 2, 1).reshape(B, C, H, W).contiguous() # B N C -> B C N -> B C H W 30 | x = self.dwconv(x) 31 | x = x.flatten(2).transpose(1, 2) # B C H W -> B N C 32 | 33 | return x 34 | 35 | 36 | class Mlp(nn.Module): 37 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 38 | super().__init__() 39 | """ 40 | MLP Block: 41 | """ 42 | out_features = 
out_features or in_features 43 | hidden_features = hidden_features or in_features 44 | self.fc1 = nn.Linear(in_features, hidden_features) 45 | self.dwconv = DWConv(hidden_features) 46 | self.act = act_layer() 47 | self.fc2 = nn.Linear(hidden_features, out_features) 48 | self.drop = nn.Dropout(drop) 49 | 50 | self.apply(self._init_weights) 51 | 52 | def _init_weights(self, m): 53 | if isinstance(m, nn.Linear): 54 | trunc_normal_(m.weight, std=.02) 55 | if isinstance(m, nn.Linear) and m.bias is not None: 56 | nn.init.constant_(m.bias, 0) 57 | elif isinstance(m, nn.LayerNorm): 58 | nn.init.constant_(m.bias, 0) 59 | nn.init.constant_(m.weight, 1.0) 60 | elif isinstance(m, nn.Conv2d): 61 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 62 | fan_out //= m.groups 63 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | 67 | def forward(self, x, H, W): 68 | x = self.fc1(x) 69 | x = self.dwconv(x, H, W) 70 | x = self.act(x) 71 | x = self.drop(x) 72 | x = self.fc2(x) 73 | x = self.drop(x) 74 | return x 75 | 76 | 77 | class Attention(nn.Module): 78 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 79 | super().__init__() 80 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 81 | 82 | self.dim = dim 83 | self.num_heads = num_heads 84 | head_dim = dim // num_heads 85 | self.scale = qk_scale or head_dim ** -0.5 86 | 87 | # Linear embedding 88 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 89 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 90 | self.attn_drop = nn.Dropout(attn_drop) 91 | self.proj = nn.Linear(dim, dim) 92 | self.proj_drop = nn.Dropout(proj_drop) 93 | 94 | self.sr_ratio = sr_ratio 95 | if sr_ratio > 1: 96 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 97 | self.norm = nn.LayerNorm(dim) 98 | 99 | self.apply(self._init_weights) 100 | 101 | def _init_weights(self, m): 102 | if isinstance(m, nn.Linear): 103 | trunc_normal_(m.weight, std=.02) 104 | if isinstance(m, nn.Linear) and m.bias is not None: 105 | nn.init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.LayerNorm): 107 | nn.init.constant_(m.bias, 0) 108 | nn.init.constant_(m.weight, 1.0) 109 | elif isinstance(m, nn.Conv2d): 110 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 111 | fan_out //= m.groups 112 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 113 | if m.bias is not None: 114 | m.bias.data.zero_() 115 | 116 | def forward(self, x, H, W): 117 | B, N, C = x.shape 118 | # B N C -> B N num_heads C//num_heads -> B num_heads N C//num_heads 119 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 120 | 121 | if self.sr_ratio > 1: 122 | # B C//num_head N num_heads -> B N C//num_heads num_heads -> B C H W 123 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 124 | # B C H W -> B C H/R W/R -> B C HW/R² -> B HW/R² C 125 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 126 | x_ = self.norm(x_) 127 | # B HW/R² C -> B HW/R² 2C -> B HW/R² 2 num_heads C//num_heads -> 2 B num_heads HW/R² C//num_heads 128 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 129 | else: 130 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 131 | k, v = kv[0], kv[1] 132 | 133 | # B num_heads N HW/R² 134 | attn = (q @ k.transpose(-2, -1)) * self.scale 135 | attn = attn.softmax(dim=-1) 136 | attn = self.attn_drop(attn) 137 | 138 
| # B num_heads N C//num_heads -> B N num_heads C//num_heads -> B N C 139 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 140 | x = self.proj(x) 141 | x = self.proj_drop(x) 142 | 143 | return x 144 | 145 | 146 | class Block(nn.Module): 147 | """ 148 | Transformer Block: Self-Attention -> Mix FFN -> OverLap Patch Merging 149 | """ 150 | 151 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 152 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 153 | super().__init__() 154 | self.norm1 = norm_layer(dim) 155 | self.attn = Attention( 156 | dim, 157 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 158 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 159 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 160 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 161 | self.norm2 = norm_layer(dim) 162 | mlp_hidden_dim = int(dim * mlp_ratio) 163 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 164 | 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 1.0) 175 | elif isinstance(m, nn.Conv2d): 176 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 177 | fan_out //= m.groups 178 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 179 | if m.bias is not None: 180 | m.bias.data.zero_() 181 | 182 | def forward(self, x, H, W): 183 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 184 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 185 | 186 | return x 187 | 188 | 189 | class OverlapPatchEmbed(nn.Module): 190 | """ Image to Patch Embedding 191 | """ 192 | 193 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 194 | super().__init__() 195 | img_size = to_2tuple(img_size) 196 | patch_size = to_2tuple(patch_size) 197 | 198 | self.img_size = img_size 199 | self.patch_size = patch_size 200 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 201 | self.num_patches = self.H * self.W 202 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 203 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 204 | self.norm = nn.LayerNorm(embed_dim) 205 | 206 | self.apply(self._init_weights) 207 | 208 | def _init_weights(self, m): 209 | if isinstance(m, nn.Linear): 210 | trunc_normal_(m.weight, std=.02) 211 | if isinstance(m, nn.Linear) and m.bias is not None: 212 | nn.init.constant_(m.bias, 0) 213 | elif isinstance(m, nn.LayerNorm): 214 | nn.init.constant_(m.bias, 0) 215 | nn.init.constant_(m.weight, 1.0) 216 | elif isinstance(m, nn.Conv2d): 217 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 218 | fan_out //= m.groups 219 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 220 | if m.bias is not None: 221 | m.bias.data.zero_() 222 | 223 | def forward(self, x): 224 | # B C H W -> B C H/4 W/4 225 | x = self.proj(x) 226 | _, _, H, W = x.shape 227 | x = x.flatten(2).transpose(1, 2) 228 | # B H*W/16 C 229 | x = self.norm(x) 230 | 231 | return x, H, W 232 | 233 | 234 | class DualSegFormer(nn.Module): 235 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, 
embed_dims=[64, 128, 256, 512], 236 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 237 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, norm_fuse=nn.BatchNorm2d, 238 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 239 | super().__init__() 240 | self.num_classes = num_classes 241 | self.depths = depths 242 | 243 | # patch_embed 244 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 245 | embed_dim=embed_dims[0]) 246 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 247 | embed_dim=embed_dims[1]) 248 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 249 | embed_dim=embed_dims[2]) 250 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], 251 | embed_dim=embed_dims[3]) 252 | 253 | self.extra_patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 254 | embed_dim=embed_dims[0]) 255 | self.extra_patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, 256 | in_chans=embed_dims[0], 257 | embed_dim=embed_dims[1]) 258 | self.extra_patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, 259 | in_chans=embed_dims[1], 260 | embed_dim=embed_dims[2]) 261 | self.extra_patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, 262 | in_chans=embed_dims[2], 263 | embed_dim=embed_dims[3]) 264 | 265 | # transformer encoder 266 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 267 | cur = 0 268 | 269 | self.block1 = nn.ModuleList([Block( 270 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 271 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 272 | sr_ratio=sr_ratios[0]) 273 | for i in range(depths[0])]) 274 | self.norm1 = norm_layer(embed_dims[0]) 275 | 276 | self.extra_block1 = nn.ModuleList([Block( 277 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 278 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 279 | sr_ratio=sr_ratios[0]) 280 | for i in range(depths[0])]) 281 | self.extra_norm1 = norm_layer(embed_dims[0]) 282 | cur += depths[0] 283 | 284 | self.block2 = nn.ModuleList([Block( 285 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 286 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur], norm_layer=norm_layer, 287 | sr_ratio=sr_ratios[1]) 288 | for i in range(depths[1])]) 289 | self.norm2 = norm_layer(embed_dims[1]) 290 | 291 | self.extra_block2 = nn.ModuleList([Block( 292 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 293 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + 1], norm_layer=norm_layer, 294 | sr_ratio=sr_ratios[1]) 295 | for i in range(depths[1])]) 296 | self.extra_norm2 = norm_layer(embed_dims[1]) 297 | 298 | cur += depths[1] 299 | 300 | self.block3 = nn.ModuleList([Block( 301 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 302 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 303 | 
sr_ratio=sr_ratios[2]) 304 | for i in range(depths[2])]) 305 | self.norm3 = norm_layer(embed_dims[2]) 306 | 307 | self.extra_block3 = nn.ModuleList([Block( 308 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 309 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 310 | sr_ratio=sr_ratios[2]) 311 | for i in range(depths[2])]) 312 | self.extra_norm3 = norm_layer(embed_dims[2]) 313 | 314 | cur += depths[2] 315 | 316 | self.block4 = nn.ModuleList([Block( 317 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 318 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 319 | sr_ratio=sr_ratios[3]) 320 | for i in range(depths[3])]) 321 | self.norm4 = norm_layer(embed_dims[3]) 322 | 323 | self.extra_block4 = nn.ModuleList([Block( 324 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 325 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 326 | sr_ratio=sr_ratios[3]) 327 | for i in range(depths[3])]) 328 | self.extra_norm4 = norm_layer(embed_dims[3]) 329 | 330 | cur += depths[3] 331 | 332 | self.FCMs = nn.ModuleList([ 333 | FCM(dim=embed_dims[0], reduction=1), 334 | FCM(dim=embed_dims[1], reduction=1), 335 | FCM(dim=embed_dims[2], reduction=1), 336 | FCM(dim=embed_dims[3], reduction=1)]) 337 | 338 | # self.FFMs = nn.ModuleList([ 339 | # FFM(dim=embed_dims[0], reduction=1, num_heads=num_heads[0], norm_layer=norm_fuse, sr_ratio=sr_ratios[0]), 340 | # FFM(dim=embed_dims[1], reduction=1, num_heads=num_heads[1], norm_layer=norm_fuse, sr_ratio=sr_ratios[1]), 341 | # FFM(dim=embed_dims[2], reduction=1, num_heads=num_heads[2], norm_layer=norm_fuse, sr_ratio=sr_ratios[2]), 342 | # FFM(dim=embed_dims[3], reduction=1, num_heads=num_heads[3], norm_layer=norm_fuse, sr_ratio=sr_ratios[3])]) 343 | 344 | self.FFMs = nn.ModuleList([ 345 | FFM(dim=embed_dims[0], reduction=1, num_heads=num_heads[0], norm_layer=norm_fuse), 346 | FFM(dim=embed_dims[1], reduction=1, num_heads=num_heads[1], norm_layer=norm_fuse), 347 | FFM(dim=embed_dims[2], reduction=1, num_heads=num_heads[2], norm_layer=norm_fuse), 348 | FFM(dim=embed_dims[3], reduction=1, num_heads=num_heads[3], norm_layer=norm_fuse)]) 349 | 350 | self.apply(self._init_weights) 351 | 352 | def _init_weights(self, m): 353 | if isinstance(m, nn.Linear): 354 | trunc_normal_(m.weight, std=.02) 355 | if isinstance(m, nn.Linear) and m.bias is not None: 356 | nn.init.constant_(m.bias, 0) 357 | elif isinstance(m, nn.LayerNorm): 358 | nn.init.constant_(m.bias, 0) 359 | nn.init.constant_(m.weight, 1.0) 360 | elif isinstance(m, nn.Conv2d): 361 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 362 | fan_out //= m.groups 363 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 364 | if m.bias is not None: 365 | m.bias.data.zero_() 366 | 367 | def init_weights(self, pretrained=None): 368 | 369 | if isinstance(pretrained, str): 370 | load_dualpath_model(self, pretrained) 371 | else: 372 | raise TypeError('pretrained must be a str or None') 373 | 374 | def forward_features(self, x1, x2): 375 | """ 376 | x1: B x N x H x W 377 | """ 378 | B = x1.shape[0] 379 | outs = [] 380 | 381 | # stage 1 382 | x1, H, W = self.patch_embed1(x1) 383 | # B H*W/16 C 384 | x2, _, _ = self.extra_patch_embed1(x2) 385 | for i, blk in enumerate(self.block1): 386 | x1 = blk(x1, H, W) 387 | for i, blk in 
enumerate(self.extra_block1): 388 | x2 = blk(x2, H, W) 389 | x1 = self.norm1(x1) 390 | x2 = self.extra_norm1(x2) 391 | 392 | x1 = x1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 393 | x2 = x2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 394 | x1, x2 = self.FCMs[0](x1, x2) 395 | x_fused = self.FFMs[0](x1, x2) 396 | outs.append(x_fused) 397 | 398 | # stage 2 399 | x1, H, W = self.patch_embed2(x1) 400 | x2, _, _ = self.extra_patch_embed2(x2) 401 | for i, blk in enumerate(self.block2): 402 | x1 = blk(x1, H, W) 403 | for i, blk in enumerate(self.extra_block2): 404 | x2 = blk(x2, H, W) 405 | x1 = self.norm2(x1) 406 | x2 = self.extra_norm2(x2) 407 | 408 | x1 = x1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 409 | x2 = x2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 410 | x1, x2 = self.FCMs[1](x1, x2) 411 | x_fused = self.FFMs[1](x1, x2) 412 | outs.append(x_fused) 413 | 414 | # stage 3 415 | x1, H, W = self.patch_embed3(x1) 416 | x2, _, _ = self.extra_patch_embed3(x2) 417 | for i, blk in enumerate(self.block3): 418 | x1 = blk(x1, H, W) 419 | for i, blk in enumerate(self.extra_block3): 420 | x2 = blk(x2, H, W) 421 | x1 = self.norm3(x1) 422 | x2 = self.extra_norm3(x2) 423 | 424 | x1 = x1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 425 | x2 = x2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 426 | x1, x2 = self.FCMs[2](x1, x2) 427 | x_fused = self.FFMs[2](x1, x2) 428 | outs.append(x_fused) 429 | 430 | # stage 4 431 | x1, H, W = self.patch_embed4(x1) 432 | x2, _, _ = self.extra_patch_embed4(x2) 433 | for i, blk in enumerate(self.block4): 434 | x1 = blk(x1, H, W) 435 | for i, blk in enumerate(self.extra_block4): 436 | x2 = blk(x2, H, W) 437 | x1 = self.norm4(x1) 438 | x2 = self.extra_norm4(x2) 439 | 440 | x1 = x1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 441 | x2 = x2.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 442 | x1, x2 = self.FCMs[3](x1, x2) 443 | x_fused = self.FFMs[3](x1, x2) 444 | outs.append(x_fused) 445 | 446 | return tuple(outs) 447 | 448 | def forward(self, x1, x2): 449 | out = self.forward_features(x1, x2) 450 | return out 451 | 452 | 453 | def load_dualpath_model(model, model_file, is_restore=False): 454 | # load raw state_dict 455 | t_start = time.time() 456 | if isinstance(model_file, str): 457 | raw_state_dict = torch.load(model_file, map_location=torch.device('cpu')) 458 | # raw_state_dict = torch.load(model_file) 459 | if 'model' in raw_state_dict.keys(): 460 | raw_state_dict = raw_state_dict['model'] 461 | else: 462 | raw_state_dict = model_file 463 | 464 | state_dict = {} 465 | for k, v in raw_state_dict.items(): 466 | if k.find('patch_embed') >= 0: 467 | state_dict[k] = v 468 | state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v 469 | elif k.find('block') >= 0: 470 | state_dict[k] = v 471 | state_dict[k.replace('block', 'extra_block')] = v 472 | elif k.find('norm') >= 0: 473 | state_dict[k] = v 474 | state_dict[k.replace('norm', 'extra_norm')] = v 475 | 476 | t_ioend = time.time() 477 | 478 | if is_restore: 479 | new_state_dict = OrderedDict() 480 | for k, v in state_dict.items(): 481 | name = 'module.' 
+ k 482 | new_state_dict[name] = v 483 | state_dict = new_state_dict 484 | 485 | model.load_state_dict(state_dict, strict=False) 486 | 487 | del state_dict 488 | 489 | t_end = time.time() 490 | print("Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( 491 | t_ioend - t_start, t_end - t_ioend)) 492 | 493 | 494 | class mit_b0(DualSegFormer): 495 | def __init__(self, fuse_cfg=None, **kwargs): 496 | super(mit_b0, self).__init__( 497 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 498 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 499 | drop_rate=0.0, drop_path_rate=0.1, **kwargs) 500 | 501 | 502 | class mit_b1(DualSegFormer): 503 | def __init__(self, fuse_cfg=None, **kwargs): 504 | super(mit_b1, self).__init__( 505 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 506 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 507 | drop_rate=0.0, drop_path_rate=0.1, **kwargs) 508 | 509 | 510 | class mit_b2(DualSegFormer): 511 | def __init__(self, fuse_cfg=None, **kwargs): 512 | super(mit_b2, self).__init__( 513 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 514 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 515 | drop_rate=0.0, drop_path_rate=0.1, **kwargs) 516 | 517 | 518 | class mit_b3(DualSegFormer): 519 | def __init__(self, fuse_cfg=None, **kwargs): 520 | super(mit_b3, self).__init__( 521 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 522 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 523 | drop_rate=0.0, drop_path_rate=0.1, **kwargs) 524 | 525 | 526 | class mit_b4(DualSegFormer): 527 | def __init__(self, fuse_cfg=None, **kwargs): 528 | super(mit_b4, self).__init__( 529 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 530 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 531 | drop_rate=0.0, drop_path_rate=0.1, **kwargs) 532 | 533 | 534 | class mit_b5(DualSegFormer): 535 | def __init__(self, fuse_cfg=None, **kwargs): 536 | super(mit_b5, self).__init__( 537 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 538 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 539 | drop_rate=0.0, drop_path_rate=0.1, **kwargs) 540 | -------------------------------------------------------------------------------- /new_model/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/__init__.py -------------------------------------------------------------------------------- /new_model/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/masurq/CFFormer/6be4de96f134c162253933dbf4271485d2286a87/new_model/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /new_model/init_func.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, **kwargs): 6 | for name, m in feature.named_modules(): 7 | if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 8 | conv_init(m.weight, **kwargs) 9 | 10 | elif isinstance(m, norm_layer): 11 | m.eps = bn_eps 12 | m.momentum = bn_momentum 13 | nn.init.constant_(m.weight, 1) 14 | nn.init.constant_(m.bias, 0) 15 | 16 | 17 | def init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, **kwargs): 18 | if isinstance(module_list, list): 19 | for feature in module_list: 20 | __init_weight(feature, conv_init, norm_layer, bn_eps, bn_momentum, 21 | **kwargs) 22 | else: 23 | __init_weight(module_list, conv_init, norm_layer, bn_eps, bn_momentum, 24 | **kwargs) 25 | 26 | 27 | def group_weight(weight_group, module, norm_layer, lr): 28 | group_decay = [] 29 | group_no_decay = [] 30 | for m in module.modules(): 31 | if isinstance(m, nn.Linear): 32 | group_decay.append(m.weight) 33 | if m.bias is not None: 34 | group_no_decay.append(m.bias) 35 | 36 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): 37 | group_decay.append(m.weight) 38 | if m.bias is not None: 39 | group_no_decay.append(m.bias) 40 | 41 | elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d) \ 42 | or isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.GroupNorm): 43 | if m.weight is not None: 44 | group_no_decay.append(m.weight) 45 | if m.bias is not None: 46 | group_no_decay.append(m.bias) 47 | 48 | elif isinstance(m, nn.Parameter): 49 | group_decay.append(m) 50 | 51 | assert len(list(module.parameters())) >= len(group_decay) + len(group_no_decay) 52 | weight_group.append(dict(params=group_decay, lr=lr)) 53 | weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) 54 | return weight_group 55 | 56 | 57 | def nostride_dilate(m, dilate): 58 | if isinstance(m, nn.Conv2d): 59 | if m.stride == (2, 2): 60 | m.stride = (1, 1) 61 | if m.kernel_size == (3, 3): 62 | m.dilation = (dilate, dilate) 63 | m.padding = (dilate, dilate) 64 | else: 65 | if m.kernel_size == (3, 3): 66 | m.dilation = (dilate, dilate) 67 | m.padding = (dilate, dilate) 68 | 69 | 70 | def patch_first_conv_single_biformer(model, in_channel1, in_channel2): 71 | """Change first convolution layer input channels. 
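For this single-stream BiFormer variant the two modality channel counts are summed and the first convolution found
under "downsample_layers" is patched to in_channel1 + in_channel2 input channels, reusing the pretrained RGB filters.
Illustrative call, assuming `model` is such a backbone (the channel counts are only an example):

    # e.g. a 3-channel optical image stacked with a 1-channel auxiliary modality -> 4-channel stem
    patch_first_conv_single_biformer(model, in_channel1=3, in_channel2=1)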
72 | In case: 73 | in_channels == 1 or in_channels == 2 -> reuse original weights 74 | in_channels > 3 -> make random kaiming normal initialization 75 | """ 76 | 77 | conv1_found = False 78 | 79 | for name, module in model.named_modules(): 80 | if not conv1_found and isinstance(module, nn.Conv2d) and "downsample_layers" in name: 81 | conv1_found = True 82 | in_channel = in_channel1 + in_channel2 83 | module.in_channels = in_channel 84 | weight = module.weight.detach() 85 | reset = False 86 | 87 | if in_channel == 1: 88 | weight = weight.sum(1, keepdim=True) 89 | elif in_channel == 2: 90 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 91 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 92 | weight = weight[:, :2] 93 | else: 94 | for i in range(3, in_channel): 95 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 96 | weight *= (3 / in_channel) 97 | 98 | module.weight = nn.parameter.Parameter(weight) 99 | 100 | if conv1_found: 101 | break 102 | 103 | 104 | def patch_first_conv_mit(model, in_channel1, in_channel2): 105 | """Change first convolution layer input channels. 106 | In case: 107 | in_channels == 1 or in_channels == 2 -> reuse original weights 108 | in_channels > 3 -> make random kaiming normal initialization 109 | """ 110 | 111 | conv1_found = False 112 | conv2_found = False 113 | 114 | for name, module in model.named_modules(): 115 | if not conv1_found and isinstance(module, nn.Conv2d) and "patch_embed1" in name: 116 | conv1_found = True 117 | module.in_channels = in_channel1 118 | weight = module.weight.detach() 119 | reset = False 120 | 121 | if in_channel1 == 1: 122 | weight = weight.sum(1, keepdim=True) 123 | elif in_channel1 == 2: 124 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 125 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 126 | weight = weight[:, :2] 127 | else: 128 | for i in range(3, in_channel1): 129 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 130 | weight *= (3 / in_channel1) 131 | 132 | module.weight = nn.parameter.Parameter(weight) 133 | 134 | if not conv2_found and isinstance(module, nn.Conv2d) and "extra_patch_embed1" in name: 135 | conv2_found = True 136 | module.in_channels = in_channel2 137 | weight = module.weight.detach() 138 | reset = False 139 | 140 | if in_channel2 == 1: 141 | weight = weight.sum(1, keepdim=True) 142 | elif in_channel2 == 2: 143 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 144 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 145 | weight = weight[:, :2] 146 | else: 147 | for i in range(3, in_channel2): 148 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 149 | weight *= (3 / in_channel2) 150 | 151 | module.weight = nn.parameter.Parameter(weight) 152 | 153 | if conv1_found and conv2_found: 154 | break 155 | 156 | 157 | def patch_first_conv_DilateFormer(model, in_channel1, in_channel2): 158 | """Change first convolution layer input channels. 
159 | In case: 160 | in_channels == 1 or in_channels == 2 -> reuse original weights 161 | in_channels > 3 -> make random kaiming normal initialization 162 | """ 163 | 164 | conv1_found = False 165 | conv2_found = False 166 | 167 | for name, module in model.named_modules(): 168 | if not conv1_found and isinstance(module, nn.Conv2d) and "patch_embed" in name: 169 | conv1_found = True 170 | module.in_channels = in_channel1 171 | weight = module.weight.detach() 172 | reset = False 173 | 174 | if in_channel1 == 1: 175 | weight = weight.sum(1, keepdim=True) 176 | elif in_channel1 == 2: 177 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 178 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 179 | weight = weight[:, :2] 180 | else: 181 | for i in range(3, in_channel1): 182 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 183 | weight *= (3 / in_channel1) 184 | 185 | module.weight = nn.parameter.Parameter(weight) 186 | 187 | if not conv2_found and isinstance(module, nn.Conv2d) and "aux_patch_embed" in name: 188 | conv2_found = True 189 | module.in_channels = in_channel2 190 | weight = module.weight.detach() 191 | reset = False 192 | 193 | if in_channel2 == 1: 194 | weight = weight.sum(1, keepdim=True) 195 | elif in_channel2 == 2: 196 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 197 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 198 | weight = weight[:, :2] 199 | else: 200 | for i in range(3, in_channel2): 201 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 202 | weight *= (3 / in_channel2) 203 | 204 | module.weight = nn.parameter.Parameter(weight) 205 | 206 | if conv1_found and conv2_found: 207 | break 208 | 209 | 210 | def patch_first_conv_swin(model, in_channel1, in_channel2): 211 | """Change first convolution layer input channels. 
212 | In case: 213 | in_channels == 1 or in_channels == 2 -> reuse original weights 214 | in_channels > 3 -> make random kaiming normal initialization 215 | """ 216 | 217 | conv1_found = False 218 | conv2_found = False 219 | 220 | for name, module in model.named_modules(): 221 | if not conv1_found and isinstance(module, nn.Conv2d) and "patch_embed" in name: 222 | conv1_found = True 223 | module.in_channels = in_channel1 224 | weight = module.weight.detach() 225 | reset = False 226 | 227 | if in_channel1 == 1: 228 | weight = weight.sum(1, keepdim=True) 229 | elif in_channel1 == 2: 230 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 231 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 232 | weight = weight[:, :2] 233 | else: 234 | for i in range(3, in_channel1): 235 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 236 | weight *= (3 / in_channel1) 237 | 238 | module.weight = nn.parameter.Parameter(weight) 239 | 240 | if not conv2_found and isinstance(module, nn.Conv2d) and "patch_embed_d" in name: 241 | conv2_found = True 242 | module.in_channels = in_channel2 243 | weight = module.weight.detach() 244 | reset = False 245 | 246 | if in_channel2 == 1: 247 | weight = weight.sum(1, keepdim=True) 248 | elif in_channel2 == 2: 249 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 250 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 251 | weight = weight[:, :2] 252 | else: 253 | for i in range(3, in_channel2): 254 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 255 | weight *= (3 / in_channel2) 256 | 257 | module.weight = nn.parameter.Parameter(weight) 258 | 259 | if conv1_found and conv2_found: 260 | break 261 | 262 | 263 | def patch_first_conv_cswin(model, in_channel1, in_channel2): 264 | """Change first convolution layer input channels. 
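This variant targets the DualCSWin encoder from dual_cswin.py: the first convolution inside "stage1_conv_embed" is
patched to in_channel1 channels and the one inside "aux_stage1_conv_embed" to in_channel2 channels.
Minimal sketch (example channel counts only, e.g. a 4-band optical image plus a single-band auxiliary modality):

    model = cswin_t()
    patch_first_conv_cswin(model, in_channel1=4, in_channel2=1)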
265 | In case: 266 | in_channels == 1 or in_channels == 2 -> reuse original weights 267 | in_channels > 3 -> make random kaiming normal initialization 268 | """ 269 | 270 | conv1_found = False 271 | conv2_found = False 272 | 273 | for name, module in model.named_modules(): 274 | if not conv1_found and isinstance(module, nn.Conv2d) and "stage1_conv_embed" in name: 275 | conv1_found = True 276 | module.in_channels = in_channel1 277 | weight = module.weight.detach() 278 | reset = False 279 | 280 | if in_channel1 == 1: 281 | weight = weight.sum(1, keepdim=True) 282 | elif in_channel1 == 2: 283 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 284 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 285 | weight = weight[:, :2] 286 | else: 287 | for i in range(3, in_channel1): 288 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 289 | weight *= (3 / in_channel1) 290 | 291 | module.weight = nn.parameter.Parameter(weight) 292 | 293 | if not conv2_found and isinstance(module, nn.Conv2d) and "aux_stage1_conv_embed" in name: 294 | conv2_found = True 295 | module.in_channels = in_channel2 296 | weight = module.weight.detach() 297 | reset = False 298 | 299 | if in_channel2 == 1: 300 | weight = weight.sum(1, keepdim=True) 301 | elif in_channel2 == 2: 302 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 303 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 304 | weight = weight[:, :2] 305 | else: 306 | for i in range(3, in_channel2): 307 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 308 | weight *= (3 / in_channel2) 309 | 310 | module.weight = nn.parameter.Parameter(weight) 311 | 312 | if conv1_found and conv2_found: 313 | break 314 | 315 | 316 | def patch_first_conv_biformer(model, in_channel1, in_channel2): 317 | """Change first convolution layer input channels. 
318 | In case: 319 | in_channels == 1 or in_channels == 2 -> reuse original weights 320 | in_channels > 3 -> make random kaiming normal initialization 321 | """ 322 | 323 | conv1_found = False 324 | conv2_found = False 325 | 326 | for name, module in model.named_modules(): 327 | if not conv1_found and isinstance(module, nn.Conv2d) and "downsample_layers" in name: 328 | conv1_found = True 329 | module.in_channels = in_channel1 330 | weight = module.weight.detach() 331 | reset = False 332 | 333 | if in_channel1 == 1: 334 | weight = weight.sum(1, keepdim=True) 335 | elif in_channel1 == 2: 336 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 337 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 338 | weight = weight[:, :2] 339 | else: 340 | for i in range(3, in_channel1): 341 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 342 | weight *= (3 / in_channel1) 343 | 344 | module.weight = nn.parameter.Parameter(weight) 345 | 346 | if not conv2_found and isinstance(module, nn.Conv2d) and "aux_downsample_layers" in name: 347 | conv2_found = True 348 | module.in_channels = in_channel2 349 | weight = module.weight.detach() 350 | reset = False 351 | 352 | if in_channel2 == 1: 353 | weight = weight.sum(1, keepdim=True) 354 | elif in_channel2 == 2: 355 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 356 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 357 | weight = weight[:, :2] 358 | else: 359 | for i in range(3, in_channel2): 360 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 361 | weight *= (3 / in_channel2) 362 | 363 | module.weight = nn.parameter.Parameter(weight) 364 | 365 | if conv1_found and conv2_found: 366 | break 367 | 368 | 369 | def patch_first_conv_vitae(model, in_channel1, in_channel2): 370 | """Change first convolution layer input channels. 
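Unlike the other variants, this one adapts three convolutions: the first matching conv under "layers.0.RC.PRM" and the
first under "layers.0.RC.PCM" (both set to in_channel1), plus the first conv under "aux_layers.0.RC.PRM" (set to in_channel2).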
371 | In case: 372 | in_channels == 1 or in_channels == 2 -> reuse original weights 373 | in_channels > 3 -> make random kaiming normal initialization 374 | """ 375 | 376 | conv1_found = False 377 | conv2_found = False 378 | conv3_found = False 379 | 380 | for name, module in model.named_modules(): 381 | if not conv1_found and isinstance(module, nn.Conv2d) and "layers.0.RC.PRM" in name: 382 | print(name) 383 | conv1_found = True 384 | module.in_channels = in_channel1 385 | weight = module.weight.detach() 386 | reset = False 387 | 388 | if in_channel1 == 1: 389 | weight = weight.sum(1, keepdim=True) 390 | elif in_channel1 == 2: 391 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 392 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 393 | weight = weight[:, :2] 394 | else: 395 | for i in range(3, in_channel1): 396 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 397 | weight *= (3 / in_channel1) 398 | 399 | module.weight = nn.parameter.Parameter(weight) 400 | if not conv3_found and isinstance(module, nn.Conv2d) and "layers.0.RC.PCM" in name: 401 | print(name) 402 | conv3_found = True 403 | module.in_channels = in_channel1 404 | weight = module.weight.detach() 405 | reset = False 406 | 407 | if in_channel1 == 1: 408 | weight = weight.sum(1, keepdim=True) 409 | elif in_channel1 == 2: 410 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 411 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 412 | weight = weight[:, :2] 413 | else: 414 | for i in range(3, in_channel1): 415 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 416 | weight *= (3 / in_channel1) 417 | 418 | module.weight = nn.parameter.Parameter(weight) 419 | if not conv2_found and isinstance(module, nn.Conv2d) and "aux_layers.0.RC.PRM" in name: 420 | print(name) 421 | conv2_found = True 422 | module.in_channels = in_channel2 423 | weight = module.weight.detach() 424 | reset = False 425 | 426 | if in_channel2 == 1: 427 | weight = weight.sum(1, keepdim=True) 428 | elif in_channel2 == 2: 429 | weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1] 430 | weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1] 431 | weight = weight[:, :2] 432 | else: 433 | for i in range(3, in_channel2): 434 | weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1) 435 | weight *= (3 / in_channel2) 436 | 437 | module.weight = nn.parameter.Parameter(weight) 438 | 439 | if conv1_found and conv2_found and conv3_found: 440 | break 441 | 442 | 443 | def patch_first_conv_SMT(model, in_channel1, in_channel2): 444 | """Change first convolution layer input channels. 
445 |     In case:
446 |     in_channels == 1 or in_channels == 2 -> reuse original weights
447 |     in_channels > 3 -> repeat the original RGB channels cyclically and rescale by 3 / in_channels
448 |     """
449 | 
450 |     conv1_found = False
451 |     conv2_found = False
452 | 
453 |     for name, module in model.named_modules():
454 |         if not conv1_found and isinstance(module, nn.Conv2d) and "patch_embed1" in name:
455 |             conv1_found = True
456 |             module.in_channels = in_channel1
457 |             weight = module.weight.detach()
458 |             reset = False
459 | 
460 |             if in_channel1 == 1:
461 |                 weight = weight.sum(1, keepdim=True)
462 |             elif in_channel1 == 2:
463 |                 weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1]
464 |                 weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1]
465 |                 weight = weight[:, :2]
466 |             else:
467 |                 for i in range(3, in_channel1):
468 |                     weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1)
469 |                 weight *= (3 / in_channel1)
470 | 
471 |             module.weight = nn.parameter.Parameter(weight)
472 | 
473 |         if not conv2_found and isinstance(module, nn.Conv2d) and "aux_patch_embed1" in name:
474 |             conv2_found = True
475 |             module.in_channels = in_channel2
476 |             weight = module.weight.detach()
477 |             reset = False
478 | 
479 |             if in_channel2 == 1:
480 |                 weight = weight.sum(1, keepdim=True)
481 |             elif in_channel2 == 2:
482 |                 weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1]
483 |                 weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1]
484 |                 weight = weight[:, :2]
485 |             else:
486 |                 for i in range(3, in_channel2):
487 |                     weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1)
488 |                 weight *= (3 / in_channel2)
489 | 
490 |             module.weight = nn.parameter.Parameter(weight)
491 | 
492 |         if conv1_found and conv2_found:
493 |             break
494 | 
495 | 
496 | def patch_first_conv_EMO(model, in_channel1, in_channel2):
497 |     """Change first convolution layer input channels.
498 |     In case:
499 |     in_channels == 1 or in_channels == 2 -> reuse original weights
500 |     in_channels > 3 -> repeat the original RGB channels cyclically and rescale by 3 / in_channels
501 |     """
502 | 
503 |     conv1_found = False
504 |     conv2_found = False
505 | 
506 |     for name, module in model.named_modules():
507 |         if not conv1_found and isinstance(module, nn.Conv2d) and "stage0" in name:
508 |             conv1_found = True
509 |             module.in_channels = in_channel1
510 |             weight = module.weight.detach()
511 |             reset = False
512 | 
513 |             if in_channel1 == 1:
514 |                 weight = weight.sum(1, keepdim=True)
515 |             elif in_channel1 == 2:
516 |                 weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1]
517 |                 weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1]
518 |                 weight = weight[:, :2]
519 |             else:
520 |                 for i in range(3, in_channel1):
521 |                     weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1)
522 |                 weight *= (3 / in_channel1)
523 | 
524 |             module.weight = nn.parameter.Parameter(weight)
525 | 
526 |         if not conv2_found and isinstance(module, nn.Conv2d) and "aux_stage0" in name:
527 |             conv2_found = True
528 |             module.in_channels = in_channel2
529 |             weight = module.weight.detach()
530 |             reset = False
531 | 
532 |             if in_channel2 == 1:
533 |                 weight = weight.sum(1, keepdim=True)
534 |             elif in_channel2 == 2:
535 |                 weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1]
536 |                 weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1]
537 |                 weight = weight[:, :2]
538 |             else:
539 |                 for i in range(3, in_channel2):
540 |                     weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1)
541 |                 weight *= (3 / in_channel2)
542 | 
543 |             module.weight = nn.parameter.Parameter(weight)
544 | 
545 |         if conv1_found and conv2_found:
546 |             break
547 | 
548 | 
549 | def patch_first_conv_resnet(model, in_channel1, in_channel2):
550 |     """Change first convolution layer input channels.
551 |     In case:
552 |     in_channels == 1 or in_channels == 2 -> reuse original weights
553 |     in_channels > 3 -> repeat the original RGB channels cyclically and rescale by 3 / in_channels
554 |     """
555 | 
556 |     conv1_found = False
557 |     conv2_found = False
558 | 
559 |     for name, module in model.named_modules():
560 |         if not conv1_found and isinstance(module, nn.Conv2d) and "conv1" in name:
561 |             conv1_found = True
562 |             module.in_channels = in_channel1
563 |             weight = module.weight.detach()
564 |             reset = False
565 | 
566 |             if in_channel1 == 1:
567 |                 weight = weight.sum(1, keepdim=True)
568 |             elif in_channel1 == 2:
569 |                 weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1]
570 |                 weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1]
571 |                 weight = weight[:, :2]
572 |             else:
573 |                 for i in range(3, in_channel1):
574 |                     weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1)
575 |                 weight *= (3 / in_channel1)
576 | 
577 |             module.weight = nn.parameter.Parameter(weight)
578 | 
579 |         if not conv2_found and isinstance(module, nn.Conv2d) and "extra_conv1" in name:
580 |             conv2_found = True
581 |             module.in_channels = in_channel2
582 |             weight = module.weight.detach()
583 |             reset = False
584 | 
585 |             if in_channel2 == 1:
586 |                 weight = weight.sum(1, keepdim=True)
587 |             elif in_channel2 == 2:
588 |                 weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1]
589 |                 weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1]
590 |                 weight = weight[:, :2]
591 |             else:
592 |                 for i in range(3, in_channel2):
593 |                     weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1)
594 |                 weight *= (3 / in_channel2)
595 | 
596 |             module.weight = nn.parameter.Parameter(weight)
597 | 
598 |         if conv1_found and conv2_found:
599 |             break
600 | 
601 | 
602 | def patch_first_conv_mobilenet(model, in_channel1, in_channel2):
603 |     """Change first convolution layer input channels.
604 |     In case:
605 |     in_channels == 1 or in_channels == 2 -> reuse original weights
606 |     in_channels > 3 -> repeat the original RGB channels cyclically and rescale by 3 / in_channels
607 |     """
608 | 
609 |     conv1_found = False
610 |     conv2_found = False
611 | 
612 |     for name, module in model.named_modules():
613 |         if not conv1_found and isinstance(module, nn.Conv2d) and "features" in name:
614 |             conv1_found = True
615 |             module.in_channels = in_channel1
616 |             weight = module.weight.detach()
617 |             reset = False
618 | 
619 |             if in_channel1 == 1:
620 |                 weight = weight.sum(1, keepdim=True)
621 |             elif in_channel1 == 2:
622 |                 weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1]
623 |                 weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1]
624 |                 weight = weight[:, :2]
625 |             else:
626 |                 for i in range(3, in_channel1):
627 |                     weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1)
628 |                 weight *= (3 / in_channel1)
629 | 
630 |             module.weight = nn.parameter.Parameter(weight)
631 | 
632 |         if not conv2_found and isinstance(module, nn.Conv2d) and "aux_features" in name:
633 |             conv2_found = True
634 |             module.in_channels = in_channel2
635 |             weight = module.weight.detach()
636 |             reset = False
637 | 
638 |             if in_channel2 == 1:
639 |                 weight = weight.sum(1, keepdim=True)
640 |             elif in_channel2 == 2:
641 |                 weight[:, 0] = weight[:, 0] + 0.5 * weight[:, 1]
642 |                 weight[:, 1] = weight[:, 2] + 0.5 * weight[:, 1]
643 |                 weight = weight[:, :2]
644 |             else:
645 |                 for i in range(3, in_channel2):
646 |                     weight = torch.cat((weight, weight[:, (i % 3):(i % 3 + 1)]), dim=1)
647 |                 weight *= (3 / in_channel2)
648 | 
649 |             module.weight = nn.parameter.Parameter(weight)
650 | 
651 |         if conv1_found and conv2_found:
652 |             break
653 | 
--------------------------------------------------------------------------------
/new_model/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from timm.models.layers import trunc_normal_
4 | import math
5 | from thop import clever_format, profile
6 | 
7 | 
8 | # Feature Correction Module (FCM) building blocks
9 | class ChannelWeights(nn.Module):
10 |     def __init__(self, dim, reduction=1):
11 |         super(ChannelWeights, self).__init__()
12 |         self.dim = dim
13 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
14 |         self.max_pool = nn.AdaptiveMaxPool2d(1)
15 |         self.mlp = nn.Sequential(
16 |             nn.Linear(self.dim * 6, self.dim * 6 // reduction),
17 |             nn.ReLU(inplace=True),
18 |             nn.Linear(self.dim * 6 // reduction, self.dim * 2),
19 |             nn.Sigmoid())
20 | 
21 |     def forward(self, x1, x2):
22 |         B, _, H, W = x1.shape
23 |         x = torch.cat((x1, x2), dim=1)
24 |         avg = self.avg_pool(x).view(B, self.dim * 2)
25 |         std = torch.std(x, dim=(2, 3), keepdim=True).view(B, self.dim * 2)
26 |         max = self.max_pool(x).view(B, self.dim * 2)
27 |         y = torch.cat((avg, std, max), dim=1)  # B 6C
28 |         y = self.mlp(y).view(B, self.dim * 2, 1)
29 |         channel_weights = y.reshape(B, 2, self.dim, 1, 1).permute(1, 0, 2, 3, 4)  # 2 B C 1 1
30 |         return channel_weights
31 | 
32 | 
33 | class SpatialWeights(nn.Module):
34 |     def __init__(self, dim, reduction=1):
35 |         super(SpatialWeights, self).__init__()
36 |         self.dim = dim
37 |         self.mlp = nn.Sequential(
38 |             nn.Conv2d(self.dim * 2, self.dim // reduction, kernel_size=1),
39 |             nn.ReLU(inplace=True),
40 |             nn.Conv2d(self.dim // reduction, 2, kernel_size=1),
41 |             nn.Sigmoid())
42 | 
43 |     def forward(self, x1, x2):
44 |         B, _, H, W = x1.shape
45 |         x = torch.cat((x1, x2), dim=1)  # B 2C H W
46 |         spatial_weights = self.mlp(x).reshape(B, 2, 1, H, W).permute(1, 0, 2, 3, 4)  # 2 B 1 H W
47 |         return spatial_weights
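# Illustrative shape sketch (not from the original code; B, C, H, W below are
# placeholder sizes): given two same-sized feature maps from the two modalities,
#
#     x1 = torch.randn(B, C, H, W)               # e.g. optical branch
#     x2 = torch.randn(B, C, H, W)               # e.g. SAR / DSM branch
#     channel_w = ChannelWeights(dim=C)(x1, x2)  # -> (2, B, C, 1, 1)
#     spatial_w = SpatialWeights(dim=C)(x1, x2)  # -> (2, B, 1, H, W)
#
# The two slices along the leading dimension are consumed as cross-branch gates
# by FeatureCorrection_s2c below: slice 1 scales the x2 term added to x1, and
# slice 0 scales the x1 term added to x2.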
48 | 
49 | 
50 | 
51 | 
52 | 
53 | # Spatial correction first, then channel correction
54 | class FeatureCorrection_s2c(nn.Module):
55 |     def __init__(self, dim, reduction=1, eps=1e-8):
56 |         super(FeatureCorrection_s2c, self).__init__()
57 |         # Custom trainable weight parameters
58 |         self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
59 |         self.eps = eps
60 |         self.spatial_weights = SpatialWeights(dim=dim, reduction=reduction)
61 |         self.channel_weights = ChannelWeights(dim=dim, reduction=reduction)
62 | 
63 |         self.apply(self._init_weights)
64 | 
65 |     @classmethod
66 |     def _init_weights(cls, m):
67 |         if isinstance(m, nn.Linear):
68 |             trunc_normal_(m.weight, std=.02)
69 |             if isinstance(m, nn.Linear) and m.bias is not None:
70 |                 nn.init.constant_(m.bias, 0)
71 |         elif isinstance(m, nn.LayerNorm):
72 |             nn.init.constant_(m.bias, 0)
73 |             nn.init.constant_(m.weight, 1.0)
74 |         elif isinstance(m, nn.Conv2d):
75 |             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
76 |             fan_out //= m.groups
77 |             m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
78 |             if m.bias is not None:
79 |                 m.bias.data.zero_()
80 | 
81 |     def forward(self, x1, x2):
82 |         weights = nn.ReLU()(self.weights)
83 |         fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
84 | 
85 |         spatial_weights = self.spatial_weights(x1, x2)
86 |         x1_1 = x1 + fuse_weights[0] * spatial_weights[1] * x2
87 |         x2_1 = x2 + fuse_weights[0] * spatial_weights[0] * x1
88 | 
89 |         channel_weights = self.channel_weights(x1_1, x2_1)
90 | 
91 |         main_out = x1_1 + fuse_weights[1] * channel_weights[1] * x2_1
92 |         aux_out = x2_1 + fuse_weights[1] * channel_weights[0] * x1_1
93 |         return main_out, aux_out
94 | 
95 | 
96 | class CrossAttention(nn.Module):
97 |     def __init__(self, dim, num_heads=8, sr_ratio=1, qkv_bias=False, qk_scale=None):
98 |         super(CrossAttention, self).__init__()
99 |         assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
100 | 101 | self.dim = dim 102 | self.num_heads = num_heads 103 | head_dim = dim // num_heads 104 | self.scale = qk_scale or head_dim ** -0.5 105 | 106 | self.q1 = nn.Linear(dim, dim, bias=qkv_bias) 107 | self.kv1 = nn.Linear(dim, dim * 2, bias=qkv_bias) 108 | 109 | self.q2 = nn.Linear(dim, dim, bias=qkv_bias) 110 | self.kv2 = nn.Linear(dim, dim * 2, bias=qkv_bias) 111 | 112 | self.sr_ratio = sr_ratio 113 | if sr_ratio > 1: 114 | self.sr1 = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim) 115 | self.norm1 = nn.LayerNorm(dim) 116 | 117 | self.sr2 = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim) 118 | self.norm2 = nn.LayerNorm(dim) 119 | 120 | def forward(self, x1, x2, H, W): 121 | B, N, C = x1.shape 122 | # B num_heads N C//num_heads 123 | q1 = self.q1(x1).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() 124 | q2 = self.q2(x2).reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous() 125 | 126 | if self.sr_ratio > 1: 127 | # B C//num_head N num_heads -> B N C//num_heads num_heads -> B C H W 128 | x_1 = x1.permute(0, 2, 1).reshape(B, C, H, W) 129 | # B C H W -> B C H/R W/R -> B C HW/R² -> B HW/R² C 130 | x_1 = self.sr1(x_1).reshape(B, C, -1).permute(0, 2, 1) 131 | x_1 = self.norm1(x_1) 132 | # B HW/R² C -> B HW/R² 2C -> B HW/R² 2 num_heads C//num_heads -> 2 B num_heads HW/R² C//num_heads 133 | kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 134 | 135 | x_2 = x2.permute(0, 2, 1).reshape(B, C, H, W) 136 | x_2 = self.sr2(x_2).reshape(B, C, -1).permute(0, 2, 1) 137 | x_2 = self.norm2(x_2) 138 | kv2 = self.kv2(x_2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 139 | else: 140 | kv1 = self.kv1(x1).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 141 | 142 | kv2 = self.kv2(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 143 | 144 | # B num_heads HW/R² C//num_heads 145 | k1, v1 = kv1[0], kv1[1] 146 | k2, v2 = kv2[0], kv2[1] 147 | 148 | # B num_heads N HW/R² 149 | attn1 = (q1 @ k2.transpose(-2, -1)) * self.scale 150 | attn1 = attn1.softmax(dim=-1) 151 | 152 | attn2 = (q2 @ k1.transpose(-2, -1)) * self.scale 153 | attn2 = attn2.softmax(dim=-1) 154 | 155 | # B num_heads N C//num_heads -> B N num_heads C//num_heads -> B N C 156 | main_out = (attn1 @ v2).transpose(1, 2).reshape(B, N, C) 157 | aux_out = (attn2 @ v1).transpose(1, 2).reshape(B, N, C) 158 | 159 | return main_out, aux_out 160 | 161 | 162 | class FeatureInteraction(nn.Module): 163 | def __init__(self, dim, reduction=1, num_heads=None, sr_ratio=None, norm_layer=nn.LayerNorm): 164 | super().__init__() 165 | self.channel_proj1 = nn.Linear(dim, dim // reduction * 2) 166 | self.channel_proj2 = nn.Linear(dim, dim // reduction * 2) 167 | self.act1 = nn.ReLU(inplace=True) 168 | self.act2 = nn.ReLU(inplace=True) 169 | self.cross_attn = CrossAttention(dim // reduction, num_heads=num_heads, sr_ratio=sr_ratio) 170 | self.end_proj1 = nn.Linear(dim // reduction * 2, dim) 171 | self.end_proj2 = nn.Linear(dim // reduction * 2, dim) 172 | self.norm1 = norm_layer(dim) 173 | self.norm2 = norm_layer(dim) 174 | 175 | def forward(self, x1, x2, H, W): 176 | y1, z1 = self.act1(self.channel_proj1(x1)).chunk(2, dim=-1) 177 | y2, z2 = self.act2(self.channel_proj2(x2)).chunk(2, dim=-1) 178 | c1, c2 = self.cross_attn(z1, z2, H, W) 179 | y1 = torch.cat((y1, c1), dim=-1) 
180 | y2 = torch.cat((y2, c2), dim=-1) 181 | main_out = self.norm1(x1 + self.end_proj1(y1)) 182 | aux_out = self.norm2(x2 + self.end_proj2(y2)) 183 | 184 | return main_out, aux_out 185 | 186 | 187 | class ChannelEmbed(nn.Module): 188 | def __init__(self, in_channels, out_channels, reduction=1, norm_layer=nn.BatchNorm2d): 189 | super(ChannelEmbed, self).__init__() 190 | self.out_channels = out_channels 191 | self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 192 | self.channel_embed = nn.Sequential( 193 | nn.Conv2d(in_channels, out_channels // reduction, kernel_size=1, bias=True), 194 | nn.Conv2d(out_channels // reduction, out_channels // reduction, kernel_size=3, stride=1, padding=1, 195 | bias=True, groups=out_channels // reduction), 196 | nn.ReLU(inplace=True), 197 | nn.Conv2d(out_channels // reduction, out_channels, kernel_size=1, bias=True), 198 | norm_layer(out_channels) 199 | ) 200 | self.norm = norm_layer(out_channels) 201 | 202 | def forward(self, x, H, W): 203 | B, N, _C = x.shape 204 | x = x.permute(0, 2, 1).reshape(B, _C, H, W).contiguous() 205 | residual = self.residual(x) 206 | x = self.channel_embed(x) 207 | out = self.norm(residual + x) 208 | 209 | return out 210 | 211 | 212 | class ChannelEmbed2(nn.Module): 213 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d): 214 | super(ChannelEmbed2, self).__init__() 215 | self.out_channels = out_channels 216 | self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True) 217 | 218 | self.channel_embed1 = nn.Sequential( 219 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=out_channels), 220 | norm_layer(out_channels), 221 | nn.ReLU(inplace=True) 222 | ) 223 | self.channel_embed2 = nn.Sequential( 224 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False, groups=out_channels), 225 | norm_layer(out_channels), 226 | nn.ReLU(inplace=True) 227 | ) 228 | self.norm = norm_layer(out_channels) 229 | 230 | def forward(self, x1, x2, H, W): 231 | B, N, _C = x1.shape 232 | 233 | x1 = x1.permute(0, 2, 1).reshape(B, _C, H, W).contiguous() 234 | x2 = x2.permute(0, 2, 1).reshape(B, _C, H, W).contiguous() 235 | residual = self.residual(x1) 236 | 237 | x1 = self.channel_embed1(x1) 238 | x2 = self.channel_embed2(x2) 239 | fuse = self.norm(x1 * x2) 240 | 241 | out = fuse + residual 242 | 243 | return out 244 | 245 | 246 | class FeatureFusion(nn.Module): 247 | def __init__(self, dim, reduction=1, sr_ratio=1, num_heads=None, norm_layer=nn.BatchNorm2d): 248 | super().__init__() 249 | self.cross = FeatureInteraction(dim=dim, reduction=reduction, num_heads=num_heads, sr_ratio=sr_ratio) 250 | self.channel_emb = ChannelEmbed(in_channels=dim * 2, out_channels=dim, reduction=reduction, 251 | norm_layer=norm_layer) 252 | self.apply(self._init_weights) 253 | 254 | @classmethod 255 | def _init_weights(cls, m): 256 | if isinstance(m, nn.Linear): 257 | trunc_normal_(m.weight, std=.02) 258 | if isinstance(m, nn.Linear) and m.bias is not None: 259 | nn.init.constant_(m.bias, 0) 260 | elif isinstance(m, nn.LayerNorm): 261 | nn.init.constant_(m.bias, 0) 262 | nn.init.constant_(m.weight, 1.0) 263 | elif isinstance(m, nn.Conv2d): 264 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 265 | fan_out //= m.groups 266 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 267 | if m.bias is not None: 268 | m.bias.data.zero_() 269 | 270 | def forward(self, x1, x2): 271 | B, C, H, W = x1.shape 272 | # B C (HW)->B N(HW) C 273 | 
x1 = x1.flatten(2).transpose(1, 2) 274 | x2 = x2.flatten(2).transpose(1, 2) 275 | # B N(HW) C 276 | x1, x2 = self.cross(x1, x2, H, W) 277 | # B N(HW) 2C 278 | fuse = torch.cat((x1, x2), dim=-1) 279 | # B C H W 280 | fuse = self.channel_emb(fuse, H, W) 281 | 282 | return fuse 283 | 284 | 285 | class FeatureFusion2(nn.Module): 286 | def __init__(self, dim, reduction=1, sr_ratio=1, num_heads=None, norm_layer=nn.BatchNorm2d): 287 | super().__init__() 288 | self.cross = FeatureInteraction(dim=dim, reduction=reduction, num_heads=num_heads, sr_ratio=sr_ratio) 289 | self.channel_emb = ChannelEmbed2(in_channels=dim, out_channels=dim, norm_layer=norm_layer) 290 | self.apply(self._init_weights) 291 | 292 | @classmethod 293 | def _init_weights(cls, m): 294 | if isinstance(m, nn.Linear): 295 | trunc_normal_(m.weight, std=.02) 296 | if isinstance(m, nn.Linear) and m.bias is not None: 297 | nn.init.constant_(m.bias, 0) 298 | elif isinstance(m, nn.LayerNorm): 299 | nn.init.constant_(m.bias, 0) 300 | nn.init.constant_(m.weight, 1.0) 301 | elif isinstance(m, nn.Conv2d): 302 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 303 | fan_out //= m.groups 304 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 305 | if m.bias is not None: 306 | m.bias.data.zero_() 307 | 308 | def forward(self, x1, x2): 309 | B, C, H, W = x1.shape 310 | # B C (HW)->B N(HW) C 311 | x1 = x1.flatten(2).transpose(1, 2) 312 | x2 = x2.flatten(2).transpose(1, 2) 313 | # B N(HW) C 314 | x1, x2 = self.cross(x1, x2, H, W) 315 | # B C H W 316 | fuse = self.channel_emb(x1, x2, H, W) 317 | 318 | return fuse 319 | 320 | 321 | if __name__ == '__main__': 322 | model = FeatureFusion(64, num_heads=1, sr_ratio=4).cuda() 323 | 324 | left = torch.randn(1, 64, 64, 64).cuda() 325 | right = torch.randn(1, 64, 64, 64).cuda() 326 | out = model(left, right) 327 | 328 | # summary(model, [(4, 256, 256), (1, 256, 256)]) 329 | flops, params = profile(model, (left, right), verbose=False) 330 | 331 | flops = flops * 2 332 | flops, params = clever_format([flops, params], "%.3f") 333 | print('Total GFLOPS: %s' % flops) 334 | print('Total params: %s' % params) 335 | --------------------------------------------------------------------------------
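For reference, the two modules above are typically chained once per encoder stage: the stage features of both modalities are first corrected (FCM) and then fused (FFM) into a single map for the decoder. The sketch below mirrors the `__main__` profiling block on CPU; the import path and tensor sizes are assumptions for illustration, not taken from the repository's training scripts.

```python
import torch

from new_model.modules import FeatureCorrection_s2c, FeatureFusion  # assumed import path

# One encoder stage with 64-channel features from both modalities.
fcm = FeatureCorrection_s2c(dim=64)                   # FCM: spatial then channel correction
ffm = FeatureFusion(dim=64, num_heads=1, sr_ratio=4)  # FFM: cross-attention + channel embedding

main_feat = torch.randn(1, 64, 64, 64)                # e.g. optical branch, B C H W
aux_feat = torch.randn(1, 64, 64, 64)                 # e.g. SAR / DSM branch, B C H W

main_feat, aux_feat = fcm(main_feat, aux_feat)        # corrected features, shapes unchanged
fused = ffm(main_feat, aux_feat)                      # fused map handed to the decoder
print(fused.shape)                                    # torch.Size([1, 64, 64, 64])
```

Similarly, the helpers in `base/_utils.py` adapt an RGB-pretrained stem to inputs with a different band count by summing, blending, or cyclically repeating the pretrained channels. A hedged sketch using a plain torchvision backbone (import path and backbone choice are assumptions; only the main stem is patched here because a vanilla ResNet has no `extra_conv1` branch):

```python
import torch
import torchvision

from base._utils import patch_first_conv_resnet  # assumed import path

model = torchvision.models.resnet18()
# Main branch gets a 4-band input, auxiliary branch a 1-band input.
patch_first_conv_resnet(model, in_channel1=4, in_channel2=1)
print(model.conv1.weight.shape)  # torch.Size([64, 4, 7, 7])
```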