├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── TransDoD ├── .DS_Store ├── MOTSDataset.py ├── TranDoD.png ├── __init__.py ├── __pycache__ │ └── MOTSDataset.cpython-37.pyc ├── models │ ├── .DS_Store │ ├── ResCNN2.py │ ├── TransDoDNet.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── ResCNN2.cpython-37.pyc │ │ ├── TransDoDNet.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── deepnorm.cpython-37.pyc │ │ ├── deformable_transformer.cpython-37.pyc │ │ └── position_encoding.cpython-37.pyc │ ├── deepnorm.py │ ├── deformable_transformer.py │ ├── ops │ │ ├── .DS_Store │ │ ├── MultiScaleDeformableAttention.egg-info │ │ │ ├── PKG-INFO │ │ │ ├── SOURCES.txt │ │ │ ├── dependency_links.txt │ │ │ └── top_level.txt │ │ ├── build │ │ │ ├── lib.linux-x86_64-3.7 │ │ │ │ ├── MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so │ │ │ │ ├── functions │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ └── modules │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── ms_deform_attn (33).py │ │ │ │ │ └── ms_deform_attn.py │ │ │ └── temp.linux-x86_64-3.7 │ │ │ │ ├── .ninja_deps │ │ │ │ ├── .ninja_log │ │ │ │ ├── build.ninja │ │ │ │ └── media │ │ │ │ ├── userdisk0 │ │ │ │ ├── myproject-Det │ │ │ │ │ └── Deformable-DETR │ │ │ │ │ │ └── models │ │ │ │ │ │ └── ops │ │ │ │ │ │ └── src │ │ │ │ │ │ ├── cpu │ │ │ │ │ │ └── ms_deform_attn_cpu.o │ │ │ │ │ │ ├── cuda │ │ │ │ │ │ └── ms_deform_attn_cuda.o │ │ │ │ │ │ └── vision.o │ │ │ │ └── myproject-Seg │ │ │ │ │ └── MOTS-pro2 │ │ │ │ │ └── f_transformers │ │ │ │ │ └── models │ │ │ │ │ └── ops │ │ │ │ │ └── src │ │ │ │ │ ├── cpu │ │ │ │ │ └── ms_deform_attn_cpu.o │ │ │ │ │ ├── cuda │ │ │ │ │ └── ms_deform_attn_cuda.o │ │ │ │ │ └── vision.o │ │ │ │ └── userdisk1 │ │ │ │ └── jpzhang │ │ │ │ └── myproject-Seg │ │ │ │ └── MOTS-pro2 │ │ │ │ └── f_transformers │ │ │ │ └── models │ │ │ │ └── ops │ │ │ │ └── src │ │ │ │ ├── cpu │ │ │ │ └── ms_deform_attn_cpu.o │ │ │ │ ├── cuda │ │ │ │ └── ms_deform_attn_cuda.o │ │ │ │ └── vision.o │ │ ├── dist │ │ │ └── MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg │ │ ├── functions │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ └── ms_deform_attn_func.cpython-37.pyc │ │ │ └── ms_deform_attn_func.py │ │ ├── make.sh │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-37.pyc │ │ │ │ └── ms_deform_attn.cpython-37.pyc │ │ │ ├── ms_deform_attn (33).py │ │ │ └── ms_deform_attn.py │ │ ├── setup.py │ │ ├── src │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── ms_deform_attn.h │ │ │ └── vision.cpp │ │ └── test.py │ └── position_encoding.py ├── test.py └── train.py ├── __init__.py ├── __pycache__ └── engine.cpython-37.pyc ├── a_DynConv ├── .train.py.swp ├── MOTSDataset.py ├── dodnet.png ├── evaluate.py ├── postp.py ├── train.py └── unet3D_DynConv882.py ├── data_list ├── .DS_Store └── MOTS │ ├── .DS_Store │ ├── MOTS_test.txt │ └── MOTS_train.txt ├── dataset ├── .DS_Store ├── list │ ├── .DS_Store │ └── MOTS │ │ ├── MOTS_test.txt │ │ └── MOTS_train.txt └── re_spacing.py ├── engine.py ├── loss_functions ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── loss.cpython-37.pyc └── loss.py ├── run_script.sh └── utils ├── ParaFlop.py ├── __init__.py ├── __pycache__ ├── ParaFlop.cpython-37.pyc ├── __init__.cpython-37.pyc ├── logger.cpython-37.pyc ├── misc.cpython-37.pyc ├── 
my_utils.cpython-37.pyc └── pyt_utils.cpython-37.pyc ├── logger.py ├── misc.py ├── my_utils.py └── pyt_utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DoDNet 2 |

3 | 4 | 5 | 6 | 7 | This repo holds the PyTorch implementation of DoDNet and TransDoDNet:
8 | 9 | **DoDNet: Learning to segment multi-organ and tumors from multiple partially labeled datasets** 10 | (https://arxiv.org/pdf/2011.10217.pdf) \ 11 | **Learning from partially labeled data for multi-organ and tumor segmentation** 12 | (https://arxiv.org/pdf/2211.06894.pdf) 13 | 14 | 22 | 23 | ## Usage 24 | 25 | 26 | 32 | 33 | ### 1. MOTS Dataset Preparation 34 | Before starting, MOTS should be rebuilt from the following medical organ and tumor segmentation datasets: 35 | 36 | Partial-label task | Data source 37 | --- | :---: 38 | Liver | [data](https://competitions.codalab.org/competitions/17094) 39 | Kidney | [data](https://kits19.grand-challenge.org/data/) 40 | Hepatic Vessel | [data](http://medicaldecathlon.com/) 41 | Pancreas | [data](http://medicaldecathlon.com/) 42 | Colon | [data](http://medicaldecathlon.com/) 43 | Lung | [data](http://medicaldecathlon.com/) 44 | Spleen | [data](http://medicaldecathlon.com/) 45 | 46 | 63 | * Preprocessed data will be available soon. 64 | 65 | ### 2. Training/Testing/Evaluation 66 | sh run_script.sh 67 | 68 | 69 | 109 | 110 | 111 | ### 3. Citation 112 | If this code is helpful for your research, please cite: 113 | ``` 114 | @inproceedings{zhang2021dodnet, 115 | title={DoDNet: Learning to segment multi-organ and tumors from multiple partially labeled datasets}, 116 | author={Zhang, Jianpeng and Xie, Yutong and Xia, Yong and Shen, Chunhua}, 117 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 118 | pages={}, 119 | year={2021} 120 | } 121 | @article{xie2023learning, 122 | title={Learning from partially labeled data for multi-organ and tumor segmentation}, 123 | author={Xie, Yutong and Zhang, Jianpeng and Xia, Yong and Shen, Chunhua}, 124 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 125 | year={2023} 126 | } 127 | ``` -------------------------------------------------------------------------------- /TransDoD/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/.DS_Store -------------------------------------------------------------------------------- /TransDoD/TranDoD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/TranDoD.png -------------------------------------------------------------------------------- /TransDoD/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/__init__.py -------------------------------------------------------------------------------- /TransDoD/__pycache__/MOTSDataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/__pycache__/MOTSDataset.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/.DS_Store -------------------------------------------------------------------------------- /TransDoD/models/ResCNN2.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from functools import partial 6 | 7 | 8 | class Conv3d_wd(nn.Conv3d): 9 | 10 | def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), groups=1, bias=False): 11 | super(Conv3d_wd, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 12 | 13 | def forward(self, x): 14 | weight = self.weight 15 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True) 16 | weight = weight - weight_mean 17 | # std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5 18 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1) 19 | weight = weight / std.expand_as(weight) 20 | return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 21 | 22 | 23 | def conv3x3x3(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), bias=False, weight_std=False): 24 | "3x3x3 convolution with padding" 25 | if weight_std: 26 | return Conv3d_wd(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) 27 | else: 28 | return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) 29 | 30 | 31 | def downsample_basic_block(x, planes, stride): 32 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 33 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), out.size(2), out.size(3),out.size(4)).zero_() 34 | if isinstance(out.data, torch.cuda.FloatTensor): 35 | zero_pads = zero_pads.cuda() 36 | 37 | out = Variable(torch.cat([out.data, zero_pads.cuda()], dim=1)) 38 | 39 | return out 40 | 41 | 42 | def Norm_layer(norm_cfg, inplanes): 43 | 44 | if norm_cfg == 'BN': 45 | out = nn.BatchNorm3d(inplanes) 46 | elif norm_cfg == 'SyncBN': 47 | out = nn.SyncBatchNorm(inplanes) 48 | elif norm_cfg == 'GN': 49 | out = nn.GroupNorm(16, inplanes) 50 | elif norm_cfg == 'IN': 51 | out = nn.InstanceNorm3d(inplanes,affine=True) 52 | 53 | return out 54 | 55 | 56 | def Activation_layer(activation_cfg, inplace=True): 57 | 58 | if activation_cfg == 'relu': 59 | out = nn.ReLU(inplace=inplace) 60 | elif activation_cfg == 'LeakyReLU': 61 | out = nn.LeakyReLU(negative_slope=1e-2, inplace=inplace) 62 | 63 | return out 64 | 65 | 66 | class BasicBlock(nn.Module): 67 | expansion = 1 68 | 69 | def __init__(self, inplanes, planes, norm_cfg, activation_cfg, stride=(1, 1, 1), downsample=None, weight_std=False): 70 | super(BasicBlock, self).__init__() 71 | self.conv1 = conv3x3x3(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, weight_std=weight_std) 72 | self.norm1 = Norm_layer(norm_cfg, planes) 73 | self.nonlin = Activation_layer(activation_cfg, inplace=True) 74 | self.conv2 = conv3x3x3(planes, planes, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False, weight_std=weight_std) 75 | self.norm2 = Norm_layer(norm_cfg, planes) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.norm1(out) 84 | out = self.nonlin(out) 85 | 86 | out = self.conv2(out) 87 | out = self.norm2(out) 88 | 89 | if self.downsample is not None: 90 | residual = 
self.downsample(x) 91 | 92 | out += residual 93 | out = self.nonlin(out) 94 | 95 | return out 96 | 97 | 98 | class Bottleneck(nn.Module): 99 | expansion = 4 100 | 101 | def __init__(self, inplanes, planes, norm_cfg, activation_cfg, stride=(1, 1, 1), downsample=None, weight_std=False): 102 | super(Bottleneck, self).__init__() 103 | self.conv1 = conv3x3x3(inplanes, planes, kernel_size=1, bias=False, weight_std=weight_std) 104 | self.norm1 = Norm_layer(norm_cfg, planes) 105 | self.conv2 = conv3x3x3(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, weight_std=weight_std) 106 | self.norm2 = Norm_layer(norm_cfg, planes) 107 | self.conv3 = conv3x3x3(planes, planes * 4, kernel_size=1, bias=False, weight_std=weight_std) 108 | self.norm3 = Norm_layer(norm_cfg, planes * 4) 109 | self.nonlin = Activation_layer(activation_cfg, inplace=True) 110 | self.downsample = downsample 111 | self.stride = stride 112 | 113 | def forward(self, x): 114 | residual = x 115 | 116 | out = self.conv1(x) 117 | out = self.norm1(out) 118 | out = self.nonlin(out) 119 | 120 | out = self.conv2(out) 121 | out = self.norm2(out) 122 | out = self.nonlin(out) 123 | 124 | out = self.conv3(out) 125 | out = self.norm3(out) 126 | 127 | if self.downsample is not None: 128 | residual = self.downsample(x) 129 | 130 | out += residual 131 | out = self.nonlin(out) 132 | 133 | return out 134 | 135 | 136 | class ResNet(nn.Module): 137 | 138 | arch_settings = { 139 | 10: (BasicBlock, (1, 1, 1, 1)), 140 | 18: (BasicBlock, (2, 2, 2, 2)), 141 | 34: (BasicBlock, (3, 4, 6, 3)), 142 | 50: (Bottleneck, (3, 4, 6, 3)), 143 | 101: (Bottleneck, (3, 4, 23, 3)), 144 | 152: (Bottleneck, (3, 8, 36, 3)), 145 | 200: (Bottleneck, (3, 24, 36, 3)) 146 | } 147 | 148 | def __init__(self, 149 | depth, 150 | in_channels=1, 151 | shortcut_type='B', 152 | norm_cfg='IN', 153 | activation_cfg='relu', 154 | weight_std=False): 155 | super(ResNet, self).__init__() 156 | 157 | if depth not in self.arch_settings: 158 | raise KeyError('invalid depth {} for resnet'.format(depth)) 159 | self.depth = depth 160 | block, layers = self.arch_settings[depth] 161 | self.inplanes = 32 162 | self.conv1 = conv3x3x3(in_channels, 32, kernel_size=3, stride=(1, 2, 2), padding=1, bias=False, weight_std=weight_std) 163 | self.norm1 = Norm_layer(norm_cfg, 32) 164 | self.nonlin1 = Activation_layer(activation_cfg, inplace=True) 165 | self.conv2 = conv3x3x3(32, 32, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False, weight_std=weight_std) 166 | self.norm2 = Norm_layer(norm_cfg, 32) 167 | self.nonlin2 = Activation_layer(activation_cfg, inplace=True) 168 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=1) 169 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type, stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std) 170 | self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std) 171 | self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type, stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std) 172 | self.layer4 = self._make_layer(block, 320, layers[3], shortcut_type, stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std) 173 | self.layers = [] 174 | 175 | for m in self.modules(): 176 | if isinstance(m, (nn.Conv3d, Conv3d_wd)): 177 | m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out') 178 | elif 
isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)): 179 | m.weight.data.fill_(1) 180 | m.bias.data.zero_() 181 | 182 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=(1, 1, 1), norm_cfg='BN', activation_cfg='relu', weight_std=False): 183 | downsample = None 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | if shortcut_type == 'A': 186 | downsample = partial( 187 | downsample_basic_block, 188 | planes=planes * block.expansion, 189 | stride=stride) 190 | else: 191 | downsample = nn.Sequential( 192 | conv3x3x3( 193 | self.inplanes, 194 | planes * block.expansion, 195 | kernel_size=1, 196 | stride=stride, 197 | bias=False, weight_std=weight_std), 198 | Norm_layer(norm_cfg, planes * block.expansion)) 199 | 200 | layers = [] 201 | layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, stride=stride, downsample=downsample, weight_std=weight_std)) 202 | self.inplanes = planes * block.expansion 203 | for i in range(1, blocks): 204 | layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, weight_std=weight_std)) 205 | 206 | return nn.Sequential(*layers) 207 | 208 | def init_weights(self): 209 | for m in self.modules(): 210 | if isinstance(m, (nn.Conv3d, Conv3d_wd)): 211 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 212 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)): 213 | if m.weight is not None: 214 | nn.init.constant_(m.weight, 1) 215 | if m.bias is not None: 216 | nn.init.constant_(m.bias, 0) 217 | 218 | def forward(self, x): 219 | self.layers = [] 220 | x = self.nonlin1(self.norm1(self.conv1(x))) 221 | x = self.nonlin2(self.norm2(self.conv2(x))) 222 | self.layers.append(x) 223 | 224 | x = self.layer1(x) 225 | self.layers.append(x) 226 | x = self.layer2(x) 227 | self.layers.append(x) 228 | x = self.layer3(x) 229 | self.layers.append(x) 230 | x = self.layer4(x) 231 | self.layers.append(x) 232 | 233 | return x 234 | 235 | def get_layers(self): 236 | return self.layers 237 | -------------------------------------------------------------------------------- /TransDoD/models/TransDoDNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models import ResCNN2 5 | from models.deformable_transformer import build_deformable_transformer 6 | from .position_encoding import build_position_encoding 7 | 8 | 9 | def _expand(tensor, length: int): 10 | return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1, 1).flatten(0, 1) 11 | 12 | 13 | class Conv3d_wd(nn.Conv3d): 14 | 15 | def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), 16 | groups=1, bias=True): 17 | super(Conv3d_wd, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 18 | 19 | def forward(self, x): 20 | weight = self.weight 21 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, 22 | keepdim=True) 23 | weight = weight - weight_mean 24 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1) 25 | weight = weight / std.expand_as(weight) 26 | return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 27 | 28 | 29 | def conv3x3x3(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1, 30 | 
bias=True, weight_std=False): 31 | "3x3x3 convolution with padding" 32 | if weight_std: 33 | return Conv3d_wd(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 34 | dilation=dilation, groups=groups, bias=bias) 35 | else: 36 | return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 37 | dilation=dilation, groups=groups, bias=bias) 38 | 39 | 40 | def Norm_layer(norm_cfg, inplanes): 41 | if norm_cfg == 'BN': 42 | out = nn.BatchNorm3d(inplanes) 43 | elif norm_cfg == 'SyncBN': 44 | out = nn.SyncBatchNorm(inplanes) 45 | elif norm_cfg == 'GN': 46 | out = nn.GroupNorm(16, inplanes) 47 | elif norm_cfg == 'IN': 48 | out = nn.InstanceNorm3d(inplanes, affine=True) 49 | 50 | return out 51 | 52 | 53 | def Activation_layer(activation_cfg, inplace=True): 54 | if activation_cfg == 'relu': 55 | out = nn.ReLU(inplace=inplace) 56 | elif activation_cfg == 'LeakyReLU': 57 | out = nn.LeakyReLU(negative_slope=1e-2, inplace=inplace) 58 | 59 | return out 60 | 61 | 62 | class ResCNN_DeformTR(nn.Module): 63 | def __init__(self, args, norm_cfg='IN', activation_cfg='relu', num_classes=None, weight_std=False, res_depth=None, dyn_head_dep_wid=[3,8]): 64 | super(ResCNN_DeformTR, self).__init__() 65 | 66 | self.args = args 67 | self.args.activation = activation_cfg 68 | self.num_classes = num_classes 69 | self.dyn_head_dep_wid = dyn_head_dep_wid 70 | if res_depth >= 50: 71 | expansion = 4 72 | else: 73 | expansion = 1 74 | 75 | # num dyn params 76 | num_dyn_params = (dyn_head_dep_wid[1]*dyn_head_dep_wid[1]+dyn_head_dep_wid[1])*(dyn_head_dep_wid[0]-1) + (dyn_head_dep_wid[1]*2+2)*1 77 | print(f"###Total dyn params {num_dyn_params}###") 78 | 79 | self.upsample = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear') 80 | 81 | if self.args.add_memory >= 1: 82 | if self.args.num_feature_levels >= 1: 83 | self.memory_conv3_l0 = nn.Sequential( 84 | conv3x3x3(args.hidden_dim, 256, kernel_size=1, bias=False, weight_std=weight_std), 85 | Norm_layer(norm_cfg, 256), 86 | Activation_layer(activation_cfg, inplace=True), 87 | ) 88 | 89 | self.cnn_bottle = nn.Sequential( 90 | conv3x3x3(320 * expansion, 256, kernel_size=1, bias=False, weight_std=weight_std), 91 | Norm_layer(norm_cfg, 256), 92 | Activation_layer(activation_cfg, inplace=True), 93 | ) 94 | 95 | self.shortcut_conv3 = nn.Sequential( 96 | conv3x3x3(256 * expansion, 256, kernel_size=1, bias=False, weight_std=weight_std), 97 | Norm_layer(norm_cfg, 256), 98 | Activation_layer(activation_cfg, inplace=True), 99 | ) 100 | 101 | self.shortcut_conv2 = nn.Sequential( 102 | conv3x3x3(128 * expansion, 128, kernel_size=1, bias=False, weight_std=weight_std), 103 | Norm_layer(norm_cfg, 128), 104 | Activation_layer(activation_cfg, inplace=True), 105 | ) 106 | 107 | self.shortcut_conv1 = nn.Sequential( 108 | conv3x3x3(64 * expansion, 64, kernel_size=1, bias=False, weight_std=weight_std), 109 | Norm_layer(norm_cfg, 64), 110 | Activation_layer(activation_cfg, inplace=True), 111 | ) 112 | 113 | self.shortcut_conv0 = nn.Sequential( 114 | conv3x3x3(32, 32, kernel_size=1, bias=False, weight_std=weight_std), 115 | Norm_layer(norm_cfg, 32), 116 | Activation_layer(activation_cfg, inplace=True), 117 | ) 118 | 119 | self.transposeconv_stage3 = nn.ConvTranspose3d(256, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False) 120 | self.transposeconv_stage2 = nn.ConvTranspose3d(256, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False) 121 | self.transposeconv_stage1 = nn.ConvTranspose3d(128, 64, kernel_size=(2, 2, 2), stride=(2, 2, 
2), bias=False) 122 | self.transposeconv_stage0 = nn.ConvTranspose3d(64, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False) 123 | 124 | self.stage3_de = ResCNN2.BasicBlock(256, 256, norm_cfg, activation_cfg, weight_std=weight_std) 125 | self.stage2_de = ResCNN2.BasicBlock(128, 128, norm_cfg, activation_cfg, weight_std=weight_std) 126 | self.stage1_de = ResCNN2.BasicBlock(64, 64, norm_cfg, activation_cfg, weight_std=weight_std) 127 | self.stage0_de = ResCNN2.BasicBlock(32, 32, norm_cfg, activation_cfg, weight_std=weight_std) 128 | 129 | self.precls_conv = nn.Sequential( 130 | conv3x3x3(32, dyn_head_dep_wid[1], kernel_size=1, bias=False, weight_std=weight_std), 131 | Norm_layer(norm_cfg, dyn_head_dep_wid[1]), 132 | Activation_layer(activation_cfg, inplace=True), 133 | ) 134 | 135 | self.backbone = ResCNN2.ResNet(depth=res_depth, shortcut_type='B', norm_cfg=norm_cfg, 136 | activation_cfg=activation_cfg, 137 | weight_std=weight_std) 138 | 139 | self.backbone_layers = self.backbone.get_layers() 140 | 141 | # 142 | if self.args.using_transformer: 143 | self.transformer = build_deformable_transformer(args) 144 | self.position_embedding = build_position_encoding(args) 145 | self.controller = nn.Sequential( 146 | nn.Linear(args.hidden_dim, args.hidden_dim), 147 | Activation_layer(activation_cfg, inplace=True), 148 | nn.Linear(args.hidden_dim, num_dyn_params) 149 | ) 150 | 151 | input_proj_list = [] 152 | backbone_num_channels = [64 * expansion, 128 * expansion, 256 * expansion, 320 * expansion] 153 | backbone_num_channels = backbone_num_channels[-args.num_feature_levels:] 154 | for _ in backbone_num_channels: 155 | in_channels = _ 156 | input_proj_list.append(nn.Sequential( 157 | conv3x3x3(in_channels, args.hidden_dim, kernel_size=1, bias=True, weight_std=weight_std), 158 | Norm_layer(norm_cfg, args.hidden_dim), 159 | )) 160 | self.input_proj = nn.ModuleList(input_proj_list) 161 | self.query_embed = nn.Embedding(self.args.num_queries, args.hidden_dim * 2) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, (nn.Conv3d, Conv3d_wd, nn.ConvTranspose3d)): 165 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out') 166 | elif isinstance(m, (nn.BatchNorm3d, nn.SyncBatchNorm, nn.InstanceNorm3d, nn.GroupNorm)): 167 | if m.weight is not None: 168 | nn.init.constant_(m.weight, 1) 169 | if m.bias is not None: 170 | nn.init.constant_(m.bias, 0) 171 | 172 | def parse_dynamic_params(self, params, channels, weight_nums, bias_nums): 173 | assert params.dim() == 2 174 | assert len(weight_nums) == len(bias_nums) 175 | assert params.size(1) == sum(weight_nums) + sum(bias_nums) 176 | 177 | num_insts = params.size(0) 178 | num_layers = len(weight_nums) 179 | 180 | params_splits = list(torch.split_with_sizes( 181 | params, weight_nums + bias_nums, dim=1 182 | )) 183 | 184 | weight_splits = params_splits[:num_layers] 185 | bias_splits = params_splits[num_layers:] 186 | 187 | for l in range(num_layers): 188 | if l < num_layers - 1: 189 | weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1, 1) 190 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels) 191 | else: 192 | weight_splits[l] = weight_splits[l].reshape(num_insts * 2, -1, 1, 1, 1) 193 | bias_splits[l] = bias_splits[l].reshape(num_insts * 2) 194 | 195 | return weight_splits, bias_splits 196 | 197 | def heads_forward(self, features, weights, biases, num_insts): 198 | assert features.dim() == 5 199 | n_layers = len(weights) 200 | x = features 201 | for i, (w, b) in enumerate(zip(weights, biases)): 202 | x = 
F.conv3d( 203 | x, w, bias=b, 204 | stride=1, padding=0, 205 | groups=num_insts 206 | ) 207 | if i < n_layers - 1: 208 | x = F.relu(x) 209 | return x 210 | 211 | def forward(self, inputs_x, task_id): 212 | bs, c, d, w, h = inputs_x.shape 213 | 214 | _ = self.backbone(inputs_x) 215 | layers = self.backbone.get_layers() # [::-1] 216 | 217 | # Transformer body 218 | srcs = [] 219 | masks = [] 220 | pos = [] 221 | for l, feat in enumerate(layers[-self.args.num_feature_levels:]): 222 | src = feat 223 | srcs.append(self.input_proj[l](src)) 224 | masks.append(torch.BoolTensor(src.shape[0], src.shape[2], src.shape[3], src.shape[4]).cuda() * False) 225 | pos.append(self.position_embedding(src).to(src.dtype)) 226 | del feat 227 | 228 | hs, _, _, _, _, memory = self.transformer(srcs, masks, pos, self.query_embed.weight) 229 | params = self.controller(hs[-1].flatten(0, 1)) 230 | 231 | if self.args.add_memory == 0: 232 | x = self.cnn_bottle(layers[-1]) 233 | elif self.args.add_memory == 1: 234 | x = self.memory_conv3_l0(memory[-1]) 235 | elif self.args.add_memory == 2: 236 | x = self.memory_conv3_l0(memory[-1]) + self.cnn_bottle(layers[-1]) 237 | else: 238 | print("Error: no pre-defined add_memory mode!") 239 | 240 | skip3 = self.shortcut_conv3(layers[-2]) 241 | x = self.transposeconv_stage3(x) 242 | x = x + skip3 243 | x = self.stage3_de(x) 244 | 245 | skip2 = self.shortcut_conv2(layers[-3]) 246 | x = self.transposeconv_stage2(x) 247 | x = x + skip2 248 | x = self.stage2_de(x) 249 | 250 | x = self.transposeconv_stage1(x) 251 | skip1 = self.shortcut_conv1(layers[-4]) 252 | x = x + skip1 253 | x = self.stage1_de(x) 254 | 255 | x = self.transposeconv_stage0(x) 256 | skip0 = self.shortcut_conv0(layers[-5]) 257 | x = x + skip0 258 | x = self.stage0_de(x) 259 | 260 | head_inputs = self.precls_conv(x) 261 | head_inputs = _expand(head_inputs, self.args.num_queries) 262 | N, _, D, H, W = head_inputs.size() 263 | head_inputs = head_inputs.reshape(1, -1, D, H, W) 264 | 265 | weight_nums, bias_nums = [], [] 266 | for i in range(self.dyn_head_dep_wid[0]-1): 267 | weight_nums.append(self.dyn_head_dep_wid[1]*self.dyn_head_dep_wid[1]) 268 | bias_nums.append(self.dyn_head_dep_wid[1]) 269 | weight_nums.append(self.dyn_head_dep_wid[1]*2) 270 | bias_nums.append(2) 271 | 272 | weights, biases = self.parse_dynamic_params(params, self.dyn_head_dep_wid[1], weight_nums, bias_nums) 273 | 274 | seg_out = self.heads_forward(head_inputs, weights, biases, N) 275 | seg_out = self.upsample(seg_out) 276 | seg_out = seg_out.view(bs, self.args.num_queries, self.num_classes, seg_out.shape[-3], seg_out.shape[-2], 277 | seg_out.shape[-1]) 278 | 279 | return seg_out 280 | 281 | 282 | class TransDoDNet(nn.Module): 283 | def __init__(self, args, norm_cfg='IN', activation_cfg='relu', num_classes=None, 284 | weight_std=False, deep_supervision=False, res_depth=None, dyn_head_dep_wid=[3,8]): 285 | super().__init__() 286 | self.do_ds = False 287 | self.ResCNN_DeformTR = ResCNN_DeformTR(args, norm_cfg, activation_cfg, num_classes, weight_std, res_depth, dyn_head_dep_wid) 288 | if weight_std == False: 289 | self.conv_op = nn.Conv3d 290 | else: 291 | self.conv_op = Conv3d_wd 292 | if norm_cfg == 'BN': 293 | self.norm_op = nn.BatchNorm3d 294 | if norm_cfg == 'GN': 295 | self.norm_op = nn.GroupNorm 296 | if norm_cfg == 'IN': 297 | self.norm_op = nn.InstanceNorm3d 298 | self.num_classes = num_classes 299 | self._deep_supervision = deep_supervision 300 | self.do_ds = deep_supervision 301 | 302 | def forward(self, x, task_id): 303 | seg_output = 
self.ResCNN_DeformTR(x, task_id) 304 | return seg_output -------------------------------------------------------------------------------- /TransDoD/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__init__.py -------------------------------------------------------------------------------- /TransDoD/models/__pycache__/ResCNN2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/ResCNN2.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/__pycache__/TransDoDNet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/TransDoDNet.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/__pycache__/deepnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/deepnorm.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/__pycache__/deformable_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/deformable_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/__pycache__/position_encoding.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/position_encoding.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/deepnorm.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import List, Optional, Tuple, Union 3 | from collections import namedtuple 4 | 5 | DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"]) 6 | 7 | def get_deepnorm_coefficients( 8 | encoder_layers: int, decoder_layers: int 9 | ) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]: 10 | """ 11 | See DeepNet_. 12 | Returns alpha and beta depending on the number of encoder and decoder layers, 13 | the first tuple is for the encoder and the second for the decoder 14 | .. 
_DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf 15 | """ 16 | 17 | N = encoder_layers 18 | M = decoder_layers 19 | 20 | if decoder_layers == 0: 21 | # Encoder only 22 | return ( 23 | DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25), 24 | None, 25 | ) 26 | 27 | elif encoder_layers == 0: 28 | # Decoder only 29 | return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25) 30 | else: 31 | # Encoder/decoder 32 | encoder_coeffs = DeepNormCoefficients( 33 | alpha=0.81 * ((N ** 4) * M) ** 0.0625, beta=0.87 * ((N ** 4) * M) ** -0.0625 34 | ) 35 | 36 | decoder_coeffs = DeepNormCoefficients( 37 | alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25 38 | ) 39 | 40 | return (encoder_coeffs, decoder_coeffs) 41 | 42 | if __name__=="__main__": 43 | coef = get_deepnorm_coefficients(1000,1000) 44 | pass -------------------------------------------------------------------------------- /TransDoD/models/ops/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/.DS_Store -------------------------------------------------------------------------------- /TransDoD/models/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: MultiScaleDeformableAttention 3 | Version: 1.0 4 | Summary: PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention 5 | Home-page: https://github.com/fundamentalvision/Deformable-DETR 6 | Author: Weijie Su 7 | License: UNKNOWN 8 | Platform: UNKNOWN 9 | 10 | UNKNOWN 11 | 12 | -------------------------------------------------------------------------------- /TransDoD/models/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.cpp 3 | /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.cpp 4 | /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.cu 5 | MultiScaleDeformableAttention.egg-info/PKG-INFO 6 | MultiScaleDeformableAttention.egg-info/SOURCES.txt 7 | MultiScaleDeformableAttention.egg-info/dependency_links.txt 8 | MultiScaleDeformableAttention.egg-info/top_level.txt 9 | functions/__init__.py 10 | functions/ms_deform_attn_func.py 11 | modules/__init__.py 12 | modules/ms_deform_attn (33).py 13 | modules/ms_deform_attn.py -------------------------------------------------------------------------------- /TransDoD/models/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /TransDoD/models/ops/MultiScaleDeformableAttention.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | MultiScaleDeformableAttention 2 | functions 3 | modules 4 | -------------------------------------------------------------------------------- /TransDoD/models/ops/build/lib.linux-x86_64-3.7/MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/lib.linux-x86_64-3.7/MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /TransDoD/models/ops/build/lib.linux-x86_64-3.7/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /TransDoD/models/ops/build/lib.linux-x86_64-3.7/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | 
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | 63 | def ms_deform_attn_core_pytorch_3D(value, value_spatial_shapes, sampling_locations, attention_weights): 64 | # for debug and test only, 65 | # need to use cuda version instead 66 | N_, S_, M_, D_ = value.shape 67 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 68 | value_list = value.split([T_ * H_ * W_ for T_, H_, W_ in value_spatial_shapes], dim=1) 69 | sampling_grids = 2 * sampling_locations - 1 70 | sampling_value_list = [] 71 | for lid_, (T_, H_, W_) in enumerate(value_spatial_shapes): 72 | # N_, T_*H_*W_, M_, D_ -> N_, T_*H_*W_, M_*D_ -> N_, M_*D_, T_*H_*W_ -> N_*M_, D_, T_, H_, W_ 73 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, T_, H_, W_) 74 | # N_, Lq_, M_, P_, 3 -> N_, M_, Lq_, P_, 3 -> N_*M_, Lq_, P_, 3 75 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)[:,None,:,:,:] 76 | # N_*M_, D_, Lq_, P_ 77 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False)[:,:,0] 78 | sampling_value_list.append(sampling_value_l_) 79 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 80 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 81 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 82 | return output.transpose(1, 2).contiguous() 83 | -------------------------------------------------------------------------------- /TransDoD/models/ops/build/lib.linux-x86_64-3.7/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /TransDoD/models/ops/build/lib.linux-x86_64-3.7/modules/ms_deform_attn (33).py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | from ..functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch_3D 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | # self.s_norm = nn.LayerNorm(n_heads * n_levels * n_points * 3) 61 | 62 | self._reset_parameters() 63 | 64 | # add for vis 65 | self.samp_location = [] 66 | self.atte_w = [] 67 | 68 | def _reset_parameters(self): 69 | constant_(self.sampling_offsets.weight.data, 
0.) 70 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 71 | # grid_init = torch.stack([thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin(), thetas.cos()], -1) 72 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.sin(), thetas.cos()], -1) 73 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.cos(), thetas.sin()], -1) 74 | grid_init = torch.stack([2*thetas.sin()*thetas.cos(), thetas.cos(), thetas.sin()], -1) 75 | 76 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 77 | # beta = torch.arange(-self.n_heads//2, self.n_heads//2, dtype=torch.float32) * (1.0 * math.pi / self.n_heads) 78 | # grid_init = torch.stack([beta.sin(), beta.cos() * thetas.cos(), beta.cos() * thetas.sin()], -1) 79 | 80 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 81 | # beta = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 82 | # grid_init = torch.stack([beta.sin() * thetas.sin(), beta.cos(), beta.sin() * thetas.cos()], -1) 83 | 84 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1) 85 | for i in range(self.n_points): 86 | grid_init[:, :, i, :] *= i + 1 87 | with torch.no_grad(): 88 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 89 | constant_(self.attention_weights.weight.data, 0.) 90 | constant_(self.attention_weights.bias.data, 0.) 91 | xavier_uniform_(self.value_proj.weight.data) 92 | constant_(self.value_proj.bias.data, 0.) 93 | xavier_uniform_(self.output_proj.weight.data) 94 | constant_(self.output_proj.bias.data, 0.) 95 | 96 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 97 | """ 98 | :param query (N, Length_{query}, C) 99 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 100 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 101 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 102 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 103 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 104 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 105 | 106 | :return output (N, Length_{query}, C) 107 | """ 108 | N, Len_q, _ = query.shape 109 | N, Len_in, _ = input_flatten.shape 110 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in 111 | 112 | value = self.value_proj(input_flatten) 113 | if input_padding_mask is not None: 114 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 115 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 116 | ## using Tanh 117 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3) 118 | # sampling_offsets = F.tanh(sampling_offsets) 119 | 120 | ## or using norm 121 | # sampling_offsets = self.s_norm(self.sampling_offsets(query)).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3) 122 | 123 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 124 | 
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 125 | # N, Len_q, n_heads, n_levels, n_points, 3 126 | if reference_points.shape[-1] == 3: 127 | offset_normalizer = torch.stack([input_spatial_shapes[..., 2], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1) 128 | sampling_locations = reference_points[:, :, None, :, None, :] \ 129 | + sampling_offsets / (2*offset_normalizer[None, None, None, :, None, :]) 130 | # sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets 131 | elif reference_points.shape[-1] == 4: 132 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 133 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 134 | else: 135 | raise ValueError( 136 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 137 | # output = MSDeformAttnFunction.apply( 138 | # value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 139 | # debug 140 | output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations, 141 | attention_weights)#.detach().cpu() 142 | # output_cuda = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index, 143 | # sampling_locations, attention_weights, self.im2col_step).detach().cpu() 144 | 145 | output = self.output_proj(output) 146 | self.atte_w = attention_weights 147 | self.samp_location = sampling_locations 148 | return output 149 | 150 | 151 | -------------------------------------------------------------------------------- /TransDoD/models/ops/build/lib.linux-x86_64-3.7/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | from ..functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch_3D 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | # self.s_norm = nn.LayerNorm(n_heads * n_levels * n_points * 3) 61 | 62 | self._reset_parameters() 63 | 64 | # add for vis 65 | self.samp_location = [] 66 | self.atte_w = [] 67 | 68 | def _reset_parameters(self): 69 | constant_(self.sampling_offsets.weight.data, 0.) 
70 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 71 | # grid_init = torch.stack([thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin(), thetas.cos()], -1) 72 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.sin(), thetas.cos()], -1) 73 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.cos(), thetas.sin()], -1) 74 | grid_init = torch.stack([2*thetas.sin()*thetas.cos(), thetas.cos(), thetas.sin()], -1) 75 | 76 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 77 | # beta = torch.arange(-self.n_heads//2, self.n_heads//2, dtype=torch.float32) * (1.0 * math.pi / self.n_heads) 78 | # grid_init = torch.stack([beta.sin(), beta.cos() * thetas.cos(), beta.cos() * thetas.sin()], -1) 79 | 80 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 81 | # beta = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 82 | # grid_init = torch.stack([beta.sin() * thetas.sin(), beta.cos(), beta.sin() * thetas.cos()], -1) 83 | 84 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1) 85 | for i in range(self.n_points): 86 | grid_init[:, :, i, :] *= i + 1 87 | with torch.no_grad(): 88 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 89 | constant_(self.attention_weights.weight.data, 0.) 90 | constant_(self.attention_weights.bias.data, 0.) 91 | xavier_uniform_(self.value_proj.weight.data) 92 | constant_(self.value_proj.bias.data, 0.) 93 | xavier_uniform_(self.output_proj.weight.data) 94 | constant_(self.output_proj.bias.data, 0.) 95 | 96 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 97 | """ 98 | :param query (N, Length_{query}, C) 99 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 100 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 101 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 102 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 103 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 104 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 105 | 106 | :return output (N, Length_{query}, C) 107 | """ 108 | N, Len_q, _ = query.shape 109 | N, Len_in, _ = input_flatten.shape 110 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in 111 | 112 | value = self.value_proj(input_flatten) 113 | if input_padding_mask is not None: 114 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 115 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 116 | ## using Tanh 117 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3) 118 | 119 | ## or using norm 120 | # sampling_offsets = self.s_norm(self.sampling_offsets(query)).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3) 121 | # sampling_offsets = F.sigmoid(sampling_offsets)-0.5 122 | 123 | # attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 124 
124 |         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
125 |         attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)  # softmax over the points within each level
126 |         # N, Len_q, n_heads, n_levels, n_points, 3
127 |         if reference_points.shape[-1] == 3:
128 |             offset_normalizer = torch.stack([input_spatial_shapes[..., 2], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1)
129 |             sampling_locations = reference_points[:, :, None, :, None, :] \
130 |                                  + sampling_offsets / (offset_normalizer[None, None, None, :, None, :])
131 |             # input_spatial_shapes_ratio = input_spatial_shapes//input_spatial_shapes[-1]
132 |             # offset_normalizer = torch.stack([input_spatial_shapes_ratio[..., 0], input_spatial_shapes_ratio[..., 2], input_spatial_shapes_ratio[..., 1]], -1)
133 |             # sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets/offset_normalizer[None, None, None, :, None, :]
134 |         elif reference_points.shape[-1] == 4:
135 |             sampling_locations = reference_points[:, :, None, :, None, :2] \
136 |                                  + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
137 |         else:
138 |             raise ValueError(
139 |                 'Last dim of reference_points must be 3 or 4, but got {} instead.'.format(reference_points.shape[-1]))
140 |         # output = MSDeformAttnFunction.apply(
141 |         #     value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
142 |         # debug: fall back to the pure-PyTorch 3D kernel instead of the compiled CUDA op
143 |         output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations,
144 |                                                 attention_weights)#.detach().cpu()
145 |         # output_cuda = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index,
146 |         #                                          sampling_locations, attention_weights, self.im2col_step).detach().cpu()
147 | 
148 |         output = self.output_proj(output)
149 |         self.atte_w = attention_weights
150 |         self.samp_location = sampling_locations
151 |         return output
152 | 
153 | 
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/.ninja_deps:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/.ninja_deps
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/.ninja_log:
--------------------------------------------------------------------------------
1 | # ninja log v5
2 | 0 2669 1612246041636417765 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o fa8eca1803b7fff
3 | 0 14281 1612246053244838432 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o 95ace05e02c8713a
4 | 2 2898 1612246810056197491 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o aa390de0c9cbff6c
5 | 1 15432 1612246822584649734 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o b1d913aa26893e0e
6 | 23 2904 1612248891812574389
/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o 57991895174314d7 7 | 23 9280 1612248898188696475 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.o ef4370695c918cc3 8 | 22 14918 1612248903824806762 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o 5c6a42b954ea6aab 9 | 0 2911 1612369311440774386 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o aab11968d0c1acd7 10 | 1 9604 1612369318128880821 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o 51082b0e211b1655 11 | 0 15104 1612369323628968524 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o 8faaab1c9fecd08 12 | 2 2634 1612369362309589583 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o c13b919edaa3adcd 13 | 2 8953 1612369368625691682 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o 748c09fa10cb6ce0 14 | 1 13941 1612369373613772446 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o 1f9993e49b4dd542 15 | 1 4206 1646716265972410395 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o c77572b2cc994346 16 | 1 13715 1646716275476429471 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o 4379a47d34917b9 17 | 0 21653 1646716283416445500 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o 2d19fbfed58046ad 18 | 3 3829 1646752319151673811 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o 29f77dba4605cab5 19 | 4 13155 1646752328467758933 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o 9fa8fdf9cadaa81e 20 | 3 20006 1646752335319820952 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o 6720fc7b8b2a845d 21 | 
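The log above records the extension's three translation units (the CPU stub, the CUDA kernels, and the pybind11 glue in vision.cpp) being compiled into the MultiScaleDeformableAttention package. A minimal smoke-test sketch, assuming the extension has already been built with make.sh below (which runs `python setup.py build install`); the two attribute names come from vision.cpp, everything else here is illustrative:

import torch
import MultiScaleDeformableAttention as MSDA  # module name set in setup.py

# vision.cpp registers exactly these two entry points
assert hasattr(MSDA, "ms_deform_attn_forward")
assert hasattr(MSDA, "ms_deform_attn_backward")
print("extension loaded; CUDA available:", torch.cuda.is_available())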
-------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /media/userdisk0/softwares/cuda-11.0/bin/nvcc 4 | 5 | cflags = -pthread -B /home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DWITH_CUDA -I/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/TH -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/THC -I/media/userdisk0/softwares/cuda-11.0/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/include/python3.7m -c 6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14 7 | cuda_cflags = -DWITH_CUDA -I/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/TH -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/THC -I/media/userdisk0/softwares/cuda-11.0/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/include/python3.7m -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | depfile = $out.d 18 | deps = gcc 19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 20 | 21 | 22 | 23 | build /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o: compile /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.cpp 24 | build /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o: compile 
/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.cpp 25 | build /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o: cuda_compile /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.cu 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.o -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o -------------------------------------------------------------------------------- /TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o -------------------------------------------------------------------------------- /TransDoD/models/ops/dist/MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/dist/MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg -------------------------------------------------------------------------------- /TransDoD/models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /TransDoD/models/ops/functions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/functions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/ops/functions/__pycache__/ms_deform_attn_func.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/functions/__pycache__/ms_deform_attn_func.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # 
for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | 63 | def ms_deform_attn_core_pytorch_3D(value, value_spatial_shapes, sampling_locations, attention_weights): 64 | # for debug and test only, 65 | # need to use cuda version instead 66 | N_, S_, M_, D_ = value.shape 67 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 68 | value_list = value.split([T_ * H_ * W_ for T_, H_, W_ in value_spatial_shapes], dim=1) 69 | sampling_grids = 2 * sampling_locations - 1 70 | sampling_value_list = [] 71 | for lid_, (T_, H_, W_) in enumerate(value_spatial_shapes): 72 | # N_, T_*H_*W_, M_, D_ -> N_, T_*H_*W_, M_*D_ -> N_, M_*D_, T_*H_*W_ -> N_*M_, D_, T_, H_, W_ 73 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, T_, H_, W_) 74 | # N_, Lq_, M_, P_, 3 -> N_, M_, Lq_, P_, 3 -> N_*M_, Lq_, P_, 3 75 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)[:,None,:,:,:] 76 | # N_*M_, D_, Lq_, P_ 77 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False)[:,:,0] 78 | sampling_value_list.append(sampling_value_l_) 79 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 80 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 81 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 82 | return output.transpose(1, 2).contiguous() 83 | -------------------------------------------------------------------------------- /TransDoD/models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /TransDoD/models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /TransDoD/models/ops/modules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/modules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/ops/modules/__pycache__/ms_deform_attn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/modules/__pycache__/ms_deform_attn.cpython-37.pyc -------------------------------------------------------------------------------- /TransDoD/models/ops/modules/ms_deform_attn (33).py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | from ..functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch_3D 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | # self.s_norm = nn.LayerNorm(n_heads * n_levels * n_points * 3) 61 | 62 | self._reset_parameters() 63 | 64 | # add for vis 65 | self.samp_location = [] 66 | self.atte_w = [] 67 | 68 | def _reset_parameters(self): 69 | constant_(self.sampling_offsets.weight.data, 0.) 
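# Bias initialization: head h gets the angle theta_h = 2*pi*h / n_heads; the
# stack below maps it to a 3D direction, dividing by the largest absolute
# component normalizes that direction onto the unit cube, and scaling point i
# by (i + 1) places successive sampling points progressively further from the
# reference point.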
70 |         thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
71 |         # grid_init = torch.stack([thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin(), thetas.cos()], -1)
72 |         # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.sin(), thetas.cos()], -1)
73 |         # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.cos(), thetas.sin()], -1)
74 |         grid_init = torch.stack([2*thetas.sin()*thetas.cos(), thetas.cos(), thetas.sin()], -1)
75 | 
76 |         # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
77 |         # beta = torch.arange(-self.n_heads//2, self.n_heads//2, dtype=torch.float32) * (1.0 * math.pi / self.n_heads)
78 |         # grid_init = torch.stack([beta.sin(), beta.cos() * thetas.cos(), beta.cos() * thetas.sin()], -1)
79 | 
80 |         # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
81 |         # beta = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
82 |         # grid_init = torch.stack([beta.sin() * thetas.sin(), beta.cos(), beta.sin() * thetas.cos()], -1)
83 | 
84 |         grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1)
85 |         for i in range(self.n_points):
86 |             grid_init[:, :, i, :] *= i + 1
87 |         with torch.no_grad():
88 |             self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
89 |         constant_(self.attention_weights.weight.data, 0.)
90 |         constant_(self.attention_weights.bias.data, 0.)
91 |         xavier_uniform_(self.value_proj.weight.data)
92 |         constant_(self.value_proj.bias.data, 0.)
93 |         xavier_uniform_(self.output_proj.weight.data)
94 |         constant_(self.output_proj.bias.data, 0.)
95 | 
96 |     def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
97 |         """
98 |         :param query                       (N, Length_{query}, C)
99 |         :param reference_points            (N, Length_{query}, n_levels, 3), range in [0, 1], normalized coordinates over the padded volume,
100 |                                            or (N, Length_{query}, n_levels, 4), with additional (w, h) to form reference boxes
101 |         :param input_flatten               (N, \sum_{l=0}^{L-1} T_l \cdot H_l \cdot W_l, C)
102 |         :param input_spatial_shapes        (n_levels, 3), [(T_0, H_0, W_0), (T_1, H_1, W_1), ..., (T_{L-1}, H_{L-1}, W_{L-1})]
103 |         :param input_level_start_index    (n_levels, ), [0, T_0*H_0*W_0, T_0*H_0*W_0+T_1*H_1*W_1, ..., T_0*H_0*W_0+...+T_{L-2}*H_{L-2}*W_{L-2}]
104 |         :param input_padding_mask          (N, \sum_{l=0}^{L-1} T_l \cdot H_l \cdot W_l), True for padding elements, False for non-padding elements
105 | 
106 |         :return output                     (N, Length_{query}, C)
107 |         """
108 |         N, Len_q, _ = query.shape
109 |         N, Len_in, _ = input_flatten.shape
110 |         assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in
111 | 
112 |         value = self.value_proj(input_flatten)
113 |         if input_padding_mask is not None:
114 |             value = value.masked_fill(input_padding_mask[..., None], float(0))
115 |         value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
116 |         ## using Tanh
117 |         sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
118 |         # sampling_offsets = F.tanh(sampling_offsets)
119 | 
120 |         ## or using norm
121 |         # sampling_offsets = self.s_norm(self.sampling_offsets(query)).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
122 | 
123 |         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
124 |         attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)  # softmax over all levels*points jointly
125 |         # N, Len_q, n_heads, n_levels, n_points, 3
126 |         if reference_points.shape[-1] == 3:
127 |             offset_normalizer = torch.stack([input_spatial_shapes[..., 2], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1)
128 |             sampling_locations = reference_points[:, :, None, :, None, :] \
129 |                                  + sampling_offsets / (2*offset_normalizer[None, None, None, :, None, :])
130 |             # sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets
131 |         elif reference_points.shape[-1] == 4:
132 |             sampling_locations = reference_points[:, :, None, :, None, :2] \
133 |                                  + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
134 |         else:
135 |             raise ValueError(
136 |                 'Last dim of reference_points must be 3 or 4, but got {} instead.'.format(reference_points.shape[-1]))
137 |         # output = MSDeformAttnFunction.apply(
138 |         #     value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
139 |         # debug: fall back to the pure-PyTorch 3D kernel instead of the compiled CUDA op
140 |         output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations,
141 |                                                 attention_weights)#.detach().cpu()
142 |         # output_cuda = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index,
143 |         #                                          sampling_locations, attention_weights, self.im2col_step).detach().cpu()
144 | 
145 |         output = self.output_proj(output)
146 |         self.atte_w = attention_weights
147 |         self.samp_location = sampling_locations
148 |         return output
149 | 
150 | 
151 | 
--------------------------------------------------------------------------------
/TransDoD/models/ops/modules/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | from ..functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch_3D 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | # self.s_norm = nn.LayerNorm(n_heads * n_levels * n_points * 3) 61 | 62 | inx = 16 63 | self.offset_normalizer = torch.tensor([[inx * (2 ** (i - 1)), inx * (2 ** (i - 1)), inx * (2 ** (i - 1))] for i in 64 | range(self.n_levels, 0, -1)]).cuda() 65 | 66 | self._reset_parameters() 67 | 68 | # add for vis 69 | self.samp_location = [] 70 | self.atte_w = [] 71 | 72 | def _reset_parameters(self): 73 | constant_(self.sampling_offsets.weight.data, 0.) 74 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 75 | grid_init = torch.stack([thetas.sin()*thetas.cos(), thetas.cos(), thetas.sin()], -1) 76 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1) 77 | for i in range(self.n_points): 78 | grid_init[:, :, i, :] *= i + 1 79 | with torch.no_grad(): 80 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 81 | constant_(self.attention_weights.weight.data, 0.) 82 | constant_(self.attention_weights.bias.data, 0.) 
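# With zero weights and bias, the attention logits start at zero, so the first
# softmax over the n_levels * n_points sampling locations is uniform; the value
# and output projections below start from Xavier-uniform weights with zero bias.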
83 |         xavier_uniform_(self.value_proj.weight.data)
84 |         constant_(self.value_proj.bias.data, 0.)
85 |         xavier_uniform_(self.output_proj.weight.data)
86 |         constant_(self.output_proj.bias.data, 0.)
87 | 
88 |     def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
89 |         N, Len_q, _ = query.shape
90 |         N, Len_in, _ = input_flatten.shape
91 |         assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in
92 | 
93 |         value = self.value_proj(input_flatten)
94 |         if input_padding_mask is not None:
95 |             value = value.masked_fill(input_padding_mask[..., None], float(0))
96 |         value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
97 |         sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
98 | 
99 |         attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)  # v0
100 |         attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
101 |         if reference_points.shape[-1] == 3:
102 |             sampling_locations = reference_points[:, :, None, :, None, :] \
103 |                                  + sampling_offsets / (self.offset_normalizer[None, None, None, :, None, :])
104 |         elif reference_points.shape[-1] == 4:
105 |             sampling_locations = reference_points[:, :, None, :, None, :2] \
106 |                                  + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
107 |         else:
108 |             raise ValueError(
109 |                 'Last dim of reference_points must be 3 or 4, but got {} instead.'.format(reference_points.shape[-1]))
110 |         output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations,
111 |                                                 attention_weights)
112 |         output = self.output_proj(output)
113 |         self.atte_w = attention_weights
114 |         self.samp_location = sampling_locations
115 |         return output
--------------------------------------------------------------------------------
/TransDoD/models/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 | 
9 | import os
10 | import glob
11 | 
12 | import torch
13 | 
14 | from torch.utils.cpp_extension import CUDA_HOME
15 | from torch.utils.cpp_extension import CppExtension
16 | from torch.utils.cpp_extension import CUDAExtension
17 | 
18 | from setuptools import find_packages
19 | from setuptools import setup
20 | 
21 | requirements = ["torch", "torchvision"]
22 | 
23 | def get_extensions():
24 |     this_dir = os.path.dirname(os.path.abspath(__file__))
25 |     extensions_dir = os.path.join(this_dir, "src")
26 | 
27 |     main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28 |     source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29 |     source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30 | 
31 |     sources = main_file + source_cpu
32 |     extension = CppExtension
33 |     extra_compile_args = {"cxx": []}
34 |     define_macros = []
35 | 
36 |     if torch.cuda.is_available() and CUDA_HOME is not None:
37 |         extension = CUDAExtension
38 |         sources += source_cuda
39 |         define_macros += [("WITH_CUDA", None)]
40 |         extra_compile_args["nvcc"] = [
41 |             "-DCUDA_HAS_FP16=1",
42 |             "-D__CUDA_NO_HALF_OPERATORS__",
43 |             "-D__CUDA_NO_HALF_CONVERSIONS__",
44 |             "-D__CUDA_NO_HALF2_OPERATORS__",
45 |         ]
46 |     else:
47 |         raise NotImplementedError('CUDA is not available')
48 | 
49 |     sources = [os.path.join(extensions_dir, s) for s in sources]
50 |     include_dirs = [extensions_dir]
51 |     ext_modules = [
52 |         extension(
53 |             "MultiScaleDeformableAttention",
54 |             sources,
55 |             include_dirs=include_dirs,
56 |             define_macros=define_macros,
57 |             extra_compile_args=extra_compile_args,
58 |         )
59 |     ]
60 |     return ext_modules
61 | 
62 | setup(
63 |     name="MultiScaleDeformableAttention",
64 |     version="1.0",
65 |     author="Weijie Su",
66 |     url="https://github.com/fundamentalvision/Deformable-DETR",
67 |     description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
68 |     packages=find_packages(exclude=("configs", "tests",)),
69 |     ext_modules=get_extensions(),
70 |     cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
71 | )
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #include <vector>
12 | 
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 | 
16 | 
17 | at::Tensor
18 | ms_deform_attn_cpu_forward(
19 |     const at::Tensor &value,
20 |     const at::Tensor &spatial_shapes,
21 |     const at::Tensor &level_start_index,
22 |     const at::Tensor &sampling_loc,
23 |     const at::Tensor &attn_weight,
24 |     const int im2col_step)
25 | {
26 |     AT_ERROR("Not implemented on the CPU");
27 | }
28 | 
29 | std::vector<at::Tensor>
30 | ms_deform_attn_cpu_backward(
31 |     const at::Tensor &value,
32 |     const at::Tensor &spatial_shapes,
33 |     const at::Tensor &level_start_index,
34 |     const at::Tensor &sampling_loc,
35 |     const at::Tensor &attn_weight,
36 |     const at::Tensor &grad_output,
37 |     const int im2col_step)
38 | {
39 |     AT_ERROR("Not implemented on the CPU");
40 | }
41 | 
42 | 
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #pragma once
12 | #include <torch/extension.h>
13 | 
14 | at::Tensor
15 | ms_deform_attn_cpu_forward(
16 |     const at::Tensor &value,
17 |     const at::Tensor &spatial_shapes,
18 |     const at::Tensor &level_start_index,
19 |     const at::Tensor &sampling_loc,
20 |     const at::Tensor &attn_weight,
21 |     const int im2col_step);
22 | 
23 | std::vector<at::Tensor>
24 | ms_deform_attn_cpu_backward(
25 |     const at::Tensor &value,
26 |     const at::Tensor &spatial_shapes,
27 |     const at::Tensor &level_start_index,
28 |     const at::Tensor &sampling_loc,
29 |     const at::Tensor &attn_weight,
30 |     const at::Tensor &grad_output,
31 |     const int im2col_step);
32 | 
33 | 
34 | 
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/cuda/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #include <vector>
12 | #include "cuda/ms_deform_im2col_cuda.cuh"
13 | 
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 | 
19 | 
20 | at::Tensor ms_deform_attn_cuda_forward(
21 |     const at::Tensor &value,
22 |     const at::Tensor &spatial_shapes,
23 |     const at::Tensor &level_start_index,
24 |     const at::Tensor &sampling_loc,
25 |     const at::Tensor &attn_weight,
26 |     const int im2col_step)
27 | {
28 |     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 |     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 |     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 |     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 |     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 | 
34 |     AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 |     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 |     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 |     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 |     AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 | 
40 |     const int batch = value.size(0);
41 |     const int spatial_size = value.size(1);
42 |     const int num_heads = value.size(2);
43 |     const int channels = value.size(3);
44 | 
45 |     const int num_levels = spatial_shapes.size(0);
46 | 
47 |     const int num_query = sampling_loc.size(1);
48 |     const int num_point = sampling_loc.size(4);
49 | 
50 |     const int im2col_step_ = std::min(batch, im2col_step);
51 | 
52 |     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 | 
54 |     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 | 
56 |     const int batch_n = im2col_step_;
57 |     auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 |     auto per_value_size = spatial_size * num_heads * channels;
59 |     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 |     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 |     for (int n = 0; n < batch/im2col_step_; ++n)
62 |     {
63 |         auto columns = output_n.select(0, n);
64 |         AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 |             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 |                 spatial_shapes.data<int64_t>(),
68 |                 level_start_index.data<int64_t>(),
69 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 |                 columns.data<scalar_t>());
73 | 
74 |         }));
75 |     }
76 | 
77 |     output = output.view({batch, num_query, num_heads*channels});
78 | 
79 |     return output;
80 | }
81 | 
82 | 
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 |     const at::Tensor &value,
85 |     const at::Tensor &spatial_shapes,
86 |     const at::Tensor &level_start_index,
87 |     const at::Tensor &sampling_loc,
88 |     const at::Tensor &attn_weight,
89 |     const at::Tensor &grad_output,
90 |     const int im2col_step)
91 | {
92 | 
93 |     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 |     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 |     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 |     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 |     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 |     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 | 
100 |     AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 |     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 |     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 |     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 |     AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 |     AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 | 
107 |     const int batch = value.size(0);
108 |     const int spatial_size = value.size(1);
109 |     const int num_heads = value.size(2);
110 |     const int channels = value.size(3);
111 | 
112 |     const int num_levels = spatial_shapes.size(0);
113 | 
114 |     const int num_query = sampling_loc.size(1);
115 |     const int num_point = sampling_loc.size(4);
116 | 
117 |     const int im2col_step_ = std::min(batch, im2col_step);
118 | 
119 |     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 | 
121 |     auto grad_value = at::zeros_like(value);
122 |     auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 |     auto grad_attn_weight = at::zeros_like(attn_weight);
124 | 
125 |     const int batch_n = im2col_step_;
126 |     auto per_value_size = spatial_size * num_heads * channels;
127 |     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 |     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 |     auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 | 
131 |     for (int n = 0; n < batch/im2col_step_; ++n)
132 |     {
133 |         auto grad_output_g = grad_output_n.select(0, n);
134 |         AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 |             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136 |                 grad_output_g.data<scalar_t>(),
137 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138 |                 spatial_shapes.data<int64_t>(),
139 |                 level_start_index.data<int64_t>(),
140 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 |                 grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144 |                 grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145 |                 grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146 | 
147 |         }));
148 |     }
149 | 
150 |     return {
151 |         grad_value, grad_sampling_loc, grad_attn_weight
152 |     };
153 | }
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #pragma once
12 | #include <torch/extension.h>
13 | 
14 | at::Tensor ms_deform_attn_cuda_forward(
15 |     const at::Tensor &value,
16 |     const at::Tensor &spatial_shapes,
17 |     const at::Tensor &level_start_index,
18 |     const at::Tensor &sampling_loc,
19 |     const at::Tensor &attn_weight,
20 |     const int im2col_step);
21 | 
22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 |     const at::Tensor &value,
24 |     const at::Tensor &spatial_shapes,
25 |     const at::Tensor &level_start_index,
26 |     const at::Tensor &sampling_loc,
27 |     const at::Tensor &attn_weight,
28 |     const at::Tensor &grad_output,
29 |     const int im2col_step);
30 | 
31 | 
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #pragma once
12 | 
13 | #include "cpu/ms_deform_attn_cpu.h"
14 | 
15 | #ifdef WITH_CUDA
16 | #include "cuda/ms_deform_attn_cuda.h"
17 | #endif
18 | 
19 | 
20 | at::Tensor
21 | ms_deform_attn_forward(
22 |     const at::Tensor &value,
23 |     const at::Tensor &spatial_shapes,
24 |     const at::Tensor &level_start_index,
25 |     const at::Tensor &sampling_loc,
26 |     const at::Tensor &attn_weight,
27 |     const int im2col_step)
28 | {
29 |     if (value.type().is_cuda())
30 |     {
31 | #ifdef WITH_CUDA
32 |         return ms_deform_attn_cuda_forward(
33 |             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 |         AT_ERROR("Not compiled with GPU support");
36 | #endif
37 |     }
38 |     AT_ERROR("Not implemented on the CPU");
39 | }
40 | 
41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 |     const at::Tensor &value,
44 |     const at::Tensor &spatial_shapes,
45 |     const at::Tensor &level_start_index,
46 |     const at::Tensor &sampling_loc,
47 |     const at::Tensor &attn_weight,
48 |     const at::Tensor &grad_output,
49 |     const int im2col_step)
50 | {
51 |     if (value.type().is_cuda())
52 |     {
53 | #ifdef WITH_CUDA
54 |         return ms_deform_attn_cuda_backward(
55 |             value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 |         AT_ERROR("Not compiled with GPU support");
58 | #endif
59 |     }
60 |     AT_ERROR("Not implemented on the CPU");
61 | }
62 | 
63 | 
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/vision.cpp:
-------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /TransDoD/models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = 
torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /TransDoD/models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import sys 6 | sys.path.append("../../utils/") 7 | 8 | import math 9 | import torch 10 | from torch import nn 11 | from utils.misc import NestedTensor 12 | 13 | class PositionEmbeddingSine(nn.Module): 14 | """ 15 | This is a more standard version of the position embedding, very similar to the one 16 | used by the Attention is all you need paper, generalized to work on images. 
17 | """ 18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 19 | super().__init__() 20 | self.num_pos_feats = num_pos_feats 21 | self.temperature = temperature 22 | self.normalize = normalize 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, tensor): 30 | bs, c, d, h, w = tensor.shape 31 | mask = torch.zeros((bs, d, h, w), dtype=torch.bool).cuda() 32 | assert mask is not None 33 | not_mask = ~mask 34 | d_embed = not_mask.cumsum(1, dtype=torch.float32) 35 | y_embed = not_mask.cumsum(2, dtype=torch.float32) 36 | x_embed = not_mask.cumsum(3, dtype=torch.float32) 37 | if self.normalize: 38 | eps = 1e-6 39 | d_embed = d_embed / (d_embed[:, -1:, :, :] + eps) * self.scale 40 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale 41 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale 42 | 43 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=tensor.device) 44 | dim_t = self.temperature ** (3 * (dim_t // 3) / self.num_pos_feats) 45 | 46 | pos_x = x_embed[:, :, :, :, None] / dim_t 47 | pos_y = y_embed[:, :, :, :, None] / dim_t 48 | pos_d = d_embed[:, :, :, :, None] / dim_t 49 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 50 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 51 | pos_d = torch.stack((pos_d[:, :, :, :, 0::2].sin(), pos_d[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 52 | pos = torch.cat((pos_d, pos_y, pos_x), dim=4).permute(0, 4, 1, 2, 3) 53 | return pos 54 | 55 | 56 | 57 | class PositionEmbeddingLearned(nn.Module): 58 | """ 59 | Absolute pos embedding, learned. 
60 | """ 61 | def __init__(self, num_pos_feats=256): 62 | super().__init__() 63 | self.row_embed = nn.Embedding(50, num_pos_feats) 64 | self.col_embed = nn.Embedding(50, num_pos_feats) 65 | self.reset_parameters() 66 | 67 | def reset_parameters(self): 68 | nn.init.uniform_(self.row_embed.weight) 69 | nn.init.uniform_(self.col_embed.weight) 70 | 71 | def forward(self, tensor_list: NestedTensor): 72 | x = tensor_list.tensors 73 | h, w = x.shape[-2:] 74 | i = torch.arange(w, device=x.device) 75 | j = torch.arange(h, device=x.device) 76 | x_emb = self.col_embed(i) 77 | y_emb = self.row_embed(j) 78 | pos = torch.cat([ 79 | x_emb.unsqueeze(0).repeat(h, 1, 1), 80 | y_emb.unsqueeze(1).repeat(1, w, 1), 81 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 82 | return pos 83 | 84 | 85 | def build_position_encoding(args): 86 | N_steps = args.hidden_dim // 3 87 | if args.position_embedding in ('v2', 'sine'): 88 | # TODO find a better way of exposing other arguments 89 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 90 | elif args.position_embedding in ('v3', 'learned'): 91 | position_embedding = PositionEmbeddingLearned(N_steps) 92 | else: 93 | raise ValueError(f"not supported {args.position_embedding}") 94 | 95 | return position_embedding 96 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/__init__.py -------------------------------------------------------------------------------- /__pycache__/engine.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/__pycache__/engine.cpython-37.pyc -------------------------------------------------------------------------------- /a_DynConv/.train.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/a_DynConv/.train.py.swp -------------------------------------------------------------------------------- /a_DynConv/dodnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/a_DynConv/dodnet.png -------------------------------------------------------------------------------- /a_DynConv/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | 4 | sys.path.append("..") 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils import data 9 | import numpy as np 10 | import pickle 11 | import cv2 12 | import torch.optim as optim 13 | import scipy.misc 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn.functional as F 16 | import matplotlib.pyplot as plt 17 | from scipy.ndimage.filters import gaussian_filter 18 | 19 | # from tqdm import tqdm 20 | import os.path as osp 21 | 22 | from unet3D_DynConv882 import UNet3D 23 | from MOTSDataset import MOTSValDataSet 24 | 25 | import random 26 | import timeit 27 | from tensorboardX import SummaryWriter 28 | from loss_functions import loss 29 | 30 | from sklearn import metrics 31 | import nibabel as nib 32 | from math import ceil 33 | 34 | from engine import Engine 35 | from apex import amp 36 
| from apex.parallel import convert_syncbn_model 37 | 38 | start = timeit.default_timer() 39 | 40 | 41 | def str2bool(v): 42 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 43 | return True 44 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 45 | return False 46 | else: 47 | raise argparse.ArgumentTypeError('Boolean value expected.') 48 | 49 | 50 | def get_arguments(): 51 | """Parse all the arguments provided from the CLI. 52 | 53 | Returns: 54 | The argparse parser with all arguments registered. 55 | """ 56 | parser = argparse.ArgumentParser(description="MOTS: DynConv solution!") 57 | 58 | parser.add_argument("--data_dir", type=str, default='../dataset/') 59 | parser.add_argument("--val_list", type=str, default='list/MOTS/tt.txt') 60 | parser.add_argument("--reload_path", type=str, default='snapshots/fold1/MOTS_DynConv_fold1_final_e999.pth') 61 | parser.add_argument("--reload_from_checkpoint", type=str2bool, default=True) 62 | parser.add_argument("--save_path", type=str, default='outputs/') 63 | 64 | parser.add_argument("--input_size", type=str, default='64,192,192') 65 | parser.add_argument("--batch_size", type=int, default=1) 66 | parser.add_argument("--num_gpus", type=int, default=1) 67 | parser.add_argument('--local_rank', type=int, default=0) 68 | parser.add_argument("--FP16", type=str2bool, default=False) 69 | parser.add_argument("--num_epochs", type=int, default=500) 70 | parser.add_argument("--patience", type=int, default=3) 71 | parser.add_argument("--start_epoch", type=int, default=0) 72 | parser.add_argument("--val_pred_every", type=int, default=10) 73 | parser.add_argument("--learning_rate", type=float, default=1e-3) 74 | parser.add_argument("--num_classes", type=int, default=2) 75 | parser.add_argument("--num_workers", type=int, default=1) 76 | 77 | parser.add_argument("--weight_std", type=str2bool, default=True) 78 | parser.add_argument("--momentum", type=float, default=0.9) 79 | parser.add_argument("--power", type=float, default=0.9) 80 | parser.add_argument("--weight_decay", type=float, default=0.0005) 81 | 82 | return parser 83 | 84 | 85 | 86 | def dice_score(preds, labels): # on GPU 87 | assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match" 88 | predict = preds.contiguous().view(preds.shape[0], -1) 89 | target = labels.contiguous().view(labels.shape[0], -1) 90 | 91 | num = torch.sum(torch.mul(predict, target), dim=1) 92 | den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + 1 93 | 94 | dice = 2 * num / den 95 | 96 | return dice.mean() 97 | 98 | 99 | 100 | def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray: 101 | tmp = np.zeros(patch_size) 102 | center_coords = [i // 2 for i in patch_size] 103 | sigmas = [i * sigma_scale for i in patch_size] 104 | tmp[tuple(center_coords)] = 1 105 | gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0) 106 | gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1 107 | gaussian_importance_map = gaussian_importance_map.astype(np.float32) 108 | 109 | # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
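# (this map later weights each sliding-window tile in predict_sliding and is also accumulated into count_predictions, which full_probs is divided by, so any zero entry would turn that normalization into a NaN)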
110 | gaussian_importance_map[gaussian_importance_map == 0] = np.min( 111 | gaussian_importance_map[gaussian_importance_map != 0]) 112 | 113 | return gaussian_importance_map 114 | 115 | def multi_net(net_list, img, task_id): 116 | # img = torch.from_numpy(img).cuda() 117 | 118 | padded_prediction = net_list[0](img, task_id) 119 | padded_prediction = F.sigmoid(padded_prediction) 120 | for i in range(1, len(net_list)): 121 | padded_prediction_i = net_list[i](img, task_id) 122 | padded_prediction_i = F.sigmoid(padded_prediction_i) 123 | padded_prediction += padded_prediction_i 124 | padded_prediction /= len(net_list) 125 | return padded_prediction#.cpu().data.numpy() 126 | 127 | def predict_sliding(args, net_list, image, tile_size, classes, task_id): # tile_size:32x256x256 128 | gaussian_importance_map = _get_gaussian(tile_size, sigma_scale=1. / 8) 129 | 130 | image_size = image.shape 131 | overlap = 1 / 2 132 | 133 | strideHW = ceil(tile_size[1] * (1 - overlap)) 134 | strideD = ceil(tile_size[0] * (1 - overlap)) 135 | tile_deps = int(ceil((image_size[2] - tile_size[0]) / strideD) + 1) 136 | tile_rows = int(ceil((image_size[3] - tile_size[1]) / strideHW) + 1) # strided convolution formula 137 | tile_cols = int(ceil((image_size[4] - tile_size[2]) / strideHW) + 1) 138 | # print("Need %i x %i x %i prediction tiles @ stride %i x %i px" % (tile_deps, tile_cols, tile_rows, strideD, strideHW)) 139 | full_probs = np.zeros((image_size[0], classes, image_size[2], image_size[3], image_size[4]))#.astype(np.float32) # 1x4x155x240x240 140 | count_predictions = np.zeros((image_size[0], classes, image_size[2], image_size[3], image_size[4]))#.astype(np.float32) 141 | full_probs = torch.from_numpy(full_probs) 142 | count_predictions = torch.from_numpy(count_predictions) 143 | 144 | for dep in range(tile_deps): 145 | for row in range(tile_rows): 146 | for col in range(tile_cols): 147 | d1 = int(dep * strideD) 148 | x1 = int(col * strideHW) 149 | y1 = int(row * strideHW) 150 | d2 = min(d1 + tile_size[0], image_size[2]) 151 | x2 = min(x1 + tile_size[2], image_size[4]) 152 | y2 = min(y1 + tile_size[1], image_size[3]) 153 | d1 = max(int(d2 - tile_size[0]), 0) 154 | x1 = max(int(x2 - tile_size[2]), 0) # for portrait images the x1 underflows sometimes 155 | y1 = max(int(y2 - tile_size[1]), 0) # for very few rows y1 underflows 156 | 157 | img = image[:, :, d1:d2, y1:y2, x1:x2] 158 | img = torch.from_numpy(img).cuda() 159 | 160 | prediction1 = multi_net(net_list, img, task_id) 161 | prediction2 = torch.flip(multi_net(net_list, torch.flip(img, [2]), task_id), [2]) 162 | prediction3 = torch.flip(multi_net(net_list, torch.flip(img, [3]), task_id), [3]) 163 | prediction4 = torch.flip(multi_net(net_list, torch.flip(img, [4]), task_id), [4]) 164 | prediction5 = torch.flip(multi_net(net_list, torch.flip(img, [2,3]), task_id), [2,3]) 165 | prediction6 = torch.flip(multi_net(net_list, torch.flip(img, [2,4]), task_id), [2,4]) 166 | prediction7 = torch.flip(multi_net(net_list, torch.flip(img, [3,4]), task_id), [3,4]) 167 | prediction8 = torch.flip(multi_net(net_list, torch.flip(img, [2,3,4]), task_id), [2,3,4]) 168 | prediction = (prediction1 + prediction2 + prediction3 + prediction4 + prediction5 + prediction6 + prediction7 + prediction8) / 8. 
169 | prediction = prediction.cpu() 170 | 171 | prediction[:,:] *= gaussian_importance_map 172 | 173 | if isinstance(prediction, list): 174 | shape = np.array(prediction[0].shape) 175 | shape[0] = prediction[0].shape[0] * len(prediction) 176 | shape = tuple(shape) 177 | preds = torch.zeros(shape).cuda() 178 | bs_singlegpu = prediction[0].shape[0] 179 | for i in range(len(prediction)): 180 | preds[i * bs_singlegpu: (i + 1) * bs_singlegpu] = prediction[i] 181 | count_predictions[:, :, d1:d2, y1:y2, x1:x2] += 1 182 | full_probs[:, :, d1:d2, y1:y2, x1:x2] += preds 183 | 184 | else: 185 | count_predictions[:, :, d1:d2, y1:y2, x1:x2] += gaussian_importance_map 186 | full_probs[:, :, d1:d2, y1:y2, x1:x2] += prediction 187 | 188 | full_probs /= count_predictions 189 | return full_probs 190 | 191 | def save_nii(args, pred, label, name, affine): # bs, c, WHD 192 | seg_pred_2class = np.asarray(np.around(pred), dtype=np.uint8) 193 | seg_pred_0 = seg_pred_2class[:, 0, :, :, :] 194 | seg_pred_1 = seg_pred_2class[:, 1, :, :, :] 195 | seg_pred = np.zeros_like(seg_pred_0) 196 | if name[0][0:3]!='spl': 197 | seg_pred = np.where(seg_pred_0 == 1, 1, seg_pred) 198 | seg_pred = np.where(seg_pred_1 == 1, 2, seg_pred) 199 | else:# spleen only organ 200 | seg_pred = seg_pred_0 201 | 202 | label_0 = label[:, 0, :, :, :] 203 | label_1 = label[:, 1, :, :, :] 204 | seg_label = np.zeros_like(label_0) 205 | seg_label = np.where(label_0 == 1, 1, seg_label) 206 | seg_label = np.where(label_1 == 1, 2, seg_label) 207 | 208 | if name[0][0:3]!='cas': 209 | seg_pred = seg_pred.transpose((0, 2, 3, 1)) 210 | seg_label = seg_label.transpose((0, 2, 3, 1)) 211 | 212 | # save 213 | for tt in range(seg_pred.shape[0]): 214 | seg_pred_tt = seg_pred[tt] 215 | seg_label_tt = seg_label[tt] 216 | seg_pred_tt = nib.Nifti1Image(seg_pred_tt, affine=affine[tt]) 217 | seg_label_tt = nib.Nifti1Image(seg_label_tt, affine=affine[tt]) 218 | if not os.path.exists(args.save_path): 219 | os.makedirs(args.save_path) 220 | seg_label_save_p = os.path.join(args.save_path + '/%s_label.nii.gz' % (name[tt])) 221 | seg_pred_save_p = os.path.join(args.save_path + '/%s_pred.nii.gz' % (name[tt])) 222 | nib.save(seg_label_tt, seg_label_save_p) 223 | nib.save(seg_pred_tt, seg_pred_save_p) 224 | return None 225 | 226 | def validate(args, input_size, model, ValLoader, num_classes, engine): 227 | 228 | val_loss = torch.zeros(size=(7, 1)).cuda() # np.zeros(shape=(7, 1)) 229 | val_Dice = torch.zeros(size=(7, 2)).cuda() # np.zeros(shape=(7, 2)) 230 | count = torch.zeros(size=(7, 2)).cuda() # np.zeros(shape=(7, 2)) 231 | 232 | for index, batch in enumerate(ValLoader): 233 | # print('%d processd' % (index)) 234 | image, label, name, task_id, affine = batch 235 | 236 | with torch.no_grad(): 237 | 238 | pred_sigmoid = predict_sliding(args, model, image.numpy(), input_size, num_classes, task_id) 239 | 240 | # loss = loss_seg_DICE.forward(pred, label) + loss_seg_CE.forward(pred, label) 241 | loss = torch.tensor(1).cuda() 242 | val_loss[task_id[0], 0] += loss 243 | 244 | if label[0, 0, 0, 0, 0] == -1: 245 | dice_c1 = torch.from_numpy(np.array([-999])) 246 | else: 247 | dice_c1 = dice_score(pred_sigmoid[:, 0, :, :, :], label[:, 0, :, :, :]) 248 | val_Dice[task_id[0], 0] += dice_c1 249 | count[task_id[0], 0] += 1 250 | if label[0, 1, 0, 0, 0] == -1: 251 | dice_c2 = torch.from_numpy(np.array([-999])) 252 | else: 253 | dice_c2 = dice_score(pred_sigmoid[:, 1, :, :, :], label[:, 1, :, :, :]) 254 | val_Dice[task_id[0], 1] += dice_c2 255 | count[task_id[0], 1] += 1 256 | 257 | 
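# (per-case log: dice_c1/dice_c2 keep the -999 sentinel assigned above when a class has no annotation, since each MOTS case is only partially labeled for its task)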
print('Task%d-%s loss:%.4f Organ:%.4f Tumor:%.4f' % (task_id, name, loss.item(), dice_c1.item(), dice_c2.item())) 258 | 259 | # save 260 | save_nii(args, pred_sigmoid, label, name, affine) 261 | 262 | count[count == 0] = 1 263 | val_Dice = val_Dice / count 264 | val_loss = val_loss / count.max(axis=1)[0].unsqueeze(1) 265 | 266 | reduce_val_loss = torch.zeros_like(val_loss).cuda() 267 | reduce_val_Dice = torch.zeros_like(val_Dice).cuda() 268 | for i in range(val_loss.shape[0]): 269 | reduce_val_loss[i] = engine.all_reduce_tensor(val_loss[i]) 270 | reduce_val_Dice[i] = engine.all_reduce_tensor(val_Dice[i]) 271 | 272 | if args.local_rank == 0: 273 | print("Sum results") 274 | for t in range(7): 275 | print('Sum: Task%d- loss:%.4f Organ:%.4f Tumor:%.4f' % (t, reduce_val_loss[t, 0], reduce_val_Dice[t, 0], reduce_val_Dice[t, 1])) 276 | 277 | return reduce_val_loss.mean(), reduce_val_Dice 278 | 279 | 280 | 281 | 282 | def main(): 283 | """Create the model and start the evaluation.""" 284 | parser = get_arguments() 285 | print(parser) 286 | 287 | with Engine(custom_parser=parser) as engine: 288 | args = parser.parse_args() 289 | if args.num_gpus > 1: 290 | torch.cuda.set_device(args.local_rank) 291 | 292 | d, h, w = map(int, args.input_size.split(',')) 293 | input_size = (d, h, w) 294 | 295 | cudnn.benchmark = True 296 | seed = 1234 297 | if engine.distributed: 298 | seed = args.local_rank 299 | torch.manual_seed(seed) 300 | if torch.cuda.is_available(): 301 | torch.cuda.manual_seed(seed) 302 | 303 | # Create network. 304 | model = UNet3D(num_classes=args.num_classes, weight_std=args.weight_std) 305 | 306 | model = nn.DataParallel(model) 307 | 308 | model.eval() 309 | 310 | device = torch.device('cuda:{}'.format(args.local_rank)) 311 | model.to(device) 312 | 313 | if args.num_gpus > 1: 314 | model = engine.data_parallel(model) 315 | 316 | # load checkpoint...
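# (two checkpoint layouts, mirroring the saving code in train.py: FP16 runs store a dict with 'model'/'optimizer'/'amp' keys, FP32 runs store the bare state_dict, hence the two load branches below)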
317 | if args.reload_from_checkpoint: 318 | print('loading from checkpoint: {}'.format(args.reload_path)) 319 | if os.path.exists(args.reload_path): 320 | if args.FP16: 321 | checkpoint = torch.load(args.reload_path, map_location=torch.device('cpu')) 322 | model.load_state_dict(checkpoint['model']) 323 | # optimizer.load_state_dict(checkpoint['optimizer']) 324 | # amp.load_state_dict(checkpoint['amp']) 325 | else: 326 | model.load_state_dict(torch.load(args.reload_path, map_location=torch.device('cpu'))) 327 | else: 328 | print('File does not exist in the reload path: {}'.format(args.reload_path)) 329 | 330 | 331 | valloader, val_sampler = engine.get_test_loader( 332 | MOTSValDataSet(args.data_dir, args.val_list)) 333 | 334 | print('validate ...') 335 | val_loss, val_Dice = validate(args, input_size, [model], valloader, args.num_classes, engine) 336 | 337 | print('Validate \n 0Liver={:.4} 0LiverT={:.4} \n 1Kidney={:.4} 1KidneyT={:.4} \n' 338 | ' 2Hepa={:.4} 2HepaT={:.4} \n 3Panc={:.4} 3PancT={:.4} \n 4ColonT={:.4} \n 5LungT={:.4} \n 6Spleen={:.4}' 339 | .format(val_Dice[0, 0].item(), val_Dice[0, 1].item(), 340 | val_Dice[1, 0].item(), val_Dice[1, 1].item(), val_Dice[2, 0].item(), 341 | val_Dice[2, 1].item(), 342 | val_Dice[3, 0].item(), val_Dice[3, 1].item(), val_Dice[4, 1].item(), 343 | val_Dice[5, 1].item(), val_Dice[6, 0].item())) 344 | 345 | end = timeit.default_timer() 346 | print(end - start, 'seconds') 347 | 348 | 349 | if __name__ == '__main__': 350 | main() 351 | -------------------------------------------------------------------------------- /a_DynConv/postp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import collections 6 | import torch 7 | import torchvision 8 | import cv2 9 | from torch.utils import data 10 | import matplotlib.pyplot as plt 11 | import nibabel as nib 12 | from skimage.measure import label as LAB 13 | from skimage.transform import resize 14 | import SimpleITK as sitk 15 | import argparse 16 | 17 | from medpy.metric import hd95 18 | 19 | def get_arguments(): 20 | """Parse all the arguments provided from the CLI. 21 | 22 | Returns: 23 | The parsed arguments.
24 | """ 25 | parser = argparse.ArgumentParser(description="Dynconv post processing!") 26 | 27 | parser.add_argument("--img_folder_path", type=str, default='outputs/dodnet/') 28 | 29 | return parser.parse_args() 30 | 31 | args = get_arguments() 32 | 33 | def continues_region_extract_organ(label, keep_region_nums): # keep_region_nums=1 34 | mask = False*np.zeros_like(label) 35 | regions = np.where(label>=1, np.ones_like(label), np.zeros_like(label)) 36 | L, n = LAB(regions, neighbors=4, background=0, connectivity=2, return_num=True) 37 | 38 | # 39 | ary_num = np.zeros(shape=(n+1,1)) 40 | for i in range(0, n+1): 41 | ary_num[i] = np.sum(L==i) 42 | max_index = np.argsort(-ary_num, axis=0) 43 | count=1 44 | for i in range(1, n+1): 45 | if count<=keep_region_nums: # keep 46 | mask = np.where(L == max_index[i][0], label, mask) 47 | count+=1 48 | label = np.where(mask==True, label, np.zeros_like(label)) 49 | return label 50 | 51 | def continues_region_extract_tumor(label): # 52 | 53 | regions = np.where(label>=1, np.ones_like(label), np.zeros_like(label)) 54 | L, n = LAB(regions, neighbors=4, background=0, connectivity=2, return_num=True) 55 | 56 | for i in range(1, n+1): 57 | if np.sum(L==i)<=50 and n>1: # remove default 50 58 | label = np.where(L == i, np.zeros_like(label), label) 59 | 60 | return label 61 | 62 | 63 | 64 | def dice_score(preds, labels): 65 | preds = preds[np.newaxis, :] 66 | labels = labels[np.newaxis, :] 67 | assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match" 68 | predict = preds.view().reshape(preds.shape[0], -1) 69 | target = labels.view().reshape(labels.shape[0], -1) 70 | 71 | num = np.sum(np.multiply(predict, target), axis=1) 72 | den = np.sum(predict, axis=1) + np.sum(target, axis=1) + 1 73 | 74 | dice = 2 * num / den 75 | 76 | return dice.mean() 77 | 78 | 79 | def task_index(name): 80 | if "liver" in name: 81 | return 0 82 | if "case" in name: 83 | return 1 84 | if "hepa" in name: 85 | return 2 86 | if "pancreas" in name: 87 | return 3 88 | if "colon" in name: 89 | return 4 90 | if "lung" in name: 91 | return 5 92 | if "spleen" in name: 93 | return 6 94 | 95 | def compute_HD95(ref, pred): 96 | num_ref = np.sum(ref) 97 | num_pred = np.sum(pred) 98 | 99 | if num_ref == 0: 100 | if num_pred == 0: 101 | return 0 102 | else: 103 | return 373.12866 104 | elif num_pred == 0 and num_ref != 0: 105 | return 373.12866 106 | else: 107 | return hd95(pred, ref, (1, 1, 1)) 108 | 109 | val_Dice = np.zeros(shape=(7, 2)) 110 | val_HD = np.zeros(shape=(7, 2)) 111 | count = np.zeros(shape=(7, 2)) 112 | 113 | for root, dirs, files in os.walk(args.img_folder_path): 114 | for i in sorted(files): 115 | if i[-12:-7] != 'label': 116 | continue 117 | i2 = i[:-12]+'pred'+i[-7:] 118 | i_file = root + i 119 | i2_file = root + i2 120 | predNII = nib.load(i2_file) 121 | labelNII = nib.load(i_file) 122 | pred = predNII.get_data() 123 | label = labelNII.get_data() 124 | 125 | # post-processing 126 | 127 | task_id = task_index(i) 128 | if task_id == 0 or task_id == 1 or task_id == 3: 129 | pred_organ = (pred >= 1) 130 | pred_tumor = (pred == 2) 131 | label_organ = (label >= 1) 132 | label_tumor = (label == 2) 133 | 134 | elif task_id == 2: 135 | pred_organ = (pred == 1) 136 | pred_tumor = (pred == 2) 137 | label_organ = (label == 1) 138 | label_tumor = (label == 2) 139 | 140 | elif task_id == 4 or task_id == 5: 141 | pred_organ = None 142 | pred_tumor = (pred == 2) 143 | label_organ = None 144 | label_tumor = (label == 2) 145 | elif task_id == 6: 146 | pred_organ = (pred 
== 1) 147 | pred_tumor = None 148 | label_organ = (label == 1) 149 | label_tumor = None 150 | else: 151 | print("No such task!") 152 | 153 | if task_id == 0: 154 | pred_organ = continues_region_extract_organ(pred_organ, 1) 155 | pred_tumor = np.where(pred_organ == True, pred_tumor, np.zeros_like(pred_tumor)) 156 | pred_tumor = continues_region_extract_tumor(pred_tumor) 157 | elif task_id == 1: 158 | pred_organ = continues_region_extract_organ(pred_organ, 2) 159 | pred_tumor = np.where(pred_organ == True, pred_tumor, np.zeros_like(pred_tumor)) 160 | pred_tumor = continues_region_extract_organ(pred_tumor, 1) 161 | elif task_id == 2: 162 | pred_tumor = continues_region_extract_tumor(pred_tumor) 163 | elif task_id == 3: 164 | pred_organ = continues_region_extract_organ(pred_organ, 1) 165 | pred_tumor = np.where(pred_organ == True, pred_tumor, np.zeros_like(pred_tumor)) 166 | pred_tumor = continues_region_extract_tumor(pred_tumor) 167 | elif task_id == 4: 168 | pred_tumor = continues_region_extract_organ(pred_tumor, 1) 169 | elif task_id == 5: 170 | pred_tumor = continues_region_extract_organ(pred_tumor, 1) 171 | elif task_id == 6: 172 | pred_organ = continues_region_extract_organ(pred_organ, 1) 173 | else: 174 | print("No such task index!!!") 175 | 176 | if label_organ is not None: 177 | dice_c1 = dice_score(pred_organ, label_organ) 178 | HD_c1 = compute_HD95(label_organ, pred_organ) 179 | val_Dice[task_id, 0] += dice_c1 180 | val_HD[task_id, 0] += HD_c1 181 | count[task_id, 0] += 1 182 | else: 183 | dice_c1=-999 184 | HD_c1=999 185 | if label_tumor is not None: 186 | dice_c2 = dice_score(pred_tumor, label_tumor) 187 | HD_c2 = compute_HD95(label_tumor, pred_tumor) 188 | val_Dice[task_id, 1] += dice_c2 189 | val_HD[task_id, 1] += HD_c2 190 | count[task_id, 1] += 1 191 | else: 192 | dice_c2=-999 193 | HD_c2=999 194 | print("%s: Organ_Dice %f, tumor_Dice %f | Organ_HD %f, tumor_HD %f" % (i[:-13], dice_c1, dice_c2, HD_c1, HD_c2)) 195 | 196 | count[count == 0] = 1 197 | val_Dice = val_Dice / count 198 | val_HD = val_HD / count 199 | 200 | print("Sum results") 201 | for t in range(7): 202 | print('Sum: Task%d- Organ_Dice:%.4f Tumor_Dice:%.4f | Organ_HD:%.4f Tumor_HD:%.4f' % (t, val_Dice[t, 0], val_Dice[t, 1], val_HD[t,0], val_HD[t,1])) 203 | -------------------------------------------------------------------------------- /a_DynConv/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os, sys 3 | 4 | sys.path.append("..") 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils import data 9 | import numpy as np 10 | import pickle 11 | import cv2 12 | import torch.optim as optim 13 | import scipy.misc 14 | import torch.backends.cudnn as cudnn 15 | import torch.nn.functional as F 16 | import matplotlib.pyplot as plt 17 | 18 | import os.path as osp 19 | from unet3D_DynConv882 import UNet3D 20 | from MOTSDataset import MOTSDataSet, my_collate 21 | 22 | import random 23 | import timeit 24 | from tensorboardX import SummaryWriter 25 | from loss_functions import loss 26 | 27 | from sklearn import metrics 28 | from math import ceil 29 | 30 | from engine import Engine 31 | from apex import amp 32 | from apex.parallel import convert_syncbn_model 33 | 34 | start = timeit.default_timer() 35 | 36 | 37 | def str2bool(v): 38 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 39 | return True 40 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 41 | return False 42 | else: 43 | raise argparse.ArgumentTypeError('Boolean value
expected.') 44 | 45 | 46 | def get_arguments(): 47 | 48 | parser = argparse.ArgumentParser(description="unet3D_DynConv882") 49 | 50 | parser.add_argument("--data_dir", type=str, default='../dataset/') 51 | parser.add_argument("--train_list", type=str, default='list/MOTS/MOTS_train.txt') 52 | parser.add_argument("--val_list", type=str, default='list/MOTS/xx.txt') 53 | parser.add_argument("--snapshot_dir", type=str, default='snapshots/fold1/') 54 | parser.add_argument("--reload_path", type=str, default='snapshots/fold1/xx.pth') 55 | parser.add_argument("--reload_from_checkpoint", type=str2bool, default=False) 56 | parser.add_argument("--input_size", type=str, default='64,64,64') 57 | parser.add_argument("--batch_size", type=int, default=2) 58 | parser.add_argument("--num_gpus", type=int, default=1) 59 | parser.add_argument('--local_rank', type=int, default=0) 60 | parser.add_argument("--FP16", type=str2bool, default=False) 61 | parser.add_argument("--num_epochs", type=int, default=500) 62 | parser.add_argument("--itrs_each_epoch", type=int, default=250) 63 | parser.add_argument("--patience", type=int, default=3) 64 | parser.add_argument("--start_epoch", type=int, default=0) 65 | parser.add_argument("--val_pred_every", type=int, default=10) 66 | parser.add_argument("--learning_rate", type=float, default=1e-3) 67 | parser.add_argument("--num_classes", type=int, default=2) 68 | parser.add_argument("--num_workers", type=int, default=1) 69 | parser.add_argument("--weight_std", type=str2bool, default=True) 70 | parser.add_argument("--momentum", type=float, default=0.9) 71 | parser.add_argument("--power", type=float, default=0.9) 72 | parser.add_argument("--weight_decay", type=float, default=0.0005) 73 | parser.add_argument("--ignore_label", type=int, default=255) 74 | parser.add_argument("--is_training", action="store_true") 75 | parser.add_argument("--random_mirror", type=str2bool, default=True) 76 | parser.add_argument("--random_scale", type=str2bool, default=True) 77 | parser.add_argument("--random_seed", type=int, default=1234) 78 | parser.add_argument("--gpu", type=str, default='None') 79 | return parser 80 | 81 | 82 | def lr_poly(base_lr, iter, max_iter, power): 83 | return base_lr * ((1 - float(iter) / max_iter) ** (power)) 84 | 85 | 86 | def adjust_learning_rate(optimizer, i_iter, lr, num_steps, power): 87 | """Polynomially decay the learning rate over num_steps iterations.""" 88 | lr = lr_poly(lr, i_iter, num_steps, power) 89 | optimizer.param_groups[0]['lr'] = lr 90 | return lr 91 | 92 | 93 | def main(): 94 | """Create the model and start the training.""" 95 | parser = get_arguments() 96 | print(parser) 97 | 98 | with Engine(custom_parser=parser) as engine: 99 | args = parser.parse_args() 100 | if args.num_gpus > 1: 101 | torch.cuda.set_device(args.local_rank) 102 | 103 | writer = SummaryWriter(args.snapshot_dir) 104 | 105 | if not args.gpu == 'None': 106 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 107 | 108 | d, h, w = map(int, args.input_size.split(',')) 109 | input_size = (d, h, w) 110 | 111 | cudnn.benchmark = True 112 | seed = args.random_seed 113 | if engine.distributed: 114 | seed = args.local_rank 115 | torch.manual_seed(seed) 116 | if torch.cuda.is_available(): 117 | torch.cuda.manual_seed(seed) 118 | 119 | # Create model 120 | model = UNet3D(num_classes=args.num_classes, weight_std=args.weight_std) 121 | 122 | model.train() 123 | 124 | device = torch.device('cuda:{}'.format(args.local_rank)) 125 | model.to(device) 126 | 127 | optimizer = 
torch.optim.SGD(model.parameters(), args.learning_rate, momentum=0.99, nesterov=True) 128 | 129 | if args.FP16: 130 | print("Note: Using FP16 during training************") 131 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1") 132 | 133 | if args.num_gpus > 1: 134 | model = engine.data_parallel(model) 135 | 136 | # load checkpoint... 137 | if args.reload_from_checkpoint: 138 | print('loading from checkpoint: {}'.format(args.reload_path)) 139 | if os.path.exists(args.reload_path): 140 | if args.FP16: 141 | checkpoint = torch.load(args.reload_path, map_location=torch.device('cpu')) 142 | model.load_state_dict(checkpoint['model']) 143 | optimizer.load_state_dict(checkpoint['optimizer']) 144 | amp.load_state_dict(checkpoint['amp']) 145 | else: 146 | model.load_state_dict(torch.load(args.reload_path, map_location=torch.device('cpu'))) 147 | else: 148 | print('File does not exist in the reload path: {}'.format(args.reload_path)) 149 | 150 | loss_seg_DICE = loss.DiceLoss4MOTS(num_classes=args.num_classes).to(device) 151 | loss_seg_CE = loss.CELoss4MOTS(num_classes=args.num_classes, ignore_index=255).to(device) 152 | 153 | if not os.path.exists(args.snapshot_dir): 154 | os.makedirs(args.snapshot_dir) 155 | 156 | trainloader, train_sampler = engine.get_train_loader( 157 | MOTSDataSet(args.data_dir, args.train_list, max_iters=args.itrs_each_epoch * args.batch_size, crop_size=input_size, scale=args.random_scale, mirror=args.random_mirror), 158 | collate_fn=my_collate) 159 | 160 | all_tr_loss = [] 161 | all_va_loss = [] 162 | train_loss_MA = None 163 | val_loss_MA = None 164 | 165 | val_best_loss = 999999 166 | 167 | for epoch in range(args.num_epochs): 168 | if epoch < args.start_epoch: 169 | continue 170 | 171 | if engine.distributed: 172 | train_sampler.set_epoch(epoch) 173 | 174 | epoch_loss = [] 175 | adjust_learning_rate(optimizer, epoch, args.learning_rate, args.num_epochs, args.power) 176 | 177 | for iter, batch in enumerate(trainloader): 178 | 179 | images = torch.from_numpy(batch['image']).cuda() 180 | labels = torch.from_numpy(batch['label']).cuda() 181 | volumeName = batch['name'] 182 | task_ids = batch['task_id'] 183 | 184 | optimizer.zero_grad() 185 | preds = model(images, task_ids) 186 | 187 | term_seg_Dice = loss_seg_DICE.forward(preds, labels) 188 | term_seg_BCE = loss_seg_CE.forward(preds, labels) 189 | term_all = term_seg_Dice + term_seg_BCE 190 | 191 | reduce_Dice = engine.all_reduce_tensor(term_seg_Dice) 192 | reduce_BCE = engine.all_reduce_tensor(term_seg_BCE) 193 | reduce_all = engine.all_reduce_tensor(term_all) 194 | 195 | if args.FP16: 196 | with amp.scale_loss(term_all, optimizer) as scaled_loss: 197 | scaled_loss.backward() 198 | else: 199 | term_all.backward() 200 | optimizer.step() 201 | 202 | epoch_loss.append(float(reduce_all)) 203 | 204 | if (args.local_rank == 0): 205 | print( 206 | 'Epoch {}: {}/{}, lr = {:.4}, loss_seg_Dice = {:.4}, loss_seg_BCE = {:.4}, loss_Sum = {:.4}'.format( \ 207 | epoch, iter, len(trainloader), optimizer.param_groups[0]['lr'], reduce_Dice.item(), 208 | reduce_BCE.item(), reduce_all.item())) 209 | 210 | epoch_loss = np.mean(epoch_loss) 211 | 212 | all_tr_loss.append(epoch_loss) 213 | 214 | if (args.local_rank == 0): 215 | print('Epoch_sum {}: lr = {:.4}, loss_Sum = {:.4}'.format(epoch, optimizer.param_groups[0]['lr'], 216 | epoch_loss.item())) 217 | writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch) 218 | writer.add_scalar('Train_loss', epoch_loss.item(), epoch) 219 | 220 | if (epoch >= 0) and
(args.local_rank == 0) and (((epoch % 50 == 0) and (epoch >= 800)) or (epoch % 50 == 0)): 221 | print('save model ...') 222 | if args.FP16: 223 | checkpoint = { 224 | 'model': model.state_dict(), 225 | 'optimizer': optimizer.state_dict(), 226 | 'amp': amp.state_dict() 227 | } 228 | torch.save(checkpoint, osp.join(args.snapshot_dir, 'MOTS_DynConv_' + args.snapshot_dir.split('/')[-2] + '_e' + str(epoch) + '.pth')) 229 | else: 230 | torch.save(model.state_dict(),osp.join(args.snapshot_dir, 'MOTS_DynConv_' + args.snapshot_dir.split('/')[-2] + '_e' + str(epoch) + '.pth')) 231 | 232 | if (epoch >= args.num_epochs - 1) and (args.local_rank == 0): 233 | print('save model ...') 234 | if args.FP16: 235 | checkpoint = { 236 | 'model': model.state_dict(), 237 | 'optimizer': optimizer.state_dict(), 238 | 'amp': amp.state_dict() 239 | } 240 | torch.save(checkpoint, osp.join(args.snapshot_dir, 'MOTS_DynConv_' + args.snapshot_dir.split('/')[-2] + '_final_e' + str(epoch) + '.pth')) 241 | else: 242 | torch.save(model.state_dict(),osp.join(args.snapshot_dir, 'MOTS_DynConv_' + args.snapshot_dir.split('/')[-2] + '_final_e' + str(epoch) + '.pth')) 243 | break 244 | 245 | end = timeit.default_timer() 246 | print(end - start, 'seconds') 247 | 248 | 249 | if __name__ == '__main__': 250 | main() 251 | -------------------------------------------------------------------------------- /a_DynConv/unet3D_DynConv882.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import functional as F 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch 6 | import numpy as np 7 | from torch.autograd import Variable 8 | import matplotlib.pyplot as plt 9 | affine_par = True 10 | import functools 11 | 12 | import sys, os 13 | 14 | in_place = True 15 | 16 | class Conv3d(nn.Conv3d): 17 | 18 | def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), groups=1, bias=False): 19 | super(Conv3d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 20 | 21 | def forward(self, x): 22 | weight = self.weight 23 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True) 24 | weight = weight - weight_mean 25 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1) 26 | weight = weight / std.expand_as(weight) 27 | return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 28 | 29 | 30 | def conv3x3x3(in_planes, out_planes, kernel_size=(3,3,3), stride=(1,1,1), padding=1, dilation=1, bias=False, weight_std=False): 31 | "3x3x3 convolution with padding" 32 | if weight_std: 33 | return Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) 34 | else: 35 | return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias) 36 | 37 | 38 | 39 | 40 | class NoBottleneck(nn.Module): 41 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1, weight_std=False): 42 | super(NoBottleneck, self).__init__() 43 | self.weight_std = weight_std 44 | self.gn1 = nn.GroupNorm(16, inplanes) 45 | self.conv1 = conv3x3x3(inplanes, planes, kernel_size=(3, 3, 3), stride=stride, padding=(1,1,1), 46 | dilation=dilation * multi_grid, bias=False, weight_std=self.weight_std) 47 | 
self.relu = nn.ReLU(inplace=in_place) 48 | 49 | self.gn2 = nn.GroupNorm(16, planes) 50 | self.conv2 = conv3x3x3(planes, planes, kernel_size=(3, 3, 3), stride=1, padding=(1,1,1), 51 | dilation=dilation * multi_grid, bias=False, weight_std=self.weight_std) 52 | self.downsample = downsample 53 | self.dilation = dilation 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | residual = x 58 | 59 | out = self.gn1(x) 60 | out = self.relu(out) 61 | out = self.conv1(out) 62 | 63 | 64 | out = self.gn2(out) 65 | out = self.relu(out) 66 | out = self.conv2(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out = out + residual 72 | 73 | return out 74 | 75 | 76 | class unet3D(nn.Module): 77 | def __init__(self, layers, num_classes=3, weight_std = False): 78 | self.inplanes = 128 79 | self.weight_std = weight_std 80 | super(unet3D, self).__init__() 81 | 82 | self.conv1 = conv3x3x3(1, 32, stride=[1, 1, 1], weight_std=self.weight_std) 83 | 84 | self.layer0 = self._make_layer(NoBottleneck, 32, 32, layers[0], stride=(1, 1, 1)) 85 | self.layer1 = self._make_layer(NoBottleneck, 32, 64, layers[1], stride=(2, 2, 2)) 86 | self.layer2 = self._make_layer(NoBottleneck, 64, 128, layers[2], stride=(2, 2, 2)) 87 | self.layer3 = self._make_layer(NoBottleneck, 128, 256, layers[3], stride=(2, 2, 2)) 88 | self.layer4 = self._make_layer(NoBottleneck, 256, 256, layers[4], stride=(2, 2, 2)) 89 | 90 | self.fusionConv = nn.Sequential( 91 | nn.GroupNorm(16, 256), 92 | nn.ReLU(inplace=in_place), 93 | conv3x3x3(256, 256, kernel_size=(1, 1, 1), padding=(0, 0, 0), weight_std=self.weight_std) 94 | ) 95 | 96 | self.upsamplex2 = nn.Upsample(scale_factor=2, mode='trilinear') 97 | 98 | self.x8_resb = self._make_layer(NoBottleneck, 256, 128, 1, stride=(1, 1, 1)) 99 | self.x4_resb = self._make_layer(NoBottleneck, 128, 64, 1, stride=(1, 1, 1)) 100 | self.x2_resb = self._make_layer(NoBottleneck, 64, 32, 1, stride=(1, 1, 1)) 101 | self.x1_resb = self._make_layer(NoBottleneck, 32, 32, 1, stride=(1, 1, 1)) 102 | 103 | self.precls_conv = nn.Sequential( 104 | nn.GroupNorm(16, 32), 105 | nn.ReLU(inplace=in_place), 106 | nn.Conv3d(32, 8, kernel_size=1) 107 | ) 108 | 109 | self.GAP = nn.Sequential( 110 | nn.GroupNorm(16, 256), 111 | nn.ReLU(inplace=in_place), 112 | torch.nn.AdaptiveAvgPool3d((1,1,1)) 113 | ) 114 | self.controller = nn.Conv3d(256+7, 162, kernel_size=1, stride=1, padding=0) 115 | 116 | def _make_layer(self, block, inplanes, planes, blocks, stride=(1, 1, 1), dilation=1, multi_grid=1): 117 | downsample = None 118 | if stride[0] != 1 or stride[1] != 1 or stride[2] != 1 or inplanes != planes: 119 | downsample = nn.Sequential( 120 | nn.GroupNorm(16, inplanes), 121 | nn.ReLU(inplace=in_place), 122 | conv3x3x3(inplanes, planes, kernel_size=(1, 1, 1), stride=stride, padding=0, 123 | weight_std=self.weight_std), 124 | ) 125 | 126 | layers = [] 127 | generate_multi_grid = lambda index, grids: grids[index % len(grids)] if isinstance(grids, tuple) else 1 128 | layers.append(block(inplanes, planes, stride, dilation=dilation, downsample=downsample, 129 | multi_grid=generate_multi_grid(0, multi_grid), weight_std=self.weight_std)) 130 | # self.inplanes = planes 131 | for i in range(1, blocks): 132 | layers.append( 133 | block(planes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid), 134 | weight_std=self.weight_std)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def encoding_task(self, task_id): 139 | N = task_id.shape[0] 140 | task_encoding = torch.zeros(size=(N, 7)) 
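# (one-hot task conditioning: with the 7 MOTS tasks, e.g. task_id = [0, 2] yields rows [1,0,0,0,0,0,0] and [0,0,1,0,0,0,0]; forward() concatenates this encoding with the GAP feature to drive the dynamic-head controller)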
141 | for i in range(N): 142 | task_encoding[i, task_id[i]]=1 143 | return task_encoding.cuda() 144 | 145 | def parse_dynamic_params(self, params, channels, weight_nums, bias_nums): 146 | assert params.dim() == 2 147 | assert len(weight_nums) == len(bias_nums) 148 | assert params.size(1) == sum(weight_nums) + sum(bias_nums) 149 | 150 | num_insts = params.size(0) 151 | num_layers = len(weight_nums) 152 | 153 | params_splits = list(torch.split_with_sizes( 154 | params, weight_nums + bias_nums, dim=1 155 | )) 156 | 157 | weight_splits = params_splits[:num_layers] 158 | bias_splits = params_splits[num_layers:] 159 | 160 | for l in range(num_layers): 161 | if l < num_layers - 1: 162 | weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1, 1) 163 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels) 164 | else: 165 | weight_splits[l] = weight_splits[l].reshape(num_insts * 2, -1, 1, 1, 1) 166 | bias_splits[l] = bias_splits[l].reshape(num_insts * 2) 167 | 168 | return weight_splits, bias_splits 169 | 170 | def heads_forward(self, features, weights, biases, num_insts): 171 | assert features.dim() == 5 172 | n_layers = len(weights) 173 | x = features 174 | for i, (w, b) in enumerate(zip(weights, biases)): 175 | x = F.conv3d( 176 | x, w, bias=b, 177 | stride=1, padding=0, 178 | groups=num_insts 179 | ) 180 | if i < n_layers - 1: 181 | x = F.relu(x) 182 | return x 183 | 184 | def forward(self, input, task_id): 185 | 186 | x = self.conv1(input) 187 | x = self.layer0(x) 188 | skip0 = x 189 | 190 | x = self.layer1(x) 191 | skip1 = x 192 | 193 | x = self.layer2(x) 194 | skip2 = x 195 | 196 | x = self.layer3(x) 197 | skip3 = x 198 | 199 | x = self.layer4(x) 200 | 201 | x = self.fusionConv(x) 202 | 203 | # generate conv filters for classification layer 204 | task_encoding = self.encoding_task(task_id) 205 | task_encoding.unsqueeze_(2).unsqueeze_(2).unsqueeze_(2) 206 | x_feat = self.GAP(x) 207 | x_cond = torch.cat([x_feat, task_encoding], 1) 208 | params = self.controller(x_cond) 209 | params.squeeze_(-1).squeeze_(-1).squeeze_(-1) 210 | 211 | 212 | # x8 213 | x = self.upsamplex2(x) 214 | x = x + skip3 215 | x = self.x8_resb(x) 216 | 217 | # x4 218 | x = self.upsamplex2(x) 219 | x = x + skip2 220 | x = self.x4_resb(x) 221 | 222 | # x2 223 | x = self.upsamplex2(x) 224 | x = x + skip1 225 | x = self.x2_resb(x) 226 | 227 | # x1 228 | x = self.upsamplex2(x) 229 | x = x + skip0 230 | x = self.x1_resb(x) 231 | 232 | head_inputs = self.precls_conv(x) 233 | 234 | N, _, D, H, W = head_inputs.size() 235 | head_inputs = head_inputs.reshape(1, -1, D, H, W) 236 | 237 | weight_nums, bias_nums = [], [] 238 | weight_nums.append(8*8) 239 | weight_nums.append(8*8) 240 | weight_nums.append(8*2) 241 | bias_nums.append(8) 242 | bias_nums.append(8) 243 | bias_nums.append(2) 244 | weights, biases = self.parse_dynamic_params(params, 8, weight_nums, bias_nums) 245 | 246 | logits = self.heads_forward(head_inputs, weights, biases, N) 247 | 248 | logits = logits.reshape(-1, 2, D, H, W) 249 | 250 | return logits 251 | 252 | def UNet3D(num_classes=1, weight_std=False): 253 | print("Using DynConv 8,8,2") 254 | model = unet3D([1, 2, 2, 2, 2], num_classes, weight_std) 255 | return model -------------------------------------------------------------------------------- /data_list/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/data_list/.DS_Store 
-------------------------------------------------------------------------------- /data_list/MOTS/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/data_list/MOTS/.DS_Store -------------------------------------------------------------------------------- /data_list/MOTS/MOTS_test.txt: -------------------------------------------------------------------------------- 1 | liver_0 2 | liver_102 3 | liver_107 4 | liver_111 5 | liver_116 6 | liver_120 7 | liver_125 8 | liver_13 9 | liver_17 10 | liver_21 11 | liver_26 12 | liver_30 13 | liver_35 14 | liver_4 15 | liver_44 16 | liver_49 17 | liver_53 18 | liver_58 19 | liver_62 20 | liver_67 21 | liver_71 22 | liver_76 23 | liver_80 24 | liver_85 25 | liver_9 26 | liver_94 27 | liver_99 28 | kidney_00000 29 | kidney_00005 30 | kidney_00010 31 | kidney_00015 32 | kidney_00020 33 | kidney_00025 34 | kidney_00030 35 | kidney_00035 36 | kidney_00040 37 | kidney_00045 38 | kidney_00050 39 | kidney_00055 40 | kidney_00060 41 | kidney_00065 42 | kidney_00070 43 | kidney_00075 44 | kidney_00080 45 | kidney_00085 46 | kidney_00090 47 | kidney_00095 48 | kidney_00100 49 | kidney_00105 50 | kidney_00110 51 | kidney_00115 52 | kidney_00120 53 | kidney_00125 54 | kidney_00130 55 | kidney_00135 56 | kidney_00140 57 | kidney_00145 58 | kidney_00150 59 | kidney_00155 60 | kidney_00160 61 | kidney_00165 62 | kidney_00170 63 | kidney_00175 64 | kidney_00180 65 | kidney_00185 66 | kidney_00190 67 | kidney_00195 68 | kidney_00200 69 | kidney_00205 70 | hepaticvessel_001 71 | hepaticvessel_008 72 | hepaticvessel_018 73 | hepaticvessel_025 74 | hepaticvessel_030 75 | hepaticvessel_040 76 | hepaticvessel_050 77 | hepaticvessel_058 78 | hepaticvessel_066 79 | hepaticvessel_072 80 | hepaticvessel_080 81 | hepaticvessel_085 82 | hepaticvessel_090 83 | hepaticvessel_096 84 | hepaticvessel_103 85 | hepaticvessel_112 86 | hepaticvessel_119 87 | hepaticvessel_127 88 | hepaticvessel_133 89 | hepaticvessel_140 90 | hepaticvessel_146 91 | hepaticvessel_154 92 | hepaticvessel_161 93 | hepaticvessel_167 94 | hepaticvessel_175 95 | hepaticvessel_183 96 | hepaticvessel_192 97 | hepaticvessel_197 98 | hepaticvessel_203 99 | hepaticvessel_210 100 | hepaticvessel_217 101 | hepaticvessel_223 102 | hepaticvessel_230 103 | hepaticvessel_236 104 | hepaticvessel_244 105 | hepaticvessel_256 106 | hepaticvessel_265 107 | hepaticvessel_271 108 | hepaticvessel_279 109 | hepaticvessel_285 110 | hepaticvessel_291 111 | hepaticvessel_299 112 | hepaticvessel_308 113 | hepaticvessel_320 114 | hepaticvessel_325 115 | hepaticvessel_333 116 | hepaticvessel_341 117 | hepaticvessel_350 118 | hepaticvessel_361 119 | hepaticvessel_369 120 | hepaticvessel_375 121 | hepaticvessel_384 122 | hepaticvessel_391 123 | hepaticvessel_400 124 | hepaticvessel_407 125 | hepaticvessel_416 126 | hepaticvessel_424 127 | hepaticvessel_432 128 | hepaticvessel_440 129 | hepaticvessel_445 130 | hepaticvessel_455 131 | pancreas_001 132 | pancreas_012 133 | pancreas_021 134 | pancreas_032 135 | pancreas_042 136 | pancreas_049 137 | pancreas_056 138 | pancreas_067 139 | pancreas_075 140 | pancreas_083 141 | pancreas_089 142 | pancreas_095 143 | pancreas_101 144 | pancreas_106 145 | pancreas_113 146 | pancreas_122 147 | pancreas_129 148 | pancreas_138 149 | pancreas_149 150 | pancreas_160 151 | pancreas_170 152 | pancreas_179 153 | pancreas_186 154 | pancreas_196 155 | pancreas_201 156 | pancreas_210 
157 | pancreas_215 158 | pancreas_224 159 | pancreas_229 160 | pancreas_236 161 | pancreas_244 162 | pancreas_254 163 | pancreas_261 164 | pancreas_267 165 | pancreas_275 166 | pancreas_280 167 | pancreas_287 168 | pancreas_293 169 | pancreas_298 170 | pancreas_303 171 | pancreas_310 172 | pancreas_316 173 | pancreas_325 174 | pancreas_330 175 | pancreas_339 176 | pancreas_346 177 | pancreas_354 178 | pancreas_360 179 | pancreas_366 180 | pancreas_374 181 | pancreas_379 182 | pancreas_387 183 | pancreas_393 184 | pancreas_401 185 | pancreas_409 186 | pancreas_414 187 | pancreas_421 188 | colon_001 189 | colon_009 190 | colon_024 191 | colon_029 192 | colon_036 193 | colon_042 194 | colon_053 195 | colon_065 196 | colon_075 197 | colon_088 198 | colon_096 199 | colon_103 200 | colon_111 201 | colon_118 202 | colon_126 203 | colon_134 204 | colon_140 205 | colon_145 206 | colon_157 207 | colon_164 208 | colon_171 209 | colon_185 210 | colon_196 211 | colon_206 212 | colon_214 213 | colon_219 214 | lung_001 215 | lung_009 216 | lung_018 217 | lung_026 218 | lung_033 219 | lung_041 220 | lung_046 221 | lung_053 222 | lung_059 223 | lung_066 224 | lung_074 225 | lung_081 226 | lung_093 227 | spleen_10 228 | spleen_17 229 | spleen_21 230 | spleen_27 231 | spleen_32 232 | spleen_44 233 | spleen_52 234 | spleen_60 235 | spleen_9 236 | -------------------------------------------------------------------------------- /dataset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/dataset/.DS_Store -------------------------------------------------------------------------------- /dataset/list/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/dataset/list/.DS_Store -------------------------------------------------------------------------------- /dataset/re_spacing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import nibabel as nib 4 | from skimage.transform import resize 5 | from tqdm import tqdm 6 | import matplotlib.pyplot as plt 7 | import SimpleITK as sitk 8 | 9 | spacing = { 10 | 0: [1.5, 0.8, 0.8], 11 | 1: [1.5, 0.8, 0.8], 12 | 2: [1.5, 0.8, 0.8], 13 | 3: [1.5, 0.8, 0.8], 14 | 4: [1.5, 0.8, 0.8], 15 | 5: [1.5, 0.8, 0.8], 16 | 6: [1.5, 0.8, 0.8], 17 | } 18 | 19 | ori_path = './0123456' 20 | new_path = './0123456_spacing_same' 21 | 22 | count = -1 23 | for root1, dirs1, _ in os.walk(ori_path): 24 | for i_dirs1 in tqdm(sorted(dirs1)): # 0Liver 25 | # if i_dirs1 != '0Liver': 26 | # continue 27 | ########################################################################### 28 | if i_dirs1 == '1Kidney': 29 | for root2, dirs2, files2 in os.walk(os.path.join(root1, i_dirs1)): 30 | 31 | for root3, dirs3, files3 in os.walk(os.path.join(root2, 'origin')): 32 | 33 | for i_dirs3 in sorted(dirs3): # case_00000 34 | # if int(i_dirs3[-2:])!=4: 35 | # continue 36 | 37 | for root4, dirs4, files4 in os.walk(os.path.join(root3, i_dirs3)): 38 | for i_files4 in sorted(files4): 39 | # read img 40 | print("Processing %s" % (i_files4)) 41 | img_path = os.path.join(root4, i_files4) 42 | imageITK = sitk.ReadImage(img_path) 43 | image = sitk.GetArrayFromImage(imageITK) 44 | ori_size = np.array(imageITK.GetSize())[[2, 1, 0]] 45 | ori_spacing = np.array(imageITK.GetSpacing())[[2, 1, 0]] 46 | 
ori_origin = imageITK.GetOrigin() 47 | ori_direction = imageITK.GetDirection() 48 | 49 | task_id = int(i_dirs1[0]) 50 | target_spacing = np.array(spacing[task_id]) 51 | 52 | if ori_spacing[0] < 0 or ori_spacing[1] < 0 or ori_spacing[2] < 0: 53 | print("error") 54 | spc_ratio = ori_spacing / target_spacing 55 | 56 | data_type = image.dtype 57 | if i_files4 != 'segmentation.nii.gz': 58 | data_type = np.int32 59 | 60 | if i_files4 == 'segmentation.nii.gz': 61 | order = 0 62 | mode_ = 'edge' 63 | else: 64 | order = 3 65 | mode_ = 'constant' 66 | 67 | image = image.astype(np.float) 68 | 69 | image_resize = resize(image, ( 70 | int(ori_size[0] * spc_ratio[0]), int(ori_size[1] * spc_ratio[1]), 71 | int(ori_size[2] * spc_ratio[2])), order=order, mode=mode_, cval=0, clip=True, 72 | preserve_range=True) 73 | 74 | image_resize = np.round(image_resize).astype(data_type) 75 | 76 | # save 77 | save_path = os.path.join(new_path, i_dirs1, 'origin', i_dirs3) 78 | if not os.path.exists(save_path): 79 | os.makedirs(save_path) 80 | saveITK = sitk.GetImageFromArray(image_resize) 81 | saveITK.SetSpacing(target_spacing[[2, 1, 0]]) 82 | saveITK.SetOrigin(ori_origin) 83 | saveITK.SetDirection(ori_direction) 84 | sitk.WriteImage(saveITK, os.path.join(save_path, i_files4)) 85 | 86 | ############################################################################# 87 | for root2, dirs2, files2 in os.walk(os.path.join(root1, i_dirs1)): 88 | for i_dirs2 in sorted(dirs2): # imagesTr 89 | 90 | for root3, dirs3, files3 in os.walk(os.path.join(root2, i_dirs2)): 91 | for i_files3 in sorted(files3): 92 | if i_files3[0] == '.': 93 | continue 94 | # read img 95 | print("Processing %s" % (i_files3)) 96 | img_path = os.path.join(root3, i_files3) 97 | imageITK = sitk.ReadImage(img_path) 98 | image = sitk.GetArrayFromImage(imageITK) 99 | ori_size = np.array(imageITK.GetSize())[[2, 1, 0]] 100 | ori_spacing = np.array(imageITK.GetSpacing())[[2, 1, 0]] 101 | ori_origin = imageITK.GetOrigin() 102 | ori_direction = imageITK.GetDirection() 103 | 104 | task_id = int(i_dirs1[0]) 105 | target_spacing = np.array(spacing[task_id]) 106 | spc_ratio = ori_spacing / target_spacing 107 | 108 | data_type = image.dtype 109 | if i_dirs2 != 'labelsTr': 110 | data_type = np.int32 111 | 112 | if i_dirs2 == 'labelsTr': 113 | order = 0 114 | mode_ = 'edge' 115 | else: 116 | order = 3 117 | mode_ = 'constant' 118 | 119 | image = image.astype(np.float) 120 | 121 | image_resize = resize(image, (int(ori_size[0] * spc_ratio[0]), int(ori_size[1] * spc_ratio[1]), 122 | int(ori_size[2] * spc_ratio[2])), 123 | order=order, mode=mode_, cval=0, clip=True, preserve_range=True) 124 | image_resize = np.round(image_resize).astype(data_type) 125 | 126 | # save 127 | save_path = os.path.join(new_path, i_dirs1, i_dirs2) 128 | if not os.path.exists(save_path): 129 | os.makedirs(save_path) 130 | saveITK = sitk.GetImageFromArray(image_resize) 131 | saveITK.SetSpacing(target_spacing[[2, 1, 0]]) 132 | saveITK.SetOrigin(ori_origin) 133 | saveITK.SetDirection(ori_direction) 134 | sitk.WriteImage(saveITK, os.path.join(save_path, i_files3)) 135 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import time 4 | import argparse 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | from utils.logger import get_logger 10 | from utils.pyt_utils import all_reduce_tensor, extant_file 11 | 12 | try: 13 | from apex.parallel
/engine.py:
--------------------------------------------------------------------------------
import os
import argparse

import torch
import torch.distributed as dist

from utils.logger import get_logger
from utils.pyt_utils import all_reduce_tensor, extant_file

try:
    from apex.parallel import DistributedDataParallel, SyncBatchNorm
except ImportError:
    raise ImportError(
        "Please install apex from https://www.github.com/nvidia/apex .")


logger = get_logger()


class Engine(object):
    def __init__(self, custom_parser=None):
        logger.info("PyTorch Version {}".format(torch.__version__))
        self.devices = None
        self.distributed = False

        if custom_parser is None:
            self.parser = argparse.ArgumentParser()
        else:
            assert isinstance(custom_parser, argparse.ArgumentParser)
            self.parser = custom_parser

        self.inject_default_parser()
        self.args = self.parser.parse_args()

        self.continue_state_object = self.args.continue_fpath

        # if not self.args.gpu == 'None':
        #     os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu

        if 'WORLD_SIZE' in os.environ:
            self.distributed = int(os.environ['WORLD_SIZE']) > 1
            print("WORLD_SIZE is %d" % (int(os.environ['WORLD_SIZE'])))
        if self.distributed:
            self.local_rank = self.args.local_rank
            self.world_size = int(os.environ['WORLD_SIZE'])
            torch.cuda.set_device(self.local_rank)
            dist.init_process_group(backend="nccl", init_method='env://')
            self.devices = [i for i in range(self.world_size)]
        else:
            gpus = os.environ["CUDA_VISIBLE_DEVICES"]
            self.devices = [i for i in range(len(gpus.split(',')))]

    def inject_default_parser(self):
        p = self.parser
        p.add_argument('-d', '--devices', default='',
                       help='set data parallel training')
        p.add_argument('-c', '--continue', type=extant_file,
                       metavar="FILE",
                       dest="continue_fpath",
                       help='continue from one certain checkpoint')
        # p.add_argument('--local_rank', default=0, type=int,
        #                help='process rank on node')

    def data_parallel(self, model):
        if self.distributed:
            model = DistributedDataParallel(model)
        else:
            model = torch.nn.DataParallel(model)
        return model

    def get_train_loader(self, train_dataset, collate_fn=None):
        train_sampler = None
        is_shuffle = True
        batch_size = self.args.batch_size

        if self.distributed:
            # the DistributedSampler shuffles for us, and the global batch
            # size is split evenly across the participating processes
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            batch_size = self.args.batch_size // self.world_size
            is_shuffle = False

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   num_workers=self.args.num_workers,
                                                   drop_last=False,
                                                   shuffle=is_shuffle,
                                                   pin_memory=True,
                                                   sampler=train_sampler,
                                                   collate_fn=collate_fn)

        return train_loader, train_sampler
    def get_test_loader(self, test_dataset):
        test_sampler = None
        is_shuffle = False

        if self.distributed:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)

        # inference always runs with batch size 1, whatever args.batch_size is
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=1,
                                                  num_workers=self.args.num_workers,
                                                  drop_last=False,
                                                  shuffle=is_shuffle,
                                                  pin_memory=True,
                                                  sampler=test_sampler)

        return test_loader, test_sampler

    def all_reduce_tensor(self, tensor, norm=True):
        if self.distributed:
            return all_reduce_tensor(tensor, world_size=self.world_size, norm=norm)
        else:
            return torch.mean(tensor)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        torch.cuda.empty_cache()
        if type is not None:
            logger.warning(
                "An exception occurred inside the Engine context; "
                "giving up the running process")
        return False
--------------------------------------------------------------------------------
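Engine is meant to wrap the whole training script as a context manager. A toy, single-process sketch, under stated assumptions: apex must be installed for the import at the top of engine.py, CUDA_VISIBLE_DEVICES must be set (Engine reads it in the non-distributed branch), and the batch_size/num_workers arguments are added here by the sketch itself, because get_train_loader reads them from args:

# run as: CUDA_VISIBLE_DEVICES=0 python toy_train.py
import argparse
import torch
from torch.utils.data import TensorDataset
from engine import Engine

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--num_workers', type=int, default=0)
parser.add_argument('--local_rank', type=int, default=0)

dataset = TensorDataset(torch.randn(8, 1, 4, 4), torch.zeros(8))  # toy stand-in dataset

with Engine(custom_parser=parser) as engine:
    train_loader, train_sampler = engine.get_train_loader(dataset)
    for x, y in train_loader:
        pass  # training step goes here; engine.all_reduce_tensor(loss) averages across GPUs
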
/loss_functions/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/loss_functions/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np


class BinaryDiceLoss(nn.Module):
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        num = torch.sum(torch.mul(predict, target), dim=1)
        den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + self.smooth

        dice_score = 2 * num / den
        dice_loss = 1 - dice_score

        # samples whose target is filled with -1 carry no annotation for this
        # class and are excluded from the average
        dice_loss_avg = dice_loss[target[:, 0] != -1].sum() / dice_loss[target[:, 0] != -1].shape[0]

        return dice_loss_avg


class DiceLoss4MOTS(nn.Module):
    def __init__(self, weight=None, ignore_index=None, num_classes=3, **kwargs):
        super(DiceLoss4MOTS, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.dice = BinaryDiceLoss(**self.kwargs)

    def forward(self, predict, target):
        total_loss = []
        predict = torch.sigmoid(predict)  # F.sigmoid is deprecated

        for i in range(self.num_classes):
            if i != self.ignore_index:
                dice_loss = self.dice(predict[:, i], target[:, i])
                if self.weight is not None:
                    assert self.weight.shape[0] == self.num_classes, \
                        'Expect weight shape [{}], get [{}]'.format(self.num_classes, self.weight.shape[0])
                    dice_loss *= self.weight[i]  # attribute is self.weight; self.weights does not exist
                total_loss.append(dice_loss)

        total_loss = torch.stack(total_loss)
        total_loss = total_loss[total_loss == total_loss]  # drop NaN entries (classes absent from the batch)

        return total_loss.sum() / total_loss.shape[0]


class CELoss4MOTS(nn.Module):
    def __init__(self, ignore_index=None, num_classes=3, **kwargs):
        super(CELoss4MOTS, self).__init__()
        self.kwargs = kwargs
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

    def weight_function(self, mask):
        weights = torch.ones_like(mask).float()
        voxels_sum = mask.shape[0] * mask.shape[1] * mask.shape[2]
        for i in range(2):
            voxels_i = [mask == i][0].sum().cpu().numpy()
            w_i = np.log(voxels_sum / voxels_i).astype(np.float32)
            weights = torch.where(mask == i, w_i * torch.ones_like(weights).float(), weights)

        return weights

    def forward(self, predict, target):
        assert predict.shape == target.shape, 'predict & target shape do not match'

        total_loss = []
        for i in range(self.num_classes):
            if i != self.ignore_index:
                ce_loss = self.criterion(predict[:, i], target[:, i])
                ce_loss = torch.mean(ce_loss, dim=[1, 2, 3])

                # average only over samples that are annotated for this class
                ce_loss_avg = ce_loss[target[:, i, 0, 0, 0] != -1].sum() / ce_loss[target[:, i, 0, 0, 0] != -1].shape[0]

                total_loss.append(ce_loss_avg)

        total_loss = torch.stack(total_loss)
        total_loss = total_loss[total_loss == total_loss]  # drop NaN entries

        return total_loss.sum() / total_loss.shape[0]
--------------------------------------------------------------------------------
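Both losses treat a target filled with -1 as "not annotated for this class", which is how partially labelled tasks are mixed in one batch. A toy shape check, with hypothetical sizes:

import torch
from loss_functions.loss import DiceLoss4MOTS, CELoss4MOTS

pred = torch.randn(2, 3, 8, 16, 16)               # logits: N=2, C=3 classes, D,H,W = 8,16,16
target = torch.randint(0, 2, pred.shape).float()
target[1] = -1                                    # second sample carries no annotation

dice = DiceLoss4MOTS(num_classes=3, smooth=1)
ce = CELoss4MOTS(num_classes=3)
loss = dice(pred, target) + ce(pred, target)      # only the annotated sample contributes
print(loss.item())
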
/run_script.sh:
--------------------------------------------------------------------------------
cd TransDoD/

# Training
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py \
--train_list='MOTS/MOTS_train.txt' \
--snapshot_dir='snapshots/TransDoDNet/' \
--nnUNet_preprocessed='/path/to/nnUNet_preprocessed' \
--input_size='64,192,192' \
--learning_rate=2e-4 \
--batch_size=2 \
--num_gpus=2 \
--num_epochs=1000

# Testing
CUDA_VISIBLE_DEVICES=0 python test.py \
--val_list='MOTS/MOTS_test.txt' \
--nnUNet_preprocessed='/path/to/nnUNet_preprocessed' \
--reload_path='/path/to/checkpoint.pth' \
--reload_from_checkpoint=True \
--save_path='outputs/TransDoDNet' \
--input_size='64,192,192'
--------------------------------------------------------------------------------
/utils/ParaFlop.py:
--------------------------------------------------------------------------------
# coding:utf8
import torch
import torch.nn as nn

import numpy as np
from collections import OrderedDict
import pandas as pd

# usage: add to train.py or test.py:
#   print_model_parm_nums(model)
#   print_model_parm_flops(model)
def print_model_parm_nums(model):
    total = sum([param.nelement() for param in model.parameters()])
    print(' + Number of params: %.2f(e6)' % (total / 1e6))


def print_model_parm_flops(model):
    prods = {}

    def save_hook(name):
        def hook_per(self, input, output):
            prods[name] = np.prod(input[0].shape)
        return hook_per

    list_1 = []

    def simple_hook(self, input, output):
        list_1.append(np.prod(input[0].shape))

    list_2 = {}

    def simple_hook2(self, input, output):
        list_2['names'] = np.prod(input[0].shape)

    multiply_adds = False
    list_conv = []

    def conv_hook(self, input, output):
        # output[0] drops the batch dimension, so a Conv3d output unpacks as (C, T, H, W)
        batch_size, input_channels, input_time, input_height, input_width = input[0].size()
        output_channels, output_time, output_height, output_width = output[0].size()

        kernel_ops = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] * (
            self.in_channels / self.groups) * (2 if multiply_adds else 1)
        bias_ops = 1 if self.bias is not None else 0

        params = output_channels * (kernel_ops + bias_ops)
        flops = batch_size * params * output_time * output_height * output_width
        list_conv.append(flops)

    list_linear = []

    def linear_hook(self, input, output):
        batch_size = input[0].size(0) if input[0].dim() == 2 else 1

        weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
        bias_ops = self.bias.nelement()

        flops = batch_size * (weight_ops + bias_ops)
        list_linear.append(flops)

    list_bn = []

    def bn_hook(self, input, output):
        list_bn.append(input[0].nelement())

    list_fc = []

    def fc_hook(self, input, output):
        # appended to list_bn in the original; this hook is never registered anyway
        list_fc.append(input[0].nelement())

    list_relu = []

    def relu_hook(self, input, output):
        list_relu.append(input[0].nelement())

    list_pooling = []
    # pooling layers are not counted; the original pooling_hook is left disabled
    def foo(net):
        childrens = list(net.children())
        if not childrens:
            if isinstance(net, torch.nn.Conv3d):
                net.register_forward_hook(conv_hook)
            if isinstance(net, torch.nn.Linear):
                net.register_forward_hook(linear_hook)
            if isinstance(net, torch.nn.BatchNorm3d):
                net.register_forward_hook(bn_hook)
            if isinstance(net, torch.nn.ReLU):
                # register once only; the original registered relu_hook twice,
                # double-counting every ReLU
                net.register_forward_hook(relu_hook)
            return
        for c in childrens:
            foo(c)

    foo(model)
    model.cuda()
    input = torch.rand(1, 64, 192, 192).unsqueeze(0).cuda()

    # one-hot task encoding, as expected by the TransDoDNet forward signature
    N = 1
    task_encoding = torch.zeros(size=(N, 7, 7)).cuda()
    for b in range(N):
        for i in range(7):
            for j in range(7):
                if i == j:
                    task_encoding[b, i, j] = 1
    task_encoding.unsqueeze_(-1).unsqueeze_(-1).unsqueeze_(-1)
    output = model(input, [], task_encoding)

    total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu))
    print(' + Number of FLOPs: %.5f(e9)' % (total_flops / 1e9))


def get_names_dict(model):
    """
    Recursive walk to get names including path
    """
    names = {}

    def _get_names(module, parent_name=''):
        for key, module in module.named_children():
            name = parent_name + '.' + key if parent_name else key
            names[name] = module
            if isinstance(module, torch.nn.Module):
                _get_names(module, parent_name=name)

    _get_names(model)
    return names


def torch_summarize_df(input_size, model, weights=False, input_shape=True, nb_trainable=False):
    """
    Summarizes torch model by showing trainable parameters and weights.

    author: wassname
    url: https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7
    license: MIT

    Modified from:
    - https://github.com/pytorch/pytorch/issues/2001#issuecomment-313735757
    - https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7/

    Usage:
        import torchvision.models as models
        model = models.alexnet()
        df = torch_summarize_df(input_size=(3, 224, 224), model=model)
        print(df)

        # name        class_name  input_shape        output_shape      nb_params
        # 1 features=>0 Conv2d    (-1, 3, 224, 224)  (-1, 64, 55, 55)  23296  # (3*11*11+1)*64
        # 2 features=>1 ReLU      (-1, 64, 55, 55)   (-1, 64, 55, 55)  0
        # ...
    """

    def register_hook(module):
        def hook(module, input, output):
            name = ''
            for key, item in names.items():
                if item == module:
                    name = key
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)

            m_key = module_idx + 1

            summary[m_key] = OrderedDict()
            summary[m_key]['name'] = name
            summary[m_key]['class_name'] = class_name
            if input_shape:
                summary[m_key]['input_shape'] = (-1,) + tuple(input[0].size())[1:]
            summary[m_key]['output_shape'] = (-1,) + tuple(output.size())[1:]
            if weights:
                summary[m_key]['weights'] = list(
                    [tuple(p.size()) for p in module.parameters()])

            if nb_trainable:
                params_trainable = sum(
                    [torch.LongTensor(list(p.size())).prod() for p in module.parameters() if p.requires_grad])
                summary[m_key]['nb_trainable'] = params_trainable
            params = sum([torch.LongTensor(list(p.size())).prod() for p in module.parameters()])
            summary[m_key]['nb_params'] = params

        if not isinstance(module, nn.Sequential) and \
           not isinstance(module, nn.ModuleList) and \
           not (module == model):
            hooks.append(module.register_forward_hook(hook))

    # Names are stored in parent and path+name is unique, not the name
    names = get_names_dict(model)

    # check if there are multiple inputs to the network
    # (torch.autograd.Variable is deprecated; plain tensors work the same way)
    if isinstance(input_size[0], (list, tuple)):
        x = [torch.rand(1, *in_size) for in_size in input_size]
    else:
        x = torch.rand(1, *input_size)

    if next(model.parameters()).is_cuda:
        x = x.cuda()

    # create properties
    summary = OrderedDict()
    hooks = []

    # register hook
    model.apply(register_hook)

    # make a forward pass
    model(x)

    # remove these hooks
    for h in hooks:
        h.remove()

    # make dataframe
    df_summary = pd.DataFrame.from_dict(summary, orient='index')

    print(df_summary)

    return df_summary
--------------------------------------------------------------------------------
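Note that print_model_parm_flops is hard-wired to the TransDoDNet calling convention (model(input, [], task_encoding) on CUDA), whereas print_model_parm_nums works on any nn.Module. A minimal, self-contained check with a hypothetical toy model:

import torch.nn as nn
from utils.ParaFlop import print_model_parm_nums

toy = nn.Sequential(nn.Conv3d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv3d(8, 1, 1))
print_model_parm_nums(toy)  # 8*(1*27)+8 = 224 plus 8+1 = 9 params -> prints 0.00(e6)
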
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
import os
import sys
import logging

_default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')
_default_level = logging.getLevelName(_default_level_name.upper())


class LogFormatter(logging.Formatter):
    log_fout = None
    date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] '
    date = '%(asctime)s '
    msg = '%(message)s'

    def format(self, record):
        if record.levelno == logging.DEBUG:
            mcl, mtxt = self._color_dbg, 'DBG'
        elif record.levelno == logging.WARNING:
            mcl, mtxt = self._color_warn, 'WRN'
        elif record.levelno == logging.ERROR:
            mcl, mtxt = self._color_err, 'ERR'
        else:
            mcl, mtxt = self._color_normal, ''

        if mtxt:
            mtxt += ' '

        # file output: full date, no ANSI colors
        if self.log_fout:
            self.__set_fmt(self.date_full + mtxt + self.msg)
            formatted = super(LogFormatter, self).format(record)
            return formatted

        # console output: short date plus level-dependent colors
        self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
        formatted = super(LogFormatter, self).format(record)

        return formatted

    if sys.version_info.major < 3:
        def __set_fmt(self, fmt):
            self._fmt = fmt
    else:
        def __set_fmt(self, fmt):
            self._style._fmt = fmt

    @staticmethod
    def _color_dbg(msg):
        return '\x1b[36m{}\x1b[0m'.format(msg)

    @staticmethod
    def _color_warn(msg):
        return '\x1b[1;31m{}\x1b[0m'.format(msg)

    @staticmethod
    def _color_err(msg):
        return '\x1b[1;4;31m{}\x1b[0m'.format(msg)

    @staticmethod
    def _color_omitted(msg):
        return '\x1b[35m{}\x1b[0m'.format(msg)

    @staticmethod
    def _color_normal(msg):
        return msg

    @staticmethod
    def _color_date(msg):
        return '\x1b[32m{}\x1b[0m'.format(msg)


def get_logger(log_dir=None, log_file=None, formatter=LogFormatter):
    logger = logging.getLogger()
    logger.setLevel(_default_level)
    del logger.handlers[:]

    if log_dir and log_file:
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)
        LogFormatter.log_fout = True
        file_handler = logging.FileHandler(log_file, mode='a')
        file_handler.setLevel(logging.INFO)
        # setFormatter expects an instance; the original passed the class itself
        file_handler.setFormatter(formatter())
        logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
    stream_handler.setLevel(0)
    logger.addHandler(stream_handler)
    return logger
--------------------------------------------------------------------------------
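get_logger configures the root logger, so one call at program start is enough; the level comes from the ENGINE_LOGGING_LEVEL environment variable (default INFO). A small usage sketch with hypothetical paths:

from utils.logger import get_logger

logger = get_logger(log_dir='logs', log_file='logs/train.log')  # paths are placeholders
logger.info('training started')   # green timestamp on the console, plain text in the file
logger.warning('loss is NaN')     # rendered in red with a WRN prefix
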
/utils/my_utils.py:
--------------------------------------------------------------------------------
# Author: Zylo117

import os

import cv2
import numpy as np
import torch
from glob import glob
from torch import nn
from torchvision.ops import nms  # required by postprocess() below
from typing import Union
import uuid


def invert_affine(metas: Union[float, list, tuple], preds):
    for i in range(len(preds)):
        if len(preds[i]['rois']) == 0:
            continue
        else:
            # the original tested `metas is float`, which is never True for a value
            if isinstance(metas, float):
                preds[i]['rois'][:, [0, 2]] = preds[i]['rois'][:, [0, 2]] / metas
                preds[i]['rois'][:, [1, 3]] = preds[i]['rois'][:, [1, 3]] / metas
            else:
                new_w, new_h, old_w, old_h, padding_w, padding_h = metas[i]
                preds[i]['rois'][:, [0, 2]] = preds[i]['rois'][:, [0, 2]] / (new_w / old_w)
                preds[i]['rois'][:, [1, 3]] = preds[i]['rois'][:, [1, 3]] / (new_h / old_h)
    return preds


def aspectaware_resize_padding(image, width, height, interpolation=None, means=None):
    old_h, old_w, c = image.shape
    if old_w > old_h:
        new_w = width
        new_h = int(width / old_w * old_h)
    else:
        new_w = int(height / old_h * old_w)
        new_h = height

    # the original allocated (height, height, c), which is only correct when width == height
    canvas = np.zeros((height, width, c), np.float32)
    if means is not None:
        canvas[...] = means

    if new_w != old_w or new_h != old_h:
        if interpolation is None:
            image = cv2.resize(image, (new_w, new_h))
        else:
            image = cv2.resize(image, (new_w, new_h), interpolation=interpolation)

    padding_h = height - new_h
    padding_w = width - new_w

    if c > 1:
        canvas[:new_h, :new_w] = image
    else:
        if len(image.shape) == 2:
            canvas[:new_h, :new_w, 0] = image
        else:
            canvas[:new_h, :new_w] = image

    return canvas, new_w, new_h, old_w, old_h, padding_w, padding_h


def preprocess(*image_path, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
    ori_imgs = [cv2.imread(img_path) for img_path in image_path]
    normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
    imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
                                            means=None) for img in normalized_imgs]
    framed_imgs = [img_meta[0] for img_meta in imgs_meta]
    framed_metas = [img_meta[1:] for img_meta in imgs_meta]

    return ori_imgs, framed_imgs, framed_metas


def preprocess_video(*frame_from_video, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
    ori_imgs = frame_from_video
    normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
    imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
                                            means=None) for img in normalized_imgs]
    framed_imgs = [img_meta[0] for img_meta in imgs_meta]
    framed_metas = [img_meta[1:] for img_meta in imgs_meta]

    return ori_imgs, framed_imgs, framed_metas
def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold):
    transformed_anchors = regressBoxes(anchors, regression)
    transformed_anchors = clipBoxes(transformed_anchors, x)
    scores = torch.max(classification, dim=2, keepdim=True)[0]
    scores_over_thresh = (scores > threshold)[:, :, 0]
    out = []
    for i in range(x.shape[0]):
        # per-image check; the original tested the whole batch and, lacking a
        # `continue`, appended a second entry for the same image
        if scores_over_thresh[i].sum() == 0:
            out.append({
                'rois': np.array(()),
                'class_ids': np.array(()),
                'scores': np.array(()),
            })
            continue

        classification_per = classification[i, scores_over_thresh[i, :], ...].permute(1, 0)
        transformed_anchors_per = transformed_anchors[i, scores_over_thresh[i, :], ...]
        scores_per = scores[i, scores_over_thresh[i, :], ...]
        anchors_nms_idx = nms(transformed_anchors_per, scores_per[:, 0], iou_threshold=iou_threshold)

        if anchors_nms_idx.shape[0] != 0:
            scores_, classes_ = classification_per[:, anchors_nms_idx].max(dim=0)
            boxes_ = transformed_anchors_per[anchors_nms_idx, :]

            out.append({
                'rois': boxes_.cpu().numpy(),
                'class_ids': classes_.cpu().numpy(),
                'scores': scores_.cpu().numpy(),
            })
        else:
            out.append({
                'rois': np.array(()),
                'class_ids': np.array(()),
                'scores': np.array(()),
            })

    return out


def display(preds, imgs, obj_list, imshow=True, imwrite=False):
    for i in range(len(imgs)):
        if len(preds[i]['rois']) == 0:
            continue

        for j in range(len(preds[i]['rois'])):
            (x1, y1, x2, y2) = preds[i]['rois'][j].astype(np.int32)  # np.int is deprecated
            cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
            obj = obj_list[preds[i]['class_ids'][j]]
            score = float(preds[i]['scores'][j])

            cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
                        (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (255, 255, 0), 1)
        if imshow:
            cv2.imshow('img', imgs[i])
            cv2.waitKey(0)

        if imwrite:
            os.makedirs('test/', exist_ok=True)
            cv2.imwrite(f'test/{uuid.uuid4().hex}.jpg', imgs[i])


class CustomDataParallel(nn.DataParallel):
    """
    Force splitting data to all GPUs instead of sending all data to cuda:0 and then moving it around.
    """

    def __init__(self, module, num_gpus):
        super().__init__(module)
        self.num_gpus = num_gpus

    def scatter(self, inputs, kwargs, device_ids):
        # More like scatter and data prep at the same time. The point is that we prep the data
        # in such a way that no scatter is necessary, and there's no need to shuffle stuff
        # around different GPUs.
        devices = ['cuda:' + str(x) for x in range(self.num_gpus)]
        splits = inputs[0].shape[0] // self.num_gpus

        return [(inputs[0][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True),
                 inputs[1][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True))
                for device_idx in range(len(devices))], \
               [kwargs] * len(devices)


def get_last_weights(weights_path):
    weights_path = glob(weights_path + f'/*.pth')
    weights_path = sorted(weights_path,
                          key=lambda x: int(x.rsplit('_')[-1].rsplit('.')[0]),
                          reverse=True)[0]
    print(f'using weights {weights_path}')
    return weights_path


def init_weights(model):
    for name, module in model.named_modules():
        is_conv_layer = isinstance(module, nn.Conv2d)

        if is_conv_layer:
            nn.init.kaiming_uniform_(module.weight.data)

            if module.bias is not None:
                module.bias.data.zero_()
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Re-start from a checkpoint.
    """
    if not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))

    # open checkpoint file
    checkpoint = torch.load(ckp_path, map_location="cpu")

    # key is what to look for in the checkpoint file;
    # value is the object to load, e.g. {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                msg = value.load_state_dict(checkpoint[key], strict=False)
                print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
            except TypeError:
                try:
                    msg = value.load_state_dict(checkpoint[key])
                    print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
                except ValueError:
                    print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
        else:
            print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))

    # reload variables important for the run
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]
--------------------------------------------------------------------------------
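restart_from_checkpoint maps checkpoint keys to objects exposing load_state_dict and silently returns when the file does not exist. A minimal sketch, where 'checkpoint.pth' is a hypothetical file saved as torch.save({'state_dict': ..., 'optimizer': ..., 'epoch': ...}, 'checkpoint.pth'):

import torch.nn as nn
import torch.optim as optim
from utils.my_utils import restart_from_checkpoint

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
to_restore = {'epoch': 0}   # filled in from the checkpoint, if present

restart_from_checkpoint('checkpoint.pth',
                        run_variables=to_restore,
                        state_dict=model,
                        optimizer=optimizer)
start_epoch = to_restore['epoch']
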
/utils/pyt_utils.py:
--------------------------------------------------------------------------------
# encoding: utf-8
import os
import argparse

import torch.distributed as dist

from .logger import get_logger

logger = get_logger()


def reduce_tensor(tensor, dst=0, op=dist.ReduceOp.SUM, world_size=1):
    tensor = tensor.clone()
    dist.reduce(tensor, dst, op)
    if dist.get_rank() == dst:
        tensor.div_(world_size)

    return tensor


def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM, world_size=1, norm=True):
    tensor = tensor.clone()
    dist.all_reduce(tensor, op)
    if norm:
        tensor.div_(world_size)

    return tensor


def extant_file(x):
    """argparse type-checker: the argument must name an existing path."""
    if not os.path.exists(x):
        raise argparse.ArgumentTypeError("{0} does not exist".format(x))
    return x
--------------------------------------------------------------------------------
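all_reduce_tensor averages a tensor across ranks when norm=True. A sketch of the typical loss-averaging call, assuming the process group has already been initialised elsewhere (as Engine.__init__ does):

import torch
import torch.distributed as dist
from utils.pyt_utils import all_reduce_tensor

def average_loss(loss: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # every rank ends up holding the mean of all ranks' losses
    return all_reduce_tensor(loss, world_size=world_size, norm=True)
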