├── README.md ├── build_sam_feat_seg_model.py ├── images └── overal.png ├── models ├── EVC.py ├── SamFeatSeg.py ├── __init__.py ├── __pycache__ │ ├── EVC.cpython-310.pyc │ ├── EVC.cpython-37.pyc │ ├── SamFeatSeg.cpython-310.pyc │ ├── SamFeatSeg.cpython-37.pyc │ ├── __init__.cpython-310.pyc │ └── __init__.cpython-37.pyc ├── mamba_ssm │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── __init__.cpython-37.pyc │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── mixer_seq_simple.cpython-310.pyc │ │ └── mixer_seq_simple.py │ ├── modules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ └── mamba_simple.cpython-310.pyc │ │ └── mamba_simple.py │ ├── ops │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── selective_scan_interface.cpython-310.pyc │ │ │ └── selective_scan_interface.cpython-37.pyc │ │ ├── selective_scan_interface.py │ │ └── triton │ │ │ ├── __init__.py │ │ │ ├── layernorm.py │ │ │ └── selective_state_update.py │ └── utils │ │ ├── __init__.py │ │ ├── generation.py │ │ └── hf.py └── ops_dcnv3 │ ├── DCNv3.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt │ ├── build │ ├── lib.linux-x86_64-3.10 │ │ ├── DCNv3.cpython-310-x86_64-linux-gnu.so │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── dcnv3_func.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ └── dcnv3.py │ ├── lib.linux-x86_64-3.7 │ │ ├── DCNv3.cpython-37m-x86_64-linux-gnu.so │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── dcnv3_func.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ └── dcnv3.py │ ├── temp.linux-x86_64-3.10 │ │ ├── build.ninja │ │ └── home │ │ │ └── zbf │ │ │ └── Desktop │ │ │ └── code │ │ │ └── zbfbackbone │ │ │ └── models │ │ │ └── ops_dcnv3 │ │ │ └── src │ │ │ ├── cpu │ │ │ └── dcnv3_cpu.o │ │ │ ├── cuda │ │ │ └── dcnv3_cuda.o │ │ │ └── vision.o │ └── temp.linux-x86_64-3.7 │ │ ├── build.ninja │ │ └── home │ │ └── zbf │ │ └── lab │ │ └── remote_all │ │ └── InternImage-master │ │ └── segmentation │ │ └── ops_dcnv3 │ │ └── src │ │ ├── cpu │ │ └── dcnv3_cpu.o │ │ ├── cuda │ │ └── dcnv3_cuda.o │ │ └── vision.o │ ├── dist │ ├── DCNv3-1.0-py3.10-linux-x86_64.egg │ └── DCNv3-1.0-py3.7-linux-x86_64.egg │ ├── functions │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── dcnv3_func.cpython-310.pyc │ │ └── dcnv3_func.cpython-37.pyc │ └── dcnv3_func.py │ ├── make.sh │ ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── dcnv3.cpython-310.pyc │ │ └── dcnv3.cpython-37.pyc │ └── dcnv3.py │ ├── setup.py │ ├── src │ ├── cpu │ │ ├── dcnv3_cpu.cpp │ │ └── dcnv3_cpu.h │ ├── cuda │ │ ├── dcnv3_cuda.cu │ │ ├── dcnv3_cuda.h │ │ └── dcnv3_im2col_cuda.cuh │ ├── dcnv3.h │ └── vision.cpp │ └── test.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | 2 | # DeMambaNet: Deformable Convolution and Mamba Integration Network for High-Precision Segmentation of Ambiguously Defined Dental Radicular Boundaries 3 | 4 | # Methods 5 |
6 | ![Overall structure of DeMambaNet](images/overal.png)
7 |
8 |
9 | Figure 1: Structure of the DeMambaNet.
10 |
11 |
12 | # Install
13 | - Compile the CUDA operators:
14 | ```bash
15 | cd ./ops_dcnv3
16 | sh ./make.sh
17 | # unit test (all checks should print True)
18 | python test.py
19 | ```
20 | - Alternatively, install the operator from the prebuilt .whl files:
21 | [DCNv3-1.0-whl](https://github.com/OpenGVLab/InternImage/releases/tag/whl_files)
22 |
23 | - For [mamba](https://github.com/state-spaces/mamba): the mamba-ssm and causal-conv1d packages must be installed; see the original GitHub repository for installation instructions.
24 |
25 | - This code depends on specific torch and CUDA versions; install the remaining dependencies with:
26 | ```bash
27 | pip install -r requirements.txt
28 | ```
29 | # Test DeMambaNet
30 | ```bash
31 | python build_sam_feat_seg_model.py
32 | ```
33 |
--------------------------------------------------------------------------------
/build_sam_feat_seg_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from models.SamFeatSeg import SamFeatSeg, InternImage, MambaEncoder, SegDecoderCNN_add
4 |
5 | def _build_feat_seg_model(num_classes=2):
6 |     DrcM = SamFeatSeg(
7 |
8 |         first=InternImage(channels=48),
9 |         second=MambaEncoder(),
10 |         seg_decoder_bl=SegDecoderCNN_add(),
11 |         num_classes=num_classes,
12 |     )
13 |     return DrcM
14 |
15 |
16 |
17 |
18 | model = _build_feat_seg_model(num_classes=2)
19 | model = model.cuda()
20 | inputs = torch.randn(2, 3, 320, 320).cuda()
21 | print(inputs.shape)
22 | output = model(inputs)
23 | print(output.shape)
24 |
--------------------------------------------------------------------------------
/images/overal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/images/overal.png
--------------------------------------------------------------------------------
/models/EVC.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 |
9 | from functools import partial
10 |
11 | from timm.models.layers import DropPath, trunc_normal_
12 |
13 |
14 | # LVC
15 | class Encoding(nn.Module):
16 |     def __init__(self, in_channels, num_codes):
17 |         super(Encoding, self).__init__()
18 |         # init codewords and smoothing factor
19 |         self.in_channels, self.num_codes = in_channels, num_codes
20 |         num_codes = 64
21 |         std = 1. / ((num_codes * in_channels)**0.5)
22 |         # [num_codes, channels]
23 |         self.codewords = nn.Parameter(
24 |             torch.empty(num_codes, in_channels, dtype=torch.float).uniform_(-std, std), requires_grad=True)
25 |         # [num_codes]
26 |         self.scale = nn.Parameter(torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0), requires_grad=True)
27 |
28 |     @staticmethod
29 |     def scaled_l2(x, codewords, scale):
30 |         num_codes, in_channels = codewords.size()
31 |         b = x.size(0)
32 |         expanded_x = x.unsqueeze(2).expand((b, x.size(1), num_codes, in_channels))
33 |
34 |         # --- process the codebook (num_codes, c1)
35 |         reshaped_codewords = codewords.view((1, 1, num_codes, in_channels))
36 |
37 |         # turn scale from (1, num_codes) into (batch, c2, N, num_codes)
38 |         reshaped_scale = scale.view((1, 1, num_codes))  # N, num_codes
39 |
40 |         # --- compute rik = z1 - d  # b, N, num_codes
41 |         scaled_l2_norm = reshaped_scale * (expanded_x - reshaped_codewords).pow(2).sum(dim=3)
42 |         return scaled_l2_norm
43 |
44 |     @staticmethod
45 |     def aggregate(assignment_weights, x, codewords):
46 |         num_codes, in_channels = codewords.size()
47 |
48 |         # --- process the codebook
49 |         reshaped_codewords = codewords.view((1, 1, num_codes, in_channels))
50 |         b = x.size(0)
51 |
52 |         # --- process the feature vector x: b, c1, N
53 |         expanded_x = x.unsqueeze(2).expand((b, x.size(1), num_codes, in_channels))
54 |
55 |         # transform rei: b, N, num_codes
56 |         assignment_weights = assignment_weights.unsqueeze(3)  # b, N, num_codes,
57 |
58 |         # --- compute eik; this must come after rei has been computed
59 |         encoded_feat = (assignment_weights * (expanded_x - reshaped_codewords)).sum(1)
60 |         return encoded_feat
61 |
62 |     def forward(self, x):
63 |         assert x.dim() == 4 and x.size(1) == self.in_channels
64 |         b, in_channels, w, h = x.size()
65 |
66 |         # [batch_size, height x width, channels]
67 |         x = x.view(b, self.in_channels, -1).transpose(1, 2).contiguous()
68 |
69 |         # assignment_weights: [batch_size, height x width, num_codes]
70 |         assignment_weights = F.softmax(self.scaled_l2(x, self.codewords, self.scale), dim=2)
71 |
72 |         # aggregate
73 |         encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
74 |         return encoded_feat
75 |
76 |
77 | # 1*1 3*3 1*1
78 | class ConvBlock(nn.Module):
79 |     def __init__(self, in_channels, out_channels, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
80 |                  norm_layer=partial(nn.BatchNorm2d, eps=1e-6), drop_block=None, drop_path=None):
81 |         super(ConvBlock, self).__init__()
82 |         self.in_channels = in_channels
83 |         expansion = 4
84 |         c = out_channels // expansion
85 |
86 |         self.conv1 = nn.Conv2d(in_channels, c, kernel_size=1, stride=1, padding=0, bias=False)  # [64, 256, 1, 1]
87 |         self.bn1 = norm_layer(c)
88 |         self.act1 = act_layer(inplace=True)
89 |
90 |         self.conv2 = nn.Conv2d(c, c, kernel_size=3, stride=stride, groups=groups, padding=1, bias=False)
91 |         self.bn2 = norm_layer(c)
92 |         self.act2 = act_layer(inplace=True)
93 |
94 |         self.conv3 = nn.Conv2d(c, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
95 |         self.bn3 = norm_layer(out_channels)
96 |         self.act3 = act_layer(inplace=True)
97 |
98 |         if res_conv:
99 |             self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
100 |             self.residual_bn = norm_layer(out_channels)
101 |
102 |         self.res_conv = res_conv
103 |         self.drop_block = drop_block
104 |         self.drop_path = drop_path
105 |
106 |     def zero_init_last_bn(self):
107 |         nn.init.zeros_(self.bn3.weight)
108 |
109 |     def forward(self, x, return_x_2=True):
110 |         residual = x
111 |
112 |         x = self.conv1(x)
113 |         x = self.bn1(x)
114 |         if self.drop_block is not None:
115 |             x = 
self.drop_block(x) 116 | x = self.act1(x) 117 | 118 | x = self.conv2(x) #if x_t_r is None else self.conv2(x + x_t_r) 119 | x = self.bn2(x) 120 | if self.drop_block is not None: 121 | x = self.drop_block(x) 122 | x2 = self.act2(x) 123 | 124 | x = self.conv3(x2) 125 | x = self.bn3(x) 126 | if self.drop_block is not None: 127 | x = self.drop_block(x) 128 | 129 | if self.drop_path is not None: 130 | x = self.drop_path(x) 131 | 132 | if self.res_conv: 133 | residual = self.residual_conv(residual) 134 | residual = self.residual_bn(residual) 135 | 136 | x += residual 137 | x = self.act3(x) 138 | 139 | if return_x_2: 140 | return x, x2 141 | else: 142 | return x 143 | 144 | 145 | class Mean(nn.Module): 146 | def __init__(self, dim, keep_dim=False): 147 | super(Mean, self).__init__() 148 | self.dim = dim 149 | self.keep_dim = keep_dim 150 | 151 | def forward(self, input): 152 | return input.mean(self.dim, self.keep_dim) 153 | 154 | 155 | class Mlp(nn.Module): 156 | """ 157 | Implementation of MLP with 1*1 convolutions. Input: tensor with shape [B, C, H, W] 158 | """ 159 | def __init__(self, in_features, hidden_features=None, 160 | out_features=None, act_layer=nn.GELU, drop=0.): 161 | super().__init__() 162 | out_features = out_features or in_features 163 | hidden_features = hidden_features or in_features 164 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 165 | self.act = act_layer() 166 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 167 | self.drop = nn.Dropout(drop) 168 | self.apply(self._init_weights) 169 | 170 | def _init_weights(self, m): 171 | if isinstance(m, nn.Conv2d): 172 | trunc_normal_(m.weight, std=.02) 173 | if m.bias is not None: 174 | nn.init.constant_(m.bias, 0) 175 | 176 | def forward(self, x): 177 | x = self.fc1(x) 178 | x = self.act(x) 179 | x = self.drop(x) 180 | x = self.fc2(x) 181 | x = self.drop(x) 182 | return x 183 | 184 | 185 | 186 | 187 | class GroupNorm(nn.GroupNorm): 188 | """ 189 | Group Normalization with 1 group. 
190 | Input: tensor in shape [B, C, H, W] 191 | """ 192 | def __init__(self, num_channels, **kwargs): 193 | super().__init__(1, num_channels, **kwargs) 194 | 195 | 196 | def get_activation(name="silu", inplace=True): 197 | if name == "silu": 198 | module = nn.SiLU(inplace=inplace) 199 | elif name == "relu": 200 | module = nn.ReLU(inplace=inplace) 201 | elif name == "lrelu": 202 | module = nn.LeakyReLU(0.1, inplace=inplace) 203 | else: 204 | raise AttributeError("Unsupported act type: {}".format(name)) 205 | return module 206 | 207 | 208 | class BaseConv(nn.Module): 209 | """A Conv2d -> Batchnorm -> silu/leaky relu block""" # CBL 210 | 211 | def __init__( 212 | self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu" 213 | ): 214 | super().__init__() 215 | # same padding 216 | pad = (ksize - 1) // 2 217 | self.conv = nn.Conv2d( 218 | in_channels, 219 | out_channels, 220 | kernel_size=ksize, 221 | stride=stride, 222 | padding=pad, 223 | groups=groups, 224 | bias=bias, 225 | ) 226 | self.bn = nn.BatchNorm2d(out_channels) 227 | self.act = get_activation(act, inplace=True) 228 | 229 | def forward(self, x): 230 | return self.act(self.bn(self.conv(x))) 231 | 232 | def fuseforward(self, x): 233 | return self.act(self.conv(x)) 234 | 235 | 236 | class DWConv(nn.Module): 237 | """Depthwise Conv + Conv""" 238 | def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"): 239 | super().__init__() 240 | self.dconv = BaseConv( 241 | in_channels, 242 | in_channels, 243 | ksize=ksize, 244 | stride=stride, 245 | groups=in_channels, 246 | act=act, 247 | ) 248 | self.pconv = BaseConv( 249 | in_channels, out_channels, ksize=1, stride=1, groups=1, act=act 250 | ) 251 | 252 | def forward(self, x): 253 | x = self.dconv(x) 254 | return self.pconv(x) 255 | 256 | 257 | 258 | class LVCBlock(nn.Module): 259 | def __init__(self, in_channels, out_channels, num_codes, channel_ratio=0.25, base_channel=64): 260 | super(LVCBlock, self).__init__() 261 | self.out_channels = out_channels 262 | self.num_codes = num_codes 263 | num_codes = 64 264 | 265 | self.conv_1 = ConvBlock(in_channels=in_channels, out_channels=in_channels, res_conv=True, stride=1) 266 | 267 | self.LVC = nn.Sequential( 268 | nn.Conv2d(in_channels, in_channels, 1, bias=False), 269 | nn.BatchNorm2d(in_channels), 270 | nn.ReLU(inplace=True), 271 | Encoding(in_channels=in_channels, num_codes=num_codes), 272 | nn.BatchNorm1d(num_codes), 273 | nn.ReLU(inplace=True), 274 | Mean(dim=1)) 275 | self.fc = nn.Sequential(nn.Linear(in_channels, in_channels), nn.Sigmoid()) 276 | 277 | def forward(self, x): 278 | x = self.conv_1(x, return_x_2=False) 279 | en = self.LVC(x) 280 | gam = self.fc(en) 281 | b, in_channels, _, _ = x.size() 282 | y = gam.view(b, in_channels, 1, 1) 283 | x = F.relu_(x + x * y) 284 | return x 285 | 286 | 287 | # LightMLPBlock 288 | class LightMLPBlock(nn.Module): 289 | def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu", 290 | mlp_ratio=4., drop=0., act_layer=nn.GELU, 291 | use_layer_scale=True, layer_scale_init_value=1e-5, drop_path=0., norm_layer=GroupNorm): # act_layer=nn.GELU, 292 | super().__init__() 293 | self.dw = DWConv(in_channels, out_channels, ksize=1, stride=1, act="silu") 294 | self.linear = nn.Linear(out_channels, out_channels) # learnable position embedding 295 | self.out_channels = out_channels 296 | 297 | self.norm1 = norm_layer(in_channels) 298 | self.norm2 = norm_layer(in_channels) 299 | 300 | mlp_hidden_dim = int(in_channels * mlp_ratio) 301 | self.mlp = 
Mlp(in_features=in_channels, hidden_features=mlp_hidden_dim, act_layer=nn.GELU,
302 |                        drop=drop)
303 |
304 |         self.drop_path = DropPath(drop_path) if drop_path > 0. \
305 |             else nn.Identity()
306 |
307 |         self.use_layer_scale = use_layer_scale
308 |         if use_layer_scale:
309 |             self.layer_scale_1 = nn.Parameter(
310 |                 layer_scale_init_value * torch.ones((out_channels)), requires_grad=True)
311 |             self.layer_scale_2 = nn.Parameter(
312 |                 layer_scale_init_value * torch.ones((out_channels)), requires_grad=True)
313 |
314 |     def forward(self, x):
315 |         if self.use_layer_scale:
316 |             x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.dw(self.norm1(x)))
317 |             x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
318 |         else:
319 |             x = x + self.drop_path(self.dw(self.norm1(x)))
320 |             x = x + self.drop_path(self.mlp(self.norm2(x)))
321 |         return x
322 |
323 |
324 | # EVCBlock
325 | class EVCBlock(nn.Module):
326 |     def __init__(self, in_channels, out_channels, channel_ratio=4, base_channel=16):
327 |         super().__init__()
328 |         expansion = 2
329 |         ch = in_channels * expansion
330 |         # Stem stage: get the feature maps by conv block (copied from resnet.py); preprocessing before entering the Conformer framework
331 |         self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, bias=False)  # 1 / 2 [112, 112]
332 |         self.bn1 = nn.BatchNorm2d(in_channels)
333 |         self.act1 = nn.ReLU(inplace=True)
334 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)  # 1 / 4 [56, 56]
335 |
336 |         # LVC
337 |         self.lvc = LVCBlock(in_channels=in_channels, out_channels=in_channels, num_codes=64)  # the c1 value is not fixed yet
338 |         # LightMLPBlock
339 |         self.l_MLP = LightMLPBlock(in_channels, in_channels, ksize=1, stride=1, act="silu", act_layer=nn.GELU, mlp_ratio=4., drop=0.,
340 |                                    use_layer_scale=True, layer_scale_init_value=1e-5, drop_path=0., norm_layer=GroupNorm)
341 |         self.cnv1 = nn.Conv2d(ch, out_channels, kernel_size=1, stride=1, padding=0)
342 |
343 |     def forward(self, x):
344 |         x1 = self.maxpool(self.act1(self.bn1(self.conv1(x))))
345 |         # LVCBlock
346 |         x_lvc = self.lvc(x1)
347 |         # LightMLPBlock
348 |         x_lmlp = self.l_MLP(x1)
349 |         # concat
350 |         x = torch.cat((x_lvc, x_lmlp), dim=1)
351 |         x = self.cnv1(x)
352 |         return x
353 |
354 |
355 | if __name__ == '__main__':
356 |     block = EVCBlock(256, 128)
357 |     input = torch.rand(1, 256, 64, 64)
358 |     output = block(input)
359 |     print(input.size(), output.size())
360 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .SamFeatSeg import SamFeatSeg
2 | # from .build_sam_feat_seg_model import sam_feat_seg_model_registry
3 |
--------------------------------------------------------------------------------
/models/__pycache__/EVC.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/__pycache__/EVC.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/EVC.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/__pycache__/EVC.cpython-37.pyc
--------------------------------------------------------------------------------
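For intuition, a minimal shape walk-through of the LVC `Encoding` layer from `models/EVC.py` above (a sketch with illustrative sizes; note that `Encoding.__init__` hard-codes the codebook size to 64 regardless of the `num_codes` argument):

```python
import torch
from models.EVC import Encoding

enc = Encoding(in_channels=32, num_codes=64)
x = torch.rand(2, 32, 8, 8)   # (B, C, H, W)
feat = enc(x)                 # softmax-weighted residuals (x - codeword),
                              # summed over the H*W spatial positions
print(feat.shape)             # torch.Size([2, 64, 32]) -> (B, num_codes, C)
```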
/models/__pycache__/SamFeatSeg.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/__pycache__/SamFeatSeg.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/SamFeatSeg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/__pycache__/SamFeatSeg.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.1" 2 | 3 | from .ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn 4 | from .modules.mamba_simple import Mamba 5 | from .models.mixer_seq_simple import MambaLMHeadModel 6 | -------------------------------------------------------------------------------- /models/mamba_ssm/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/models/__init__.py -------------------------------------------------------------------------------- /models/mamba_ssm/models/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/models/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/models/__pycache__/mixer_seq_simple.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/models/__pycache__/mixer_seq_simple.cpython-310.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/models/mixer_seq_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Albert Gu, Tri Dao. 2 | 3 | import math 4 | from functools import partial 5 | 6 | from collections import namedtuple 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from mamba_ssm.modules.mamba_simple import Mamba, Block 12 | from mamba_ssm.utils.generation import GenerationMixin 13 | from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf 14 | 15 | try: 16 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 17 | except ImportError: 18 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 19 | 20 | 21 | def create_block( 22 | d_model, 23 | ssm_cfg=None, 24 | norm_epsilon=1e-5, 25 | rms_norm=False, 26 | residual_in_fp32=False, 27 | fused_add_norm=False, 28 | layer_idx=None, 29 | device=None, 30 | dtype=None, 31 | ): 32 | if ssm_cfg is None: 33 | ssm_cfg = {} 34 | factory_kwargs = {"device": device, "dtype": dtype} 35 | mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs) 36 | norm_cls = partial( 37 | nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs 38 | ) 39 | block = Block( 40 | d_model, 41 | mixer_cls, 42 | norm_cls=norm_cls, 43 | fused_add_norm=fused_add_norm, 44 | residual_in_fp32=residual_in_fp32, 45 | ) 46 | block.layer_idx = layer_idx 47 | return block 48 | 49 | 50 | # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 51 | def _init_weights( 52 | module, 53 | n_layer, 54 | initializer_range=0.02, # Now only used for embedding layer. 55 | rescale_prenorm_residual=True, 56 | n_residuals_per_layer=1, # Change to 2 if we have MLP 57 | ): 58 | if isinstance(module, nn.Linear): 59 | if module.bias is not None: 60 | if not getattr(module.bias, "_no_reinit", False): 61 | nn.init.zeros_(module.bias) 62 | elif isinstance(module, nn.Embedding): 63 | nn.init.normal_(module.weight, std=initializer_range) 64 | 65 | if rescale_prenorm_residual: 66 | # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: 67 | # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale 68 | # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
69 | # > -- GPT-2 :: https://openai.com/blog/better-language-models/ 70 | # 71 | # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py 72 | for name, p in module.named_parameters(): 73 | if name in ["out_proj.weight", "fc2.weight"]: 74 | # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block 75 | # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) 76 | # We need to reinit p since this code could be called multiple times 77 | # Having just p *= scale would repeatedly scale it down 78 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 79 | with torch.no_grad(): 80 | p /= math.sqrt(n_residuals_per_layer * n_layer) 81 | 82 | 83 | class MixerModel(nn.Module): 84 | def __init__( 85 | self, 86 | d_model: int, 87 | n_layer: int, 88 | vocab_size: int, 89 | ssm_cfg=None, 90 | norm_epsilon: float = 1e-5, 91 | rms_norm: bool = False, 92 | initializer_cfg=None, 93 | fused_add_norm=False, 94 | residual_in_fp32=False, 95 | device=None, 96 | dtype=None, 97 | ) -> None: 98 | factory_kwargs = {"device": device, "dtype": dtype} 99 | super().__init__() 100 | self.residual_in_fp32 = residual_in_fp32 101 | 102 | self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) 103 | 104 | # We change the order of residual and layer norm: 105 | # Instead of LN -> Attn / MLP -> Add, we do: 106 | # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and 107 | # the main branch (output of MLP / Mixer). The model definition is unchanged. 108 | # This is for performance reason: we can fuse add + layer_norm. 109 | self.fused_add_norm = fused_add_norm 110 | if self.fused_add_norm: 111 | if layer_norm_fn is None or rms_norm_fn is None: 112 | raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") 113 | 114 | self.layers = nn.ModuleList( 115 | [ 116 | create_block( 117 | d_model, 118 | ssm_cfg=ssm_cfg, 119 | norm_epsilon=norm_epsilon, 120 | rms_norm=rms_norm, 121 | residual_in_fp32=residual_in_fp32, 122 | fused_add_norm=fused_add_norm, 123 | layer_idx=i, 124 | **factory_kwargs, 125 | ) 126 | for i in range(n_layer) 127 | ] 128 | ) 129 | 130 | self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( 131 | d_model, eps=norm_epsilon, **factory_kwargs 132 | ) 133 | 134 | self.apply( 135 | partial( 136 | _init_weights, 137 | n_layer=n_layer, 138 | **(initializer_cfg if initializer_cfg is not None else {}), 139 | ) 140 | ) 141 | 142 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 143 | return { 144 | i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 145 | for i, layer in enumerate(self.layers) 146 | } 147 | 148 | def forward(self, input_ids, inference_params=None): 149 | hidden_states = self.embedding(input_ids) 150 | residual = None 151 | for layer in self.layers: 152 | hidden_states, residual = layer( 153 | hidden_states, residual, inference_params=inference_params 154 | ) 155 | if not self.fused_add_norm: 156 | residual = (hidden_states + residual) if residual is not None else hidden_states 157 | hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) 158 | else: 159 | # Set prenorm=False here since we don't need the residual 160 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn 161 | hidden_states = fused_add_norm_fn( 162 | hidden_states, 163 | self.norm_f.weight, 164 | self.norm_f.bias, 165 | eps=self.norm_f.eps, 166 | residual=residual, 167 | prenorm=False, 168 | 
residual_in_fp32=self.residual_in_fp32, 169 | ) 170 | return hidden_states 171 | 172 | 173 | class MambaLMHeadModel(nn.Module, GenerationMixin): 174 | 175 | def __init__( 176 | self, 177 | d_model: int, 178 | n_layer: int, 179 | vocab_size: int, 180 | initializer_cfg=None, 181 | pad_vocab_size_multiple: int = 1, 182 | device=None, 183 | dtype=None, 184 | **backbone_kwargs, 185 | ) -> None: 186 | factory_kwargs = {"device": device, "dtype": dtype} 187 | super().__init__() 188 | if vocab_size % pad_vocab_size_multiple != 0: 189 | vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) 190 | self.backbone = MixerModel( 191 | d_model=d_model, 192 | n_layer=n_layer, 193 | vocab_size=vocab_size, 194 | initializer_cfg=initializer_cfg, 195 | **backbone_kwargs, 196 | **factory_kwargs, 197 | ) 198 | self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) 199 | 200 | # Initialize weights and apply final processing 201 | self.apply( 202 | partial( 203 | _init_weights, 204 | n_layer=n_layer, 205 | **(initializer_cfg if initializer_cfg is not None else {}), 206 | ) 207 | ) 208 | self.tie_weights() 209 | 210 | def tie_weights(self): 211 | self.lm_head.weight = self.backbone.embedding.weight 212 | 213 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 214 | return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 215 | 216 | def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0): 217 | """ 218 | "position_ids" is just to be compatible with Transformer generation. We don't use it. 219 | num_last_tokens: if > 0, only return the logits for the last n tokens 220 | """ 221 | hidden_states = self.backbone(input_ids, inference_params=inference_params) 222 | if num_last_tokens > 0: 223 | hidden_states = hidden_states[:, -num_last_tokens:] 224 | lm_logits = self.lm_head(hidden_states) 225 | CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) 226 | return CausalLMOutput(logits=lm_logits) 227 | 228 | @classmethod 229 | def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): 230 | config = load_config_hf(pretrained_model_name) 231 | model = cls(**config, device=device, dtype=dtype, **kwargs) 232 | model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) 233 | return model 234 | -------------------------------------------------------------------------------- /models/mamba_ssm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/modules/__init__.py -------------------------------------------------------------------------------- /models/mamba_ssm/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/modules/__pycache__/mamba_simple.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/modules/__pycache__/mamba_simple.cpython-310.pyc 
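For orientation, a minimal usage sketch of the `MambaLMHeadModel` defined in `mixer_seq_simple.py` above (an illustrative example, not repo code: it assumes the vendored package is importable as `mamba_ssm`, that the CUDA kernels are built, and the sizes are arbitrary). Because the vendored `Mamba` in the next file asserts `bimamba_type == "v3"`, it has to be set through `ssm_cfg`, and the v3 path splits the sequence into `nslices` (default 5) chunks, so the sequence length should be a multiple of 5:

```python
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel(
    d_model=256, n_layer=4, vocab_size=1000,
    ssm_cfg={"bimamba_type": "v3"},  # required by this vendored Mamba
    device="cuda",
)
input_ids = torch.randint(0, 1000, (1, 60), device="cuda")  # 60 % 5 == 0
out = model(input_ids)     # CausalLMOutput namedtuple
print(out.logits.shape)    # torch.Size([1, 60, 1000])
```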
--------------------------------------------------------------------------------
/models/mamba_ssm/modules/mamba_simple.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023, Tri Dao, Albert Gu.
2 |
3 | import math
4 | from typing import Optional
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch import Tensor
10 |
11 | from einops import rearrange, repeat
12 |
13 | try:
14 |     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
15 | except ImportError:
16 |     causal_conv1d_fn, causal_conv1d_update = None, None
17 |
18 | # try:
19 | #     from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj
20 | # except ImportError:
21 | #     selective_scan_fn, mamba_inner_fn, bimamba_inner_fn, mamba_inner_fn_no_out_proj = None, None, None, None
22 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn, mamba_inner_fn_no_out_proj
23 |
24 |
25 | try:
26 |     from mamba_ssm.ops.triton.selective_state_update import selective_state_update
27 | except ImportError:
28 |     selective_state_update = None
29 |
30 | try:
31 |     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
32 | except ImportError:
33 |     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
34 |
35 |
36 | class Mamba(nn.Module):
37 |     def __init__(
38 |         self,
39 |         d_model,
40 |         d_state=16,
41 |         d_conv=4,
42 |         expand=2,
43 |         dt_rank="auto",
44 |         dt_min=0.001,
45 |         dt_max=0.1,
46 |         dt_init="random",
47 |         dt_scale=1.0,
48 |         dt_init_floor=1e-4,
49 |         conv_bias=True,
50 |         bias=False,
51 |         use_fast_path=True,  # Fused kernel options
52 |         layer_idx=None,
53 |         device=None,
54 |         dtype=None,
55 |         bimamba_type="none",
56 |         nslices=5
57 |     ):
58 |         factory_kwargs = {"device": device, "dtype": dtype}
59 |         super().__init__()
60 |         self.d_model = d_model
61 |         self.d_state = d_state
62 |         self.d_conv = d_conv
63 |         self.expand = expand
64 |         self.d_inner = int(self.expand * self.d_model)
65 |         self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
66 |         self.use_fast_path = use_fast_path
67 |         self.layer_idx = layer_idx
68 |         self.bimamba_type = bimamba_type
69 |         self.nslices = nslices
70 |
71 |         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
72 |
73 |         self.conv1d = nn.Conv1d(
74 |             in_channels=self.d_inner,
75 |             out_channels=self.d_inner,
76 |             bias=conv_bias,
77 |             kernel_size=d_conv,
78 |             groups=self.d_inner,
79 |             padding=d_conv - 1,
80 |             **factory_kwargs,
81 |         )
82 |
83 |         self.activation = "silu"
84 |         self.act = nn.SiLU()
85 |
86 |         self.x_proj = nn.Linear(
87 |             self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
88 |         )
89 |         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)
90 |
91 |         # Initialize special dt projection to preserve variance at initialization
92 |         dt_init_std = self.dt_rank**-0.5 * dt_scale
93 |         if dt_init == "constant":
94 |             nn.init.constant_(self.dt_proj.weight, dt_init_std)
95 |         elif dt_init == "random":
96 |             nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
97 |         else:
98 |             raise NotImplementedError
99 |
100 |         # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
101 |         dt = torch.exp(
102 |             torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
103 |             + math.log(dt_min)
104 |         ).clamp(min=dt_init_floor)
105 |         # Inverse of softplus: 
https://github.com/pytorch/pytorch/issues/72759 106 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 107 | with torch.no_grad(): 108 | self.dt_proj.bias.copy_(inv_dt) 109 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 110 | self.dt_proj.bias._no_reinit = True 111 | 112 | # S4D real initialization 113 | A = repeat( 114 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 115 | "n -> d n", 116 | d=self.d_inner, 117 | ).contiguous() 118 | A_log = torch.log(A) # Keep A_log in fp32 119 | self.A_log = nn.Parameter(A_log) 120 | self.A_log._no_weight_decay = True 121 | 122 | # D "skip" parameter 123 | self.D = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 124 | self.D._no_weight_decay = True 125 | 126 | # bidirectional 127 | assert bimamba_type == "v3" 128 | 129 | A_b = repeat( 130 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 131 | "n -> d n", 132 | d=self.d_inner, 133 | ).contiguous() 134 | A_b_log = torch.log(A_b) # Keep A_b_log in fp32 135 | self.A_b_log = nn.Parameter(A_b_log) 136 | self.A_b_log._no_weight_decay = True 137 | 138 | self.conv1d_b = nn.Conv1d( 139 | in_channels=self.d_inner, 140 | out_channels=self.d_inner, 141 | bias=conv_bias, 142 | kernel_size=d_conv, 143 | groups=self.d_inner, 144 | padding=d_conv - 1, 145 | **factory_kwargs, 146 | ) 147 | 148 | self.x_proj_b = nn.Linear( 149 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 150 | ) 151 | self.dt_proj_b = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 152 | 153 | self.D_b = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 154 | self.D_b._no_weight_decay = True 155 | 156 | # assert bimamba_type == "v3" 157 | # spatial 158 | A_s = repeat( 159 | torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device), 160 | "n -> d n", 161 | d=self.d_inner, 162 | ).contiguous() 163 | A_s_log = torch.log(A_s) # Keep A_b_log in fp32 164 | self.A_s_log = nn.Parameter(A_s_log) 165 | self.A_s_log._no_weight_decay = True 166 | 167 | self.conv1d_s = nn.Conv1d( 168 | in_channels=self.d_inner, 169 | out_channels=self.d_inner, 170 | bias=conv_bias, 171 | kernel_size=d_conv, 172 | groups=self.d_inner, 173 | padding=d_conv - 1, 174 | **factory_kwargs, 175 | ) 176 | 177 | self.x_proj_s = nn.Linear( 178 | self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs 179 | ) 180 | self.dt_proj_s = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs) 181 | 182 | self.D_s = nn.Parameter(torch.ones(self.d_inner, device=device)) # Keep in fp32 183 | self.D_s._no_weight_decay = True 184 | 185 | 186 | 187 | 188 | self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) 189 | 190 | def forward(self, hidden_states, inference_params=None): 191 | """ 192 | hidden_states: (B, L, D) 193 | Returns: same shape as hidden_states 194 | """ 195 | batch, seqlen, dim = hidden_states.shape 196 | 197 | conv_state, ssm_state = None, None 198 | if inference_params is not None: 199 | conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) 200 | if inference_params.seqlen_offset > 0: 201 | # The states are updated inplace 202 | out, _, _ = self.step(hidden_states, conv_state, ssm_state) 203 | return out 204 | 205 | # We do matmul and transpose BLH -> HBL at the same time 206 | xz = rearrange( 207 | self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"), 208 | "d (b l) -> b d l", 209 | l=seqlen, 210 | ) 211 | 
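        # Note: xz stacks the x and z projections along the channel dim, shape
        # (B, 2 * d_inner, L); the fused fast paths below consume it whole,
        # while the fallback path splits it with xz.chunk(2, dim=1).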
if self.in_proj.bias is not None: 212 | xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1") 213 | 214 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 215 | # In the backward pass we write dx and dz next to each other to avoid torch.cat 216 | if self.use_fast_path and inference_params is None: # Doesn't support outputting the states 217 | if self.bimamba_type == "v3": 218 | A_b = -torch.exp(self.A_b_log.float()) 219 | out = mamba_inner_fn_no_out_proj( 220 | xz, 221 | self.conv1d.weight, 222 | self.conv1d.bias, 223 | self.x_proj.weight, 224 | self.dt_proj.weight, 225 | A, 226 | None, # input-dependent B 227 | None, # input-dependent C 228 | self.D.float(), 229 | delta_bias=self.dt_proj.bias.float(), 230 | delta_softplus=True, 231 | ) 232 | out_b = mamba_inner_fn_no_out_proj( 233 | xz.flip([-1]), 234 | self.conv1d_b.weight, 235 | self.conv1d_b.bias, 236 | self.x_proj_b.weight, 237 | self.dt_proj_b.weight, 238 | A_b, 239 | None, 240 | None, 241 | self.D_b.float(), 242 | delta_bias=self.dt_proj_b.bias.float(), 243 | delta_softplus=True, 244 | ) 245 | A_s = -torch.exp(self.A_s_log.float()) 246 | 247 | xz_s = xz.chunk(self.nslices, dim=-1) 248 | xz_s = torch.stack(xz_s,dim=-1) 249 | xz_s = xz_s.flatten(-2) 250 | out_s = mamba_inner_fn_no_out_proj( 251 | xz_s, 252 | self.conv1d_s.weight, 253 | self.conv1d_s.bias, 254 | self.x_proj_s.weight, 255 | self.dt_proj_s.weight, 256 | A_s, 257 | None, 258 | None, 259 | self.D_s.float(), 260 | delta_bias=self.dt_proj_s.bias.float(), 261 | delta_softplus=True, 262 | ) 263 | out_s = out_s.reshape(batch,self.d_inner,seqlen//self.nslices,self.nslices).permute(0,1,3,2).flatten(-2) 264 | 265 | # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) 266 | out = F.linear(rearrange(out + out_b.flip([-1]) + out_s, "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) 267 | elif self.bimamba_type == "v2": 268 | A_b = -torch.exp(self.A_b_log.float()) 269 | out = mamba_inner_fn_no_out_proj( 270 | xz, 271 | self.conv1d.weight, 272 | self.conv1d.bias, 273 | self.x_proj.weight, 274 | self.dt_proj.weight, 275 | A, 276 | None, # input-dependent B 277 | None, # input-dependent C 278 | self.D.float(), 279 | delta_bias=self.dt_proj.bias.float(), 280 | delta_softplus=True, 281 | ) 282 | out_b = mamba_inner_fn_no_out_proj( 283 | xz.flip([-1]), 284 | self.conv1d_b.weight, 285 | self.conv1d_b.bias, 286 | self.x_proj_b.weight, 287 | self.dt_proj_b.weight, 288 | A_b, 289 | None, 290 | None, 291 | self.D_b.float(), 292 | delta_bias=self.dt_proj_b.bias.float(), 293 | delta_softplus=True, 294 | ) 295 | # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias) 296 | out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias) 297 | else: 298 | out = mamba_inner_fn( 299 | xz, 300 | self.conv1d.weight, 301 | self.conv1d.bias, 302 | self.x_proj.weight, 303 | self.dt_proj.weight, 304 | self.out_proj.weight, 305 | self.out_proj.bias, 306 | A, 307 | None, # input-dependent B 308 | None, # input-dependent C 309 | self.D.float(), 310 | delta_bias=self.dt_proj.bias.float(), 311 | delta_softplus=True, 312 | ) 313 | else: 314 | x, z = xz.chunk(2, dim=1) 315 | # Compute short convolution 316 | if conv_state is not None: 317 | conv_state.copy_(x[:, :, -self.d_conv :]) # Update state (B D W) 318 | if causal_conv1d_fn is None: 319 | x = self.act(self.conv1d(x)[..., :seqlen]) 320 | else: 321 | assert self.activation in ["silu", "swish"] 322 | x = causal_conv1d_fn( 323 | x, 324 | 
rearrange(self.conv1d.weight, "d 1 w -> d w"), 325 | self.conv1d.bias, 326 | self.activation, 327 | ) 328 | 329 | # We're careful here about the layout, to avoid extra transposes. 330 | # We want dt to have d as the slowest moving dimension 331 | # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 332 | x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) 333 | dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) 334 | dt = self.dt_proj.weight @ dt.t() 335 | dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) 336 | B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 337 | C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() 338 | assert self.activation in ["silu", "swish"] 339 | y = selective_scan_fn( 340 | x, 341 | dt, 342 | A, 343 | B, 344 | C, 345 | self.D.float(), 346 | z=z, 347 | delta_bias=self.dt_proj.bias.float(), 348 | delta_softplus=True, 349 | return_last_state=ssm_state is not None, 350 | ) 351 | if ssm_state is not None: 352 | y, last_state = y 353 | ssm_state.copy_(last_state) 354 | y = rearrange(y, "b d l -> b l d") 355 | out = self.out_proj(y) 356 | return out 357 | 358 | def step(self, hidden_states, conv_state, ssm_state): 359 | dtype = hidden_states.dtype 360 | assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now" 361 | xz = self.in_proj(hidden_states.squeeze(1)) # (B 2D) 362 | x, z = xz.chunk(2, dim=-1) # (B D) 363 | 364 | # Conv step 365 | if causal_conv1d_update is None: 366 | conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) 367 | conv_state[:, :, -1] = x 368 | x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) 369 | if self.conv1d.bias is not None: 370 | x = x + self.conv1d.bias 371 | x = self.act(x).to(dtype=dtype) 372 | else: 373 | x = causal_conv1d_update( 374 | x, 375 | conv_state, 376 | rearrange(self.conv1d.weight, "d 1 w -> d w"), 377 | self.conv1d.bias, 378 | self.activation, 379 | ) 380 | 381 | x_db = self.x_proj(x) # (B dt_rank+2*d_state) 382 | dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) 383 | # Don't add dt_bias here 384 | dt = F.linear(dt, self.dt_proj.weight) # (B d_inner) 385 | A = -torch.exp(self.A_log.float()) # (d_inner, d_state) 386 | 387 | # SSM step 388 | if selective_state_update is None: 389 | # Discretize A and B 390 | dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) 391 | dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) 392 | dB = torch.einsum("bd,bn->bdn", dt, B) 393 | ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB) 394 | y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C) 395 | y = y + self.D.to(dtype) * x 396 | y = y * self.act(z) # (B D) 397 | else: 398 | y = selective_state_update( 399 | ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True 400 | ) 401 | 402 | out = self.out_proj(y) 403 | return out.unsqueeze(1), conv_state, ssm_state 404 | 405 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 406 | device = self.out_proj.weight.device 407 | conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype 408 | conv_state = torch.zeros( 409 | batch_size, self.d_model * self.expand, self.d_conv, device=device, dtype=conv_dtype 410 | ) 411 | ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype 412 | # ssm_dtype = torch.float32 413 | ssm_state = torch.zeros( 414 | batch_size, 
self.d_model * self.expand, self.d_state, device=device, dtype=ssm_dtype 415 | ) 416 | return conv_state, ssm_state 417 | 418 | def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): 419 | assert self.layer_idx is not None 420 | if self.layer_idx not in inference_params.key_value_memory_dict: 421 | batch_shape = (batch_size,) 422 | conv_state = torch.zeros( 423 | batch_size, 424 | self.d_model * self.expand, 425 | self.d_conv, 426 | device=self.conv1d.weight.device, 427 | dtype=self.conv1d.weight.dtype, 428 | ) 429 | ssm_state = torch.zeros( 430 | batch_size, 431 | self.d_model * self.expand, 432 | self.d_state, 433 | device=self.dt_proj.weight.device, 434 | dtype=self.dt_proj.weight.dtype, 435 | # dtype=torch.float32, 436 | ) 437 | inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state) 438 | else: 439 | conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] 440 | # TODO: What if batch size changes between generation, and we reuse the same states? 441 | if initialize_states: 442 | conv_state.zero_() 443 | ssm_state.zero_() 444 | return conv_state, ssm_state 445 | 446 | 447 | class Block(nn.Module): 448 | def __init__( 449 | self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False 450 | ): 451 | """ 452 | Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection" 453 | 454 | This Block has a slightly different structure compared to a regular 455 | prenorm Transformer block. 456 | The standard block is: LN -> MHA/MLP -> Add. 457 | [Ref: https://arxiv.org/abs/2002.04745] 458 | Here we have: Add -> LN -> Mixer, returning both 459 | the hidden_states (output of the mixer) and the residual. 460 | This is purely for performance reasons, as we can fuse add and LayerNorm. 461 | The residual needs to be provided (except for the very first block). 462 | """ 463 | super().__init__() 464 | self.residual_in_fp32 = residual_in_fp32 465 | self.fused_add_norm = fused_add_norm 466 | self.mixer = mixer_cls(dim) 467 | self.norm = norm_cls(dim) 468 | if self.fused_add_norm: 469 | assert RMSNorm is not None, "RMSNorm import fails" 470 | assert isinstance( 471 | self.norm, (nn.LayerNorm, RMSNorm) 472 | ), "Only LayerNorm and RMSNorm are supported for fused_add_norm" 473 | 474 | def forward( 475 | self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None 476 | ): 477 | r"""Pass the input through the encoder layer. 478 | 479 | Args: 480 | hidden_states: the sequence to the encoder layer (required). 
481 | residual: hidden_states = Mixer(LN(residual)) 482 | """ 483 | if not self.fused_add_norm: 484 | residual = (hidden_states + residual) if residual is not None else hidden_states 485 | hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) 486 | if self.residual_in_fp32: 487 | residual = residual.to(torch.float32) 488 | else: 489 | fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn 490 | hidden_states, residual = fused_add_norm_fn( 491 | hidden_states, 492 | self.norm.weight, 493 | self.norm.bias, 494 | residual=residual, 495 | prenorm=True, 496 | residual_in_fp32=self.residual_in_fp32, 497 | eps=self.norm.eps, 498 | ) 499 | hidden_states = self.mixer(hidden_states, inference_params=inference_params) 500 | return hidden_states, residual 501 | 502 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 503 | return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) 504 | -------------------------------------------------------------------------------- /models/mamba_ssm/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/ops/__init__.py -------------------------------------------------------------------------------- /models/mamba_ssm/ops/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/ops/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/ops/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/ops/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/ops/__pycache__/selective_scan_interface.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/ops/__pycache__/selective_scan_interface.cpython-310.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/ops/__pycache__/selective_scan_interface.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/ops/__pycache__/selective_scan_interface.cpython-37.pyc -------------------------------------------------------------------------------- /models/mamba_ssm/ops/triton/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/ops/triton/__init__.py -------------------------------------------------------------------------------- /models/mamba_ssm/ops/triton/layernorm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | # Implement residual + layer_norm / rms_norm. 
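# Reference semantics (see layer_norm_ref / rms_norm_ref below):
#   out = Norm(x + residual), with the affine weight/bias applied inside Norm;
#   when prenorm=True, the pre-norm sum (x + residual) is returned as well.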
3 | 4 | # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html 5 | # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. 6 | # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling. 7 | # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 8 | 9 | import math 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.cuda.amp import custom_fwd, custom_bwd 14 | 15 | import triton 16 | import triton.language as tl 17 | 18 | 19 | def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): 20 | dtype = x.dtype 21 | if upcast: 22 | weight = weight.float() 23 | bias = bias.float() if bias is not None else None 24 | if upcast: 25 | x = x.float() 26 | residual = residual.float() if residual is not None else residual 27 | if residual is not None: 28 | x = (x + residual).to(x.dtype) 29 | out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to( 30 | dtype 31 | ) 32 | return out if not prenorm else (out, x) 33 | 34 | 35 | def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, prenorm=False, upcast=False): 36 | dtype = x.dtype 37 | if upcast: 38 | weight = weight.float() 39 | bias = bias.float() if bias is not None else None 40 | if upcast: 41 | x = x.float() 42 | residual = residual.float() if residual is not None else residual 43 | if residual is not None: 44 | x = (x + residual).to(x.dtype) 45 | rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) 46 | out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight) 47 | out = out.to(dtype) 48 | return out if not prenorm else (out, x) 49 | 50 | 51 | @triton.autotune( 52 | configs=[ 53 | triton.Config({}, num_warps=1), 54 | triton.Config({}, num_warps=2), 55 | triton.Config({}, num_warps=4), 56 | triton.Config({}, num_warps=8), 57 | triton.Config({}, num_warps=16), 58 | triton.Config({}, num_warps=32), 59 | ], 60 | key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], 61 | ) 62 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 63 | # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) 64 | @triton.jit 65 | def _layer_norm_fwd_1pass_kernel( 66 | X, # pointer to the input 67 | Y, # pointer to the output 68 | W, # pointer to the weights 69 | B, # pointer to the biases 70 | RESIDUAL, # pointer to the residual 71 | RESIDUAL_OUT, # pointer to the residual 72 | Mean, # pointer to the mean 73 | Rstd, # pointer to the 1/std 74 | stride_x_row, # how much to increase the pointer when moving by 1 row 75 | stride_y_row, 76 | stride_res_row, 77 | stride_res_out_row, 78 | N, # number of columns in X 79 | eps, # epsilon to avoid division by zero 80 | IS_RMS_NORM: tl.constexpr, 81 | BLOCK_N: tl.constexpr, 82 | HAS_RESIDUAL: tl.constexpr, 83 | STORE_RESIDUAL_OUT: tl.constexpr, 84 | HAS_BIAS: tl.constexpr, 85 | ): 86 | # Map the program id to the row of X and Y it should compute. 
87 | row = tl.program_id(0) 88 | X += row * stride_x_row 89 | Y += row * stride_y_row 90 | if HAS_RESIDUAL: 91 | RESIDUAL += row * stride_res_row 92 | if STORE_RESIDUAL_OUT: 93 | RESIDUAL_OUT += row * stride_res_out_row 94 | # Compute mean and variance 95 | cols = tl.arange(0, BLOCK_N) 96 | x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) 97 | if HAS_RESIDUAL: 98 | residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) 99 | x += residual 100 | if STORE_RESIDUAL_OUT: 101 | tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) 102 | if not IS_RMS_NORM: 103 | mean = tl.sum(x, axis=0) / N 104 | tl.store(Mean + row, mean) 105 | xbar = tl.where(cols < N, x - mean, 0.0) 106 | var = tl.sum(xbar * xbar, axis=0) / N 107 | else: 108 | xbar = tl.where(cols < N, x, 0.0) 109 | var = tl.sum(xbar * xbar, axis=0) / N 110 | rstd = 1 / tl.sqrt(var + eps) 111 | tl.store(Rstd + row, rstd) 112 | # Normalize and apply linear transformation 113 | mask = cols < N 114 | w = tl.load(W + cols, mask=mask).to(tl.float32) 115 | if HAS_BIAS: 116 | b = tl.load(B + cols, mask=mask).to(tl.float32) 117 | x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd 118 | y = x_hat * w + b if HAS_BIAS else x_hat * w 119 | # Write output 120 | tl.store(Y + cols, y, mask=mask) 121 | 122 | 123 | def _layer_norm_fwd( 124 | x, weight, bias, eps, residual=None, out_dtype=None, residual_dtype=None, is_rms_norm=False 125 | ): 126 | if residual is not None: 127 | residual_dtype = residual.dtype 128 | M, N = x.shape 129 | assert x.stride(-1) == 1 130 | if residual is not None: 131 | assert residual.stride(-1) == 1 132 | assert residual.shape == (M, N) 133 | assert weight.shape == (N,) 134 | assert weight.stride(-1) == 1 135 | if bias is not None: 136 | assert bias.stride(-1) == 1 137 | assert bias.shape == (N,) 138 | # allocate output 139 | y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) 140 | assert y.stride(-1) == 1 141 | if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype): 142 | residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype) 143 | assert residual_out.stride(-1) == 1 144 | else: 145 | residual_out = None 146 | mean = torch.empty((M,), dtype=torch.float32, device="cuda") if not is_rms_norm else None 147 | rstd = torch.empty((M,), dtype=torch.float32, device="cuda") 148 | # Less than 64KB per feature: enqueue fused kernel 149 | MAX_FUSED_SIZE = 65536 // x.element_size() 150 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 151 | if N > BLOCK_N: 152 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 153 | # heuristics for number of warps 154 | with torch.cuda.device(x.device.index): 155 | _layer_norm_fwd_1pass_kernel[(M,)]( 156 | x, 157 | y, 158 | weight, 159 | bias, 160 | residual, 161 | residual_out, 162 | mean, 163 | rstd, 164 | x.stride(0), 165 | y.stride(0), 166 | residual.stride(0) if residual is not None else 0, 167 | residual_out.stride(0) if residual_out is not None else 0, 168 | N, 169 | eps, 170 | is_rms_norm, 171 | BLOCK_N, 172 | residual is not None, 173 | residual_out is not None, 174 | bias is not None, 175 | ) 176 | # residual_out is None if residual is None and residual_dtype == input_dtype 177 | return y, mean, rstd, residual_out if residual_out is not None else x 178 | 179 | 180 | @triton.autotune( 181 | configs=[ 182 | triton.Config({}, num_warps=1), 183 | triton.Config({}, num_warps=2), 184 | triton.Config({}, num_warps=4), 185 | triton.Config({}, num_warps=8), 186 | 
triton.Config({}, num_warps=16), 187 | triton.Config({}, num_warps=32), 188 | ], 189 | key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"], 190 | ) 191 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 192 | # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) 193 | # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) 194 | @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) 195 | @triton.jit 196 | def _layer_norm_bwd_kernel( 197 | X, # pointer to the input 198 | W, # pointer to the weights 199 | B, # pointer to the biases 200 | Y, # pointer to the output to be recomputed 201 | DY, # pointer to the output gradient 202 | DX, # pointer to the input gradient 203 | DW, # pointer to the partial sum of weights gradient 204 | DB, # pointer to the partial sum of biases gradient 205 | DRESIDUAL, 206 | DRESIDUAL_IN, 207 | Mean, # pointer to the mean 208 | Rstd, # pointer to the 1/std 209 | stride_x_row, # how much to increase the pointer when moving by 1 row 210 | stride_y_row, 211 | stride_dy_row, 212 | stride_dx_row, 213 | stride_dres_row, 214 | stride_dres_in_row, 215 | M, # number of rows in X 216 | N, # number of columns in X 217 | eps, # epsilon to avoid division by zero 218 | rows_per_program, 219 | IS_RMS_NORM: tl.constexpr, 220 | BLOCK_N: tl.constexpr, 221 | HAS_DRESIDUAL: tl.constexpr, 222 | STORE_DRESIDUAL: tl.constexpr, 223 | HAS_BIAS: tl.constexpr, 224 | RECOMPUTE_OUTPUT: tl.constexpr, 225 | ): 226 | # Map the program id to the elements of X, DX, and DY it should compute. 227 | row_block_id = tl.program_id(0) 228 | row_start = row_block_id * rows_per_program 229 | cols = tl.arange(0, BLOCK_N) 230 | mask = cols < N 231 | X += row_start * stride_x_row 232 | if HAS_DRESIDUAL: 233 | DRESIDUAL += row_start * stride_dres_row 234 | if STORE_DRESIDUAL: 235 | DRESIDUAL_IN += row_start * stride_dres_in_row 236 | DY += row_start * stride_dy_row 237 | DX += row_start * stride_dx_row 238 | if RECOMPUTE_OUTPUT: 239 | Y += row_start * stride_y_row 240 | w = tl.load(W + cols, mask=mask).to(tl.float32) 241 | if RECOMPUTE_OUTPUT and HAS_BIAS: 242 | b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) 243 | dw = tl.zeros((BLOCK_N,), dtype=tl.float32) 244 | if HAS_BIAS: 245 | db = tl.zeros((BLOCK_N,), dtype=tl.float32) 246 | row_end = min((row_block_id + 1) * rows_per_program, M) 247 | for row in range(row_start, row_end): 248 | # Load data to SRAM 249 | x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) 250 | dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) 251 | if not IS_RMS_NORM: 252 | mean = tl.load(Mean + row) 253 | rstd = tl.load(Rstd + row) 254 | # Compute dx 255 | xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd 256 | xhat = tl.where(mask, xhat, 0.0) 257 | if RECOMPUTE_OUTPUT: 258 | y = xhat * w + b if HAS_BIAS else xhat * w 259 | tl.store(Y + cols, y, mask=mask) 260 | wdy = w * dy 261 | dw += dy * xhat 262 | if HAS_BIAS: 263 | db += dy 264 | if not IS_RMS_NORM: 265 | c1 = tl.sum(xhat * wdy, axis=0) / N 266 | c2 = tl.sum(wdy, axis=0) / N 267 | dx = (wdy - (xhat * c1 + c2)) * rstd 268 | else: 269 | c1 = tl.sum(xhat * wdy, axis=0) / N 270 | dx = (wdy - xhat * c1) * rstd 271 | if HAS_DRESIDUAL: 272 | dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) 273 | dx += dres 274 | # Write dx 275 | if STORE_DRESIDUAL: 276 | tl.store(DRESIDUAL_IN + cols, dx, mask=mask) 277 | tl.store(DX + cols, dx, mask=mask) 278 | 279 | X += 
stride_x_row 280 | if HAS_DRESIDUAL: 281 | DRESIDUAL += stride_dres_row 282 | if STORE_DRESIDUAL: 283 | DRESIDUAL_IN += stride_dres_in_row 284 | if RECOMPUTE_OUTPUT: 285 | Y += stride_y_row 286 | DY += stride_dy_row 287 | DX += stride_dx_row 288 | tl.store(DW + row_block_id * N + cols, dw, mask=mask) 289 | if HAS_BIAS: 290 | tl.store(DB + row_block_id * N + cols, db, mask=mask) 291 | 292 | 293 | def _layer_norm_bwd( 294 | dy, 295 | x, 296 | weight, 297 | bias, 298 | eps, 299 | mean, 300 | rstd, 301 | dresidual=None, 302 | has_residual=False, 303 | is_rms_norm=False, 304 | x_dtype=None, 305 | recompute_output=False, 306 | ): 307 | M, N = x.shape 308 | assert x.stride(-1) == 1 309 | assert dy.stride(-1) == 1 310 | assert dy.shape == (M, N) 311 | if dresidual is not None: 312 | assert dresidual.stride(-1) == 1 313 | assert dresidual.shape == (M, N) 314 | assert weight.shape == (N,) 315 | assert weight.stride(-1) == 1 316 | if bias is not None: 317 | assert bias.stride(-1) == 1 318 | assert bias.shape == (N,) 319 | # allocate output 320 | dx = ( 321 | torch.empty_like(x) 322 | if x_dtype is None 323 | else torch.empty(M, N, dtype=x_dtype, device=x.device) 324 | ) 325 | dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None 326 | y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None 327 | 328 | # Less than 64KB per feature: enqueue fused kernel 329 | MAX_FUSED_SIZE = 65536 // x.element_size() 330 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 331 | if N > BLOCK_N: 332 | raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") 333 | sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count 334 | _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) 335 | _db = ( 336 | torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) 337 | if bias is not None 338 | else None 339 | ) 340 | rows_per_program = math.ceil(M / sm_count) 341 | grid = (sm_count,) 342 | with torch.cuda.device(x.device.index): 343 | _layer_norm_bwd_kernel[grid]( 344 | x, 345 | weight, 346 | bias, 347 | y, 348 | dy, 349 | dx, 350 | _dw, 351 | _db, 352 | dresidual, 353 | dresidual_in, 354 | mean, 355 | rstd, 356 | x.stride(0), 357 | 0 if not recompute_output else y.stride(0), 358 | dy.stride(0), 359 | dx.stride(0), 360 | dresidual.stride(0) if dresidual is not None else 0, 361 | dresidual_in.stride(0) if dresidual_in is not None else 0, 362 | M, 363 | N, 364 | eps, 365 | rows_per_program, 366 | is_rms_norm, 367 | BLOCK_N, 368 | dresidual is not None, 369 | dresidual_in is not None, 370 | bias is not None, 371 | ) 372 | dw = _dw.sum(0).to(weight.dtype) 373 | db = _db.sum(0).to(bias.dtype) if bias is not None else None 374 | # Don't need to compute dresidual_in separately in this case 375 | if has_residual and dx.dtype == x.dtype: 376 | dresidual_in = dx 377 | return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y) 378 | 379 | 380 | class LayerNormFn(torch.autograd.Function): 381 | @staticmethod 382 | def forward( 383 | ctx, 384 | x, 385 | weight, 386 | bias, 387 | residual=None, 388 | eps=1e-6, 389 | prenorm=False, 390 | residual_in_fp32=False, 391 | is_rms_norm=False, 392 | ): 393 | x_shape_og = x.shape 394 | # reshape input data into 2D tensor 395 | x = x.reshape(-1, x.shape[-1]) 396 | if x.stride(-1) != 1: 397 | x = x.contiguous() 398 | if residual is not None: 399 | assert residual.shape == x_shape_og 400 | residual = residual.reshape(-1, 
residual.shape[-1]) 401 | if residual.stride(-1) != 1: 402 | residual = residual.contiguous() 403 | weight = weight.contiguous() 404 | if bias is not None: 405 | bias = bias.contiguous() 406 | residual_dtype = ( 407 | residual.dtype 408 | if residual is not None 409 | else (torch.float32 if residual_in_fp32 else None) 410 | ) 411 | y, mean, rstd, residual_out = _layer_norm_fwd( 412 | x, weight, bias, eps, residual, residual_dtype=residual_dtype, is_rms_norm=is_rms_norm 413 | ) 414 | ctx.save_for_backward(residual_out, weight, bias, mean, rstd) 415 | ctx.x_shape_og = x_shape_og 416 | ctx.eps = eps 417 | ctx.is_rms_norm = is_rms_norm 418 | ctx.has_residual = residual is not None 419 | ctx.prenorm = prenorm 420 | ctx.x_dtype = x.dtype 421 | y = y.reshape(x_shape_og) 422 | return y if not prenorm else (y, residual_out.reshape(x_shape_og)) 423 | 424 | @staticmethod 425 | def backward(ctx, dy, *args): 426 | x, weight, bias, mean, rstd = ctx.saved_tensors 427 | dy = dy.reshape(-1, dy.shape[-1]) 428 | if dy.stride(-1) != 1: 429 | dy = dy.contiguous() 430 | assert dy.shape == x.shape 431 | if ctx.prenorm: 432 | dresidual = args[0] 433 | dresidual = dresidual.reshape(-1, dresidual.shape[-1]) 434 | if dresidual.stride(-1) != 1: 435 | dresidual = dresidual.contiguous() 436 | assert dresidual.shape == x.shape 437 | else: 438 | dresidual = None 439 | dx, dw, db, dresidual_in = _layer_norm_bwd( 440 | dy, 441 | x, 442 | weight, 443 | bias, 444 | ctx.eps, 445 | mean, 446 | rstd, 447 | dresidual, 448 | ctx.has_residual, 449 | ctx.is_rms_norm, 450 | x_dtype=ctx.x_dtype, 451 | ) 452 | return ( 453 | dx.reshape(ctx.x_shape_og), 454 | dw, 455 | db, 456 | dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, 457 | None, 458 | None, 459 | None, 460 | None, 461 | ) 462 | 463 | 464 | def layer_norm_fn( 465 | x, 466 | weight, 467 | bias, 468 | residual=None, 469 | eps=1e-6, 470 | prenorm=False, 471 | residual_in_fp32=False, 472 | is_rms_norm=False, 473 | ): 474 | return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm) 475 | 476 | 477 | def rms_norm_fn(x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False, eps=1e-6): 478 | return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True) 479 | 480 | 481 | class RMSNorm(torch.nn.Module): 482 | def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): 483 | factory_kwargs = {"device": device, "dtype": dtype} 484 | super().__init__() 485 | self.eps = eps 486 | self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) 487 | self.register_parameter("bias", None) 488 | self.reset_parameters() 489 | 490 | def reset_parameters(self): 491 | torch.nn.init.ones_(self.weight) 492 | 493 | def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): 494 | return rms_norm_fn( 495 | x, 496 | self.weight, 497 | self.bias, 498 | residual=residual, 499 | eps=self.eps, 500 | prenorm=prenorm, 501 | residual_in_fp32=residual_in_fp32, 502 | is_rms_norm=True, 503 | ) 504 | 505 | 506 | class LayerNormLinearFn(torch.autograd.Function): 507 | @staticmethod 508 | @custom_fwd 509 | def forward( 510 | ctx, 511 | x, 512 | norm_weight, 513 | norm_bias, 514 | linear_weight, 515 | linear_bias, 516 | residual=None, 517 | eps=1e-6, 518 | prenorm=False, 519 | residual_in_fp32=False, 520 | is_rms_norm=False, 521 | ): 522 | x_shape_og = x.shape 523 | # reshape input data into 2D tensor 524 | x = x.reshape(-1, x.shape[-1]) 525 | if x.stride(-1) != 1: 526 | x = 
x.contiguous() 527 | if residual is not None: 528 | assert residual.shape == x_shape_og 529 | residual = residual.reshape(-1, residual.shape[-1]) 530 | if residual.stride(-1) != 1: 531 | residual = residual.contiguous() 532 | norm_weight = norm_weight.contiguous() 533 | if norm_bias is not None: 534 | norm_bias = norm_bias.contiguous() 535 | residual_dtype = ( 536 | residual.dtype 537 | if residual is not None 538 | else (torch.float32 if residual_in_fp32 else None) 539 | ) 540 | y, mean, rstd, residual_out = _layer_norm_fwd( 541 | x, 542 | norm_weight, 543 | norm_bias, 544 | eps, 545 | residual, 546 | out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(), 547 | residual_dtype=residual_dtype, 548 | is_rms_norm=is_rms_norm, 549 | ) 550 | y = y.reshape(x_shape_og) 551 | dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype 552 | linear_weight = linear_weight.to(dtype) 553 | linear_bias = linear_bias.to(dtype) if linear_bias is not None else None 554 | out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias) 555 | # We don't store y, will be recomputed in the backward pass to save memory 556 | ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd) 557 | ctx.x_shape_og = x_shape_og 558 | ctx.eps = eps 559 | ctx.is_rms_norm = is_rms_norm 560 | ctx.has_residual = residual is not None 561 | ctx.prenorm = prenorm 562 | ctx.x_dtype = x.dtype 563 | ctx.linear_bias_is_none = linear_bias is None 564 | return out if not prenorm else (out, residual_out.reshape(x_shape_og)) 565 | 566 | @staticmethod 567 | @custom_bwd 568 | def backward(ctx, dout, *args): 569 | x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors 570 | dout = dout.reshape(-1, dout.shape[-1]) 571 | dy = F.linear(dout, linear_weight.t()) 572 | dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0) 573 | if dy.stride(-1) != 1: 574 | dy = dy.contiguous() 575 | assert dy.shape == x.shape 576 | if ctx.prenorm: 577 | dresidual = args[0] 578 | dresidual = dresidual.reshape(-1, dresidual.shape[-1]) 579 | if dresidual.stride(-1) != 1: 580 | dresidual = dresidual.contiguous() 581 | assert dresidual.shape == x.shape 582 | else: 583 | dresidual = None 584 | dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd( 585 | dy, 586 | x, 587 | norm_weight, 588 | norm_bias, 589 | ctx.eps, 590 | mean, 591 | rstd, 592 | dresidual, 593 | ctx.has_residual, 594 | ctx.is_rms_norm, 595 | x_dtype=ctx.x_dtype, 596 | recompute_output=True, 597 | ) 598 | dlinear_weight = torch.einsum("bo,bi->oi", dout, y) 599 | return ( 600 | dx.reshape(ctx.x_shape_og), 601 | dnorm_weight, 602 | dnorm_bias, 603 | dlinear_weight, 604 | dlinear_bias, 605 | dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, 606 | None, 607 | None, 608 | None, 609 | None, 610 | ) 611 | 612 | 613 | def layer_norm_linear_fn( 614 | x, 615 | norm_weight, 616 | norm_bias, 617 | linear_weight, 618 | linear_bias, 619 | residual=None, 620 | eps=1e-6, 621 | prenorm=False, 622 | residual_in_fp32=False, 623 | is_rms_norm=False, 624 | ): 625 | return LayerNormLinearFn.apply( 626 | x, 627 | norm_weight, 628 | norm_bias, 629 | linear_weight, 630 | linear_bias, 631 | residual, 632 | eps, 633 | prenorm, 634 | residual_in_fp32, 635 | is_rms_norm, 636 | ) 637 | -------------------------------------------------------------------------------- /models/mamba_ssm/ops/triton/selective_state_update.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Tri Dao. 2 | 3 | """We want triton==2.1.0 for this 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import triton 11 | import triton.language as tl 12 | 13 | from einops import rearrange, repeat 14 | 15 | 16 | @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) 17 | @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) 18 | @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) 19 | @triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) 20 | @triton.jit 21 | def _selective_scan_update_kernel( 22 | # Pointers to matrices 23 | state_ptr, x_ptr, dt_ptr, dt_bias_ptr, A_ptr, B_ptr, C_ptr, D_ptr, z_ptr, out_ptr, 24 | # Matrix dimensions 25 | batch, dim, dstate, 26 | # Strides 27 | stride_state_batch, stride_state_dim, stride_state_dstate, 28 | stride_x_batch, stride_x_dim, 29 | stride_dt_batch, stride_dt_dim, 30 | stride_dt_bias_dim, 31 | stride_A_dim, stride_A_dstate, 32 | stride_B_batch, stride_B_dstate, 33 | stride_C_batch, stride_C_dstate, 34 | stride_D_dim, 35 | stride_z_batch, stride_z_dim, 36 | stride_out_batch, stride_out_dim, 37 | # Meta-parameters 38 | DT_SOFTPLUS: tl.constexpr, 39 | BLOCK_SIZE_M: tl.constexpr, 40 | HAS_DT_BIAS: tl.constexpr, 41 | HAS_D: tl.constexpr, 42 | HAS_Z: tl.constexpr, 43 | BLOCK_SIZE_DSTATE: tl.constexpr, 44 | ): 45 | pid_m = tl.program_id(axis=0) 46 | pid_b = tl.program_id(axis=1) 47 | state_ptr += pid_b * stride_state_batch 48 | x_ptr += pid_b * stride_x_batch 49 | dt_ptr += pid_b * stride_dt_batch 50 | B_ptr += pid_b * stride_B_batch 51 | C_ptr += pid_b * stride_C_batch 52 | if HAS_Z: 53 | z_ptr += pid_b * stride_z_batch 54 | out_ptr += pid_b * stride_out_batch 55 | 56 | offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 57 | offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) 58 | state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate) 59 | x_ptrs = x_ptr + offs_m * stride_x_dim 60 | dt_ptrs = dt_ptr + offs_m * stride_dt_dim 61 | if HAS_DT_BIAS: 62 | dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim 63 | A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate) 64 | B_ptrs = B_ptr + offs_n * stride_B_dstate 65 | C_ptrs = C_ptr + offs_n * stride_C_dstate 66 | if HAS_D: 67 | D_ptrs = D_ptr + offs_m * stride_D_dim 68 | if HAS_Z: 69 | z_ptrs = z_ptr + offs_m * stride_z_dim 70 | out_ptrs = out_ptr + offs_m * stride_out_dim 71 | 72 | state = tl.load(state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0) 73 | x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 74 | dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 75 | if HAS_DT_BIAS: 76 | dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 77 | if DT_SOFTPLUS: 78 | dt = tl.log(1.0 + tl.exp(dt)) 79 | A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) 80 | dA = tl.exp(A * dt[:, None]) 81 | B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 82 | C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) 83 | if HAS_D: 84 | D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 85 | if HAS_Z: 86 | z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) 87 | 88 | dB = B[None, :] * dt[:, None] 89 | state = state * dA + dB * 
x[:, None] 90 | tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) 91 | out = tl.sum(state * C[None, :], axis=1) 92 | if HAS_D: 93 | out += x * D 94 | if HAS_Z: 95 | out *= z * tl.sigmoid(z) 96 | tl.store(out_ptrs, out, mask=offs_m < dim) 97 | 98 | 99 | def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 100 | """ 101 | Argument: 102 | state: (batch, dim, dstate) 103 | x: (batch, dim) 104 | dt: (batch, dim) 105 | A: (dim, dstate) 106 | B: (batch, dstate) 107 | C: (batch, dstate) 108 | D: (dim,) 109 | z: (batch, dim) 110 | dt_bias: (dim,) 111 | Return: 112 | out: (batch, dim) 113 | """ 114 | batch, dim, dstate = state.shape 115 | assert x.shape == (batch, dim) 116 | assert dt.shape == x.shape 117 | assert A.shape == (dim, dstate) 118 | assert B.shape == (batch, dstate) 119 | assert C.shape == B.shape 120 | if D is not None: 121 | assert D.shape == (dim,) 122 | if z is not None: 123 | assert z.shape == x.shape 124 | if dt_bias is not None: 125 | assert dt_bias.shape == (dim,) 126 | out = torch.empty_like(x) 127 | grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch) 128 | z_strides = ((z.stride(0), z.stride(1)) if z is not None else (0, 0)) 129 | # We don't want autotune since it will overwrite the state 130 | # We instead tune by hand. 131 | BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 132 | else ((16, 4) if dstate <= 32 else 133 | ((8, 4) if dstate <= 64 else 134 | ((4, 4) if dstate <= 128 else 135 | ((4, 8)))))) 136 | with torch.cuda.device(x.device.index): 137 | _selective_scan_update_kernel[grid]( 138 | state, x, dt, dt_bias, A, B, C, D, z, out, 139 | batch, dim, dstate, 140 | state.stride(0), state.stride(1), state.stride(2), 141 | x.stride(0), x.stride(1), 142 | dt.stride(0), dt.stride(1), 143 | dt_bias.stride(0) if dt_bias is not None else 0, 144 | A.stride(0), A.stride(1), 145 | B.stride(0), B.stride(1), 146 | C.stride(0), C.stride(1), 147 | D.stride(0) if D is not None else 0, 148 | z_strides[0], z_strides[1], 149 | out.stride(0), out.stride(1), 150 | dt_softplus, 151 | BLOCK_SIZE_M, 152 | num_warps=num_warps, 153 | ) 154 | return out 155 | 156 | 157 | def selective_state_update_ref(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False): 158 | """ 159 | Argument: 160 | state: (batch, dim, dstate) 161 | x: (batch, dim) 162 | dt: (batch, dim) 163 | A: (dim, dstate) 164 | B: (batch, dstate) 165 | C: (batch, dstate) 166 | D: (dim,) 167 | z: (batch, dim) 168 | dt_bias: (dim,) 169 | Return: 170 | out: (batch, dim) 171 | """ 172 | batch, dim, dstate = state.shape 173 | assert x.shape == (batch, dim) 174 | assert dt.shape == x.shape 175 | assert A.shape == (dim, dstate) 176 | assert B.shape == (batch, dstate) 177 | assert C.shape == B.shape 178 | if D is not None: 179 | assert D.shape == (dim,) 180 | if z is not None: 181 | assert z.shape == x.shape 182 | if dt_bias is not None: 183 | assert dt_bias.shape == (dim,) 184 | dt = dt + dt_bias if dt_bias is not None else dt 185 | dt = F.softplus(dt) if dt_softplus else dt 186 | dA = torch.exp(rearrange(dt, "b d -> b d 1") * A) # (batch, dim, dstate) 187 | dB = rearrange(dt, "b d -> b d 1") * rearrange(B, "b n -> b 1 n") # (batch, dim, dstate) 188 | state.copy_(state * dA + dB * rearrange(x, "b d -> b d 1")) # (batch, dim, dstate) 189 | out = torch.einsum("bdn,bn->bd", state.to(C.dtype), C) 190 | if D is not None: 191 | out += (x * D).to(out.dtype) 192 | return (out if z is None else out * F.silu(z)).to(x.dtype) 193 | 
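The two functions above are meant to agree: the Triton kernel performs the single-token state-space update in one fused pass, while the `_ref` version spells out the same recurrence in plain PyTorch. A minimal parity-check sketch (not part of the repo; it assumes a CUDA device with triton installed, and that the package imports as laid out in this tree):

```python
import torch
from models.mamba_ssm.ops.triton.selective_state_update import (
    selective_state_update,
    selective_state_update_ref,
)

batch, dim, dstate = 2, 64, 16
state = torch.randn(batch, dim, dstate, device="cuda")
state_ref = state.clone()  # both implementations update the state in place
x = torch.randn(batch, dim, device="cuda")
dt = torch.randn(batch, dim, device="cuda")
z = torch.randn(batch, dim, device="cuda")
A = -torch.rand(dim, dstate, device="cuda")  # negative A keeps exp(A * dt) bounded
B = torch.randn(batch, dstate, device="cuda")
C = torch.randn(batch, dstate, device="cuda")
D = torch.randn(dim, device="cuda")
dt_bias = torch.randn(dim, device="cuda")

out = selective_state_update(state, x, dt, A, B, C, D=D, z=z,
                             dt_bias=dt_bias, dt_softplus=True)
out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, D=D, z=z,
                                     dt_bias=dt_bias, dt_softplus=True)
print(torch.allclose(out, out_ref, atol=1e-4))      # outputs agree
print(torch.allclose(state, state_ref, atol=1e-4))  # updated states agree
```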
-------------------------------------------------------------------------------- /models/mamba_ssm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/mamba_ssm/utils/__init__.py -------------------------------------------------------------------------------- /models/mamba_ssm/utils/generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023, Albert Gu, Tri Dao. 2 | import gc 3 | import time 4 | from collections import namedtuple 5 | from dataclasses import dataclass, field 6 | from functools import partial 7 | from typing import Callable, Optional, Sequence, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from einops import rearrange, repeat 12 | from torch import Tensor 13 | from torch.profiler import ProfilerActivity, profile, record_function 14 | from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput 15 | 16 | 17 | @dataclass 18 | class InferenceParams: 19 | """Inference parameters that are passed to the main model in order 20 | to efficiently calculate and store the context during inference.""" 21 | 22 | max_seqlen: int 23 | max_batch_size: int 24 | seqlen_offset: int = 0 25 | batch_size_offset: int = 0 26 | key_value_memory_dict: dict = field(default_factory=dict) 27 | lengths_per_sample: Optional[Tensor] = None 28 | 29 | def reset(self, max_seqlen, max_batch_size): 30 | self.max_seqlen = max_seqlen 31 | self.max_batch_size = max_batch_size 32 | self.seqlen_offset = 0 33 | if self.lengths_per_sample is not None: 34 | self.lengths_per_sample.zero_() 35 | 36 | 37 | # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py 38 | # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231 39 | def modify_logits_for_top_k_filtering(logits, top_k): 40 | """Set the logits for non-top-k values to -inf. Done in-place.""" 41 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 42 | logits.masked_fill_(indices_to_remove, float("-Inf")) 43 | 44 | 45 | # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py 46 | # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 47 | def modify_logits_for_top_p_filtering(logits, top_p): 48 | """Set the logits for non-top-p values to -inf. Done in-place.""" 49 | if top_p <= 0.0 or top_p >= 1.0: 50 | return 51 | # First sort and calculate cumulative sum of probabilities. 52 | sorted_logits, sorted_indices = torch.sort(logits, descending=False) 53 | cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) 54 | # Remove tokens with cumulative top_p above the threshold (tokens with 0 are kept) 55 | sorted_indices_to_remove = cumulative_probs <= (1 - top_p) 56 | # scatter sorted tensors to original indexing 57 | indices_to_remove = sorted_indices_to_remove.scatter( 58 | 1, sorted_indices, sorted_indices_to_remove 59 | ) 60 | logits.masked_fill_(indices_to_remove, float("-inf")) 61 | 62 | 63 | def sample(logits, top_k=1, top_p=0.0, temperature=1.0): 64 | """Sample from top-k logits. 
65 | Arguments: 66 | logits: Tensor of shape (batch_size, vocab_size) 67 | """ 68 | if top_k == 1: # Short-circuit for greedy decoding 69 | return logits.argmax(dim=-1) 70 | else: 71 | if top_p > 0.0: 72 | assert top_p <= 1.0, "top-p should be in (0, 1]." 73 | if top_k > 0: 74 | top_k = min(top_k, logits.size(-1)) # Safety check 75 | logits_top, indices = torch.topk(logits, top_k, dim=-1) 76 | if temperature != 1.0: 77 | logits_top /= temperature 78 | modify_logits_for_top_p_filtering(logits_top, top_p) 79 | return indices[ 80 | torch.arange(indices.shape[0], device=indices.device), 81 | torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1), 82 | ] 83 | else: 84 | # Clone so that when we modify for top_p we don't change the original logits 85 | logits_top = logits / temperature if temperature != 1.0 else logits.clone() 86 | modify_logits_for_top_p_filtering(logits_top, top_p) 87 | return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze( 88 | dim=-1 89 | ) 90 | 91 | 92 | @torch.inference_mode() 93 | def decode( 94 | input_ids, 95 | model, 96 | max_length, 97 | top_k=1, 98 | top_p=0.0, 99 | temperature=1.0, 100 | eos_token_id=None, 101 | teacher_outputs=None, 102 | vocab_size=None, 103 | tensor_parallel=1, 104 | cg=False, 105 | enable_timing=False, 106 | ): 107 | """Decoding, either greedy or with top-k or top-p sampling. 108 | If top-k = 0, don't limit the number of candidates (pure sampling). 109 | Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, 110 | then top-p. 111 | We assume that all sequences in the same batch have the same length. 112 | 113 | Arguments: 114 | input_ids: (batch, seq_len) 115 | max_length: int 116 | teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the 117 | logits, the next token is taken from the teacher_outputs. Useful for testing. 
118 | Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: 119 | sequences: (batch, max_length) 120 | scores: tuples of (batch, vocab_size) 121 | """ 122 | batch_size, seqlen_og = input_ids.shape 123 | teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 124 | if cg: 125 | if not hasattr(model, "_decoding_cache"): 126 | model._decoding_cache = None 127 | model._decoding_cache = update_graph_cache( 128 | model, 129 | model._decoding_cache, 130 | batch_size, 131 | seqlen_og, 132 | max_length, 133 | tensor_parallel=tensor_parallel, 134 | ) 135 | inference_params = model._decoding_cache.inference_params 136 | inference_params.reset(max_length, batch_size) 137 | else: 138 | inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size) 139 | 140 | def get_logits(input_ids, inference_params): 141 | decoding = inference_params.seqlen_offset > 0 142 | if decoding: 143 | position_ids = torch.full( 144 | (batch_size, 1), 145 | inference_params.seqlen_offset, 146 | dtype=torch.long, 147 | device=input_ids.device, 148 | ) 149 | else: 150 | position_ids = None 151 | if not cg or not decoding: 152 | logits = model( 153 | input_ids, 154 | position_ids=position_ids, 155 | inference_params=inference_params, 156 | num_last_tokens=1, 157 | ).logits.squeeze(dim=1) 158 | else: 159 | logits = model._decoding_cache.run( 160 | input_ids, position_ids, inference_params.seqlen_offset 161 | ).squeeze(dim=1) 162 | return logits[..., :vocab_size] if vocab_size is not None else logits 163 | 164 | def sample_tokens(logits, inference_params): 165 | if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset: 166 | token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) 167 | else: 168 | token = teacher_outputs[:, inference_params.seqlen_offset] 169 | # return rearrange(token, "b -> b 1") 170 | return token.unsqueeze(1) 171 | 172 | def should_stop(current_token, inference_params): 173 | if inference_params.seqlen_offset == 0: 174 | return False 175 | if eos_token_id is not None and (current_token == eos_token_id).all(): 176 | return True 177 | if inference_params.seqlen_offset >= max_length - 1: 178 | return True 179 | return False 180 | 181 | start = torch.cuda.Event(enable_timing=enable_timing) 182 | end = torch.cuda.Event(enable_timing=enable_timing) 183 | 184 | if enable_timing: 185 | if tensor_parallel > 1: 186 | torch.distributed.barrier() 187 | start.record() 188 | scores, sequences = [], [input_ids] 189 | while not should_stop(sequences[-1], inference_params): 190 | scores.append(get_logits(sequences[-1], inference_params)) 191 | inference_params.seqlen_offset += sequences[-1].shape[1] 192 | sequences.append(sample_tokens(scores[-1], inference_params)) 193 | if enable_timing: 194 | end.record() 195 | if tensor_parallel > 1: 196 | torch.distributed.barrier() 197 | torch.cuda.synchronize() 198 | print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") 199 | output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput 200 | return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) 201 | 202 | 203 | class GenerationMixin: 204 | def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): 205 | raise NotImplementedError 206 | 207 | def generate( 208 | self, 209 | input_ids, 210 | max_length, 211 | top_k=1, 212 | top_p=0.0, 213 | temperature=1.0, 214 | return_dict_in_generate=False, 215 | 
output_scores=False, 216 | **kwargs, 217 | ): 218 | output = decode( 219 | input_ids, self, max_length, top_k=top_k, top_p=top_p, temperature=temperature, **kwargs 220 | ) 221 | if not output_scores: 222 | output.scores = None 223 | return output if return_dict_in_generate else output.sequences 224 | 225 | 226 | def allocate_inference_cache( 227 | max_batch_size, 228 | max_seqlen, 229 | nheads, 230 | headdim, 231 | layers: Union[int, Sequence], 232 | device, 233 | dtype=torch.float16, 234 | ): 235 | assert dtype in [torch.float16, torch.bfloat16, torch.float32] 236 | kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim) 237 | if isinstance(layers, int): 238 | layers = range(layers) 239 | return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers} 240 | 241 | 242 | @dataclass 243 | class DecodingCGCache: 244 | max_batch_size: int = 0 245 | max_seqlen: int = 0 246 | device = None 247 | dtype = None 248 | callables: dict = field(default_factory=dict) 249 | mempool = None 250 | inference_params: Optional[InferenceParams] = None 251 | run: Optional[Callable] = None 252 | 253 | 254 | @torch.inference_mode() 255 | def update_graph_cache( 256 | model, 257 | cache, 258 | batch_size, 259 | seqlen_og, 260 | max_seqlen, 261 | decoding_seqlens=(1,), 262 | tensor_parallel=1, 263 | dtype=None, 264 | n_warmups=2, 265 | ): 266 | if cache is None: 267 | cache = DecodingCGCache() 268 | param_example = next(iter(model.parameters())) 269 | device = param_example.device 270 | if dtype is None: 271 | dtype = param_example.dtype 272 | if ( 273 | (device, dtype) != (cache.device, cache.dtype) 274 | or batch_size > cache.max_batch_size 275 | or max_seqlen > cache.max_seqlen 276 | ): # Invalidate the cache 277 | cache.callables = {} 278 | cache.mempool = None 279 | cache.inference_params = None 280 | gc.collect() 281 | cache.device, cache.dtype = device, dtype 282 | cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen 283 | if hasattr(model, "allocate_inference_cache"): 284 | inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) 285 | else: 286 | headdim = getattr( 287 | model.config, 288 | "head_dim", 289 | model.config.hidden_size // model.config.num_attention_heads, 290 | ) 291 | inf_cache = allocate_inference_cache( 292 | batch_size, 293 | max_seqlen, 294 | model.config.num_attention_heads // tensor_parallel, 295 | headdim, 296 | model.config.num_hidden_layers, 297 | device, 298 | dtype, 299 | ) 300 | lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) 301 | cache.inference_params = InferenceParams( 302 | max_seqlen=max_seqlen, 303 | max_batch_size=batch_size, 304 | seqlen_offset=seqlen_og, 305 | key_value_memory_dict=inf_cache, 306 | lengths_per_sample=lengths_per_sample, 307 | ) 308 | cache.mempool = torch.cuda.graphs.graph_pool_handle() 309 | for decoding_seqlen in decoding_seqlens: 310 | if (batch_size, decoding_seqlen) not in cache.callables: 311 | cache.callables[batch_size, decoding_seqlen] = capture_graph( 312 | model, 313 | cache.inference_params, 314 | batch_size, 315 | max_seqlen, 316 | decoding_seqlen=decoding_seqlen, 317 | mempool=cache.mempool, 318 | n_warmups=n_warmups, 319 | ) 320 | 321 | def dispatch(input_ids, position_ids, seqlen): 322 | batch_size, decoding_seqlen = input_ids.shape[:2] 323 | return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen) 324 | 325 | cache.run = dispatch 326 | cache.inference_params.seqlen_offset = 0 # Reset so it's not confusing 327 
| return cache 328 | 329 | 330 | def capture_graph( 331 | model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2 332 | ): 333 | device = next(iter(model.parameters())).device 334 | input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) 335 | position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device) 336 | seqlen_offset_og = inference_params.seqlen_offset 337 | inference_params.seqlen_offset = max_seqlen - decoding_seqlen 338 | inference_params.lengths_per_sample[:] = inference_params.seqlen_offset 339 | 340 | # Warmup before capture 341 | s = torch.cuda.Stream() 342 | s.wait_stream(torch.cuda.current_stream()) 343 | with torch.cuda.stream(s): 344 | for _ in range(n_warmups): 345 | logits = model( 346 | input_ids, 347 | position_ids=position_ids, 348 | inference_params=inference_params, 349 | num_last_tokens=decoding_seqlen, 350 | ).logits 351 | s.synchronize() 352 | # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, 353 | # which requires that graph launch and non-captured launch to not overlap (I think, 354 | # that's how I interpret the documentation). I'm not sure if this is required. 355 | if torch.distributed.is_initialized(): 356 | torch.distributed.barrier() 357 | torch.cuda.current_stream().wait_stream(s) 358 | # Captures the graph 359 | # To allow capture, automatically sets a side stream as the current stream in the context 360 | graph = torch.cuda.CUDAGraph() 361 | with torch.cuda.graph(graph, pool=mempool): 362 | logits = model( 363 | input_ids, 364 | position_ids=position_ids, 365 | inference_params=inference_params, 366 | num_last_tokens=decoding_seqlen, 367 | ).logits 368 | 369 | def run(new_input_ids, new_position_ids, seqlen): 370 | inference_params.lengths_per_sample[:] = seqlen 371 | input_ids.copy_(new_input_ids) 372 | position_ids.copy_(new_position_ids) 373 | graph.replay() 374 | return logits.clone() 375 | 376 | inference_params.seqlen_offset = seqlen_offset_og 377 | return run 378 | -------------------------------------------------------------------------------- /models/mamba_ssm/utils/hf.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | from transformers.utils import WEIGHTS_NAME, CONFIG_NAME 6 | from transformers.utils.hub import cached_file 7 | 8 | 9 | def load_config_hf(model_name): 10 | resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False) 11 | return json.load(open(resolved_archive_file)) 12 | 13 | 14 | def load_state_dict_hf(model_name, device=None, dtype=None): 15 | # If not fp32, then we don't want to load directly to the GPU 16 | mapped_device = "cpu" if dtype not in [torch.float32, None] else device 17 | resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False) 18 | state_dict = torch.load(resolved_archive_file, map_location=mapped_device) 19 | # Convert dtype before moving to GPU to save memory 20 | if dtype is not None: 21 | state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} 22 | state_dict = {k: v.to(device=device) for k, v in state_dict.items()} 23 | return state_dict 24 | -------------------------------------------------------------------------------- /models/ops_dcnv3/DCNv3.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: DCNv3 3 |
Version: 1.0 4 | Summary: PyTorch Wrapper for CUDA Functions of DCNv3 5 | Home-page: https://github.com/OpenGVLab/InternImage 6 | Author: InternImage 7 | License: UNKNOWN 8 | Platform: UNKNOWN 9 | 10 | UNKNOWN 11 | 12 | -------------------------------------------------------------------------------- /models/ops_dcnv3/DCNv3.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/vision.cpp 3 | /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cpu/dcnv3_cpu.cpp 4 | /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cuda/dcnv3_cuda.cu 5 | /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/vision.cpp 6 | /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.cpp 7 | /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cuda/dcnv3_cuda.cu 8 | DCNv3.egg-info/PKG-INFO 9 | DCNv3.egg-info/SOURCES.txt 10 | DCNv3.egg-info/dependency_links.txt 11 | DCNv3.egg-info/top_level.txt 12 | functions/__init__.py 13 | functions/dcnv3_func.py 14 | modules/__init__.py 15 | modules/dcnv3.py -------------------------------------------------------------------------------- /models/ops_dcnv3/DCNv3.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/ops_dcnv3/DCNv3.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | DCNv3 2 | functions 3 | modules 4 | -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.10/DCNv3.cpython-310-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/build/lib.linux-x86_64-3.10/DCNv3.cpython-310-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.10/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch 8 | -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.10/functions/dcnv3_func.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.autograd import Function 14 | from torch.autograd.function import once_differentiable 15 | from torch.cuda.amp import custom_bwd, custom_fwd 16 | import DCNv3 17 | 18 | 19 | class DCNv3Function(Function): 20 | @staticmethod 21 | @custom_fwd 22 | def 
forward( 23 | ctx, input, offset, mask, 24 | kernel_h, kernel_w, stride_h, stride_w, 25 | pad_h, pad_w, dilation_h, dilation_w, 26 | group, group_channels, offset_scale, im2col_step): 27 | ctx.kernel_h = kernel_h 28 | ctx.kernel_w = kernel_w 29 | ctx.stride_h = stride_h 30 | ctx.stride_w = stride_w 31 | ctx.pad_h = pad_h 32 | ctx.pad_w = pad_w 33 | ctx.dilation_h = dilation_h 34 | ctx.dilation_w = dilation_w 35 | ctx.group = group 36 | ctx.group_channels = group_channels 37 | ctx.offset_scale = offset_scale 38 | ctx.im2col_step = im2col_step 39 | output = DCNv3.dcnv3_forward( 40 | input, offset, mask, kernel_h, 41 | kernel_w, stride_h, stride_w, pad_h, 42 | pad_w, dilation_h, dilation_w, group, 43 | group_channels, offset_scale, ctx.im2col_step) 44 | ctx.save_for_backward(input, offset, mask) 45 | 46 | return output 47 | 48 | @staticmethod 49 | @once_differentiable 50 | @custom_bwd 51 | def backward(ctx, grad_output): 52 | input, offset, mask = ctx.saved_tensors 53 | grad_input, grad_offset, grad_mask = \ 54 | DCNv3.dcnv3_backward( 55 | input, offset, mask, ctx.kernel_h, 56 | ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h, 57 | ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group, 58 | ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step) 59 | 60 | return grad_input, grad_offset, grad_mask, \ 61 | None, None, None, None, None, None, None, None, None, None, None, None 62 | 63 | @staticmethod 64 | def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h, 65 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 66 | group_channels, offset_scale, im2col_step): 67 | """Symbolic function for mmdeploy::DCNv3. 68 | 69 | Returns: 70 | DCNv3 op for onnx. 71 | """ 72 | return g.op( 73 | 'mmdeploy::TRTDCNv3', 74 | input, 75 | offset, 76 | mask, 77 | kernel_h_i=int(kernel_h), 78 | kernel_w_i=int(kernel_w), 79 | stride_h_i=int(stride_h), 80 | stride_w_i=int(stride_w), 81 | pad_h_i=int(pad_h), 82 | pad_w_i=int(pad_w), 83 | dilation_h_i=int(dilation_h), 84 | dilation_w_i=int(dilation_w), 85 | group_i=int(group), 86 | group_channels_i=int(group_channels), 87 | offset_scale_f=float(offset_scale), 88 | im2col_step_i=int(im2col_step), 89 | ) 90 | 91 | def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): 92 | _, H_, W_, _ = spatial_shapes 93 | H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 94 | W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 95 | 96 | ref_y, ref_x = torch.meshgrid( 97 | torch.linspace( 98 | # pad_h + 0.5, 99 | # H_ - pad_h - 0.5, 100 | (dilation_h * (kernel_h - 1)) // 2 + 0.5, 101 | (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, 102 | H_out, 103 | dtype=torch.float32, 104 | device=device), 105 | torch.linspace( 106 | # pad_w + 0.5, 107 | # W_ - pad_w - 0.5, 108 | (dilation_w * (kernel_w - 1)) // 2 + 0.5, 109 | (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, 110 | W_out, 111 | dtype=torch.float32, 112 | device=device)) 113 | ref_y = ref_y.reshape(-1)[None] / H_ 114 | ref_x = ref_x.reshape(-1)[None] / W_ 115 | 116 | ref = torch.stack((ref_x, ref_y), -1).reshape( 117 | 1, H_out, W_out, 1, 2) 118 | 119 | return ref 120 | 121 | 122 | def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device): 123 | _, H_, W_, _ = spatial_shapes 124 | points_list = [] 125 | x, y = torch.meshgrid( 126 | torch.linspace( 127 | -((dilation_w * (kernel_w - 1)) // 2), 128 | 
-((dilation_w * (kernel_w - 1)) // 2) + 129 | (kernel_w - 1) * dilation_w, kernel_w, 130 | dtype=torch.float32, 131 | device=device), 132 | torch.linspace( 133 | -((dilation_h * (kernel_h - 1)) // 2), 134 | -((dilation_h * (kernel_h - 1)) // 2) + 135 | (kernel_h - 1) * dilation_h, kernel_h, 136 | dtype=torch.float32, 137 | device=device)) 138 | 139 | points_list.extend([x / W_, y / H_]) 140 | grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ 141 | repeat(1, group, 1).permute(1, 0, 2) 142 | grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2) 143 | 144 | return grid 145 | 146 | 147 | def dcnv3_core_pytorch( 148 | input, offset, mask, kernel_h, 149 | kernel_w, stride_h, stride_w, pad_h, 150 | pad_w, dilation_h, dilation_w, group, 151 | group_channels, offset_scale): 152 | # for debug and test only, 153 | # need to use cuda version instead 154 | input = F.pad( 155 | input, 156 | [0, 0, pad_h, pad_h, pad_w, pad_w]) 157 | N_, H_in, W_in, _ = input.shape 158 | _, H_out, W_out, _ = offset.shape 159 | 160 | ref = _get_reference_points( 161 | input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w) 162 | grid = _generate_dilation_grids( 163 | input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device) 164 | spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\ 165 | repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device) 166 | 167 | sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \ 168 | offset * offset_scale / spatial_norm 169 | 170 | P_ = kernel_h * kernel_w 171 | sampling_grids = 2 * sampling_locations - 1 172 | # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in 173 | input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\ 174 | reshape(N_*group, group_channels, H_in, W_in) 175 | # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2 176 | sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\ 177 | flatten(0, 1) 178 | # N_*group, group_channels, H_out*W_out, P_ 179 | sampling_input_ = F.grid_sample( 180 | input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False) 181 | 182 | # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_) 183 | mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\ 184 | reshape(N_*group, 1, H_out*W_out, P_) 185 | output = (sampling_input_ * mask).sum(-1).view(N_, 186 | group*group_channels, H_out*W_out) 187 | 188 | return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous() 189 | -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.10/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .dcnv3 import DCNv3, DCNv3_pytorch -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.10/modules/dcnv3.py: 
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import warnings 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | from torch.nn.init import xavier_uniform_, constant_ 16 | from ..functions import DCNv3Function, dcnv3_core_pytorch 17 | 18 | 19 | class to_channels_first(nn.Module): 20 | 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | return x.permute(0, 3, 1, 2) 26 | 27 | 28 | class to_channels_last(nn.Module): 29 | 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def forward(self, x): 34 | return x.permute(0, 2, 3, 1) 35 | 36 | 37 | def build_norm_layer(dim, 38 | norm_layer, 39 | in_format='channels_last', 40 | out_format='channels_last', 41 | eps=1e-6): 42 | layers = [] 43 | if norm_layer == 'BN': 44 | if in_format == 'channels_last': 45 | layers.append(to_channels_first()) 46 | layers.append(nn.BatchNorm2d(dim)) 47 | if out_format == 'channels_last': 48 | layers.append(to_channels_last()) 49 | elif norm_layer == 'LN': 50 | if in_format == 'channels_first': 51 | layers.append(to_channels_last()) 52 | layers.append(nn.LayerNorm(dim, eps=eps)) 53 | if out_format == 'channels_first': 54 | layers.append(to_channels_first()) 55 | else: 56 | raise NotImplementedError( 57 | f'build_norm_layer does not support {norm_layer}') 58 | return nn.Sequential(*layers) 59 | 60 | 61 | def build_act_layer(act_layer): 62 | if act_layer == 'ReLU': 63 | return nn.ReLU(inplace=True) 64 | elif act_layer == 'SiLU': 65 | return nn.SiLU(inplace=True) 66 | elif act_layer == 'GELU': 67 | return nn.GELU() 68 | 69 | raise NotImplementedError(f'build_act_layer does not support {act_layer}') 70 | 71 | 72 | def _is_power_of_2(n): 73 | if (not isinstance(n, int)) or (n < 0): 74 | raise ValueError( 75 | "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 76 | 77 | return (n & (n - 1) == 0) and n != 0 78 | 79 | 80 | class CenterFeatureScaleModule(nn.Module): 81 | def forward(self, 82 | query, 83 | center_feature_scale_proj_weight, 84 | center_feature_scale_proj_bias): 85 | center_feature_scale = F.linear(query, 86 | weight=center_feature_scale_proj_weight, 87 | bias=center_feature_scale_proj_bias).sigmoid() 88 | return center_feature_scale 89 | 90 | 91 | class DCNv3_pytorch(nn.Module): 92 | def __init__( 93 | self, 94 | channels=64, 95 | kernel_size=3, 96 | dw_kernel_size=None, 97 | stride=1, 98 | pad=1, 99 | dilation=1, 100 | group=4, 101 | offset_scale=1.0, 102 | act_layer='GELU', 103 | norm_layer='LN', 104 | center_feature_scale=False): 105 | """ 106 | DCNv3 Module 107 | :param channels 108 | :param kernel_size 109 | :param stride 110 | :param pad 111 | :param dilation 112 | :param group 113 | :param offset_scale 114 | :param act_layer 115 | :param norm_layer 116 | """ 117 | super().__init__() 118 | if channels % group != 0: 119 | raise ValueError( 120 | f'channels must be divisible by group, but got {channels} and {group}') 121 | _d_per_group = channels // group 122 | dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size 123 | # you'd better set _d_per_group to a power of 2 which is more 
efficient in our CUDA implementation 124 | if not _is_power_of_2(_d_per_group): 125 | warnings.warn( 126 | "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " 127 | "which is more efficient in our CUDA implementation.") 128 | 129 | self.offset_scale = offset_scale 130 | self.channels = channels 131 | self.kernel_size = kernel_size 132 | self.dw_kernel_size = dw_kernel_size 133 | self.stride = stride 134 | self.dilation = dilation 135 | self.pad = pad 136 | self.group = group 137 | self.group_channels = channels // group 138 | self.offset_scale = offset_scale 139 | self.center_feature_scale = center_feature_scale 140 | 141 | self.dw_conv = nn.Sequential( 142 | nn.Conv2d( 143 | channels, 144 | channels, 145 | kernel_size=dw_kernel_size, 146 | stride=1, 147 | padding=(dw_kernel_size - 1) // 2, 148 | groups=channels), 149 | build_norm_layer( 150 | channels, 151 | norm_layer, 152 | 'channels_first', 153 | 'channels_last'), 154 | build_act_layer(act_layer)) 155 | self.offset = nn.Linear( 156 | channels, 157 | group * kernel_size * kernel_size * 2) 158 | self.mask = nn.Linear( 159 | channels, 160 | group * kernel_size * kernel_size) 161 | self.input_proj = nn.Linear(channels, channels) 162 | self.output_proj = nn.Linear(channels, channels) 163 | self._reset_parameters() 164 | 165 | if center_feature_scale: 166 | self.center_feature_scale_proj_weight = nn.Parameter( 167 | torch.zeros((group, channels), dtype=torch.float)) 168 | self.center_feature_scale_proj_bias = nn.Parameter( 169 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) 170 | self.center_feature_scale_module = CenterFeatureScaleModule() 171 | 172 | def _reset_parameters(self): 173 | constant_(self.offset.weight.data, 0.) 174 | constant_(self.offset.bias.data, 0.) 175 | constant_(self.mask.weight.data, 0.) 176 | constant_(self.mask.bias.data, 0.) 177 | xavier_uniform_(self.input_proj.weight.data) 178 | constant_(self.input_proj.bias.data, 0.) 179 | xavier_uniform_(self.output_proj.weight.data) 180 | constant_(self.output_proj.bias.data, 0.) 
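# Note on the zero initialization in _reset_parameters above: with the offset and mask projections
# initialized to zero, the first forward pass samples exactly the regular (dilated) kernel grid and the
# softmax-normalized mask weights each of the kernel_size * kernel_size points uniformly, so the module
# starts out behaving like an ordinary convolution and only learns to deform during training.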
181 | 182 | def forward(self, input): 183 | """ 184 | :param query (N, H, W, C) 185 | :return output (N, H, W, C) 186 | """ 187 | N, H, W, _ = input.shape 188 | 189 | x = self.input_proj(input) 190 | x_proj = x 191 | 192 | x1 = input.permute(0, 3, 1, 2) 193 | x1 = self.dw_conv(x1) 194 | offset = self.offset(x1) 195 | mask = self.mask(x1).reshape(N, H, W, self.group, -1) 196 | mask = F.softmax(mask, -1).reshape(N, H, W, -1) 197 | 198 | x = dcnv3_core_pytorch( 199 | x, offset, mask, 200 | self.kernel_size, self.kernel_size, 201 | self.stride, self.stride, 202 | self.pad, self.pad, 203 | self.dilation, self.dilation, 204 | self.group, self.group_channels, 205 | self.offset_scale) 206 | if self.center_feature_scale: 207 | center_feature_scale = self.center_feature_scale_module( 208 | x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) 209 | # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels 210 | center_feature_scale = center_feature_scale[..., None].repeat( 211 | 1, 1, 1, 1, self.channels // self.group).flatten(-2) 212 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale 213 | x = self.output_proj(x) 214 | 215 | return x 216 | 217 | 218 | class DCNv3(nn.Module): 219 | def __init__( 220 | self, 221 | channels=64, 222 | kernel_size=3, 223 | dw_kernel_size=None, 224 | stride=1, 225 | pad=1, 226 | dilation=1, 227 | group=4, 228 | offset_scale=1.0, 229 | act_layer='GELU', 230 | norm_layer='LN', 231 | center_feature_scale=False): 232 | """ 233 | DCNv3 Module 234 | :param channels 235 | :param kernel_size 236 | :param stride 237 | :param pad 238 | :param dilation 239 | :param group 240 | :param offset_scale 241 | :param act_layer 242 | :param norm_layer 243 | """ 244 | super().__init__() 245 | 246 | # if channels % group != 0: 247 | # raise ValueError( 248 | # f'channels must be divisible by group, but got {channels} and {group}') 249 | 250 | _d_per_group = channels // group 251 | dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size 252 | # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation 253 | if not _is_power_of_2(_d_per_group): 254 | warnings.warn( 255 | "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " 256 | "which is more efficient in our CUDA implementation.") 257 | 258 | self.offset_scale = offset_scale 259 | self.channels = channels 260 | self.kernel_size = kernel_size 261 | self.dw_kernel_size = dw_kernel_size 262 | self.stride = stride 263 | self.dilation = dilation 264 | self.pad = pad 265 | self.group = group 266 | self.group_channels = channels // group 267 | self.offset_scale = offset_scale 268 | self.center_feature_scale = center_feature_scale 269 | 270 | self.dw_conv = nn.Sequential( 271 | nn.Conv2d( 272 | channels, 273 | channels, 274 | kernel_size=dw_kernel_size, 275 | stride=1, 276 | padding=(dw_kernel_size - 1) // 2, 277 | groups=channels), 278 | build_norm_layer( 279 | channels, 280 | norm_layer, 281 | 'channels_first', 282 | 'channels_last'), 283 | build_act_layer(act_layer)) 284 | self.offset = nn.Linear( 285 | channels, 286 | group * kernel_size * kernel_size * 2) 287 | self.mask = nn.Linear( 288 | channels, 289 | group * kernel_size * kernel_size) 290 | self.input_proj = nn.Linear(channels, channels) 291 | self.output_proj = nn.Linear(channels, channels) 292 | self._reset_parameters() 293 | 294 | if center_feature_scale: 295 | self.center_feature_scale_proj_weight = 
nn.Parameter( 296 | torch.zeros((group, channels), dtype=torch.float)) 297 | self.center_feature_scale_proj_bias = nn.Parameter( 298 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) 299 | self.center_feature_scale_module = CenterFeatureScaleModule() 300 | 301 | def _reset_parameters(self): 302 | constant_(self.offset.weight.data, 0.) 303 | constant_(self.offset.bias.data, 0.) 304 | constant_(self.mask.weight.data, 0.) 305 | constant_(self.mask.bias.data, 0.) 306 | xavier_uniform_(self.input_proj.weight.data) 307 | constant_(self.input_proj.bias.data, 0.) 308 | xavier_uniform_(self.output_proj.weight.data) 309 | constant_(self.output_proj.bias.data, 0.) 310 | 311 | def forward(self, input): 312 | """ 313 | :param query (N, H, W, C) 314 | :return output (N, H, W, C) 315 | """ 316 | N, H, W, _ = input.shape 317 | 318 | x = self.input_proj(input) 319 | x_proj = x 320 | dtype = x.dtype 321 | 322 | x1 = input.permute(0, 3, 1, 2) 323 | x1 = self.dw_conv(x1) 324 | offset = self.offset(x1) 325 | mask = self.mask(x1).reshape(N, H, W, self.group, -1) 326 | mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype) 327 | 328 | x = DCNv3Function.apply( 329 | x, offset, mask, 330 | self.kernel_size, self.kernel_size, 331 | self.stride, self.stride, 332 | self.pad, self.pad, 333 | self.dilation, self.dilation, 334 | self.group, self.group_channels, 335 | self.offset_scale, 336 | 256) 337 | 338 | if self.center_feature_scale: 339 | center_feature_scale = self.center_feature_scale_module( 340 | x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) 341 | # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels 342 | center_feature_scale = center_feature_scale[..., None].repeat( 343 | 1, 1, 1, 1, self.channels // self.group).flatten(-2) 344 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale 345 | x = self.output_proj(x) 346 | 347 | return x 348 | -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.7/DCNv3.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/build/lib.linux-x86_64-3.7/DCNv3.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.7/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch 8 | -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.7/functions/dcnv3_func.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import torch 12 | 
import torch.nn.functional as F 13 | from torch.autograd import Function 14 | from torch.autograd.function import once_differentiable 15 | from torch.cuda.amp import custom_bwd, custom_fwd 16 | import DCNv3 17 | 18 | 19 | class DCNv3Function(Function): 20 | @staticmethod 21 | @custom_fwd 22 | def forward( 23 | ctx, input, offset, mask, 24 | kernel_h, kernel_w, stride_h, stride_w, 25 | pad_h, pad_w, dilation_h, dilation_w, 26 | group, group_channels, offset_scale, im2col_step): 27 | ctx.kernel_h = kernel_h 28 | ctx.kernel_w = kernel_w 29 | ctx.stride_h = stride_h 30 | ctx.stride_w = stride_w 31 | ctx.pad_h = pad_h 32 | ctx.pad_w = pad_w 33 | ctx.dilation_h = dilation_h 34 | ctx.dilation_w = dilation_w 35 | ctx.group = group 36 | ctx.group_channels = group_channels 37 | ctx.offset_scale = offset_scale 38 | ctx.im2col_step = im2col_step 39 | output = DCNv3.dcnv3_forward( 40 | input, offset, mask, kernel_h, 41 | kernel_w, stride_h, stride_w, pad_h, 42 | pad_w, dilation_h, dilation_w, group, 43 | group_channels, offset_scale, ctx.im2col_step) 44 | ctx.save_for_backward(input, offset, mask) 45 | 46 | return output 47 | 48 | @staticmethod 49 | @once_differentiable 50 | @custom_bwd 51 | def backward(ctx, grad_output): 52 | input, offset, mask = ctx.saved_tensors 53 | grad_input, grad_offset, grad_mask = \ 54 | DCNv3.dcnv3_backward( 55 | input, offset, mask, ctx.kernel_h, 56 | ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h, 57 | ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group, 58 | ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step) 59 | 60 | return grad_input, grad_offset, grad_mask, \ 61 | None, None, None, None, None, None, None, None, None, None, None, None 62 | 63 | @staticmethod 64 | def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h, 65 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 66 | group_channels, offset_scale, im2col_step): 67 | """Symbolic function for mmdeploy::DCNv3. 68 | 69 | Returns: 70 | DCNv3 op for onnx. 
71 | """ 72 | return g.op( 73 | 'mmdeploy::TRTDCNv3', 74 | input, 75 | offset, 76 | mask, 77 | kernel_h_i=int(kernel_h), 78 | kernel_w_i=int(kernel_w), 79 | stride_h_i=int(stride_h), 80 | stride_w_i=int(stride_w), 81 | pad_h_i=int(pad_h), 82 | pad_w_i=int(pad_w), 83 | dilation_h_i=int(dilation_h), 84 | dilation_w_i=int(dilation_w), 85 | group_i=int(group), 86 | group_channels_i=int(group_channels), 87 | offset_scale_f=float(offset_scale), 88 | im2col_step_i=int(im2col_step), 89 | ) 90 | 91 | def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): 92 | _, H_, W_, _ = spatial_shapes 93 | H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 94 | W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 95 | 96 | ref_y, ref_x = torch.meshgrid( 97 | torch.linspace( 98 | # pad_h + 0.5, 99 | # H_ - pad_h - 0.5, 100 | (dilation_h * (kernel_h - 1)) // 2 + 0.5, 101 | (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, 102 | H_out, 103 | dtype=torch.float32, 104 | device=device), 105 | torch.linspace( 106 | # pad_w + 0.5, 107 | # W_ - pad_w - 0.5, 108 | (dilation_w * (kernel_w - 1)) // 2 + 0.5, 109 | (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, 110 | W_out, 111 | dtype=torch.float32, 112 | device=device)) 113 | ref_y = ref_y.reshape(-1)[None] / H_ 114 | ref_x = ref_x.reshape(-1)[None] / W_ 115 | 116 | ref = torch.stack((ref_x, ref_y), -1).reshape( 117 | 1, H_out, W_out, 1, 2) 118 | 119 | return ref 120 | 121 | 122 | def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device): 123 | _, H_, W_, _ = spatial_shapes 124 | points_list = [] 125 | x, y = torch.meshgrid( 126 | torch.linspace( 127 | -((dilation_w * (kernel_w - 1)) // 2), 128 | -((dilation_w * (kernel_w - 1)) // 2) + 129 | (kernel_w - 1) * dilation_w, kernel_w, 130 | dtype=torch.float32, 131 | device=device), 132 | torch.linspace( 133 | -((dilation_h * (kernel_h - 1)) // 2), 134 | -((dilation_h * (kernel_h - 1)) // 2) + 135 | (kernel_h - 1) * dilation_h, kernel_h, 136 | dtype=torch.float32, 137 | device=device)) 138 | 139 | points_list.extend([x / W_, y / H_]) 140 | grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ 141 | repeat(1, group, 1).permute(1, 0, 2) 142 | grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2) 143 | 144 | return grid 145 | 146 | 147 | def dcnv3_core_pytorch( 148 | input, offset, mask, kernel_h, 149 | kernel_w, stride_h, stride_w, pad_h, 150 | pad_w, dilation_h, dilation_w, group, 151 | group_channels, offset_scale): 152 | # for debug and test only, 153 | # need to use cuda version instead 154 | input = F.pad( 155 | input, 156 | [0, 0, pad_h, pad_h, pad_w, pad_w]) 157 | N_, H_in, W_in, _ = input.shape 158 | _, H_out, W_out, _ = offset.shape 159 | 160 | ref = _get_reference_points( 161 | input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w) 162 | grid = _generate_dilation_grids( 163 | input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device) 164 | spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\ 165 | repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device) 166 | 167 | sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \ 168 | offset * offset_scale / spatial_norm 169 | 170 | P_ = kernel_h * kernel_w 171 | sampling_grids = 2 * sampling_locations - 1 172 | # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, 
group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in 173 | input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\ 174 | reshape(N_*group, group_channels, H_in, W_in) 175 | # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2 176 | sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\ 177 | flatten(0, 1) 178 | # N_*group, group_channels, H_out*W_out, P_ 179 | sampling_input_ = F.grid_sample( 180 | input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False) 181 | 182 | # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_) 183 | mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\ 184 | reshape(N_*group, 1, H_out*W_out, P_) 185 | output = (sampling_input_ * mask).sum(-1).view(N_, 186 | group*group_channels, H_out*W_out) 187 | 188 | return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous() 189 | -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.7/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .dcnv3 import DCNv3, DCNv3_pytorch -------------------------------------------------------------------------------- /models/ops_dcnv3/build/lib.linux-x86_64-3.7/modules/dcnv3.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import warnings 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | from torch.nn.init import xavier_uniform_, constant_ 16 | from ..functions import DCNv3Function, dcnv3_core_pytorch 17 | 18 | 19 | class to_channels_first(nn.Module): 20 | 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | return x.permute(0, 3, 1, 2) 26 | 27 | 28 | class to_channels_last(nn.Module): 29 | 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def forward(self, x): 34 | return x.permute(0, 2, 3, 1) 35 | 36 | 37 | def build_norm_layer(dim, 38 | norm_layer, 39 | in_format='channels_last', 40 | out_format='channels_last', 41 | eps=1e-6): 42 | layers = [] 43 | if norm_layer == 'BN': 44 | if in_format == 'channels_last': 45 | layers.append(to_channels_first()) 46 | layers.append(nn.BatchNorm2d(dim)) 47 | if out_format == 'channels_last': 48 | layers.append(to_channels_last()) 49 | elif norm_layer == 'LN': 50 | if in_format == 'channels_first': 51 | layers.append(to_channels_last()) 52 | layers.append(nn.LayerNorm(dim, eps=eps)) 53 | if out_format == 'channels_first': 54 | layers.append(to_channels_first()) 55 | else: 56 | raise NotImplementedError( 57 | f'build_norm_layer does not support {norm_layer}') 58 | return nn.Sequential(*layers) 59 | 60 | 61 | def 
build_act_layer(act_layer): 62 | if act_layer == 'ReLU': 63 | return nn.ReLU(inplace=True) 64 | elif act_layer == 'SiLU': 65 | return nn.SiLU(inplace=True) 66 | elif act_layer == 'GELU': 67 | return nn.GELU() 68 | 69 | raise NotImplementedError(f'build_act_layer does not support {act_layer}') 70 | 71 | 72 | def _is_power_of_2(n): 73 | if (not isinstance(n, int)) or (n < 0): 74 | raise ValueError( 75 | "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 76 | 77 | return (n & (n - 1) == 0) and n != 0 78 | 79 | 80 | class CenterFeatureScaleModule(nn.Module): 81 | def forward(self, 82 | query, 83 | center_feature_scale_proj_weight, 84 | center_feature_scale_proj_bias): 85 | center_feature_scale = F.linear(query, 86 | weight=center_feature_scale_proj_weight, 87 | bias=center_feature_scale_proj_bias).sigmoid() 88 | return center_feature_scale 89 | 90 | 91 | class DCNv3_pytorch(nn.Module): 92 | def __init__( 93 | self, 94 | channels=64, 95 | kernel_size=3, 96 | dw_kernel_size=None, 97 | stride=1, 98 | pad=1, 99 | dilation=1, 100 | group=4, 101 | offset_scale=1.0, 102 | act_layer='GELU', 103 | norm_layer='LN', 104 | center_feature_scale=False): 105 | """ 106 | DCNv3 Module 107 | :param channels 108 | :param kernel_size 109 | :param stride 110 | :param pad 111 | :param dilation 112 | :param group 113 | :param offset_scale 114 | :param act_layer 115 | :param norm_layer 116 | """ 117 | super().__init__() 118 | if channels % group != 0: 119 | raise ValueError( 120 | f'channels must be divisible by group, but got {channels} and {group}') 121 | _d_per_group = channels // group 122 | dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size 123 | # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation 124 | if not _is_power_of_2(_d_per_group): 125 | warnings.warn( 126 | "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " 127 | "which is more efficient in our CUDA implementation.") 128 | 129 | self.offset_scale = offset_scale 130 | self.channels = channels 131 | self.kernel_size = kernel_size 132 | self.dw_kernel_size = dw_kernel_size 133 | self.stride = stride 134 | self.dilation = dilation 135 | self.pad = pad 136 | self.group = group 137 | self.group_channels = channels // group 138 | self.offset_scale = offset_scale 139 | self.center_feature_scale = center_feature_scale 140 | 141 | self.dw_conv = nn.Sequential( 142 | nn.Conv2d( 143 | channels, 144 | channels, 145 | kernel_size=dw_kernel_size, 146 | stride=1, 147 | padding=(dw_kernel_size - 1) // 2, 148 | groups=channels), 149 | build_norm_layer( 150 | channels, 151 | norm_layer, 152 | 'channels_first', 153 | 'channels_last'), 154 | build_act_layer(act_layer)) 155 | self.offset = nn.Linear( 156 | channels, 157 | group * kernel_size * kernel_size * 2) 158 | self.mask = nn.Linear( 159 | channels, 160 | group * kernel_size * kernel_size) 161 | self.input_proj = nn.Linear(channels, channels) 162 | self.output_proj = nn.Linear(channels, channels) 163 | self._reset_parameters() 164 | 165 | if center_feature_scale: 166 | self.center_feature_scale_proj_weight = nn.Parameter( 167 | torch.zeros((group, channels), dtype=torch.float)) 168 | self.center_feature_scale_proj_bias = nn.Parameter( 169 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) 170 | self.center_feature_scale_module = CenterFeatureScaleModule() 171 | 172 | def _reset_parameters(self): 173 | constant_(self.offset.weight.data, 0.) 
174 | constant_(self.offset.bias.data, 0.) 175 | constant_(self.mask.weight.data, 0.) 176 | constant_(self.mask.bias.data, 0.) 177 | xavier_uniform_(self.input_proj.weight.data) 178 | constant_(self.input_proj.bias.data, 0.) 179 | xavier_uniform_(self.output_proj.weight.data) 180 | constant_(self.output_proj.bias.data, 0.) 181 | 182 | def forward(self, input): 183 | """ 184 | :param query (N, H, W, C) 185 | :return output (N, H, W, C) 186 | """ 187 | N, H, W, _ = input.shape 188 | 189 | x = self.input_proj(input) 190 | x_proj = x 191 | 192 | x1 = input.permute(0, 3, 1, 2) 193 | x1 = self.dw_conv(x1) 194 | offset = self.offset(x1) 195 | mask = self.mask(x1).reshape(N, H, W, self.group, -1) 196 | mask = F.softmax(mask, -1).reshape(N, H, W, -1) 197 | 198 | x = dcnv3_core_pytorch( 199 | x, offset, mask, 200 | self.kernel_size, self.kernel_size, 201 | self.stride, self.stride, 202 | self.pad, self.pad, 203 | self.dilation, self.dilation, 204 | self.group, self.group_channels, 205 | self.offset_scale) 206 | if self.center_feature_scale: 207 | center_feature_scale = self.center_feature_scale_module( 208 | x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) 209 | # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels 210 | center_feature_scale = center_feature_scale[..., None].repeat( 211 | 1, 1, 1, 1, self.channels // self.group).flatten(-2) 212 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale 213 | x = self.output_proj(x) 214 | 215 | return x 216 | 217 | 218 | class DCNv3(nn.Module): 219 | def __init__( 220 | self, 221 | channels=64, 222 | kernel_size=3, 223 | dw_kernel_size=None, 224 | stride=1, 225 | pad=1, 226 | dilation=1, 227 | group=4, 228 | offset_scale=1.0, 229 | act_layer='GELU', 230 | norm_layer='LN', 231 | center_feature_scale=False): 232 | """ 233 | DCNv3 Module 234 | :param channels 235 | :param kernel_size 236 | :param stride 237 | :param pad 238 | :param dilation 239 | :param group 240 | :param offset_scale 241 | :param act_layer 242 | :param norm_layer 243 | """ 244 | super().__init__() 245 | if channels % group != 0: 246 | raise ValueError( 247 | f'channels must be divisible by group, but got {channels} and {group}') 248 | _d_per_group = channels // group 249 | dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size 250 | # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation 251 | if not _is_power_of_2(_d_per_group): 252 | warnings.warn( 253 | "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " 254 | "which is more efficient in our CUDA implementation.") 255 | 256 | self.offset_scale = offset_scale 257 | self.channels = channels 258 | self.kernel_size = kernel_size 259 | self.dw_kernel_size = dw_kernel_size 260 | self.stride = stride 261 | self.dilation = dilation 262 | self.pad = pad 263 | self.group = group 264 | self.group_channels = channels // group 265 | self.offset_scale = offset_scale 266 | self.center_feature_scale = center_feature_scale 267 | 268 | self.dw_conv = nn.Sequential( 269 | nn.Conv2d( 270 | channels, 271 | channels, 272 | kernel_size=dw_kernel_size, 273 | stride=1, 274 | padding=(dw_kernel_size - 1) // 2, 275 | groups=channels), 276 | build_norm_layer( 277 | channels, 278 | norm_layer, 279 | 'channels_first', 280 | 'channels_last'), 281 | build_act_layer(act_layer)) 282 | self.offset = nn.Linear( 283 | channels, 284 | group * kernel_size * kernel_size * 2) 285 | 
self.mask = nn.Linear( 286 | channels, 287 | group * kernel_size * kernel_size) 288 | self.input_proj = nn.Linear(channels, channels) 289 | self.output_proj = nn.Linear(channels, channels) 290 | self._reset_parameters() 291 | 292 | if center_feature_scale: 293 | self.center_feature_scale_proj_weight = nn.Parameter( 294 | torch.zeros((group, channels), dtype=torch.float)) 295 | self.center_feature_scale_proj_bias = nn.Parameter( 296 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) 297 | self.center_feature_scale_module = CenterFeatureScaleModule() 298 | 299 | def _reset_parameters(self): 300 | constant_(self.offset.weight.data, 0.) 301 | constant_(self.offset.bias.data, 0.) 302 | constant_(self.mask.weight.data, 0.) 303 | constant_(self.mask.bias.data, 0.) 304 | xavier_uniform_(self.input_proj.weight.data) 305 | constant_(self.input_proj.bias.data, 0.) 306 | xavier_uniform_(self.output_proj.weight.data) 307 | constant_(self.output_proj.bias.data, 0.) 308 | 309 | def forward(self, input): 310 | """ 311 | :param query (N, H, W, C) 312 | :return output (N, H, W, C) 313 | """ 314 | N, H, W, _ = input.shape 315 | 316 | x = self.input_proj(input) 317 | x_proj = x 318 | dtype = x.dtype 319 | 320 | x1 = input.permute(0, 3, 1, 2) 321 | x1 = self.dw_conv(x1) 322 | offset = self.offset(x1) 323 | mask = self.mask(x1).reshape(N, H, W, self.group, -1) 324 | mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype) 325 | 326 | x = DCNv3Function.apply( 327 | x, offset, mask, 328 | self.kernel_size, self.kernel_size, 329 | self.stride, self.stride, 330 | self.pad, self.pad, 331 | self.dilation, self.dilation, 332 | self.group, self.group_channels, 333 | self.offset_scale, 334 | 256) 335 | 336 | if self.center_feature_scale: 337 | center_feature_scale = self.center_feature_scale_module( 338 | x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) 339 | # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels 340 | center_feature_scale = center_feature_scale[..., None].repeat( 341 | 1, 1, 1, 1, self.channels // self.group).flatten(-2) 342 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale 343 | x = self.output_proj(x) 344 | 345 | return x 346 | -------------------------------------------------------------------------------- /models/ops_dcnv3/build/temp.linux-x86_64-3.10/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda-11.7/bin/nvcc 4 | 5 | cflags = -pthread -B /home/zbf/anaconda3/envs/zbf/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /home/zbf/anaconda3/envs/zbf/include -fPIC -O2 -isystem /home/zbf/anaconda3/envs/zbf/include -fPIC -DWITH_CUDA -I/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src -I/home/zbf/anaconda3/envs/zbf/lib/python3.10/site-packages/torch/include -I/home/zbf/anaconda3/envs/zbf/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/zbf/anaconda3/envs/zbf/lib/python3.10/site-packages/torch/include/TH -I/home/zbf/anaconda3/envs/zbf/lib/python3.10/site-packages/torch/include/THC -I/usr/local/cuda-11.7/include -I/home/zbf/anaconda3/envs/zbf/include/python3.10 -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=DCNv3 -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17 7 | cuda_cflags = 
-DWITH_CUDA -I/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src -I/home/zbf/anaconda3/envs/zbf/lib/python3.10/site-packages/torch/include -I/home/zbf/anaconda3/envs/zbf/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/zbf/anaconda3/envs/zbf/lib/python3.10/site-packages/torch/include/TH -I/home/zbf/anaconda3/envs/zbf/lib/python3.10/site-packages/torch/include/THC -I/usr/local/cuda-11.7/include -I/home/zbf/anaconda3/envs/zbf/include/python3.10 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=DCNv3 -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17 9 | cuda_dlink_post_cflags = 10 | ldflags = 11 | 12 | rule compile 13 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 14 | depfile = $out.d 15 | deps = gcc 16 | 17 | rule cuda_compile 18 | depfile = $out.d 19 | deps = gcc 20 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 21 | 22 | 23 | 24 | 25 | 26 | build /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cpu/dcnv3_cpu.o: compile /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cpu/dcnv3_cpu.cpp 27 | build /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cuda/dcnv3_cuda.o: cuda_compile /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cuda/dcnv3_cuda.cu 28 | build /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/vision.o: compile /home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/vision.cpp 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cpu/dcnv3_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cpu/dcnv3_cpu.o -------------------------------------------------------------------------------- /models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cuda/dcnv3_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/cuda/dcnv3_cuda.o -------------------------------------------------------------------------------- /models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/vision.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/build/temp.linux-x86_64-3.10/home/zbf/Desktop/code/zbfbackbone/models/ops_dcnv3/src/vision.o 
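A practical note on the flags above: `-gencode=arch=compute_86,code=sm_86` means this prebuilt extension targets Ampere-class GPUs (compute capability 8.6), with PTX forward compatibility only for newer cards. On older hardware the operator must be recompiled, which `make.sh` does via PyTorch's `BuildExtension` (it detects the local architecture, or honors `TORCH_CUDA_ARCH_LIST`). A small sanity check one might run before trusting the prebuilt kernels; the capability threshold here simply mirrors the gencode flags:
```python
import torch

# Assumes the DCNv3 extension shipped with this repo is the one in use.
assert torch.cuda.is_available(), "the DCNv3 CUDA kernels need a GPU"
major, minor = torch.cuda.get_device_capability()
print(f"local GPU compute capability: {major}.{minor}")
if (major, minor) < (8, 6):
    print("prebuilt sm_86 binaries will not run here; "
          "rebuild with: cd models/ops_dcnv3 && sh ./make.sh")
```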
-------------------------------------------------------------------------------- /models/ops_dcnv3/build/temp.linux-x86_64-3.7/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /usr/local/cuda/bin/nvcc 4 | 5 | cflags = -pthread -B /home/zbf/anaconda3/envs/internimage/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DWITH_CUDA -I/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src -I/home/zbf/anaconda3/envs/internimage/lib/python3.7/site-packages/torch/include -I/home/zbf/anaconda3/envs/internimage/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/home/zbf/anaconda3/envs/internimage/lib/python3.7/site-packages/torch/include/TH -I/home/zbf/anaconda3/envs/internimage/lib/python3.7/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zbf/anaconda3/envs/internimage/include/python3.7m -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=DCNv3 -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14 7 | cuda_cflags = -DWITH_CUDA -I/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src -I/home/zbf/anaconda3/envs/internimage/lib/python3.7/site-packages/torch/include -I/home/zbf/anaconda3/envs/internimage/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/home/zbf/anaconda3/envs/internimage/lib/python3.7/site-packages/torch/include/TH -I/home/zbf/anaconda3/envs/internimage/lib/python3.7/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/zbf/anaconda3/envs/internimage/include/python3.7m -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=DCNv3 -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.o: compile /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.cpp 24 | build /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cuda/dcnv3_cuda.o: cuda_compile /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cuda/dcnv3_cuda.cu 25 | build /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/vision.o: compile /home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/vision.cpp 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- 
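The two ninja files differ only in toolchain details (Python 3.10 + CUDA 11.7 + `-std=c++17` versus Python 3.7 + `-std=c++14`); the sources and kernels they compile are identical. Whichever build is installed, the CUDA path can be cross-checked against the pure-PyTorch reference, which is what `ops_dcnv3/test.py` does ("all checking is True" in the README). A compressed sketch of that comparison, assumed to be run from inside `models/ops_dcnv3/`, with illustrative shapes rather than the ones `test.py` uses:
```python
import torch
from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch

N, H, W = 2, 16, 16
group, group_channels, K = 4, 16, 3           # 3x3 kernel, 64 channels total
x = torch.rand(N, H, W, group * group_channels).cuda()
offset = (torch.rand(N, H, W, group * K * K * 2).cuda() * 2 - 1) * 2
mask = torch.rand(N, H, W, group, K * K).cuda()
mask = (mask / mask.sum(-1, keepdim=True)).reshape(N, H, W, -1)

args = (x, offset, mask, K, K, 1, 1, 1, 1, 1, 1,
        group, group_channels, 1.0)           # kernel/stride/pad/dilation
ref = dcnv3_core_pytorch(*args)               # slow grid_sample reference
out = DCNv3Function.apply(*args, 256)         # compiled kernel, im2col_step=256
print("forward close:", torch.allclose(ref, out, atol=1e-3))
```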
/models/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cpu/dcnv3_cpu.o -------------------------------------------------------------------------------- /models/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cuda/dcnv3_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/cuda/dcnv3_cuda.o -------------------------------------------------------------------------------- /models/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/vision.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/build/temp.linux-x86_64-3.7/home/zbf/lab/remote_all/InternImage-master/segmentation/ops_dcnv3/src/vision.o -------------------------------------------------------------------------------- /models/ops_dcnv3/dist/DCNv3-1.0-py3.10-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/dist/DCNv3-1.0-py3.10-linux-x86_64.egg -------------------------------------------------------------------------------- /models/ops_dcnv3/dist/DCNv3-1.0-py3.7-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/dist/DCNv3-1.0-py3.7-linux-x86_64.egg -------------------------------------------------------------------------------- /models/ops_dcnv3/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .dcnv3_func import DCNv3Function, dcnv3_core_pytorch 8 | -------------------------------------------------------------------------------- /models/ops_dcnv3/functions/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/functions/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/ops_dcnv3/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/functions/__pycache__/__init__.cpython-37.pyc 
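The `.egg` files above are the installable artifacts produced by `sh ./make.sh` (i.e. `python setup.py build install`). After installation, user code does not call the compiled `DCNv3` module directly; it goes through the wrappers re-exported by `modules/__init__.py`. A usage sketch under the assumption that the repository root is on `sys.path`; channel and group numbers are illustrative:
```python
import torch
from models.ops_dcnv3.modules import DCNv3, DCNv3_pytorch

# CUDA-backed op when a GPU is present, otherwise the grid_sample
# reference. Note: functions/dcnv3_func.py imports the compiled
# extension at load time, so the build step is required either way.
use_cuda = torch.cuda.is_available()
Layer = DCNv3 if use_cuda else DCNv3_pytorch
layer = Layer(channels=64, group=4)      # 16 channels/group: a power of 2
x = torch.randn(2, 32, 32, 64)           # DCNv3 is channels-last: (N, H, W, C)
if use_cuda:
    layer, x = layer.cuda(), x.cuda()
print(layer(x).shape)                    # torch.Size([2, 32, 32, 64])
```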
-------------------------------------------------------------------------------- /models/ops_dcnv3/functions/__pycache__/dcnv3_func.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/functions/__pycache__/dcnv3_func.cpython-310.pyc -------------------------------------------------------------------------------- /models/ops_dcnv3/functions/__pycache__/dcnv3_func.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/functions/__pycache__/dcnv3_func.cpython-37.pyc -------------------------------------------------------------------------------- /models/ops_dcnv3/functions/dcnv3_func.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.autograd import Function 14 | from torch.autograd.function import once_differentiable 15 | from torch.cuda.amp import custom_bwd, custom_fwd 16 | import DCNv3 17 | 18 | 19 | class DCNv3Function(Function): 20 | @staticmethod 21 | @custom_fwd 22 | def forward( 23 | ctx, input, offset, mask, 24 | kernel_h, kernel_w, stride_h, stride_w, 25 | pad_h, pad_w, dilation_h, dilation_w, 26 | group, group_channels, offset_scale, im2col_step): 27 | ctx.kernel_h = kernel_h 28 | ctx.kernel_w = kernel_w 29 | ctx.stride_h = stride_h 30 | ctx.stride_w = stride_w 31 | ctx.pad_h = pad_h 32 | ctx.pad_w = pad_w 33 | ctx.dilation_h = dilation_h 34 | ctx.dilation_w = dilation_w 35 | ctx.group = group 36 | ctx.group_channels = group_channels 37 | ctx.offset_scale = offset_scale 38 | ctx.im2col_step = im2col_step 39 | output = DCNv3.dcnv3_forward( 40 | input, offset, mask, kernel_h, 41 | kernel_w, stride_h, stride_w, pad_h, 42 | pad_w, dilation_h, dilation_w, group, 43 | group_channels, offset_scale, ctx.im2col_step) 44 | ctx.save_for_backward(input, offset, mask) 45 | 46 | return output 47 | 48 | @staticmethod 49 | @once_differentiable 50 | @custom_bwd 51 | def backward(ctx, grad_output): 52 | input, offset, mask = ctx.saved_tensors 53 | grad_input, grad_offset, grad_mask = \ 54 | DCNv3.dcnv3_backward( 55 | input, offset, mask, ctx.kernel_h, 56 | ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h, 57 | ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group, 58 | ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step) 59 | 60 | return grad_input, grad_offset, grad_mask, \ 61 | None, None, None, None, None, None, None, None, None, None, None, None 62 | 63 | @staticmethod 64 | def symbolic(g, input, offset, mask, kernel_h, kernel_w, stride_h, 65 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 66 | group_channels, offset_scale, im2col_step): 67 | """Symbolic function for mmdeploy::DCNv3. 68 | 69 | Returns: 70 | DCNv3 op for onnx. 
71 | """ 72 | return g.op( 73 | 'mmdeploy::TRTDCNv3', 74 | input, 75 | offset, 76 | mask, 77 | kernel_h_i=int(kernel_h), 78 | kernel_w_i=int(kernel_w), 79 | stride_h_i=int(stride_h), 80 | stride_w_i=int(stride_w), 81 | pad_h_i=int(pad_h), 82 | pad_w_i=int(pad_w), 83 | dilation_h_i=int(dilation_h), 84 | dilation_w_i=int(dilation_w), 85 | group_i=int(group), 86 | group_channels_i=int(group_channels), 87 | offset_scale_f=float(offset_scale), 88 | im2col_step_i=int(im2col_step), 89 | ) 90 | 91 | def _get_reference_points(spatial_shapes, device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h=0, pad_w=0, stride_h=1, stride_w=1): 92 | _, H_, W_, _ = spatial_shapes 93 | H_out = (H_ - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 94 | W_out = (W_ - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 95 | 96 | ref_y, ref_x = torch.meshgrid( 97 | torch.linspace( 98 | # pad_h + 0.5, 99 | # H_ - pad_h - 0.5, 100 | (dilation_h * (kernel_h - 1)) // 2 + 0.5, 101 | (dilation_h * (kernel_h - 1)) // 2 + 0.5 + (H_out - 1) * stride_h, 102 | H_out, 103 | dtype=torch.float32, 104 | device=device), 105 | torch.linspace( 106 | # pad_w + 0.5, 107 | # W_ - pad_w - 0.5, 108 | (dilation_w * (kernel_w - 1)) // 2 + 0.5, 109 | (dilation_w * (kernel_w - 1)) // 2 + 0.5 + (W_out - 1) * stride_w, 110 | W_out, 111 | dtype=torch.float32, 112 | device=device)) 113 | ref_y = ref_y.reshape(-1)[None] / H_ 114 | ref_x = ref_x.reshape(-1)[None] / W_ 115 | 116 | ref = torch.stack((ref_x, ref_y), -1).reshape( 117 | 1, H_out, W_out, 1, 2) 118 | 119 | return ref 120 | 121 | 122 | def _generate_dilation_grids(spatial_shapes, kernel_h, kernel_w, dilation_h, dilation_w, group, device): 123 | _, H_, W_, _ = spatial_shapes 124 | points_list = [] 125 | x, y = torch.meshgrid( 126 | torch.linspace( 127 | -((dilation_w * (kernel_w - 1)) // 2), 128 | -((dilation_w * (kernel_w - 1)) // 2) + 129 | (kernel_w - 1) * dilation_w, kernel_w, 130 | dtype=torch.float32, 131 | device=device), 132 | torch.linspace( 133 | -((dilation_h * (kernel_h - 1)) // 2), 134 | -((dilation_h * (kernel_h - 1)) // 2) + 135 | (kernel_h - 1) * dilation_h, kernel_h, 136 | dtype=torch.float32, 137 | device=device)) 138 | 139 | points_list.extend([x / W_, y / H_]) 140 | grid = torch.stack(points_list, -1).reshape(-1, 1, 2).\ 141 | repeat(1, group, 1).permute(1, 0, 2) 142 | grid = grid.reshape(1, 1, 1, group * kernel_h * kernel_w, 2) 143 | 144 | return grid 145 | 146 | 147 | def dcnv3_core_pytorch( 148 | input, offset, mask, kernel_h, 149 | kernel_w, stride_h, stride_w, pad_h, 150 | pad_w, dilation_h, dilation_w, group, 151 | group_channels, offset_scale): 152 | # for debug and test only, 153 | # need to use cuda version instead 154 | input = F.pad( 155 | input, 156 | [0, 0, pad_h, pad_h, pad_w, pad_w]) 157 | N_, H_in, W_in, _ = input.shape 158 | _, H_out, W_out, _ = offset.shape 159 | 160 | ref = _get_reference_points( 161 | input.shape, input.device, kernel_h, kernel_w, dilation_h, dilation_w, pad_h, pad_w, stride_h, stride_w) 162 | grid = _generate_dilation_grids( 163 | input.shape, kernel_h, kernel_w, dilation_h, dilation_w, group, input.device) 164 | spatial_norm = torch.tensor([W_in, H_in]).reshape(1, 1, 1, 2).\ 165 | repeat(1, 1, 1, group*kernel_h*kernel_w).to(input.device) 166 | 167 | sampling_locations = (ref + grid * offset_scale).repeat(N_, 1, 1, 1, 1).flatten(3, 4) + \ 168 | offset * offset_scale / spatial_norm 169 | 170 | P_ = kernel_h * kernel_w 171 | sampling_grids = 2 * sampling_locations - 1 172 | # N_, H_in, W_in, group*group_channels -> N_, H_in*W_in, 
group*group_channels -> N_, group*group_channels, H_in*W_in -> N_*group, group_channels, H_in, W_in 173 | input_ = input.view(N_, H_in*W_in, group*group_channels).transpose(1, 2).\ 174 | reshape(N_*group, group_channels, H_in, W_in) 175 | # N_, H_out, W_out, group*P_*2 -> N_, H_out*W_out, group, P_, 2 -> N_, group, H_out*W_out, P_, 2 -> N_*group, H_out*W_out, P_, 2 176 | sampling_grid_ = sampling_grids.view(N_, H_out*W_out, group, P_, 2).transpose(1, 2).\ 177 | flatten(0, 1) 178 | # N_*group, group_channels, H_out*W_out, P_ 179 | sampling_input_ = F.grid_sample( 180 | input_, sampling_grid_, mode='bilinear', padding_mode='zeros', align_corners=False) 181 | 182 | # (N_, H_out, W_out, group*P_) -> N_, H_out*W_out, group, P_ -> (N_, group, H_out*W_out, P_) -> (N_*group, 1, H_out*W_out, P_) 183 | mask = mask.view(N_, H_out*W_out, group, P_).transpose(1, 2).\ 184 | reshape(N_*group, 1, H_out*W_out, P_) 185 | output = (sampling_input_ * mask).sum(-1).view(N_, 186 | group*group_channels, H_out*W_out) 187 | 188 | return output.transpose(1, 2).reshape(N_, H_out, W_out, -1).contiguous() 189 | -------------------------------------------------------------------------------- /models/ops_dcnv3/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # -------------------------------------------------------- 3 | # InternImage 4 | # Copyright (c) 2022 OpenGVLab 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # -------------------------------------------------------- 7 | 8 | python setup.py build install 9 | -------------------------------------------------------------------------------- /models/ops_dcnv3/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .dcnv3 import DCNv3, DCNv3_pytorch -------------------------------------------------------------------------------- /models/ops_dcnv3/modules/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/modules/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/ops_dcnv3/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/ops_dcnv3/modules/__pycache__/dcnv3.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/modules/__pycache__/dcnv3.cpython-310.pyc -------------------------------------------------------------------------------- /models/ops_dcnv3/modules/__pycache__/dcnv3.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/models/ops_dcnv3/modules/__pycache__/dcnv3.cpython-37.pyc -------------------------------------------------------------------------------- /models/ops_dcnv3/modules/dcnv3.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import warnings 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | from torch.nn.init import xavier_uniform_, constant_ 16 | from ..functions import DCNv3Function, dcnv3_core_pytorch 17 | 18 | 19 | class to_channels_first(nn.Module): 20 | 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, x): 25 | return x.permute(0, 3, 1, 2) 26 | 27 | 28 | class to_channels_last(nn.Module): 29 | 30 | def __init__(self): 31 | super().__init__() 32 | 33 | def forward(self, x): 34 | return x.permute(0, 2, 3, 1) 35 | 36 | 37 | def build_norm_layer(dim, 38 | norm_layer, 39 | in_format='channels_last', 40 | out_format='channels_last', 41 | eps=1e-6): 42 | layers = [] 43 | if norm_layer == 'BN': 44 | if in_format == 'channels_last': 45 | layers.append(to_channels_first()) 46 | layers.append(nn.BatchNorm2d(dim)) 47 | if out_format == 'channels_last': 48 | layers.append(to_channels_last()) 49 | elif norm_layer == 'LN': 50 | if in_format == 'channels_first': 51 | layers.append(to_channels_last()) 52 | layers.append(nn.LayerNorm(dim, eps=eps)) 53 | if out_format == 'channels_first': 54 | layers.append(to_channels_first()) 55 | else: 56 | raise NotImplementedError( 57 | f'build_norm_layer does not support {norm_layer}') 58 | return nn.Sequential(*layers) 59 | 60 | 61 | def build_act_layer(act_layer): 62 | if act_layer == 'ReLU': 63 | return nn.ReLU(inplace=True) 64 | elif act_layer == 'SiLU': 65 | return nn.SiLU(inplace=True) 66 | elif act_layer == 'GELU': 67 | return nn.GELU() 68 | 69 | raise NotImplementedError(f'build_act_layer does not support {act_layer}') 70 | 71 | 72 | def _is_power_of_2(n): 73 | if (not isinstance(n, int)) or (n < 0): 74 | raise ValueError( 75 | "invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 76 | 77 | return (n & (n - 1) == 0) and n != 0 78 | 79 | 80 | class CenterFeatureScaleModule(nn.Module): 81 | def forward(self, 82 | query, 83 | center_feature_scale_proj_weight, 84 | center_feature_scale_proj_bias): 85 | center_feature_scale = F.linear(query, 86 | weight=center_feature_scale_proj_weight, 87 | bias=center_feature_scale_proj_bias).sigmoid() 88 | return center_feature_scale 89 | 90 | 91 | class DCNv3_pytorch(nn.Module): 92 | def __init__( 93 | self, 94 | channels=64, 95 | kernel_size=3, 96 | dw_kernel_size=None, 97 | stride=1, 98 | pad=1, 99 | dilation=1, 100 | group=4, 101 | offset_scale=1.0, 102 | act_layer='GELU', 103 | norm_layer='LN', 104 | center_feature_scale=False): 105 | """ 106 | DCNv3 Module 107 | :param channels 108 | :param kernel_size 109 | :param stride 110 | :param pad 111 | :param dilation 112 | :param group 113 | :param offset_scale 114 | :param act_layer 115 | :param norm_layer 116 | """ 117 | super().__init__() 118 | if channels % group != 0: 119 | raise ValueError( 120 | 
f'channels must be divisible by group, but got {channels} and {group}') 121 | _d_per_group = channels // group 122 | dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size 123 | # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation 124 | if not _is_power_of_2(_d_per_group): 125 | warnings.warn( 126 | "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " 127 | "which is more efficient in our CUDA implementation.") 128 | 129 | self.offset_scale = offset_scale 130 | self.channels = channels 131 | self.kernel_size = kernel_size 132 | self.dw_kernel_size = dw_kernel_size 133 | self.stride = stride 134 | self.dilation = dilation 135 | self.pad = pad 136 | self.group = group 137 | self.group_channels = channels // group 138 | self.offset_scale = offset_scale 139 | self.center_feature_scale = center_feature_scale 140 | 141 | self.dw_conv = nn.Sequential( 142 | nn.Conv2d( 143 | channels, 144 | channels, 145 | kernel_size=dw_kernel_size, 146 | stride=1, 147 | padding=(dw_kernel_size - 1) // 2, 148 | groups=channels), 149 | build_norm_layer( 150 | channels, 151 | norm_layer, 152 | 'channels_first', 153 | 'channels_last'), 154 | build_act_layer(act_layer)) 155 | self.offset = nn.Linear( 156 | channels, 157 | group * kernel_size * kernel_size * 2) 158 | self.mask = nn.Linear( 159 | channels, 160 | group * kernel_size * kernel_size) 161 | self.input_proj = nn.Linear(channels, channels) 162 | self.output_proj = nn.Linear(channels, channels) 163 | self._reset_parameters() 164 | 165 | if center_feature_scale: 166 | self.center_feature_scale_proj_weight = nn.Parameter( 167 | torch.zeros((group, channels), dtype=torch.float)) 168 | self.center_feature_scale_proj_bias = nn.Parameter( 169 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) 170 | self.center_feature_scale_module = CenterFeatureScaleModule() 171 | 172 | def _reset_parameters(self): 173 | constant_(self.offset.weight.data, 0.) 174 | constant_(self.offset.bias.data, 0.) 175 | constant_(self.mask.weight.data, 0.) 176 | constant_(self.mask.bias.data, 0.) 177 | xavier_uniform_(self.input_proj.weight.data) 178 | constant_(self.input_proj.bias.data, 0.) 179 | xavier_uniform_(self.output_proj.weight.data) 180 | constant_(self.output_proj.bias.data, 0.) 
181 | 182 | def forward(self, input): 183 | """ 184 | :param query (N, H, W, C) 185 | :return output (N, H, W, C) 186 | """ 187 | N, H, W, _ = input.shape 188 | 189 | x = self.input_proj(input) 190 | x_proj = x 191 | 192 | x1 = input.permute(0, 3, 1, 2) 193 | x1 = self.dw_conv(x1) 194 | offset = self.offset(x1) 195 | mask = self.mask(x1).reshape(N, H, W, self.group, -1) 196 | mask = F.softmax(mask, -1).reshape(N, H, W, -1) 197 | 198 | x = dcnv3_core_pytorch( 199 | x, offset, mask, 200 | self.kernel_size, self.kernel_size, 201 | self.stride, self.stride, 202 | self.pad, self.pad, 203 | self.dilation, self.dilation, 204 | self.group, self.group_channels, 205 | self.offset_scale) 206 | if self.center_feature_scale: 207 | center_feature_scale = self.center_feature_scale_module( 208 | x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) 209 | # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels 210 | center_feature_scale = center_feature_scale[..., None].repeat( 211 | 1, 1, 1, 1, self.channels // self.group).flatten(-2) 212 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale 213 | x = self.output_proj(x) 214 | 215 | return x 216 | 217 | 218 | class DCNv3(nn.Module): 219 | def __init__( 220 | self, 221 | channels=64, 222 | kernel_size=3, 223 | dw_kernel_size=None, 224 | stride=1, 225 | pad=1, 226 | dilation=1, 227 | group=4, 228 | offset_scale=1.0, 229 | act_layer='GELU', 230 | norm_layer='LN', 231 | center_feature_scale=False): 232 | """ 233 | DCNv3 Module 234 | :param channels 235 | :param kernel_size 236 | :param stride 237 | :param pad 238 | :param dilation 239 | :param group 240 | :param offset_scale 241 | :param act_layer 242 | :param norm_layer 243 | """ 244 | super().__init__() 245 | 246 | # if channels % group != 0: 247 | # raise ValueError( 248 | # f'channels must be divisible by group, but got {channels} and {group}') 249 | 250 | _d_per_group = channels // group 251 | dw_kernel_size = dw_kernel_size if dw_kernel_size is not None else kernel_size 252 | # you'd better set _d_per_group to a power of 2 which is more efficient in our CUDA implementation 253 | if not _is_power_of_2(_d_per_group): 254 | warnings.warn( 255 | "You'd better set channels in DCNv3 to make the dimension of each attention head a power of 2 " 256 | "which is more efficient in our CUDA implementation.") 257 | 258 | self.offset_scale = offset_scale 259 | self.channels = channels 260 | self.kernel_size = kernel_size 261 | self.dw_kernel_size = dw_kernel_size 262 | self.stride = stride 263 | self.dilation = dilation 264 | self.pad = pad 265 | self.group = group 266 | self.group_channels = channels // group 267 | self.offset_scale = offset_scale 268 | self.center_feature_scale = center_feature_scale 269 | 270 | self.dw_conv = nn.Sequential( 271 | nn.Conv2d( 272 | channels, 273 | channels, 274 | kernel_size=dw_kernel_size, 275 | stride=1, 276 | padding=(dw_kernel_size - 1) // 2, 277 | groups=channels), 278 | build_norm_layer( 279 | channels, 280 | norm_layer, 281 | 'channels_first', 282 | 'channels_last'), 283 | build_act_layer(act_layer)) 284 | self.offset = nn.Linear( 285 | channels, 286 | group * kernel_size * kernel_size * 2) 287 | self.mask = nn.Linear( 288 | channels, 289 | group * kernel_size * kernel_size) 290 | self.input_proj = nn.Linear(channels, channels) 291 | self.output_proj = nn.Linear(channels, channels) 292 | self._reset_parameters() 293 | 294 | if center_feature_scale: 295 | self.center_feature_scale_proj_weight = 
nn.Parameter( 296 | torch.zeros((group, channels), dtype=torch.float)) 297 | self.center_feature_scale_proj_bias = nn.Parameter( 298 | torch.tensor(0.0, dtype=torch.float).view((1,)).repeat(group, )) 299 | self.center_feature_scale_module = CenterFeatureScaleModule() 300 | 301 | def _reset_parameters(self): 302 | constant_(self.offset.weight.data, 0.) 303 | constant_(self.offset.bias.data, 0.) 304 | constant_(self.mask.weight.data, 0.) 305 | constant_(self.mask.bias.data, 0.) 306 | xavier_uniform_(self.input_proj.weight.data) 307 | constant_(self.input_proj.bias.data, 0.) 308 | xavier_uniform_(self.output_proj.weight.data) 309 | constant_(self.output_proj.bias.data, 0.) 310 | 311 | def forward(self, input): 312 | """ 313 | :param input (N, H, W, C) 314 | :return output (N, H, W, C) 315 | """ 316 | N, H, W, _ = input.shape 317 | 318 | x = self.input_proj(input) 319 | x_proj = x 320 | dtype = x.dtype 321 | 322 | x1 = input.permute(0, 3, 1, 2) 323 | x1 = self.dw_conv(x1) 324 | offset = self.offset(x1) 325 | mask = self.mask(x1).reshape(N, H, W, self.group, -1) 326 | mask = F.softmax(mask, -1).reshape(N, H, W, -1).type(dtype) 327 | 328 | x = DCNv3Function.apply( 329 | x, offset, mask, 330 | self.kernel_size, self.kernel_size, 331 | self.stride, self.stride, 332 | self.pad, self.pad, 333 | self.dilation, self.dilation, 334 | self.group, self.group_channels, 335 | self.offset_scale, 336 | 256)  # im2col_step 337 | 338 | if self.center_feature_scale: 339 | center_feature_scale = self.center_feature_scale_module( 340 | x1, self.center_feature_scale_proj_weight, self.center_feature_scale_proj_bias) 341 | # N, H, W, groups -> N, H, W, groups, 1 -> N, H, W, groups, _d_per_group -> N, H, W, channels 342 | center_feature_scale = center_feature_scale[..., None].repeat( 343 | 1, 1, 1, 1, self.channels // self.group).flatten(-2) 344 | x = x * (1 - center_feature_scale) + x_proj * center_feature_scale 345 | x = self.output_proj(x) 346 | 347 | return x 348 | -------------------------------------------------------------------------------- /models/ops_dcnv3/setup.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import os 8 | import glob 9 | 10 | import torch 11 | 12 | from torch.utils.cpp_extension import CUDA_HOME 13 | from torch.utils.cpp_extension import CppExtension 14 | from torch.utils.cpp_extension import CUDAExtension 15 | 16 | from setuptools import find_packages 17 | from setuptools import setup 18 | 19 | requirements = ["torch", "torchvision"] 20 | 21 | 22 | def get_extensions(): 23 | this_dir = os.path.dirname(os.path.abspath(__file__)) 24 | extensions_dir = os.path.join(this_dir, "src") 25 | 26 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 27 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 28 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 29 | 30 | sources = main_file + source_cpu 31 | extension = CppExtension 32 | extra_compile_args = {"cxx": []} 33 | define_macros = [] 34 | 35 | if torch.cuda.is_available() and CUDA_HOME is not None: 36 | extension = CUDAExtension 37 | sources += source_cuda 38 | define_macros += [("WITH_CUDA", None)] 39 | extra_compile_args["nvcc"] = [ 40 | # "-DCUDA_HAS_FP16=1", 41 | # "-D__CUDA_NO_HALF_OPERATORS__", 42 | # 
"-D__CUDA_NO_HALF_CONVERSIONS__", 43 | # "-D__CUDA_NO_HALF2_OPERATORS__", 44 | ] 45 | else: 46 | raise NotImplementedError('CUDA is not available') 47 | 48 | sources = [os.path.join(extensions_dir, s) for s in sources] 49 | include_dirs = [extensions_dir] 50 | ext_modules = [ 51 | extension( 52 | "DCNv3", 53 | sources, 54 | include_dirs=include_dirs, 55 | define_macros=define_macros, 56 | extra_compile_args=extra_compile_args, 57 | ) 58 | ] 59 | return ext_modules 60 | 61 | 62 | setup( 63 | name="DCNv3", 64 | version="1.0", 65 | author="InternImage", 66 | url="https://github.com/OpenGVLab/InternImage", 67 | description= 68 | "PyTorch Wrapper for CUDA Functions of DCNv3", 69 | packages=find_packages(exclude=( 70 | "configs", 71 | "tests", 72 | )), 73 | ext_modules=get_extensions(), 74 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 75 | ) 76 | -------------------------------------------------------------------------------- /models/ops_dcnv3/src/cpu/dcnv3_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * InternImage 4 | * Copyright (c) 2022 OpenGVLab 5 | * Licensed under The MIT License [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #include <vector> 13 | 14 | #include <ATen/ATen.h> 15 | #include <ATen/cuda/CUDAContext.h> 16 | 17 | at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, 18 | const at::Tensor &mask, const int kernel_h, 19 | const int kernel_w, const int stride_h, 20 | const int stride_w, const int pad_h, 21 | const int pad_w, const int dilation_h, 22 | const int dilation_w, const int group, 23 | const int group_channels, const float offset_scale, 24 | const int im2col_step) { 25 | AT_ERROR("Not implemented on the CPU"); 26 | } 27 | 28 | std::vector<at::Tensor> 29 | dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, 30 | const at::Tensor &mask, const int kernel_h, 31 | const int kernel_w, const int stride_h, const int stride_w, 32 | const int pad_h, const int pad_w, const int dilation_h, 33 | const int dilation_w, const int group, 34 | const int group_channels, const float offset_scale, 35 | const at::Tensor &grad_output, const int im2col_step) { 36 | AT_ERROR("Not implemented on the CPU"); 37 | } 38 | -------------------------------------------------------------------------------- /models/ops_dcnv3/src/cpu/dcnv3_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * InternImage 4 | * Copyright (c) 2022 OpenGVLab 5 | * Licensed under The MIT License [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #pragma once 13 | #include <torch/extension.h> 14 | 15 | at::Tensor dcnv3_cpu_forward(const at::Tensor &input, const at::Tensor &offset, 16 | const at::Tensor &mask, const int kernel_h, 17 | const int kernel_w, const int stride_h, 18 | const int stride_w, const int pad_h, 19 | const int pad_w, const int dilation_h, 20 | const int dilation_w, const int group, 21 | const int group_channels, const float offset_scale, 22 | const int im2col_step); 23 | 24 | std::vector<at::Tensor> 25 | dcnv3_cpu_backward(const at::Tensor &input, const at::Tensor &offset, 26 | const at::Tensor &mask, const int kernel_h, 27 | const int kernel_w, const int stride_h, const int stride_w, 28 | const int pad_h, const int pad_w, const int dilation_h, 29 | const int dilation_w, const int group, 30 | const int group_channels, const float offset_scale, 31 | const at::Tensor &grad_output, const int im2col_step); 32 | -------------------------------------------------------------------------------- /models/ops_dcnv3/src/cuda/dcnv3_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * InternImage 4 | * Copyright (c) 2022 OpenGVLab 5 | * Licensed under The MIT License [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #include "cuda/dcnv3_im2col_cuda.cuh" 13 | #include <vector> 14 | 15 | #include <ATen/ATen.h> 16 | #include <ATen/cuda/CUDAContext.h> 17 | #include <cuda.h> 18 | #include <cuda_runtime.h> 19 | #include <torch/extension.h> 20 | 21 | at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, 22 | const at::Tensor &mask, const int kernel_h, 23 | const int kernel_w, const int stride_h, 24 | const int stride_w, const int pad_h, 25 | const int pad_w, const int dilation_h, 26 | const int dilation_w, const int group, 27 | const int group_channels, 28 | const float offset_scale, const int im2col_step) { 29 | AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); 30 | AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); 31 | AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); 32 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 33 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); 34 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); 35 | 36 | const int batch = input.size(0); 37 | const int height_in = input.size(1); 38 | const int width_in = input.size(2); 39 | const int channels = input.size(3); 40 | const int height_out = 41 | (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 42 | 1; 43 | const int width_out = 44 | (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 45 | 1; 46 | 
const int im2col_step_ = std::min(batch, im2col_step); 47 | 48 | AT_ASSERTM(batch % im2col_step_ == 0, 49 | "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_); 50 | AT_ASSERTM( 51 | channels == (group * group_channels), 52 | "Input channels and group times group channels won't match: (%d vs %d).", 53 | channels, group * group_channels); 54 | 55 | auto output = 56 | at::zeros({batch, height_out, width_out, group * group_channels}, 57 | input.options()); 58 | 59 | const int batch_n = im2col_step_; 60 | auto output_n = output.view({batch / batch_n, batch_n, height_out, 61 | width_out, group * group_channels}); 62 | auto per_input_size = height_in * width_in * group * group_channels; 63 | auto per_offset_size = 64 | height_out * width_out * group * kernel_h * kernel_w * 2; 65 | auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w; 66 | for (int n = 0; n < batch / im2col_step_; ++n) { 67 | auto columns = output_n.select(0, n); 68 | // AT_DISPATCH_FLOATING_TYPES( 69 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 70 | input.type(), "ms_deform_attn_forward_cuda", ([&] { 71 | dcnv3_im2col_cuda( 72 | at::cuda::getCurrentCUDAStream(), 73 | input.data<scalar_t>() + n * im2col_step_ * per_input_size, 74 | offset.data<scalar_t>() + 75 | n * im2col_step_ * per_offset_size, 76 | mask.data<scalar_t>() + n * im2col_step_ * per_mask_size, 77 | columns.data<scalar_t>(), kernel_h, kernel_w, stride_h, 78 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 79 | group_channels, batch_n, height_in, width_in, height_out, 80 | width_out, offset_scale); 81 | })); 82 | } 83 | 84 | return output; 85 | } 86 | 87 | std::vector<at::Tensor> 88 | dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, 89 | const at::Tensor &mask, const int kernel_h, 90 | const int kernel_w, const int stride_h, const int stride_w, 91 | const int pad_h, const int pad_w, const int dilation_h, 92 | const int dilation_w, const int group, 93 | const int group_channels, const float offset_scale, 94 | const at::Tensor &grad_output, const int im2col_step) { 95 | 96 | AT_ASSERTM(input.is_contiguous(), "input tensor has to be contiguous"); 97 | AT_ASSERTM(offset.is_contiguous(), "offset tensor has to be contiguous"); 98 | AT_ASSERTM(mask.is_contiguous(), "mask tensor has to be contiguous"); 99 | AT_ASSERTM(grad_output.is_contiguous(), 100 | "grad_output tensor has to be contiguous"); 101 | AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); 102 | AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); 103 | AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor"); 104 | AT_ASSERTM(grad_output.type().is_cuda(), 105 | "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = input.size(0); 108 | const int height_in = input.size(1); 109 | const int width_in = input.size(2); 110 | const int channels = input.size(3); 111 | const int height_out = 112 | (height_in + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 113 | 1; 114 | const int width_out = 115 | (width_in + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 116 | 1; 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, 120 | "batch(%d) must be divisible by im2col_step(%d)", batch, im2col_step_); 121 | AT_ASSERTM( 122 | channels == (group * group_channels), 123 | "Input channels and group times group channels won't match: (%d vs %d).", 124 | channels, group * group_channels); 125 | 126 | auto dtype = input.dtype(); 127 | if (dtype == at::kHalf) { 128 | dtype = at::kFloat; 129 | } 130 | 131 | auto 
grad_input = at::zeros_like(input, dtype); 132 | auto grad_offset = at::zeros_like(offset, dtype); 133 | auto grad_mask = at::zeros_like(mask, dtype); 134 | 135 | const int batch_n = im2col_step_; 136 | auto per_input_size = height_in * width_in * group * group_channels; 137 | auto per_offset_size = 138 | height_out * width_out * group * kernel_h * kernel_w * 2; 139 | auto per_mask_size = height_out * width_out * group * kernel_h * kernel_w; 140 | auto grad_output_n = 141 | grad_output.view({batch / im2col_step_, batch_n, height_out * width_out, 142 | group, group_channels}); 143 | 144 | for (int n = 0; n < batch / im2col_step_; ++n) { 145 | auto grad_output_g = grad_output_n.select(0, n); 146 | // AT_DISPATCH_FLOATING_TYPES( 147 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 148 | input.type(), "ms_deform_attn_backward_cuda", ([&] { 149 | dcnv3_col2im_cuda( 150 | at::cuda::getCurrentCUDAStream(), 151 | grad_output_g.data<scalar_t>(), 152 | input.data<scalar_t>() + n * im2col_step_ * per_input_size, 153 | offset.data<scalar_t>() + 154 | n * im2col_step_ * per_offset_size, 155 | mask.data<scalar_t>() + n * im2col_step_ * per_mask_size, 156 | kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, 157 | dilation_h, dilation_w, group, group_channels, batch_n, 158 | height_in, width_in, height_out, width_out, offset_scale, 159 | grad_input.data<opmath_t>() + 160 | n * im2col_step_ * per_input_size, 161 | grad_offset.data<opmath_t>() + 162 | n * im2col_step_ * per_offset_size, 163 | grad_mask.data<opmath_t>() + 164 | n * im2col_step_ * per_mask_size); 165 | })); 166 | } 167 | 168 | if (input.dtype() == torch::kHalf) { 169 | return {grad_input.to(torch::kHalf), grad_offset.to(torch::kHalf), 170 | grad_mask.to(torch::kHalf)}; 171 | } else { 172 | return {grad_input, grad_offset, grad_mask}; 173 | } 174 | } -------------------------------------------------------------------------------- /models/ops_dcnv3/src/cuda/dcnv3_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * InternImage 4 | * Copyright (c) 2022 OpenGVLab 5 | * Licensed under The MIT License [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #pragma once 13 | #include <torch/extension.h> 14 | 15 | at::Tensor dcnv3_cuda_forward(const at::Tensor &input, const at::Tensor &offset, 16 | const at::Tensor &mask, const int kernel_h, 17 | const int kernel_w, const int stride_h, 18 | const int stride_w, const int pad_h, 19 | const int pad_w, const int dilation_h, 20 | const int dilation_w, const int group, 21 | const int group_channels, 22 | const float offset_scale, const int im2col_step); 23 | 24 | std::vector<at::Tensor> 25 | dcnv3_cuda_backward(const at::Tensor &input, const at::Tensor &offset, 26 | const at::Tensor &mask, const int kernel_h, 27 | const int kernel_w, const int stride_h, const int stride_w, 28 | const int pad_h, const int pad_w, const int dilation_h, 29 | const int dilation_w, const int group, 30 | const int group_channels, const float offset_scale, 31 | const at::Tensor &grad_output, const int im2col_step); 32 | -------------------------------------------------------------------------------- /models/ops_dcnv3/src/dcnv3.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * InternImage 4 | * Copyright (c) 2022 OpenGVLab 5 | * Licensed under The MIT License [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #pragma once 13 | 14 | #include "cpu/dcnv3_cpu.h" 15 | 16 | #ifdef WITH_CUDA 17 | #include "cuda/dcnv3_cuda.h" 18 | #endif 19 | 20 | at::Tensor dcnv3_forward(const at::Tensor &input, const at::Tensor &offset, 21 | const at::Tensor &mask, const int kernel_h, 22 | const int kernel_w, const int stride_h, 23 | const int stride_w, const int pad_h, const int pad_w, 24 | const int dilation_h, const int dilation_w, 25 | const int group, const int group_channels, 26 | const float offset_scale, const int im2col_step) { 27 | if (input.type().is_cuda()) { 28 | #ifdef WITH_CUDA 29 | return dcnv3_cuda_forward(input, offset, mask, kernel_h, kernel_w, 30 | stride_h, stride_w, pad_h, pad_w, dilation_h, 31 | dilation_w, group, group_channels, 32 | offset_scale, im2col_step); 33 | #else 34 | AT_ERROR("Not compiled with GPU support"); 35 | #endif 36 | } 37 | AT_ERROR("Not implemented on the CPU"); 38 | } 39 | 40 | std::vector<at::Tensor> 41 | dcnv3_backward(const at::Tensor &input, const at::Tensor &offset, 42 | const at::Tensor &mask, const int kernel_h, const int kernel_w, 43 | const int stride_h, const int stride_w, const int pad_h, 44 | const int pad_w, const int dilation_h, const int dilation_w, 45 | const int group, const int group_channels, 46 | const float offset_scale, const at::Tensor &grad_output, 47 | const int im2col_step) { 48 | if (input.type().is_cuda()) { 49 | #ifdef WITH_CUDA 
50 | return dcnv3_cuda_backward(input, offset, mask, kernel_h, kernel_w, 51 | stride_h, stride_w, pad_h, pad_w, dilation_h, 52 | dilation_w, group, group_channels, 53 | offset_scale, grad_output, im2col_step); 54 | #else 55 | AT_ERROR("Not compiled with GPU support"); 56 | #endif 57 | } 58 | AT_ERROR("Not implemented on the CPU"); 59 | } 60 | -------------------------------------------------------------------------------- /models/ops_dcnv3/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * InternImage 4 | * Copyright (c) 2022 OpenGVLab 5 | * Licensed under The MIT License [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from 8 | *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 9 | ************************************************************************************************** 10 | */ 11 | 12 | #include "dcnv3.h" 13 | 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("dcnv3_forward", &dcnv3_forward, "dcnv3_forward"); 16 | m.def("dcnv3_backward", &dcnv3_backward, "dcnv3_backward"); 17 | } 18 | -------------------------------------------------------------------------------- /models/ops_dcnv3/test.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternImage 3 | # Copyright (c) 2022 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import time 12 | import torch 13 | import torch.nn as nn 14 | import math 15 | from torch.autograd import gradcheck 16 | 17 | from functions.dcnv3_func import DCNv3Function, dcnv3_core_pytorch 18 | 19 | H_in, W_in = 8, 8 20 | N, M, D = 2, 4, 16 21 | Kh, Kw = 3, 3 22 | P = Kh * Kw 23 | offset_scale = 2.0 24 | pad = 1 25 | dilation = 1 26 | stride = 1 27 | H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 28 | W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 29 | 30 | torch.manual_seed(3) 31 | 32 | 33 | @torch.no_grad() 34 | def check_forward_equal_with_pytorch_double(): 35 | input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 36 | offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 37 | mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 38 | mask /= mask.sum(-1, keepdim=True) 39 | mask = mask.reshape(N, H_out, W_out, M*P) 40 | 41 | output_pytorch = dcnv3_core_pytorch( 42 | input.double(), 43 | offset.double(), 44 | mask.double(), 45 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu() 46 | 47 | im2col_step = 2 48 | output_cuda = DCNv3Function.apply( 49 | input.double(), 50 | offset.double(), 51 | mask.double(), 52 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, 53 | im2col_step).detach().cpu() 54 | 55 | fwdok = torch.allclose(output_cuda, output_pytorch) 56 | max_abs_err = (output_cuda - output_pytorch).abs().max() 57 | max_rel_err = ((output_cuda - output_pytorch).abs() / 58 | output_pytorch.abs()).max() 59 | print('>>> forward double') 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err 
{max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | @torch.no_grad() 64 | def check_forward_equal_with_pytorch_float(): 65 | input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 66 | offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 67 | mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 68 | mask /= mask.sum(-1, keepdim=True) 69 | mask = mask.reshape(N, H_out, W_out, M*P) 70 | 71 | output_pytorch = dcnv3_core_pytorch( 72 | input, 73 | offset, 74 | mask, 75 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale).detach().cpu() 76 | 77 | im2col_step = 2 78 | output_cuda = DCNv3Function.apply( 79 | input, 80 | offset, 81 | mask, 82 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, 83 | im2col_step).detach().cpu() 84 | 85 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 86 | max_abs_err = (output_cuda - output_pytorch).abs().max() 87 | max_rel_err = ((output_cuda - output_pytorch).abs() / 88 | output_pytorch.abs()).max() 89 | print('>>> forward float') 90 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 91 | 92 | 93 | def check_backward_equal_with_pytorch_double(channels=4, grad_input=True, grad_offset=True, grad_mask=True): 94 | # H_in, W_in = 4, 4 95 | N = 2 96 | M = 2 97 | H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 98 | W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 99 | 100 | D = channels 101 | input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 102 | offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 103 | mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 104 | mask0 /= mask0.sum(-1, keepdim=True) 105 | mask0 = mask0.reshape(N, H_out, W_out, M*P) 106 | input0.requires_grad = grad_input 107 | offset0.requires_grad = grad_offset 108 | mask0.requires_grad = grad_mask 109 | 110 | output_pytorch = dcnv3_core_pytorch( 111 | input0.double(), 112 | offset0.double(), 113 | mask0.double(), 114 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale) 115 | output_pytorch.sum().backward() 116 | 117 | input1 = input0.detach() 118 | offset1 = offset0.detach() 119 | mask1 = mask0.detach() 120 | input1.requires_grad = grad_input 121 | offset1.requires_grad = grad_offset 122 | mask1.requires_grad = grad_mask 123 | 124 | im2col_step = 2 125 | output_cuda = DCNv3Function.apply( 126 | input1.double(), 127 | offset1.double(), 128 | mask1.double(), 129 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, 130 | im2col_step) 131 | output_cuda.sum().backward() 132 | 133 | print(f'>>> backward double: channels {D}') 134 | bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3) 135 | max_abs_err = (input0.grad - input1.grad).abs().max() 136 | max_rel_err = ((input0.grad - input1.grad).abs() / 137 | input0.grad.abs()).max() 138 | print( 139 | f'* {bwdok} input_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 140 | 141 | bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3) 142 | max_abs_err = (offset0.grad - offset1.grad).abs().max() 143 | max_rel_err = ((offset0.grad - offset1.grad).abs() / 144 | offset0.grad.abs()).max() 145 | print( 146 | f'* {bwdok} offset_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 147 | 148 | bwdok = torch.allclose(mask0.grad, mask1.grad, 
rtol=1e-2, atol=1e-3) 149 | max_abs_err = (mask0.grad - mask1.grad).abs().max() 150 | max_rel_err = ((mask0.grad - mask1.grad).abs() / 151 | mask0.grad.abs()).max() 152 | print( 153 | f'* {bwdok} mask_grad check_backward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 154 | 155 | 156 | def check_backward_equal_with_pytorch_float(channels=4, grad_input=True, grad_offset=True, grad_mask=True): 157 | # H_in, W_in = 4, 4 158 | N = 2 159 | M = 2 160 | H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 161 | W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 162 | 163 | D = channels 164 | input0 = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 165 | offset0 = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 166 | mask0 = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 167 | mask0 /= mask0.sum(-1, keepdim=True) 168 | mask0 = mask0.reshape(N, H_out, W_out, M*P) 169 | input0.requires_grad = grad_input 170 | offset0.requires_grad = grad_offset 171 | mask0.requires_grad = grad_mask 172 | 173 | output_pytorch = dcnv3_core_pytorch( 174 | input0, 175 | offset0, 176 | mask0, 177 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale) 178 | output_pytorch.sum().backward() 179 | 180 | input1 = input0.detach() 181 | offset1 = offset0.detach() 182 | mask1 = mask0.detach() 183 | input1.requires_grad = grad_input 184 | offset1.requires_grad = grad_offset 185 | mask1.requires_grad = grad_mask 186 | 187 | im2col_step = 2 188 | output_cuda = DCNv3Function.apply( 189 | input1, 190 | offset1, 191 | mask1, 192 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, offset_scale, 193 | im2col_step) 194 | output_cuda.sum().backward() 195 | 196 | print(f'>>> backward float: channels {D}') 197 | bwdok = torch.allclose(input0.grad, input1.grad, rtol=1e-2, atol=1e-3) 198 | max_abs_err = (input0.grad - input1.grad).abs().max() 199 | max_rel_err = ((input0.grad - input1.grad).abs() / 200 | input0.grad.abs()).max() 201 | print( 202 | f'* {bwdok} input_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 203 | 204 | bwdok = torch.allclose(offset0.grad, offset1.grad, rtol=1e-2, atol=1e-3) 205 | max_abs_err = (offset0.grad - offset1.grad).abs().max() 206 | max_rel_err = ((offset0.grad - offset1.grad).abs() / 207 | offset0.grad.abs()).max() 208 | print( 209 | f'* {bwdok} offset_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 210 | 211 | bwdok = torch.allclose(mask0.grad, mask1.grad, rtol=1e-2, atol=1e-3) 212 | max_abs_err = (mask0.grad - mask1.grad).abs().max() 213 | max_rel_err = ((mask0.grad - mask1.grad).abs() / 214 | mask0.grad.abs()).max() 215 | print( 216 | f'* {bwdok} mask_grad check_backward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 217 | 218 | 219 | @torch.no_grad() 220 | def check_time_cost(im2col_step=128): 221 | N = 512 222 | H_in, W_in = 64, 64 223 | H_out = (H_in + 2 * pad - (dilation * (Kh - 1) + 1)) // stride + 1 224 | W_out = (W_in + 2 * pad - (dilation * (Kw - 1) + 1)) // stride + 1 225 | 226 | input = torch.rand(N, H_in, W_in, M*D).cuda() * 0.01 227 | offset = torch.rand(N, H_out, W_out, M*P*2).cuda() * 10 228 | mask = torch.rand(N, H_out, W_out, M, P).cuda() + 1e-5 229 | mask /= mask.sum(-1, keepdim=True) 230 | mask = mask.reshape(N, H_out, W_out, M*P) 231 | print( 232 | f'>>> time cost: im2col_step {im2col_step}; input {input.shape}; points 
{P} ') 233 | repeat = 100 234 | for i in range(repeat): 235 | output_cuda = DCNv3Function.apply( 236 | input, 237 | offset, 238 | mask, 239 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0, 240 | im2col_step) 241 | torch.cuda.synchronize() 242 | start = time.time() 243 | for i in range(repeat): 244 | output_cuda = DCNv3Function.apply( 245 | input, 246 | offset, 247 | mask, 248 | Kh, Kw, stride, stride, Kh // 2, Kw // 2, dilation, dilation, M, D, 1.0, 249 | im2col_step) 250 | torch.cuda.synchronize() 251 | print(f'forward time cost: {(time.time() - start) / repeat}') 252 | 253 | 254 | if __name__ == '__main__': 255 | check_forward_equal_with_pytorch_double() 256 | check_forward_equal_with_pytorch_float() 257 | for channels in [1, 16, 30, 32, 64, 71, 1025]: 258 | check_backward_equal_with_pytorch_double(channels, True, True, True) 259 | for channels in [1, 16, 30, 32, 64, 71, 1025]: 260 | check_backward_equal_with_pytorch_float(channels, True, True, True) 261 | for i in range(3): 262 | im2col_step = 128 * (2 ** i) 263 | check_time_cost(im2col_step) 264 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMOP-lab/DeMambaNet/4576667c8e388fa79e1e200468ae3669b34e8901/requirements.txt --------------------------------------------------------------------------------
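
As a quick sanity check beyond `test.py`, the sketch below shows how the `DCNv3` module from `models/ops_dcnv3/modules/dcnv3.py` can be called directly. This is a minimal usage sketch, not part of the repository: the import path is assumed from the tree above, and the compiled `DCNv3` CUDA extension must already be built via `make.sh`. Note that the module takes channels-last `(N, H, W, C)` input and that `channels` must be divisible by `group`:

```python
import torch
from models.ops_dcnv3.modules.dcnv3 import DCNv3  # import path assumed from the repo tree

# 64 channels split into 4 groups of 16 channels each
# (a power of 2, as the constructor warning recommends for the CUDA kernel)
m = DCNv3(channels=64, kernel_size=3, stride=1, pad=1, group=4).cuda()

x = torch.randn(2, 32, 32, 64).cuda()  # channels-last (N, H, W, C)
y = m(x)                               # k=3, stride=1, pad=1 preserves H and W
print(y.shape)                         # torch.Size([2, 32, 32, 64])
```

For the channels-first `(N, C, H, W)` tensors used elsewhere in the network, permute with `x.permute(0, 2, 3, 1)` before the call and permute back afterwards; the module itself only permutes internally for its depthwise convolution.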