├── .DS_Store
├── .gitignore
├── LICENSE
├── README.md
├── TransDoD
│   ├── .DS_Store
│   ├── MOTSDataset.py
│   ├── TranDoD.png
│   ├── __init__.py
│   ├── __pycache__
│   │   └── MOTSDataset.cpython-37.pyc
│   ├── models
│   │   ├── .DS_Store
│   │   ├── ResCNN2.py
│   │   ├── TransDoDNet.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── ResCNN2.cpython-37.pyc
│   │   │   ├── TransDoDNet.cpython-37.pyc
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── deepnorm.cpython-37.pyc
│   │   │   ├── deformable_transformer.cpython-37.pyc
│   │   │   └── position_encoding.cpython-37.pyc
│   │   ├── deepnorm.py
│   │   ├── deformable_transformer.py
│   │   ├── ops
│   │   │   ├── .DS_Store
│   │   │   ├── MultiScaleDeformableAttention.egg-info
│   │   │   │   ├── PKG-INFO
│   │   │   │   ├── SOURCES.txt
│   │   │   │   ├── dependency_links.txt
│   │   │   │   └── top_level.txt
│   │   │   ├── build
│   │   │   │   ├── lib.linux-x86_64-3.7
│   │   │   │   │   ├── MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so
│   │   │   │   │   ├── functions
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── ms_deform_attn_func.py
│   │   │   │   │   └── modules
│   │   │   │   │       ├── __init__.py
│   │   │   │   │       ├── ms_deform_attn (33).py
│   │   │   │   │       └── ms_deform_attn.py
│   │   │   │   └── temp.linux-x86_64-3.7
│   │   │   │       ├── .ninja_deps
│   │   │   │       ├── .ninja_log
│   │   │   │       ├── build.ninja
│   │   │   │       └── media
│   │   │   │           ├── userdisk0
│   │   │   │           │   ├── myproject-Det
│   │   │   │           │   │   └── Deformable-DETR
│   │   │   │           │   │       └── models
│   │   │   │           │   │           └── ops
│   │   │   │           │   │               └── src
│   │   │   │           │   │                   ├── cpu
│   │   │   │           │   │                   │   └── ms_deform_attn_cpu.o
│   │   │   │           │   │                   ├── cuda
│   │   │   │           │   │                   │   └── ms_deform_attn_cuda.o
│   │   │   │           │   │                   └── vision.o
│   │   │   │           │   └── myproject-Seg
│   │   │   │           │       └── MOTS-pro2
│   │   │   │           │           └── f_transformers
│   │   │   │           │               └── models
│   │   │   │           │                   └── ops
│   │   │   │           │                       └── src
│   │   │   │           │                           ├── cpu
│   │   │   │           │                           │   └── ms_deform_attn_cpu.o
│   │   │   │           │                           ├── cuda
│   │   │   │           │                           │   └── ms_deform_attn_cuda.o
│   │   │   │           │                           └── vision.o
│   │   │   │           └── userdisk1
│   │   │   │               └── jpzhang
│   │   │   │                   └── myproject-Seg
│   │   │   │                       └── MOTS-pro2
│   │   │   │                           └── f_transformers
│   │   │   │                               └── models
│   │   │   │                                   └── ops
│   │   │   │                                       └── src
│   │   │   │                                           ├── cpu
│   │   │   │                                           │   └── ms_deform_attn_cpu.o
│   │   │   │                                           ├── cuda
│   │   │   │                                           │   └── ms_deform_attn_cuda.o
│   │   │   │                                           └── vision.o
│   │   │   ├── dist
│   │   │   │   └── MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg
│   │   │   ├── functions
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   │   └── ms_deform_attn_func.cpython-37.pyc
│   │   │   │   └── ms_deform_attn_func.py
│   │   │   ├── make.sh
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── __pycache__
│   │   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   │   └── ms_deform_attn.cpython-37.pyc
│   │   │   │   ├── ms_deform_attn (33).py
│   │   │   │   └── ms_deform_attn.py
│   │   │   ├── setup.py
│   │   │   ├── src
│   │   │   │   ├── cpu
│   │   │   │   │   ├── ms_deform_attn_cpu.cpp
│   │   │   │   │   └── ms_deform_attn_cpu.h
│   │   │   │   ├── cuda
│   │   │   │   │   ├── ms_deform_attn_cuda.cu
│   │   │   │   │   ├── ms_deform_attn_cuda.h
│   │   │   │   │   └── ms_deform_im2col_cuda.cuh
│   │   │   │   ├── ms_deform_attn.h
│   │   │   │   └── vision.cpp
│   │   │   └── test.py
│   │   └── position_encoding.py
│   ├── test.py
│   └── train.py
├── __init__.py
├── __pycache__
│   └── engine.cpython-37.pyc
├── a_DynConv
│   ├── .train.py.swp
│   ├── MOTSDataset.py
│   ├── dodnet.png
│   ├── evaluate.py
│   ├── postp.py
│   ├── train.py
│   └── unet3D_DynConv882.py
├── data_list
│   ├── .DS_Store
│   └── MOTS
│       ├── .DS_Store
│       ├── MOTS_test.txt
│       └── MOTS_train.txt
├── dataset
│   ├── .DS_Store
│   ├── list
│   │   ├── .DS_Store
│   │   └── MOTS
│   │       ├── MOTS_test.txt
│   │       └── MOTS_train.txt
│   └── re_spacing.py
├── engine.py
├── loss_functions
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── loss.cpython-37.pyc
│   └── loss.py
├── run_script.sh
└── utils
    ├── ParaFlop.py
    ├── __init__.py
    ├── __pycache__
    │   ├── ParaFlop.cpython-37.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── logger.cpython-37.pyc
    │   ├── misc.cpython-37.pyc
    │   ├── my_utils.cpython-37.pyc
    │   └── pyt_utils.cpython-37.pyc
    ├── logger.py
    ├── misc.py
    ├── my_utils.py
    └── pyt_utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DoDNet
2 |
3 |
4 |
5 |
6 |
7 | This repo holds the PyTorch implementation of DoDNet and TransDoDNet:
8 |
9 | **DoDNet: Learning to segment multi-organ and tumors from multiple partially labeled datasets**
10 | (https://arxiv.org/pdf/2011.10217.pdf) \
11 | **Learning from partially labeled data for multi-organ and tumor segmentation**
12 | (https://arxiv.org/pdf/2211.06894.pdf)
13 |
14 |
22 |
23 | ## Usage
24 |
25 |
26 |
32 |
33 | ### 1. MOTS Dataset Preparation
34 | Before starting, MOTS should be rebuilt from the following medical organ and tumor segmentation datasets:
35 |
36 | Partial-label task | Data source
37 | --- | :---:
38 | Liver | [data](https://competitions.codalab.org/competitions/17094)
39 | Kidney | [data](https://kits19.grand-challenge.org/data/)
40 | Hepatic Vessel | [data](http://medicaldecathlon.com/)
41 | Pancreas | [data](http://medicaldecathlon.com/)
42 | Colon | [data](http://medicaldecathlon.com/)
43 | Lung | [data](http://medicaldecathlon.com/)
44 | Spleen | [data](http://medicaldecathlon.com/)
45 |
46 |
63 | * Preprocessed data will be available soon.
64 |
65 | ### 2. Training/Testing/Evaluation
66 | sh run_script.sh
67 |
68 |
69 |
109 |
110 |
111 | ### 3. Citation
112 | If this code is helpful for your study, please cite:
113 | ```
114 | @inproceedings{zhang2021dodnet,
115 | title={DoDNet: Learning to segment multi-organ and tumors from multiple partially labeled datasets},
116 | author={Zhang, Jianpeng and Xie, Yutong and Xia, Yong and Shen, Chunhua},
117 | booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
118 | pages={},
119 | year={2021}
120 | }
121 | @article{xie2023learning,
122 | title={Learning from partially labeled data for multi-organ and tumor segmentation},
123 | author={Xie, Yutong and Zhang, Jianpeng and Xia, Yong and Shen, Chunhua},
124 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
125 | year={2023}
126 | }
127 | ```
--------------------------------------------------------------------------------
/TransDoD/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/.DS_Store
--------------------------------------------------------------------------------
/TransDoD/TranDoD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/TranDoD.png
--------------------------------------------------------------------------------
/TransDoD/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/__init__.py
--------------------------------------------------------------------------------
/TransDoD/__pycache__/MOTSDataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/__pycache__/MOTSDataset.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/.DS_Store
--------------------------------------------------------------------------------
/TransDoD/models/ResCNN2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from functools import partial
6 |
7 |
8 | class Conv3d_wd(nn.Conv3d):
9 |
10 | def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), groups=1, bias=False):
11 | super(Conv3d_wd, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
12 |
13 | def forward(self, x):
14 | weight = self.weight
15 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)
16 | weight = weight - weight_mean
17 | # std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1, 1) + 1e-5
18 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)
19 | weight = weight / std.expand_as(weight)
20 | return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
21 |
22 |
23 | def conv3x3x3(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), bias=False, weight_std=False):
24 | "3x3x3 convolution with padding"
25 | if weight_std:
26 | return Conv3d_wd(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
27 | else:
28 | return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
29 |
30 |
31 | def downsample_basic_block(x, planes, stride):
32 | out = F.avg_pool3d(x, kernel_size=1, stride=stride)
33 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), out.size(2), out.size(3),out.size(4)).zero_()
34 | if isinstance(out.data, torch.cuda.FloatTensor):
35 | zero_pads = zero_pads.cuda()
36 |
37 |     out = Variable(torch.cat([out.data, zero_pads], dim=1))
38 |
39 | return out
40 |
41 |
42 | def Norm_layer(norm_cfg, inplanes):
43 |
44 | if norm_cfg == 'BN':
45 | out = nn.BatchNorm3d(inplanes)
46 | elif norm_cfg == 'SyncBN':
47 | out = nn.SyncBatchNorm(inplanes)
48 | elif norm_cfg == 'GN':
49 | out = nn.GroupNorm(16, inplanes)
50 | elif norm_cfg == 'IN':
51 | out = nn.InstanceNorm3d(inplanes,affine=True)
52 |
53 | return out
54 |
55 |
56 | def Activation_layer(activation_cfg, inplace=True):
57 |
58 | if activation_cfg == 'relu':
59 | out = nn.ReLU(inplace=inplace)
60 | elif activation_cfg == 'LeakyReLU':
61 | out = nn.LeakyReLU(negative_slope=1e-2, inplace=inplace)
62 |
63 | return out
64 |
65 |
66 | class BasicBlock(nn.Module):
67 | expansion = 1
68 |
69 | def __init__(self, inplanes, planes, norm_cfg, activation_cfg, stride=(1, 1, 1), downsample=None, weight_std=False):
70 | super(BasicBlock, self).__init__()
71 | self.conv1 = conv3x3x3(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, weight_std=weight_std)
72 | self.norm1 = Norm_layer(norm_cfg, planes)
73 | self.nonlin = Activation_layer(activation_cfg, inplace=True)
74 | self.conv2 = conv3x3x3(planes, planes, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False, weight_std=weight_std)
75 | self.norm2 = Norm_layer(norm_cfg, planes)
76 | self.downsample = downsample
77 | self.stride = stride
78 |
79 | def forward(self, x):
80 | residual = x
81 |
82 | out = self.conv1(x)
83 | out = self.norm1(out)
84 | out = self.nonlin(out)
85 |
86 | out = self.conv2(out)
87 | out = self.norm2(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.nonlin(out)
94 |
95 | return out
96 |
97 |
98 | class Bottleneck(nn.Module):
99 | expansion = 4
100 |
101 | def __init__(self, inplanes, planes, norm_cfg, activation_cfg, stride=(1, 1, 1), downsample=None, weight_std=False):
102 | super(Bottleneck, self).__init__()
103 | self.conv1 = conv3x3x3(inplanes, planes, kernel_size=1, bias=False, weight_std=weight_std)
104 | self.norm1 = Norm_layer(norm_cfg, planes)
105 | self.conv2 = conv3x3x3(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, weight_std=weight_std)
106 | self.norm2 = Norm_layer(norm_cfg, planes)
107 | self.conv3 = conv3x3x3(planes, planes * 4, kernel_size=1, bias=False, weight_std=weight_std)
108 | self.norm3 = Norm_layer(norm_cfg, planes * 4)
109 | self.nonlin = Activation_layer(activation_cfg, inplace=True)
110 | self.downsample = downsample
111 | self.stride = stride
112 |
113 | def forward(self, x):
114 | residual = x
115 |
116 | out = self.conv1(x)
117 | out = self.norm1(out)
118 | out = self.nonlin(out)
119 |
120 | out = self.conv2(out)
121 | out = self.norm2(out)
122 | out = self.nonlin(out)
123 |
124 | out = self.conv3(out)
125 | out = self.norm3(out)
126 |
127 | if self.downsample is not None:
128 | residual = self.downsample(x)
129 |
130 | out += residual
131 | out = self.nonlin(out)
132 |
133 | return out
134 |
135 |
136 | class ResNet(nn.Module):
137 |
138 | arch_settings = {
139 | 10: (BasicBlock, (1, 1, 1, 1)),
140 | 18: (BasicBlock, (2, 2, 2, 2)),
141 | 34: (BasicBlock, (3, 4, 6, 3)),
142 | 50: (Bottleneck, (3, 4, 6, 3)),
143 | 101: (Bottleneck, (3, 4, 23, 3)),
144 | 152: (Bottleneck, (3, 8, 36, 3)),
145 | 200: (Bottleneck, (3, 24, 36, 3))
146 | }
147 |
148 | def __init__(self,
149 | depth,
150 | in_channels=1,
151 | shortcut_type='B',
152 | norm_cfg='IN',
153 | activation_cfg='relu',
154 | weight_std=False):
155 | super(ResNet, self).__init__()
156 |
157 | if depth not in self.arch_settings:
158 | raise KeyError('invalid depth {} for resnet'.format(depth))
159 | self.depth = depth
160 | block, layers = self.arch_settings[depth]
161 | self.inplanes = 32
162 | self.conv1 = conv3x3x3(in_channels, 32, kernel_size=3, stride=(1, 2, 2), padding=1, bias=False, weight_std=weight_std)
163 | self.norm1 = Norm_layer(norm_cfg, 32)
164 | self.nonlin1 = Activation_layer(activation_cfg, inplace=True)
165 | self.conv2 = conv3x3x3(32, 32, kernel_size=3, stride=(1, 1, 1), padding=1, bias=False, weight_std=weight_std)
166 | self.norm2 = Norm_layer(norm_cfg, 32)
167 | self.nonlin2 = Activation_layer(activation_cfg, inplace=True)
168 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=1)
169 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type, stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
170 | self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
171 | self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type, stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
172 | self.layer4 = self._make_layer(block, 320, layers[3], shortcut_type, stride=(2, 2, 2), norm_cfg=norm_cfg, activation_cfg=activation_cfg, weight_std=weight_std)
173 | self.layers = []
174 |
175 | for m in self.modules():
176 | if isinstance(m, (nn.Conv3d, Conv3d_wd)):
177 |                 m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
178 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)):
179 | m.weight.data.fill_(1)
180 | m.bias.data.zero_()
181 |
182 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=(1, 1, 1), norm_cfg='BN', activation_cfg='relu', weight_std=False):
183 | downsample = None
184 | if stride != 1 or self.inplanes != planes * block.expansion:
185 | if shortcut_type == 'A':
186 | downsample = partial(
187 | downsample_basic_block,
188 | planes=planes * block.expansion,
189 | stride=stride)
190 | else:
191 | downsample = nn.Sequential(
192 | conv3x3x3(
193 | self.inplanes,
194 | planes * block.expansion,
195 | kernel_size=1,
196 | stride=stride,
197 | bias=False, weight_std=weight_std),
198 | Norm_layer(norm_cfg, planes * block.expansion))
199 |
200 | layers = []
201 | layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, stride=stride, downsample=downsample, weight_std=weight_std))
202 | self.inplanes = planes * block.expansion
203 | for i in range(1, blocks):
204 | layers.append(block(self.inplanes, planes, norm_cfg, activation_cfg, weight_std=weight_std))
205 |
206 | return nn.Sequential(*layers)
207 |
208 | def init_weights(self):
209 | for m in self.modules():
210 | if isinstance(m, (nn.Conv3d, Conv3d_wd)):
211 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
212 | elif isinstance(m, (nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm3d, nn.SyncBatchNorm)):
213 | if m.weight is not None:
214 | nn.init.constant_(m.weight, 1)
215 | if m.bias is not None:
216 | nn.init.constant_(m.bias, 0)
217 |
218 | def forward(self, x):
219 | self.layers = []
220 | x = self.nonlin1(self.norm1(self.conv1(x)))
221 | x = self.nonlin2(self.norm2(self.conv2(x)))
222 | self.layers.append(x)
223 |
224 | x = self.layer1(x)
225 | self.layers.append(x)
226 | x = self.layer2(x)
227 | self.layers.append(x)
228 | x = self.layer3(x)
229 | self.layers.append(x)
230 | x = self.layer4(x)
231 | self.layers.append(x)
232 |
233 | return x
234 |
235 | def get_layers(self):
236 | return self.layers
237 |
--------------------------------------------------------------------------------
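A minimal smoke test for the backbone above (not part of the repository; the input size, depth, and import path are assumptions): it runs a dummy single-channel volume through the ResNet and prints the five multi-scale feature maps that `get_layers()` collects for the TransDoDNet decoder.

```python
# Sketch only: assumes it is run from the TransDoD directory so `models` is importable.
import torch
from models import ResCNN2

backbone = ResCNN2.ResNet(depth=18, in_channels=1, norm_cfg='IN',
                          activation_cfg='relu', weight_std=False)
x = torch.randn(1, 1, 32, 128, 128)   # (B, C, D, H, W) dummy CT patch
_ = backbone(x)                        # forward() refills self.layers
for i, feat in enumerate(backbone.get_layers()):
    print(i, tuple(feat.shape))
# The stem halves H/W only, then each residual stage halves D, H and W, so for
# this input and depth=18 the collected shapes are (1, 32, 32, 64, 64),
# (1, 64, 16, 32, 32), (1, 128, 8, 16, 16), (1, 256, 4, 8, 8), (1, 320, 2, 4, 4).
```
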
/TransDoD/models/TransDoDNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from models import ResCNN2
5 | from models.deformable_transformer import build_deformable_transformer
6 | from .position_encoding import build_position_encoding
7 |
8 |
9 | def _expand(tensor, length: int):
10 | return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1, 1).flatten(0, 1)
11 |
12 |
13 | class Conv3d_wd(nn.Conv3d):
14 |
15 | def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1),
16 | groups=1, bias=True):
17 | super(Conv3d_wd, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
18 |
19 | def forward(self, x):
20 | weight = self.weight
21 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4,
22 | keepdim=True)
23 | weight = weight - weight_mean
24 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)
25 | weight = weight / std.expand_as(weight)
26 | return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
27 |
28 |
29 | def conv3x3x3(in_planes, out_planes, kernel_size, stride=(1, 1, 1), padding=(0, 0, 0), dilation=(1, 1, 1), groups=1,
30 | bias=True, weight_std=False):
31 | "3x3x3 convolution with padding"
32 | if weight_std:
33 | return Conv3d_wd(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
34 | dilation=dilation, groups=groups, bias=bias)
35 | else:
36 | return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
37 | dilation=dilation, groups=groups, bias=bias)
38 |
39 |
40 | def Norm_layer(norm_cfg, inplanes):
41 | if norm_cfg == 'BN':
42 | out = nn.BatchNorm3d(inplanes)
43 | elif norm_cfg == 'SyncBN':
44 | out = nn.SyncBatchNorm(inplanes)
45 | elif norm_cfg == 'GN':
46 | out = nn.GroupNorm(16, inplanes)
47 | elif norm_cfg == 'IN':
48 | out = nn.InstanceNorm3d(inplanes, affine=True)
49 |
50 | return out
51 |
52 |
53 | def Activation_layer(activation_cfg, inplace=True):
54 | if activation_cfg == 'relu':
55 | out = nn.ReLU(inplace=inplace)
56 | elif activation_cfg == 'LeakyReLU':
57 | out = nn.LeakyReLU(negative_slope=1e-2, inplace=inplace)
58 |
59 | return out
60 |
61 |
62 | class ResCNN_DeformTR(nn.Module):
63 | def __init__(self, args, norm_cfg='IN', activation_cfg='relu', num_classes=None, weight_std=False, res_depth=None, dyn_head_dep_wid=[3,8]):
64 | super(ResCNN_DeformTR, self).__init__()
65 |
66 | self.args = args
67 | self.args.activation = activation_cfg
68 | self.num_classes = num_classes
69 | self.dyn_head_dep_wid = dyn_head_dep_wid
70 | if res_depth >= 50:
71 | expansion = 4
72 | else:
73 | expansion = 1
74 |
75 | # num dyn params
76 | num_dyn_params = (dyn_head_dep_wid[1]*dyn_head_dep_wid[1]+dyn_head_dep_wid[1])*(dyn_head_dep_wid[0]-1) + (dyn_head_dep_wid[1]*2+2)*1
77 | print(f"###Total dyn params {num_dyn_params}###")
78 |
79 | self.upsample = nn.Upsample(scale_factor=(1, 2, 2), mode='trilinear')
80 |
81 | if self.args.add_memory >= 1:
82 | if self.args.num_feature_levels >= 1:
83 | self.memory_conv3_l0 = nn.Sequential(
84 | conv3x3x3(args.hidden_dim, 256, kernel_size=1, bias=False, weight_std=weight_std),
85 | Norm_layer(norm_cfg, 256),
86 | Activation_layer(activation_cfg, inplace=True),
87 | )
88 |
89 | self.cnn_bottle = nn.Sequential(
90 | conv3x3x3(320 * expansion, 256, kernel_size=1, bias=False, weight_std=weight_std),
91 | Norm_layer(norm_cfg, 256),
92 | Activation_layer(activation_cfg, inplace=True),
93 | )
94 |
95 | self.shortcut_conv3 = nn.Sequential(
96 | conv3x3x3(256 * expansion, 256, kernel_size=1, bias=False, weight_std=weight_std),
97 | Norm_layer(norm_cfg, 256),
98 | Activation_layer(activation_cfg, inplace=True),
99 | )
100 |
101 | self.shortcut_conv2 = nn.Sequential(
102 | conv3x3x3(128 * expansion, 128, kernel_size=1, bias=False, weight_std=weight_std),
103 | Norm_layer(norm_cfg, 128),
104 | Activation_layer(activation_cfg, inplace=True),
105 | )
106 |
107 | self.shortcut_conv1 = nn.Sequential(
108 | conv3x3x3(64 * expansion, 64, kernel_size=1, bias=False, weight_std=weight_std),
109 | Norm_layer(norm_cfg, 64),
110 | Activation_layer(activation_cfg, inplace=True),
111 | )
112 |
113 | self.shortcut_conv0 = nn.Sequential(
114 | conv3x3x3(32, 32, kernel_size=1, bias=False, weight_std=weight_std),
115 | Norm_layer(norm_cfg, 32),
116 | Activation_layer(activation_cfg, inplace=True),
117 | )
118 |
119 | self.transposeconv_stage3 = nn.ConvTranspose3d(256, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
120 | self.transposeconv_stage2 = nn.ConvTranspose3d(256, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
121 | self.transposeconv_stage1 = nn.ConvTranspose3d(128, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
122 | self.transposeconv_stage0 = nn.ConvTranspose3d(64, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2), bias=False)
123 |
124 | self.stage3_de = ResCNN2.BasicBlock(256, 256, norm_cfg, activation_cfg, weight_std=weight_std)
125 | self.stage2_de = ResCNN2.BasicBlock(128, 128, norm_cfg, activation_cfg, weight_std=weight_std)
126 | self.stage1_de = ResCNN2.BasicBlock(64, 64, norm_cfg, activation_cfg, weight_std=weight_std)
127 | self.stage0_de = ResCNN2.BasicBlock(32, 32, norm_cfg, activation_cfg, weight_std=weight_std)
128 |
129 | self.precls_conv = nn.Sequential(
130 | conv3x3x3(32, dyn_head_dep_wid[1], kernel_size=1, bias=False, weight_std=weight_std),
131 | Norm_layer(norm_cfg, dyn_head_dep_wid[1]),
132 | Activation_layer(activation_cfg, inplace=True),
133 | )
134 |
135 | self.backbone = ResCNN2.ResNet(depth=res_depth, shortcut_type='B', norm_cfg=norm_cfg,
136 | activation_cfg=activation_cfg,
137 | weight_std=weight_std)
138 |
139 | self.backbone_layers = self.backbone.get_layers()
140 |
141 | #
142 | if self.args.using_transformer:
143 | self.transformer = build_deformable_transformer(args)
144 | self.position_embedding = build_position_encoding(args)
145 | self.controller = nn.Sequential(
146 | nn.Linear(args.hidden_dim, args.hidden_dim),
147 | Activation_layer(activation_cfg, inplace=True),
148 | nn.Linear(args.hidden_dim, num_dyn_params)
149 | )
150 |
151 | input_proj_list = []
152 | backbone_num_channels = [64 * expansion, 128 * expansion, 256 * expansion, 320 * expansion]
153 | backbone_num_channels = backbone_num_channels[-args.num_feature_levels:]
154 | for _ in backbone_num_channels:
155 | in_channels = _
156 | input_proj_list.append(nn.Sequential(
157 | conv3x3x3(in_channels, args.hidden_dim, kernel_size=1, bias=True, weight_std=weight_std),
158 | Norm_layer(norm_cfg, args.hidden_dim),
159 | ))
160 | self.input_proj = nn.ModuleList(input_proj_list)
161 | self.query_embed = nn.Embedding(self.args.num_queries, args.hidden_dim * 2)
162 |
163 | for m in self.modules():
164 | if isinstance(m, (nn.Conv3d, Conv3d_wd, nn.ConvTranspose3d)):
165 | m.weight = nn.init.kaiming_normal_(m.weight, mode='fan_out')
166 | elif isinstance(m, (nn.BatchNorm3d, nn.SyncBatchNorm, nn.InstanceNorm3d, nn.GroupNorm)):
167 | if m.weight is not None:
168 | nn.init.constant_(m.weight, 1)
169 | if m.bias is not None:
170 | nn.init.constant_(m.bias, 0)
171 |
172 | def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
173 | assert params.dim() == 2
174 | assert len(weight_nums) == len(bias_nums)
175 | assert params.size(1) == sum(weight_nums) + sum(bias_nums)
176 |
177 | num_insts = params.size(0)
178 | num_layers = len(weight_nums)
179 |
180 | params_splits = list(torch.split_with_sizes(
181 | params, weight_nums + bias_nums, dim=1
182 | ))
183 |
184 | weight_splits = params_splits[:num_layers]
185 | bias_splits = params_splits[num_layers:]
186 |
187 | for l in range(num_layers):
188 | if l < num_layers - 1:
189 | weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1, 1)
190 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
191 | else:
192 | weight_splits[l] = weight_splits[l].reshape(num_insts * 2, -1, 1, 1, 1)
193 | bias_splits[l] = bias_splits[l].reshape(num_insts * 2)
194 |
195 | return weight_splits, bias_splits
196 |
197 | def heads_forward(self, features, weights, biases, num_insts):
198 | assert features.dim() == 5
199 | n_layers = len(weights)
200 | x = features
201 | for i, (w, b) in enumerate(zip(weights, biases)):
202 | x = F.conv3d(
203 | x, w, bias=b,
204 | stride=1, padding=0,
205 | groups=num_insts
206 | )
207 | if i < n_layers - 1:
208 | x = F.relu(x)
209 | return x
210 |
211 | def forward(self, inputs_x, task_id):
212 | bs, c, d, w, h = inputs_x.shape
213 |
214 | _ = self.backbone(inputs_x)
215 | layers = self.backbone.get_layers() # [::-1]
216 |
217 | # Transformer body
218 | srcs = []
219 | masks = []
220 | pos = []
221 | for l, feat in enumerate(layers[-self.args.num_feature_levels:]):
222 | src = feat
223 | srcs.append(self.input_proj[l](src))
224 |             masks.append(torch.zeros(src.shape[0], src.shape[2], src.shape[3], src.shape[4], dtype=torch.bool, device=src.device))
225 | pos.append(self.position_embedding(src).to(src.dtype))
226 | del feat
227 |
228 | hs, _, _, _, _, memory = self.transformer(srcs, masks, pos, self.query_embed.weight)
229 | params = self.controller(hs[-1].flatten(0, 1))
230 |
231 | if self.args.add_memory == 0:
232 | x = self.cnn_bottle(layers[-1])
233 | elif self.args.add_memory == 1:
234 | x = self.memory_conv3_l0(memory[-1])
235 | elif self.args.add_memory == 2:
236 | x = self.memory_conv3_l0(memory[-1]) + self.cnn_bottle(layers[-1])
237 | else:
238 |             raise ValueError("no pre-defined add_memory mode: {}".format(self.args.add_memory))
239 |
240 | skip3 = self.shortcut_conv3(layers[-2])
241 | x = self.transposeconv_stage3(x)
242 | x = x + skip3
243 | x = self.stage3_de(x)
244 |
245 | skip2 = self.shortcut_conv2(layers[-3])
246 | x = self.transposeconv_stage2(x)
247 | x = x + skip2
248 | x = self.stage2_de(x)
249 |
250 | x = self.transposeconv_stage1(x)
251 | skip1 = self.shortcut_conv1(layers[-4])
252 | x = x + skip1
253 | x = self.stage1_de(x)
254 |
255 | x = self.transposeconv_stage0(x)
256 | skip0 = self.shortcut_conv0(layers[-5])
257 | x = x + skip0
258 | x = self.stage0_de(x)
259 |
260 | head_inputs = self.precls_conv(x)
261 | head_inputs = _expand(head_inputs, self.args.num_queries)
262 | N, _, D, H, W = head_inputs.size()
263 | head_inputs = head_inputs.reshape(1, -1, D, H, W)
264 |
265 | weight_nums, bias_nums = [], []
266 | for i in range(self.dyn_head_dep_wid[0]-1):
267 | weight_nums.append(self.dyn_head_dep_wid[1]*self.dyn_head_dep_wid[1])
268 | bias_nums.append(self.dyn_head_dep_wid[1])
269 | weight_nums.append(self.dyn_head_dep_wid[1]*2)
270 | bias_nums.append(2)
271 |
272 | weights, biases = self.parse_dynamic_params(params, self.dyn_head_dep_wid[1], weight_nums, bias_nums)
273 |
274 | seg_out = self.heads_forward(head_inputs, weights, biases, N)
275 | seg_out = self.upsample(seg_out)
276 | seg_out = seg_out.view(bs, self.args.num_queries, self.num_classes, seg_out.shape[-3], seg_out.shape[-2],
277 | seg_out.shape[-1])
278 |
279 | return seg_out
280 |
281 |
282 | class TransDoDNet(nn.Module):
283 | def __init__(self, args, norm_cfg='IN', activation_cfg='relu', num_classes=None,
284 | weight_std=False, deep_supervision=False, res_depth=None, dyn_head_dep_wid=[3,8]):
285 | super().__init__()
286 | self.do_ds = False
287 | self.ResCNN_DeformTR = ResCNN_DeformTR(args, norm_cfg, activation_cfg, num_classes, weight_std, res_depth, dyn_head_dep_wid)
288 | if weight_std == False:
289 | self.conv_op = nn.Conv3d
290 | else:
291 | self.conv_op = Conv3d_wd
292 | if norm_cfg == 'BN':
293 | self.norm_op = nn.BatchNorm3d
294 | if norm_cfg == 'GN':
295 | self.norm_op = nn.GroupNorm
296 | if norm_cfg == 'IN':
297 | self.norm_op = nn.InstanceNorm3d
298 | self.num_classes = num_classes
299 | self._deep_supervision = deep_supervision
300 | self.do_ds = deep_supervision
301 |
302 | def forward(self, x, task_id):
303 | seg_output = self.ResCNN_DeformTR(x, task_id)
304 | return seg_output
--------------------------------------------------------------------------------
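A standalone sketch (not part of the repository) of the dynamic-head bookkeeping used by `ResCNN_DeformTR` above, assuming the default `dyn_head_dep_wid=[3, 8]` and two query instances: the controller emits 162 parameters per query, `parse_dynamic_params` reshapes them into the weights and biases of a three-layer 1x1x1 conv head, and `heads_forward` runs all instances at once as a grouped convolution.

```python
import torch
import torch.nn.functional as F

depth, width = 3, 8                                         # dyn_head_dep_wid
weight_nums = [width * width] * (depth - 1) + [width * 2]   # [64, 64, 16]
bias_nums = [width] * (depth - 1) + [2]                     # [8, 8, 2]
num_params = sum(weight_nums) + sum(bias_nums)              # 162, matches the print in __init__

num_insts = 2                                  # batch_size * num_queries in the real model
params = torch.randn(num_insts, num_params)    # stand-in for controller(hs[-1])

# parse_dynamic_params, inlined
splits = list(torch.split_with_sizes(params, weight_nums + bias_nums, dim=1))
weights, biases = splits[:depth], splits[depth:]
for l in range(depth):
    out_ch = width if l < depth - 1 else 2
    weights[l] = weights[l].reshape(num_insts * out_ch, -1, 1, 1, 1)
    biases[l] = biases[l].reshape(num_insts * out_ch)

# heads_forward, inlined: every instance gets its own copy of the shared features
x = torch.randn(1, num_insts * width, 8, 16, 16)
for l in range(depth):
    x = F.conv3d(x, weights[l], bias=biases[l], stride=1, padding=0, groups=num_insts)
    if l < depth - 1:
        x = F.relu(x)
print(x.shape)  # torch.Size([1, 4, 8, 16, 16]): a 2-channel prediction per instance
```
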
/TransDoD/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__init__.py
--------------------------------------------------------------------------------
/TransDoD/models/__pycache__/ResCNN2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/ResCNN2.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/__pycache__/TransDoDNet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/TransDoDNet.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/__pycache__/deepnorm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/deepnorm.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/__pycache__/deformable_transformer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/deformable_transformer.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/__pycache__/position_encoding.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/__pycache__/position_encoding.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/deepnorm.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import List, Optional, Tuple, Union
3 | from collections import namedtuple
4 |
5 | DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])
6 |
7 | def get_deepnorm_coefficients(
8 | encoder_layers: int, decoder_layers: int
9 | ) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]:
10 | """
11 | See DeepNet_.
12 | Returns alpha and beta depending on the number of encoder and decoder layers,
13 |     first tuple is for the encoder and the second for the decoder.
14 | .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf
15 | """
16 |
17 | N = encoder_layers
18 | M = decoder_layers
19 |
20 | if decoder_layers == 0:
21 | # Encoder only
22 | return (
23 | DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25),
24 | None,
25 | )
26 |
27 | elif encoder_layers == 0:
28 | # Decoder only
29 | return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25)
30 | else:
31 | # Encoder/decoder
32 | encoder_coeffs = DeepNormCoefficients(
33 | alpha=0.81 * ((N ** 4) * M) ** 0.0625, beta=0.87 * ((N ** 4) * M) ** -0.0625
34 | )
35 |
36 | decoder_coeffs = DeepNormCoefficients(
37 | alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25
38 | )
39 |
40 | return (encoder_coeffs, decoder_coeffs)
41 |
42 | if __name__=="__main__":
43 | coef = get_deepnorm_coefficients(1000,1000)
44 | pass
--------------------------------------------------------------------------------
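A short usage sketch for `get_deepnorm_coefficients` above (the layer counts are assumed for illustration): it returns the DeepNet alpha/beta pair for each side, or `None` for a side with zero layers.

```python
from models.deepnorm import get_deepnorm_coefficients  # import path assumed (TransDoD on PYTHONPATH)

# 6-layer encoder + 6-layer decoder
enc, dec = get_deepnorm_coefficients(encoder_layers=6, decoder_layers=6)
print(enc)  # DeepNormCoefficients(alpha≈1.42, beta≈0.50), i.e. 0.81*(6**4*6)**0.0625 and 0.87*(6**4*6)**-0.0625
print(dec)  # DeepNormCoefficients(alpha≈2.06, beta≈0.34), i.e. (3*6)**0.25 and (12*6)**-0.25

# Encoder-only configuration: the decoder entry is simply None
enc_only, dec_none = get_deepnorm_coefficients(encoder_layers=6, decoder_layers=0)
assert dec_none is None
```
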
/TransDoD/models/ops/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/.DS_Store
--------------------------------------------------------------------------------
/TransDoD/models/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: MultiScaleDeformableAttention
3 | Version: 1.0
4 | Summary: PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention
5 | Home-page: https://github.com/fundamentalvision/Deformable-DETR
6 | Author: Weijie Su
7 | License: UNKNOWN
8 | Platform: UNKNOWN
9 |
10 | UNKNOWN
11 |
12 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/MultiScaleDeformableAttention.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | setup.py
2 | /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.cpp
3 | /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.cpp
4 | /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.cu
5 | MultiScaleDeformableAttention.egg-info/PKG-INFO
6 | MultiScaleDeformableAttention.egg-info/SOURCES.txt
7 | MultiScaleDeformableAttention.egg-info/dependency_links.txt
8 | MultiScaleDeformableAttention.egg-info/top_level.txt
9 | functions/__init__.py
10 | functions/ms_deform_attn_func.py
11 | modules/__init__.py
12 | modules/ms_deform_attn (33).py
13 | modules/ms_deform_attn.py
--------------------------------------------------------------------------------
/TransDoD/models/ops/MultiScaleDeformableAttention.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/MultiScaleDeformableAttention.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | MultiScaleDeformableAttention
2 | functions
3 | modules
4 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/lib.linux-x86_64-3.7/MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/lib.linux-x86_64-3.7/MultiScaleDeformableAttention.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/lib.linux-x86_64-3.7/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn_func import MSDeformAttnFunction
10 |
11 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/lib.linux-x86_64-3.7/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 |
18 | import MultiScaleDeformableAttention as MSDA
19 |
20 |
21 | class MSDeformAttnFunction(Function):
22 | @staticmethod
23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24 | ctx.im2col_step = im2col_step
25 | output = MSDA.ms_deform_attn_forward(
26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28 | return output
29 |
30 | @staticmethod
31 | @once_differentiable
32 | def backward(ctx, grad_output):
33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34 | grad_value, grad_sampling_loc, grad_attn_weight = \
35 | MSDA.ms_deform_attn_backward(
36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37 |
38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39 |
40 |
41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42 | # for debug and test only,
43 | # need to use cuda version instead
44 | N_, S_, M_, D_ = value.shape
45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47 | sampling_grids = 2 * sampling_locations - 1
48 | sampling_value_list = []
49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54 | # N_*M_, D_, Lq_, P_
55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56 | mode='bilinear', padding_mode='zeros', align_corners=False)
57 | sampling_value_list.append(sampling_value_l_)
58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61 | return output.transpose(1, 2).contiguous()
62 |
63 | def ms_deform_attn_core_pytorch_3D(value, value_spatial_shapes, sampling_locations, attention_weights):
64 | # for debug and test only,
65 | # need to use cuda version instead
66 | N_, S_, M_, D_ = value.shape
67 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
68 | value_list = value.split([T_ * H_ * W_ for T_, H_, W_ in value_spatial_shapes], dim=1)
69 | sampling_grids = 2 * sampling_locations - 1
70 | sampling_value_list = []
71 | for lid_, (T_, H_, W_) in enumerate(value_spatial_shapes):
72 | # N_, T_*H_*W_, M_, D_ -> N_, T_*H_*W_, M_*D_ -> N_, M_*D_, T_*H_*W_ -> N_*M_, D_, T_, H_, W_
73 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, T_, H_, W_)
74 | # N_, Lq_, M_, P_, 3 -> N_, M_, Lq_, P_, 3 -> N_*M_, Lq_, P_, 3
75 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)[:,None,:,:,:]
76 | # N_*M_, D_, Lq_, P_
77 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False)[:,:,0]
78 | sampling_value_list.append(sampling_value_l_)
79 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
80 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
81 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
82 | return output.transpose(1, 2).contiguous()
83 |
--------------------------------------------------------------------------------
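A shape-checking sketch for `ms_deform_attn_core_pytorch_3D` above (all sizes are assumptions). Note that importing it via `functions.ms_deform_attn_func` also pulls in the compiled `MultiScaleDeformableAttention` extension at module import time, even though this fallback itself is pure PyTorch.

```python
import torch
# ms_deform_attn_core_pytorch_3D as defined above (or imported from functions.ms_deform_attn_func)

N, M, D = 2, 8, 32                        # batch, attention heads, channels per head
spatial_shapes = [(4, 8, 8), (2, 4, 4)]   # (T, H, W) per feature level
L = len(spatial_shapes)
S = sum(t * h * w for t, h, w in spatial_shapes)
Lq, P = 16, 4                             # queries, sampling points per head per level

value = torch.randn(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 3)   # normalised to [0, 1]
attention_weights = torch.randn(N, Lq, M, L, P).flatten(-2).softmax(-1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch_3D(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 16, 256]) == (N, Lq, M * D)
```
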
/TransDoD/models/ops/build/lib.linux-x86_64-3.7/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn import MSDeformAttn
10 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/lib.linux-x86_64-3.7/modules/ms_deform_attn (33).py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import warnings
14 | import math
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 | from torch.nn.init import xavier_uniform_, constant_
20 |
21 | from ..functions import MSDeformAttnFunction
22 | from ..functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch_3D
23 |
24 | def _is_power_of_2(n):
25 | if (not isinstance(n, int)) or (n < 0):
26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
27 | return (n & (n-1) == 0) and n != 0
28 |
29 |
30 | class MSDeformAttn(nn.Module):
31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
32 | """
33 | Multi-Scale Deformable Attention Module
34 | :param d_model hidden dimension
35 | :param n_levels number of feature levels
36 | :param n_heads number of attention heads
37 | :param n_points number of sampling points per attention head per feature level
38 | """
39 | super().__init__()
40 | if d_model % n_heads != 0:
41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
42 | _d_per_head = d_model // n_heads
43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
44 | if not _is_power_of_2(_d_per_head):
45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
46 | "which is more efficient in our CUDA implementation.")
47 |
48 | self.im2col_step = 64
49 |
50 | self.d_model = d_model
51 | self.n_levels = n_levels
52 | self.n_heads = n_heads
53 | self.n_points = n_points
54 |
55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3)
56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
57 | self.value_proj = nn.Linear(d_model, d_model)
58 | self.output_proj = nn.Linear(d_model, d_model)
59 |
60 | # self.s_norm = nn.LayerNorm(n_heads * n_levels * n_points * 3)
61 |
62 | self._reset_parameters()
63 |
64 | # add for vis
65 | self.samp_location = []
66 | self.atte_w = []
67 |
68 | def _reset_parameters(self):
69 | constant_(self.sampling_offsets.weight.data, 0.)
70 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
71 | # grid_init = torch.stack([thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin(), thetas.cos()], -1)
72 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.sin(), thetas.cos()], -1)
73 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.cos(), thetas.sin()], -1)
74 | grid_init = torch.stack([2*thetas.sin()*thetas.cos(), thetas.cos(), thetas.sin()], -1)
75 |
76 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
77 | # beta = torch.arange(-self.n_heads//2, self.n_heads//2, dtype=torch.float32) * (1.0 * math.pi / self.n_heads)
78 | # grid_init = torch.stack([beta.sin(), beta.cos() * thetas.cos(), beta.cos() * thetas.sin()], -1)
79 |
80 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
81 | # beta = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
82 | # grid_init = torch.stack([beta.sin() * thetas.sin(), beta.cos(), beta.sin() * thetas.cos()], -1)
83 |
84 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1)
85 | for i in range(self.n_points):
86 | grid_init[:, :, i, :] *= i + 1
87 | with torch.no_grad():
88 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
89 | constant_(self.attention_weights.weight.data, 0.)
90 | constant_(self.attention_weights.bias.data, 0.)
91 | xavier_uniform_(self.value_proj.weight.data)
92 | constant_(self.value_proj.bias.data, 0.)
93 | xavier_uniform_(self.output_proj.weight.data)
94 | constant_(self.output_proj.bias.data, 0.)
95 |
96 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
97 | """
98 | :param query (N, Length_{query}, C)
99 |         :param reference_points            (N, Length_{query}, n_levels, 3), range in [0, 1], normalized coordinates including padding area
100 |                                            or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
101 |         :param input_flatten               (N, \sum_{l=0}^{L-1} T_l \cdot H_l \cdot W_l, C)
102 |         :param input_spatial_shapes        (n_levels, 3), [(T_0, H_0, W_0), (T_1, H_1, W_1), ..., (T_{L-1}, H_{L-1}, W_{L-1})]
103 |         :param input_level_start_index     (n_levels, ), [0, T_0*H_0*W_0, T_0*H_0*W_0+T_1*H_1*W_1, ..., T_0*H_0*W_0+...+T_{L-2}*H_{L-2}*W_{L-2}]
104 |         :param input_padding_mask          (N, \sum_{l=0}^{L-1} T_l \cdot H_l \cdot W_l), True for padding elements, False for non-padding elements
105 |
106 | :return output (N, Length_{query}, C)
107 | """
108 | N, Len_q, _ = query.shape
109 | N, Len_in, _ = input_flatten.shape
110 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in
111 |
112 | value = self.value_proj(input_flatten)
113 | if input_padding_mask is not None:
114 | value = value.masked_fill(input_padding_mask[..., None], float(0))
115 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
116 | ## using Tanh
117 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
118 | # sampling_offsets = F.tanh(sampling_offsets)
119 |
120 | ## or using norm
121 | # sampling_offsets = self.s_norm(self.sampling_offsets(query)).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
122 |
123 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
124 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
125 | # N, Len_q, n_heads, n_levels, n_points, 3
126 | if reference_points.shape[-1] == 3:
127 | offset_normalizer = torch.stack([input_spatial_shapes[..., 2], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1)
128 | sampling_locations = reference_points[:, :, None, :, None, :] \
129 | + sampling_offsets / (2*offset_normalizer[None, None, None, :, None, :])
130 | # sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets
131 | elif reference_points.shape[-1] == 4:
132 | sampling_locations = reference_points[:, :, None, :, None, :2] \
133 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
134 | else:
135 | raise ValueError(
136 |                 'Last dim of reference_points must be 3 or 4, but got {} instead.'.format(reference_points.shape[-1]))
137 | # output = MSDeformAttnFunction.apply(
138 | # value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
139 | # debug
140 | output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations,
141 | attention_weights)#.detach().cpu()
142 | # output_cuda = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index,
143 | # sampling_locations, attention_weights, self.im2col_step).detach().cpu()
144 |
145 | output = self.output_proj(output)
146 | self.atte_w = attention_weights
147 | self.samp_location = sampling_locations
148 | return output
149 |
150 |
151 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/lib.linux-x86_64-3.7/modules/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import warnings
14 | import math
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 | from torch.nn.init import xavier_uniform_, constant_
20 |
21 | from ..functions import MSDeformAttnFunction
22 | from ..functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch_3D
23 |
24 | def _is_power_of_2(n):
25 | if (not isinstance(n, int)) or (n < 0):
26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
27 | return (n & (n-1) == 0) and n != 0
28 |
29 |
30 | class MSDeformAttn(nn.Module):
31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
32 | """
33 | Multi-Scale Deformable Attention Module
34 | :param d_model hidden dimension
35 | :param n_levels number of feature levels
36 | :param n_heads number of attention heads
37 | :param n_points number of sampling points per attention head per feature level
38 | """
39 | super().__init__()
40 | if d_model % n_heads != 0:
41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
42 | _d_per_head = d_model // n_heads
43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
44 | if not _is_power_of_2(_d_per_head):
45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
46 | "which is more efficient in our CUDA implementation.")
47 |
48 | self.im2col_step = 64
49 |
50 | self.d_model = d_model
51 | self.n_levels = n_levels
52 | self.n_heads = n_heads
53 | self.n_points = n_points
54 |
55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3)
56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
57 | self.value_proj = nn.Linear(d_model, d_model)
58 | self.output_proj = nn.Linear(d_model, d_model)
59 |
60 | # self.s_norm = nn.LayerNorm(n_heads * n_levels * n_points * 3)
61 |
62 | self._reset_parameters()
63 |
64 | # add for vis
65 | self.samp_location = []
66 | self.atte_w = []
67 |
68 | def _reset_parameters(self):
69 | constant_(self.sampling_offsets.weight.data, 0.)
70 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
71 | # grid_init = torch.stack([thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin(), thetas.cos()], -1)
72 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.sin(), thetas.cos()], -1)
73 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.cos(), thetas.sin()], -1)
74 | grid_init = torch.stack([2*thetas.sin()*thetas.cos(), thetas.cos(), thetas.sin()], -1)
75 |
76 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
77 | # beta = torch.arange(-self.n_heads//2, self.n_heads//2, dtype=torch.float32) * (1.0 * math.pi / self.n_heads)
78 | # grid_init = torch.stack([beta.sin(), beta.cos() * thetas.cos(), beta.cos() * thetas.sin()], -1)
79 |
80 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
81 | # beta = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
82 | # grid_init = torch.stack([beta.sin() * thetas.sin(), beta.cos(), beta.sin() * thetas.cos()], -1)
83 |
84 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1)
85 | for i in range(self.n_points):
86 | grid_init[:, :, i, :] *= i + 1
87 | with torch.no_grad():
88 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
89 | constant_(self.attention_weights.weight.data, 0.)
90 | constant_(self.attention_weights.bias.data, 0.)
91 | xavier_uniform_(self.value_proj.weight.data)
92 | constant_(self.value_proj.bias.data, 0.)
93 | xavier_uniform_(self.output_proj.weight.data)
94 | constant_(self.output_proj.bias.data, 0.)
95 |
96 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
97 | """
98 | :param query (N, Length_{query}, C)
99 |         :param reference_points            (N, Length_{query}, n_levels, 3), range in [0, 1], normalized coordinates including padding area
100 |                                            or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
101 |         :param input_flatten               (N, \sum_{l=0}^{L-1} T_l \cdot H_l \cdot W_l, C)
102 |         :param input_spatial_shapes        (n_levels, 3), [(T_0, H_0, W_0), (T_1, H_1, W_1), ..., (T_{L-1}, H_{L-1}, W_{L-1})]
103 |         :param input_level_start_index     (n_levels, ), [0, T_0*H_0*W_0, T_0*H_0*W_0+T_1*H_1*W_1, ..., T_0*H_0*W_0+...+T_{L-2}*H_{L-2}*W_{L-2}]
104 |         :param input_padding_mask          (N, \sum_{l=0}^{L-1} T_l \cdot H_l \cdot W_l), True for padding elements, False for non-padding elements
105 |
106 | :return output (N, Length_{query}, C)
107 | """
108 | N, Len_q, _ = query.shape
109 | N, Len_in, _ = input_flatten.shape
110 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in
111 |
112 | value = self.value_proj(input_flatten)
113 | if input_padding_mask is not None:
114 | value = value.masked_fill(input_padding_mask[..., None], float(0))
115 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
116 | ## using Tanh
117 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
118 |
119 | ## or using norm
120 | # sampling_offsets = self.s_norm(self.sampling_offsets(query)).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
121 | # sampling_offsets = F.sigmoid(sampling_offsets)-0.5
122 |
123 | # attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
124 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
125 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
126 | # N, Len_q, n_heads, n_levels, n_points, 3
127 | if reference_points.shape[-1] == 3:
128 | offset_normalizer = torch.stack([input_spatial_shapes[..., 2], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1)
129 | sampling_locations = reference_points[:, :, None, :, None, :] \
130 | + sampling_offsets / (offset_normalizer[None, None, None, :, None, :])
131 | # input_spatial_shapes_ratio = input_spatial_shapes//input_spatial_shapes[-1]
132 | # offset_normalizer = torch.stack([input_spatial_shapes_ratio[..., 0], input_spatial_shapes_ratio[..., 2], input_spatial_shapes_ratio[..., 1]], -1)
133 | # sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets/offset_normalizer[None, None, None, :, None, :]
134 | elif reference_points.shape[-1] == 4:
135 | sampling_locations = reference_points[:, :, None, :, None, :2] \
136 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
137 | else:
138 | raise ValueError(
139 | 'Last dim of reference_points must be 3 or 4, but got {} instead.'.format(reference_points.shape[-1]))
140 | # output = MSDeformAttnFunction.apply(
141 | # value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
142 | # debug
143 | output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations,
144 | attention_weights)#.detach().cpu()
145 | # output_cuda = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index,
146 | # sampling_locations, attention_weights, self.im2col_step).detach().cpu()
147 |
148 | output = self.output_proj(output)
149 | self.atte_w = attention_weights
150 | self.samp_location = sampling_locations
151 | return output
152 |
153 |
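A minimal sketch of what the 3-D branch above computes, using made-up sizes (one feature level of shape (T, H, W) = (8, 16, 16)); the tensor names mirror forward() but none of the values below come from the repository:

import torch

# one level; forward() above stacks columns [W, W, H] of input_spatial_shapes
# as the per-level normalizer for the raw offsets
input_spatial_shapes = torch.tensor([[8, 16, 16]])                        # (n_levels, 3)
offset_normalizer = torch.stack([input_spatial_shapes[..., 2],
                                 input_spatial_shapes[..., 2],
                                 input_spatial_shapes[..., 1]], -1)       # (n_levels, 3)

N, Len_q, n_heads, n_levels, n_points = 1, 2, 1, 1, 1
reference_points = torch.rand(N, Len_q, n_levels, 3)                      # normalized to [0, 1]
sampling_offsets = torch.randn(N, Len_q, n_heads, n_levels, n_points, 3)

# broadcast exactly as in the module: (N, Len_q, n_heads, n_levels, n_points, 3)
sampling_locations = reference_points[:, :, None, :, None, :] \
    + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
print(sampling_locations.shape)   # torch.Size([1, 2, 1, 1, 1, 3])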
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/.ninja_deps:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/.ninja_deps
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/.ninja_log:
--------------------------------------------------------------------------------
1 | # ninja log v5
2 | 0 2669 1612246041636417765 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o fa8eca1803b7fff
3 | 0 14281 1612246053244838432 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o 95ace05e02c8713a
4 | 2 2898 1612246810056197491 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o aa390de0c9cbff6c
5 | 1 15432 1612246822584649734 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o b1d913aa26893e0e
6 | 23 2904 1612248891812574389 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o 57991895174314d7
7 | 23 9280 1612248898188696475 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.o ef4370695c918cc3
8 | 22 14918 1612248903824806762 /media/userdisk0/myproject-Det/Deformable-DETR/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o 5c6a42b954ea6aab
9 | 0 2911 1612369311440774386 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o aab11968d0c1acd7
10 | 1 9604 1612369318128880821 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o 51082b0e211b1655
11 | 0 15104 1612369323628968524 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o 8faaab1c9fecd08
12 | 2 2634 1612369362309589583 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o c13b919edaa3adcd
13 | 2 8953 1612369368625691682 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o 748c09fa10cb6ce0
14 | 1 13941 1612369373613772446 /media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o 1f9993e49b4dd542
15 | 1 4206 1646716265972410395 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o c77572b2cc994346
16 | 1 13715 1646716275476429471 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o 4379a47d34917b9
17 | 0 21653 1646716283416445500 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o 2d19fbfed58046ad
18 | 3 3829 1646752319151673811 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o 29f77dba4605cab5
19 | 4 13155 1646752328467758933 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o 9fa8fdf9cadaa81e
20 | 3 20006 1646752335319820952 /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o 6720fc7b8b2a845d
21 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/build.ninja:
--------------------------------------------------------------------------------
1 | ninja_required_version = 1.3
2 | cxx = c++
3 | nvcc = /media/userdisk0/softwares/cuda-11.0/bin/nvcc
4 |
5 | cflags = -pthread -B /home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DWITH_CUDA -I/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/TH -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/THC -I/media/userdisk0/softwares/cuda-11.0/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/include/python3.7m -c
6 | post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++14
7 | cuda_cflags = -DWITH_CUDA -I/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/TH -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/lib/python3.7/site-packages/torch/include/THC -I/media/userdisk0/softwares/cuda-11.0/include -I/home/jpzhang/.pyenv/versions/anaconda3-4.4.0/envs/fast/include/python3.7m -c
8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=1 -gencode=arch=compute_75,code=compute_75 -gencode=arch=compute_75,code=sm_75 -std=c++14
9 | ldflags =
10 |
11 | rule compile
12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
13 | depfile = $out.d
14 | deps = gcc
15 |
16 | rule cuda_compile
17 | depfile = $out.d
18 | deps = gcc
19 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
20 |
21 |
22 |
23 | build /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o: compile /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.cpp
24 | build /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o: compile /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.cpp
25 | build /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o: cuda_compile /media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.cu
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cpu/ms_deform_attn_cpu.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/cuda/ms_deform_attn_cuda.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Det/Deformable-DETR/models/ops/src/vision.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk0/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cpu/ms_deform_attn_cpu.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/cuda/ms_deform_attn_cuda.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/build/temp.linux-x86_64-3.7/media/userdisk1/jpzhang/myproject-Seg/MOTS-pro2/f_transformers/models/ops/src/vision.o
--------------------------------------------------------------------------------
/TransDoD/models/ops/dist/MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/dist/MultiScaleDeformableAttention-1.0-py3.7-linux-x86_64.egg
--------------------------------------------------------------------------------
/TransDoD/models/ops/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn_func import MSDeformAttnFunction
10 |
11 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/functions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/functions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/ops/functions/__pycache__/ms_deform_attn_func.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/functions/__pycache__/ms_deform_attn_func.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 |
18 | import MultiScaleDeformableAttention as MSDA
19 |
20 |
21 | class MSDeformAttnFunction(Function):
22 | @staticmethod
23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24 | ctx.im2col_step = im2col_step
25 | output = MSDA.ms_deform_attn_forward(
26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28 | return output
29 |
30 | @staticmethod
31 | @once_differentiable
32 | def backward(ctx, grad_output):
33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34 | grad_value, grad_sampling_loc, grad_attn_weight = \
35 | MSDA.ms_deform_attn_backward(
36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37 |
38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39 |
40 |
41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42 | # for debug and test only,
43 | # need to use cuda version instead
44 | N_, S_, M_, D_ = value.shape
45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47 | sampling_grids = 2 * sampling_locations - 1
48 | sampling_value_list = []
49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54 | # N_*M_, D_, Lq_, P_
55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56 | mode='bilinear', padding_mode='zeros', align_corners=False)
57 | sampling_value_list.append(sampling_value_l_)
58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61 | return output.transpose(1, 2).contiguous()
62 |
63 | def ms_deform_attn_core_pytorch_3D(value, value_spatial_shapes, sampling_locations, attention_weights):
64 | # for debug and test only,
65 | # need to use cuda version instead
66 | N_, S_, M_, D_ = value.shape
67 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
68 | value_list = value.split([T_ * H_ * W_ for T_, H_, W_ in value_spatial_shapes], dim=1)
69 | sampling_grids = 2 * sampling_locations - 1
70 | sampling_value_list = []
71 | for lid_, (T_, H_, W_) in enumerate(value_spatial_shapes):
72 | # N_, T_*H_*W_, M_, D_ -> N_, T_*H_*W_, M_*D_ -> N_, M_*D_, T_*H_*W_ -> N_*M_, D_, T_, H_, W_
73 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, T_, H_, W_)
74 | # N_, Lq_, M_, P_, 3 -> N_, M_, Lq_, P_, 3 -> N_*M_, Lq_, P_, 3
75 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)[:,None,:,:,:]
76 | # N_*M_, D_, Lq_, P_
77 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False)[:,:,0]
78 | sampling_value_list.append(sampling_value_l_)
79 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
80 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
81 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
82 | return output.transpose(1, 2).contiguous()
83 |
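A hedged usage sketch for ms_deform_attn_core_pytorch_3D above, with two made-up feature levels on CPU; it assumes the MultiScaleDeformableAttention extension is already built (this module imports it at load time) and that the script runs from the ops directory, as test.py does:

import torch
from functions.ms_deform_attn_func import ms_deform_attn_core_pytorch_3D

N, M, D = 1, 2, 4                    # batch, heads, channels per head
Lq, L, P = 3, 2, 4                   # queries, levels, sampling points per level
shapes = [(4, 8, 8), (2, 4, 4)]      # (T, H, W) per level
S = sum(t * h * w for t, h, w in shapes)

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 3)     # normalized to [0, 1]
attention_weights = torch.rand(N, Lq, M, L, P)
attention_weights /= attention_weights.sum((-1, -2), keepdim=True)

out = ms_deform_attn_core_pytorch_3D(value, shapes, sampling_locations, attention_weights)
print(out.shape)   # torch.Size([1, 3, 8]) == (N, Lq, M*D)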
--------------------------------------------------------------------------------
/TransDoD/models/ops/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ------------------------------------------------------------------------------------------------
3 | # Deformable DETR
4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------------------------------
7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | # ------------------------------------------------------------------------------------------------
9 |
10 | python setup.py build install
11 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn import MSDeformAttn
10 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/modules/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/modules/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/ops/modules/__pycache__/ms_deform_attn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/TransDoD/models/ops/modules/__pycache__/ms_deform_attn.cpython-37.pyc
--------------------------------------------------------------------------------
/TransDoD/models/ops/modules/ms_deform_attn (33).py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import warnings
14 | import math
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 | from torch.nn.init import xavier_uniform_, constant_
20 |
21 | from ..functions import MSDeformAttnFunction
22 | from ..functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch_3D
23 |
24 | def _is_power_of_2(n):
25 | if (not isinstance(n, int)) or (n < 0):
26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
27 | return (n & (n-1) == 0) and n != 0
28 |
29 |
30 | class MSDeformAttn(nn.Module):
31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
32 | """
33 | Multi-Scale Deformable Attention Module
34 | :param d_model hidden dimension
35 | :param n_levels number of feature levels
36 | :param n_heads number of attention heads
37 | :param n_points number of sampling points per attention head per feature level
38 | """
39 | super().__init__()
40 | if d_model % n_heads != 0:
41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
42 | _d_per_head = d_model // n_heads
43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
44 | if not _is_power_of_2(_d_per_head):
45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
46 | "which is more efficient in our CUDA implementation.")
47 |
48 | self.im2col_step = 64
49 |
50 | self.d_model = d_model
51 | self.n_levels = n_levels
52 | self.n_heads = n_heads
53 | self.n_points = n_points
54 |
55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3)
56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
57 | self.value_proj = nn.Linear(d_model, d_model)
58 | self.output_proj = nn.Linear(d_model, d_model)
59 |
60 | # self.s_norm = nn.LayerNorm(n_heads * n_levels * n_points * 3)
61 |
62 | self._reset_parameters()
63 |
64 | # add for vis
65 | self.samp_location = []
66 | self.atte_w = []
67 |
68 | def _reset_parameters(self):
69 | constant_(self.sampling_offsets.weight.data, 0.)
70 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
71 | # grid_init = torch.stack([thetas.sin()*thetas.cos(), thetas.sin()*thetas.sin(), thetas.cos()], -1)
72 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.sin(), thetas.cos()], -1)
73 | # grid_init = torch.stack([thetas.sin() * thetas.cos(), thetas.cos(), thetas.sin()], -1)
74 | grid_init = torch.stack([2*thetas.sin()*thetas.cos(), thetas.cos(), thetas.sin()], -1)
75 |
76 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
77 | # beta = torch.arange(-self.n_heads//2, self.n_heads//2, dtype=torch.float32) * (1.0 * math.pi / self.n_heads)
78 | # grid_init = torch.stack([beta.sin(), beta.cos() * thetas.cos(), beta.cos() * thetas.sin()], -1)
79 |
80 | # thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
81 | # beta = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
82 | # grid_init = torch.stack([beta.sin() * thetas.sin(), beta.cos(), beta.sin() * thetas.cos()], -1)
83 |
84 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1)
85 | for i in range(self.n_points):
86 | grid_init[:, :, i, :] *= i + 1
87 | with torch.no_grad():
88 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
89 | constant_(self.attention_weights.weight.data, 0.)
90 | constant_(self.attention_weights.bias.data, 0.)
91 | xavier_uniform_(self.value_proj.weight.data)
92 | constant_(self.value_proj.bias.data, 0.)
93 | xavier_uniform_(self.output_proj.weight.data)
94 | constant_(self.output_proj.bias.data, 0.)
95 |
96 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
97 | """
98 | :param query (N, Length_{query}, C)
99 | :param reference_points (N, Length_{query}, n_levels, 3), range in [0, 1], (0, 0, 0) to (1, 1, 1), including padding area
100 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
101 | :param input_flatten (N, \sum_{l=0}^{L-1} T_l \cdot H_l \cdot W_l, C)
102 | :param input_spatial_shapes (n_levels, 3), [(T_0, H_0, W_0), (T_1, H_1, W_1), ..., (T_{L-1}, H_{L-1}, W_{L-1})]
103 | :param input_level_start_index (n_levels, ), [0, T_0*H_0*W_0, T_0*H_0*W_0+T_1*H_1*W_1, ...]
104 | :param input_padding_mask (N, \sum_{l=0}^{L-1} T_l \cdot H_l \cdot W_l), True for padding elements, False for non-padding elements
105 |
106 | :return output (N, Length_{query}, C)
107 | """
108 | N, Len_q, _ = query.shape
109 | N, Len_in, _ = input_flatten.shape
110 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in
111 |
112 | value = self.value_proj(input_flatten)
113 | if input_padding_mask is not None:
114 | value = value.masked_fill(input_padding_mask[..., None], float(0))
115 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
116 | ## using Tanh
117 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
118 | # sampling_offsets = F.tanh(sampling_offsets)
119 |
120 | ## or using norm
121 | # sampling_offsets = self.s_norm(self.sampling_offsets(query)).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
122 |
123 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
124 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
125 | # N, Len_q, n_heads, n_levels, n_points, 3
126 | if reference_points.shape[-1] == 3:
127 | offset_normalizer = torch.stack([input_spatial_shapes[..., 2], input_spatial_shapes[..., 2], input_spatial_shapes[..., 1]], -1)
128 | sampling_locations = reference_points[:, :, None, :, None, :] \
129 | + sampling_offsets / (2*offset_normalizer[None, None, None, :, None, :])
130 | # sampling_locations = reference_points[:, :, None, :, None, :] + sampling_offsets
131 | elif reference_points.shape[-1] == 4:
132 | sampling_locations = reference_points[:, :, None, :, None, :2] \
133 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
134 | else:
135 | raise ValueError(
136 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
137 | # output = MSDeformAttnFunction.apply(
138 | # value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
139 | # debug
140 | output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations,
141 | attention_weights)#.detach().cpu()
142 | # output_cuda = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index,
143 | # sampling_locations, attention_weights, self.im2col_step).detach().cpu()
144 |
145 | output = self.output_proj(output)
146 | self.atte_w = attention_weights
147 | self.samp_location = sampling_locations
148 | return output
149 |
150 |
151 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/modules/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import warnings
14 | import math
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 | from torch.nn.init import xavier_uniform_, constant_
20 |
21 | from ..functions import MSDeformAttnFunction
22 | from ..functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch_3D
23 |
24 | def _is_power_of_2(n):
25 | if (not isinstance(n, int)) or (n < 0):
26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
27 | return (n & (n-1) == 0) and n != 0
28 |
29 |
30 | class MSDeformAttn(nn.Module):
31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
32 | """
33 | Multi-Scale Deformable Attention Module
34 | :param d_model hidden dimension
35 | :param n_levels number of feature levels
36 | :param n_heads number of attention heads
37 | :param n_points number of sampling points per attention head per feature level
38 | """
39 | super().__init__()
40 | if d_model % n_heads != 0:
41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
42 | _d_per_head = d_model // n_heads
43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
44 | if not _is_power_of_2(_d_per_head):
45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
46 | "which is more efficient in our CUDA implementation.")
47 |
48 | self.im2col_step = 64
49 |
50 | self.d_model = d_model
51 | self.n_levels = n_levels
52 | self.n_heads = n_heads
53 | self.n_points = n_points
54 |
55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 3)
56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
57 | self.value_proj = nn.Linear(d_model, d_model)
58 | self.output_proj = nn.Linear(d_model, d_model)
59 |
60 | # self.s_norm = nn.LayerNorm(n_heads * n_levels * n_points * 3)
61 |
62 | inx = 16
63 | self.offset_normalizer = torch.tensor([[inx * (2 ** (i - 1)), inx * (2 ** (i - 1)), inx * (2 ** (i - 1))] for i in
64 | range(self.n_levels, 0, -1)]).cuda()
65 |
66 | self._reset_parameters()
67 |
68 | # add for vis
69 | self.samp_location = []
70 | self.atte_w = []
71 |
72 | def _reset_parameters(self):
73 | constant_(self.sampling_offsets.weight.data, 0.)
74 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
75 | grid_init = torch.stack([thetas.sin()*thetas.cos(), thetas.cos(), thetas.sin()], -1)
76 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 3).repeat(1, self.n_levels, self.n_points, 1)
77 | for i in range(self.n_points):
78 | grid_init[:, :, i, :] *= i + 1
79 | with torch.no_grad():
80 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
81 | constant_(self.attention_weights.weight.data, 0.)
82 | constant_(self.attention_weights.bias.data, 0.)
83 | xavier_uniform_(self.value_proj.weight.data)
84 | constant_(self.value_proj.bias.data, 0.)
85 | xavier_uniform_(self.output_proj.weight.data)
86 | constant_(self.output_proj.bias.data, 0.)
87 |
88 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
89 | N, Len_q, _ = query.shape
90 | N, Len_in, _ = input_flatten.shape
91 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1] * input_spatial_shapes[:, 2]).sum() == Len_in
92 |
93 | value = self.value_proj(input_flatten)
94 | if input_padding_mask is not None:
95 | value = value.masked_fill(input_padding_mask[..., None], float(0))
96 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
97 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 3)
98 |
99 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) # v0
100 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
101 | if reference_points.shape[-1] == 3:
102 | sampling_locations = reference_points[:, :, None, :, None, :] \
103 | + sampling_offsets / (self.offset_normalizer[None, None, None, :, None, :])
104 | elif reference_points.shape[-1] == 4:
105 | sampling_locations = reference_points[:, :, None, :, None, :2] \
106 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
107 | else:
108 | raise ValueError(
109 | 'Last dim of reference_points must be 3 or 4, but got {} instead.'.format(reference_points.shape[-1]))
110 | output = ms_deform_attn_core_pytorch_3D(value, input_spatial_shapes, sampling_locations,
111 | attention_weights)
112 | output = self.output_proj(output)
113 | self.atte_w = attention_weights
114 | self.samp_location = sampling_locations
115 | return output
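A minimal, hypothetical driver for this 3-D MSDeformAttn (sizes are illustrative; nothing below is in the repository). It assumes a CUDA device, a built MultiScaleDeformableAttention extension, and that models.ops.modules is importable from the TransDoD directory:

import torch
from models.ops.modules import MSDeformAttn   # assumed import path

d_model, n_levels, n_heads, n_points = 64, 2, 4, 4
attn = MSDeformAttn(d_model, n_levels, n_heads, n_points).cuda()   # offset_normalizer is created on the GPU

# two feature levels of shape (T, H, W), flattened and concatenated along dim 1
shapes = torch.as_tensor([(4, 8, 8), (2, 4, 4)], dtype=torch.long).cuda()
level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
Len_in = int(shapes.prod(1).sum())

N, Len_q = 1, 10
query = torch.rand(N, Len_q, d_model).cuda()
reference_points = torch.rand(N, Len_q, n_levels, 3).cuda()   # normalized to [0, 1]
input_flatten = torch.rand(N, Len_in, d_model).cuda()

out = attn(query, reference_points, input_flatten, shapes, level_start_index)
print(out.shape)   # torch.Size([1, 10, 64])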
--------------------------------------------------------------------------------
/TransDoD/models/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | import os
10 | import glob
11 |
12 | import torch
13 |
14 | from torch.utils.cpp_extension import CUDA_HOME
15 | from torch.utils.cpp_extension import CppExtension
16 | from torch.utils.cpp_extension import CUDAExtension
17 |
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 | requirements = ["torch", "torchvision"]
22 |
23 | def get_extensions():
24 | this_dir = os.path.dirname(os.path.abspath(__file__))
25 | extensions_dir = os.path.join(this_dir, "src")
26 |
27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30 |
31 | sources = main_file + source_cpu
32 | extension = CppExtension
33 | extra_compile_args = {"cxx": []}
34 | define_macros = []
35 |
36 | if torch.cuda.is_available() and CUDA_HOME is not None:
37 | extension = CUDAExtension
38 | sources += source_cuda
39 | define_macros += [("WITH_CUDA", None)]
40 | extra_compile_args["nvcc"] = [
41 | "-DCUDA_HAS_FP16=1",
42 | "-D__CUDA_NO_HALF_OPERATORS__",
43 | "-D__CUDA_NO_HALF_CONVERSIONS__",
44 | "-D__CUDA_NO_HALF2_OPERATORS__",
45 | ]
46 | else:
47 | raise NotImplementedError('CUDA is not available')
48 |
49 | sources = [os.path.join(extensions_dir, s) for s in sources]
50 | include_dirs = [extensions_dir]
51 | ext_modules = [
52 | extension(
53 | "MultiScaleDeformableAttention",
54 | sources,
55 | include_dirs=include_dirs,
56 | define_macros=define_macros,
57 | extra_compile_args=extra_compile_args,
58 | )
59 | ]
60 | return ext_modules
61 |
62 | setup(
63 | name="MultiScaleDeformableAttention",
64 | version="1.0",
65 | author="Weijie Su",
66 | url="https://github.com/fundamentalvision/Deformable-DETR",
67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
68 | packages=find_packages(exclude=("configs", "tests",)),
69 | ext_modules=get_extensions(),
70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
71 | )
72 |
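A small, optional smoke test (not part of the repository) after running make.sh: the built extension should expose the two functions bound in src/vision.cpp below.

import MultiScaleDeformableAttention as MSDA

# both names are registered by PYBIND11_MODULE in src/vision.cpp
print(hasattr(MSDA, "ms_deform_attn_forward"))    # True if the build succeeded
print(hasattr(MSDA, "ms_deform_attn_backward"))   # True if the build succeeded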
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 |
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 |
16 |
17 | at::Tensor
18 | ms_deform_attn_cpu_forward(
19 | const at::Tensor &value,
20 | const at::Tensor &spatial_shapes,
21 | const at::Tensor &level_start_index,
22 | const at::Tensor &sampling_loc,
23 | const at::Tensor &attn_weight,
24 | const int im2col_step)
25 | {
26 | AT_ERROR("Not implemented on CPU");
27 | }
28 |
29 | std::vector<at::Tensor>
30 | ms_deform_attn_cpu_backward(
31 | const at::Tensor &value,
32 | const at::Tensor &spatial_shapes,
33 | const at::Tensor &level_start_index,
34 | const at::Tensor &sampling_loc,
35 | const at::Tensor &attn_weight,
36 | const at::Tensor &grad_output,
37 | const int im2col_step)
38 | {
39 | AT_ERROR("Not implemented on CPU");
40 | }
41 |
42 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor
15 | ms_deform_attn_cpu_forward(
16 | const at::Tensor &value,
17 | const at::Tensor &spatial_shapes,
18 | const at::Tensor &level_start_index,
19 | const at::Tensor &sampling_loc,
20 | const at::Tensor &attn_weight,
21 | const int im2col_step);
22 |
23 | std::vector<at::Tensor>
24 | ms_deform_attn_cpu_backward(
25 | const at::Tensor &value,
26 | const at::Tensor &spatial_shapes,
27 | const at::Tensor &level_start_index,
28 | const at::Tensor &sampling_loc,
29 | const at::Tensor &attn_weight,
30 | const at::Tensor &grad_output,
31 | const int im2col_step);
32 |
33 |
34 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/cuda/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 | #include "cuda/ms_deform_im2col_cuda.cuh"
13 |
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 |
19 |
20 | at::Tensor ms_deform_attn_cuda_forward(
21 | const at::Tensor &value,
22 | const at::Tensor &spatial_shapes,
23 | const at::Tensor &level_start_index,
24 | const at::Tensor &sampling_loc,
25 | const at::Tensor &attn_weight,
26 | const int im2col_step)
27 | {
28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 |
34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 |
40 | const int batch = value.size(0);
41 | const int spatial_size = value.size(1);
42 | const int num_heads = value.size(2);
43 | const int channels = value.size(3);
44 |
45 | const int num_levels = spatial_shapes.size(0);
46 |
47 | const int num_query = sampling_loc.size(1);
48 | const int num_point = sampling_loc.size(4);
49 |
50 | const int im2col_step_ = std::min(batch, im2col_step);
51 |
52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 |
54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 |
56 | const int batch_n = im2col_step_;
57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 | auto per_value_size = spatial_size * num_heads * channels;
59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 | for (int n = 0; n < batch/im2col_step_; ++n)
62 | {
63 | auto columns = output_n.select(0, n);
64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 | ms_deformable_im2col_cuda<scalar_t>(at::cuda::getCurrentCUDAStream(),
66 | value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 | spatial_shapes.data<int64_t>(),
68 | level_start_index.data<int64_t>(),
69 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 | columns.data<scalar_t>());
73 |
74 | }));
75 | }
76 |
77 | output = output.view({batch, num_query, num_heads*channels});
78 |
79 | return output;
80 | }
81 |
82 |
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 | const at::Tensor &value,
85 | const at::Tensor &spatial_shapes,
86 | const at::Tensor &level_start_index,
87 | const at::Tensor &sampling_loc,
88 | const at::Tensor &attn_weight,
89 | const at::Tensor &grad_output,
90 | const int im2col_step)
91 | {
92 |
93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 |
100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 |
107 | const int batch = value.size(0);
108 | const int spatial_size = value.size(1);
109 | const int num_heads = value.size(2);
110 | const int channels = value.size(3);
111 |
112 | const int num_levels = spatial_shapes.size(0);
113 |
114 | const int num_query = sampling_loc.size(1);
115 | const int num_point = sampling_loc.size(4);
116 |
117 | const int im2col_step_ = std::min(batch, im2col_step);
118 |
119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 |
121 | auto grad_value = at::zeros_like(value);
122 | auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 | auto grad_attn_weight = at::zeros_like(attn_weight);
124 |
125 | const int batch_n = im2col_step_;
126 | auto per_value_size = spatial_size * num_heads * channels;
127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 |
131 | for (int n = 0; n < batch/im2col_step_; ++n)
132 | {
133 | auto grad_output_g = grad_output_n.select(0, n);
134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 | ms_deformable_col2im_cuda<scalar_t>(at::cuda::getCurrentCUDAStream(),
136 | grad_output_g.data<scalar_t>(),
137 | value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138 | spatial_shapes.data<int64_t>(),
139 | level_start_index.data<int64_t>(),
140 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 | grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144 | grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145 | grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146 |
147 | }));
148 | }
149 |
150 | return {
151 | grad_value, grad_sampling_loc, grad_attn_weight
152 | };
153 | }
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor ms_deform_attn_cuda_forward(
15 | const at::Tensor &value,
16 | const at::Tensor &spatial_shapes,
17 | const at::Tensor &level_start_index,
18 | const at::Tensor &sampling_loc,
19 | const at::Tensor &attn_weight,
20 | const int im2col_step);
21 |
22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 | const at::Tensor &value,
24 | const at::Tensor &spatial_shapes,
25 | const at::Tensor &level_start_index,
26 | const at::Tensor &sampling_loc,
27 | const at::Tensor &attn_weight,
28 | const at::Tensor &grad_output,
29 | const int im2col_step);
30 |
31 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 |
13 | #include "cpu/ms_deform_attn_cpu.h"
14 |
15 | #ifdef WITH_CUDA
16 | #include "cuda/ms_deform_attn_cuda.h"
17 | #endif
18 |
19 |
20 | at::Tensor
21 | ms_deform_attn_forward(
22 | const at::Tensor &value,
23 | const at::Tensor &spatial_shapes,
24 | const at::Tensor &level_start_index,
25 | const at::Tensor &sampling_loc,
26 | const at::Tensor &attn_weight,
27 | const int im2col_step)
28 | {
29 | if (value.type().is_cuda())
30 | {
31 | #ifdef WITH_CUDA
32 | return ms_deform_attn_cuda_forward(
33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 | const at::Tensor &value,
44 | const at::Tensor &spatial_shapes,
45 | const at::Tensor &level_start_index,
46 | const at::Tensor &sampling_loc,
47 | const at::Tensor &attn_weight,
48 | const at::Tensor &grad_output,
49 | const int im2col_step)
50 | {
51 | if (value.type().is_cuda())
52 | {
53 | #ifdef WITH_CUDA
54 | return ms_deform_attn_cuda_backward(
55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 | AT_ERROR("Not compiled with GPU support");
58 | #endif
59 | }
60 | AT_ERROR("Not implemented on the CPU");
61 | }
62 |
63 |
--------------------------------------------------------------------------------
/TransDoD/models/ops/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include "ms_deform_attn.h"
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16 | }
17 |
--------------------------------------------------------------------------------
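
The two functions registered above are the entire C++/CUDA surface of the extension; `functions/ms_deform_attn_func.py` (exercised by `test.py` below) wraps them in a `torch.autograd.Function`. A minimal sketch of that wrapping, assuming the extension has been built (e.g. with the provided `setup.py`) and imports under the assumed module name `MultiScaleDeformableAttention`:

    import torch
    import MultiScaleDeformableAttention as MSDA  # module name assumed from the build output

    class MSDeformAttnFunctionSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step):
            ctx.im2col_step = im2col_step
            output = MSDA.ms_deform_attn_forward(
                value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step)
            ctx.save_for_backward(value, spatial_shapes, level_start_index, sampling_loc, attn_weight)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            value, spatial_shapes, level_start_index, sampling_loc, attn_weight = ctx.saved_tensors
            grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward(
                value, spatial_shapes, level_start_index, sampling_loc, attn_weight,
                grad_output.contiguous(), ctx.im2col_step)
            # spatial_shapes, level_start_index and im2col_step receive no gradient
            return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
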
/TransDoD/models/ops/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import time
14 | import torch
15 | import torch.nn as nn
16 | from torch.autograd import gradcheck
17 |
18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19 |
20 |
21 | N, M, D = 1, 2, 2
22 | Lq, L, P = 2, 2, 2
23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
25 | S = sum([(H*W).item() for H, W in shapes])
26 |
27 |
28 | torch.manual_seed(3)
29 |
30 |
31 | @torch.no_grad()
32 | def check_forward_equal_with_pytorch_double():
33 | value = torch.rand(N, S, M, D).cuda() * 0.01
34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
37 | im2col_step = 2
38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
40 | fwdok = torch.allclose(output_cuda, output_pytorch)
41 | max_abs_err = (output_cuda - output_pytorch).abs().max()
42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
43 |
44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45 |
46 |
47 | @torch.no_grad()
48 | def check_forward_equal_with_pytorch_float():
49 | value = torch.rand(N, S, M, D).cuda() * 0.01
50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
53 | im2col_step = 2
54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
57 | max_abs_err = (output_cuda - output_pytorch).abs().max()
58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
59 |
60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61 |
62 |
63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
64 |
65 | value = torch.rand(N, S, M, channels).cuda() * 0.01
66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
69 | im2col_step = 2
70 | func = MSDeformAttnFunction.apply
71 |
72 | value.requires_grad = grad_value
73 | sampling_locations.requires_grad = grad_sampling_loc
74 | attention_weights.requires_grad = grad_attn_weight
75 |
76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
77 |
78 | print(f'* {gradok} check_gradient_numerical(D={channels})')
79 |
80 |
81 | if __name__ == '__main__':
82 | check_forward_equal_with_pytorch_double()
83 | check_forward_equal_with_pytorch_float()
84 |
85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
86 | check_gradient_numerical(channels, True, True, True)
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/TransDoD/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import sys
6 | sys.path.append("../../utils/")
7 |
8 | import math
9 | import torch
10 | from torch import nn
11 | from utils.misc import NestedTensor
12 |
13 | class PositionEmbeddingSine(nn.Module):
14 | """
15 | This is a more standard version of the position embedding, very similar to the one
16 | used by the Attention is all you need paper, generalized to work on images.
17 | """
18 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
19 | super().__init__()
20 | self.num_pos_feats = num_pos_feats
21 | self.temperature = temperature
22 | self.normalize = normalize
23 | if scale is not None and normalize is False:
24 | raise ValueError("normalize should be True if scale is passed")
25 | if scale is None:
26 | scale = 2 * math.pi
27 | self.scale = scale
28 |
29 | def forward(self, tensor):
30 | bs, c, d, h, w = tensor.shape
31 | mask = torch.zeros((bs, d, h, w), dtype=torch.bool).cuda()
32 | assert mask is not None
33 | not_mask = ~mask
34 | d_embed = not_mask.cumsum(1, dtype=torch.float32)
35 | y_embed = not_mask.cumsum(2, dtype=torch.float32)
36 | x_embed = not_mask.cumsum(3, dtype=torch.float32)
37 | if self.normalize:
38 | eps = 1e-6
39 | d_embed = d_embed / (d_embed[:, -1:, :, :] + eps) * self.scale
40 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
41 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale
42 |
43 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=tensor.device)
44 | dim_t = self.temperature ** (3 * (dim_t // 3) / self.num_pos_feats)
45 |
46 | pos_x = x_embed[:, :, :, :, None] / dim_t
47 | pos_y = y_embed[:, :, :, :, None] / dim_t
48 | pos_d = d_embed[:, :, :, :, None] / dim_t
49 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
50 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
51 | pos_d = torch.stack((pos_d[:, :, :, :, 0::2].sin(), pos_d[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
52 | pos = torch.cat((pos_d, pos_y, pos_x), dim=4).permute(0, 4, 1, 2, 3)
53 | return pos
54 |
55 |
56 |
57 | class PositionEmbeddingLearned(nn.Module):
58 | """
59 | Absolute pos embedding, learned.
60 | """
61 | def __init__(self, num_pos_feats=256):
62 | super().__init__()
63 | self.row_embed = nn.Embedding(50, num_pos_feats)
64 | self.col_embed = nn.Embedding(50, num_pos_feats)
65 | self.reset_parameters()
66 |
67 | def reset_parameters(self):
68 | nn.init.uniform_(self.row_embed.weight)
69 | nn.init.uniform_(self.col_embed.weight)
70 |
71 | def forward(self, tensor_list: NestedTensor):
72 | x = tensor_list.tensors
73 | h, w = x.shape[-2:]
74 | i = torch.arange(w, device=x.device)
75 | j = torch.arange(h, device=x.device)
76 | x_emb = self.col_embed(i)
77 | y_emb = self.row_embed(j)
78 | pos = torch.cat([
79 | x_emb.unsqueeze(0).repeat(h, 1, 1),
80 | y_emb.unsqueeze(1).repeat(1, w, 1),
81 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
82 | return pos
83 |
84 |
85 | def build_position_encoding(args):
86 | N_steps = args.hidden_dim // 3
87 | if args.position_embedding in ('v2', 'sine'):
88 | # TODO find a better way of exposing other arguments
89 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
90 | elif args.position_embedding in ('v3', 'learned'):
91 | position_embedding = PositionEmbeddingLearned(N_steps)
92 | else:
93 | raise ValueError(f"not supported {args.position_embedding}")
94 |
95 | return position_embedding
96 |
--------------------------------------------------------------------------------
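
`PositionEmbeddingSine` above takes a 5-D feature map and produces one sinusoidal channel group per spatial axis (d, y, x), each of width `num_pos_feats`. A minimal shape check, assuming a CUDA device (the forward pass allocates its mask with `.cuda()`) and `hidden_dim = 192` so that `build_position_encoding` picks `N_steps = 64`:

    import torch

    pe = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
    feat = torch.zeros(2, 256, 8, 24, 24).cuda()   # (bs, c, d, h, w) encoder feature map
    pos = pe(feat)                                  # channels = 3 * num_pos_feats
    assert pos.shape == (2, 192, 8, 24, 24)
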
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/__init__.py
--------------------------------------------------------------------------------
/__pycache__/engine.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/__pycache__/engine.cpython-37.pyc
--------------------------------------------------------------------------------
/a_DynConv/.train.py.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/a_DynConv/.train.py.swp
--------------------------------------------------------------------------------
/a_DynConv/dodnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/a_DynConv/dodnet.png
--------------------------------------------------------------------------------
/a_DynConv/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 |
4 | sys.path.append("..")
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.utils import data
9 | import numpy as np
10 | import pickle
11 | import cv2
12 | import torch.optim as optim
13 | import scipy.misc
14 | import torch.backends.cudnn as cudnn
15 | import torch.nn.functional as F
16 | import matplotlib.pyplot as plt
17 | from scipy.ndimage.filters import gaussian_filter
18 |
19 | # from tqdm import tqdm
20 | import os.path as osp
21 |
22 | from unet3D_DynConv882 import UNet3D
23 | from MOTSDataset import MOTSValDataSet
24 |
25 | import random
26 | import timeit
27 | from tensorboardX import SummaryWriter
28 | from loss_functions import loss
29 |
30 | from sklearn import metrics
31 | import nibabel as nib
32 | from math import ceil
33 |
34 | from engine import Engine
35 | from apex import amp
36 | from apex.parallel import convert_syncbn_model
37 |
38 | start = timeit.default_timer()
39 |
40 |
41 | def str2bool(v):
42 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
43 | return True
44 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
45 | return False
46 | else:
47 | raise argparse.ArgumentTypeError('Boolean value expected.')
48 |
49 |
50 | def get_arguments():
51 | """Parse all the arguments provided from the CLI.
52 |
53 | Returns:
54 | A list of parsed arguments.
55 | """
56 | parser = argparse.ArgumentParser(description="MOTS: DynConv solution!")
57 |
58 | parser.add_argument("--data_dir", type=str, default='../dataset/')
59 | parser.add_argument("--val_list", type=str, default='list/MOTS/tt.txt')
60 | parser.add_argument("--reload_path", type=str, default='snapshots/fold1/MOTS_DynConv_fold1_final_e999.pth')
61 | parser.add_argument("--reload_from_checkpoint", type=str2bool, default=True)
62 | parser.add_argument("--save_path", type=str, default='outputs/')
63 |
64 | parser.add_argument("--input_size", type=str, default='64,192,192')
65 | parser.add_argument("--batch_size", type=int, default=1)
66 | parser.add_argument("--num_gpus", type=int, default=1)
67 | parser.add_argument('--local_rank', type=int, default=0)
68 | parser.add_argument("--FP16", type=str2bool, default=False)
69 | parser.add_argument("--num_epochs", type=int, default=500)
70 | parser.add_argument("--patience", type=int, default=3)
71 | parser.add_argument("--start_epoch", type=int, default=0)
72 | parser.add_argument("--val_pred_every", type=int, default=10)
73 | parser.add_argument("--learning_rate", type=float, default=1e-3)
74 | parser.add_argument("--num_classes", type=int, default=2)
75 | parser.add_argument("--num_workers", type=int, default=1)
76 |
77 | parser.add_argument("--weight_std", type=str2bool, default=True)
78 | parser.add_argument("--momentum", type=float, default=0.9)
79 | parser.add_argument("--power", type=float, default=0.9)
80 | parser.add_argument("--weight_decay", type=float, default=0.0005)
81 |
82 | return parser
83 |
84 |
85 |
86 | def dice_score(preds, labels): # on GPU
87 | assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match"
88 | predict = preds.contiguous().view(preds.shape[0], -1)
89 | target = labels.contiguous().view(labels.shape[0], -1)
90 |
91 | num = torch.sum(torch.mul(predict, target), dim=1)
92 | den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + 1
93 |
94 | dice = 2 * num / den
95 |
96 | return dice.mean()
97 |
98 |
99 |
100 | def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
101 | tmp = np.zeros(patch_size)
102 | center_coords = [i // 2 for i in patch_size]
103 | sigmas = [i * sigma_scale for i in patch_size]
104 | tmp[tuple(center_coords)] = 1
105 | gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
106 | gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * 1
107 | gaussian_importance_map = gaussian_importance_map.astype(np.float32)
108 |
109 | # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
110 | gaussian_importance_map[gaussian_importance_map == 0] = np.min(
111 | gaussian_importance_map[gaussian_importance_map != 0])
112 |
113 | return gaussian_importance_map
114 |
115 | def multi_net(net_list, img, task_id):
116 | # img = torch.from_numpy(img).cuda()
117 |
118 | padded_prediction = net_list[0](img, task_id)
119 | padded_prediction = F.sigmoid(padded_prediction)
120 | for i in range(1, len(net_list)):
121 | padded_prediction_i = net_list[i](img, task_id)
122 | padded_prediction_i = F.sigmoid(padded_prediction_i)
123 | padded_prediction += padded_prediction_i
124 | padded_prediction /= len(net_list)
125 | return padded_prediction#.cpu().data.numpy()
126 |
127 | def predict_sliding(args, net_list, image, tile_size, classes, task_id): # tile_size e.g. 64x192x192
128 | gaussian_importance_map = _get_gaussian(tile_size, sigma_scale=1. / 8)
129 |
130 | image_size = image.shape
131 | overlap = 1 / 2
132 |
133 | strideHW = ceil(tile_size[1] * (1 - overlap))
134 | strideD = ceil(tile_size[0] * (1 - overlap))
135 | tile_deps = int(ceil((image_size[2] - tile_size[0]) / strideD) + 1)
136 | tile_rows = int(ceil((image_size[3] - tile_size[1]) / strideHW) + 1) # strided convolution formula
137 | tile_cols = int(ceil((image_size[4] - tile_size[2]) / strideHW) + 1)
138 | # print("Need %i x %i x %i prediction tiles @ stride %i x %i px" % (tile_deps, tile_cols, tile_rows, strideD, strideHW))
139 | full_probs = np.zeros((image_size[0], classes, image_size[2], image_size[3], image_size[4]))#.astype(np.float32) # 1x4x155x240x240
140 | count_predictions = np.zeros((image_size[0], classes, image_size[2], image_size[3], image_size[4]))#.astype(np.float32)
141 | full_probs = torch.from_numpy(full_probs)
142 | count_predictions = torch.from_numpy(count_predictions)
143 |
144 | for dep in range(tile_deps):
145 | for row in range(tile_rows):
146 | for col in range(tile_cols):
147 | d1 = int(dep * strideD)
148 | x1 = int(col * strideHW)
149 | y1 = int(row * strideHW)
150 | d2 = min(d1 + tile_size[0], image_size[2])
151 | x2 = min(x1 + tile_size[2], image_size[4])
152 | y2 = min(y1 + tile_size[1], image_size[3])
153 | d1 = max(int(d2 - tile_size[0]), 0)
154 | x1 = max(int(x2 - tile_size[2]), 0) # for portrait images the x1 underflows sometimes
155 | y1 = max(int(y2 - tile_size[1]), 0) # for very few rows y1 underflows
156 |
157 | img = image[:, :, d1:d2, y1:y2, x1:x2]
158 | img = torch.from_numpy(img).cuda()
159 |
160 | prediction1 = multi_net(net_list, img, task_id)
161 | prediction2 = torch.flip(multi_net(net_list, torch.flip(img, [2]), task_id), [2])
162 | prediction3 = torch.flip(multi_net(net_list, torch.flip(img, [3]), task_id), [3])
163 | prediction4 = torch.flip(multi_net(net_list, torch.flip(img, [4]), task_id), [4])
164 | prediction5 = torch.flip(multi_net(net_list, torch.flip(img, [2,3]), task_id), [2,3])
165 | prediction6 = torch.flip(multi_net(net_list, torch.flip(img, [2,4]), task_id), [2,4])
166 | prediction7 = torch.flip(multi_net(net_list, torch.flip(img, [3,4]), task_id), [3,4])
167 | prediction8 = torch.flip(multi_net(net_list, torch.flip(img, [2,3,4]), task_id), [2,3,4])
168 | prediction = (prediction1 + prediction2 + prediction3 + prediction4 + prediction5 + prediction6 + prediction7 + prediction8) / 8.
169 | prediction = prediction.cpu()
170 |
171 | prediction[:,:] *= gaussian_importance_map
172 |
173 | if isinstance(prediction, list):
174 | shape = np.array(prediction[0].shape)
175 | shape[0] = prediction[0].shape[0] * len(prediction)
176 | shape = tuple(shape)
177 | preds = torch.zeros(shape).cuda()
178 | bs_singlegpu = prediction[0].shape[0]
179 | for i in range(len(prediction)):
180 | preds[i * bs_singlegpu: (i + 1) * bs_singlegpu] = prediction[i]
181 | count_predictions[:, :, d1:d2, y1:y2, x1:x2] += 1
182 | full_probs[:, :, d1:d2, y1:y2, x1:x2] += preds
183 |
184 | else:
185 | count_predictions[:, :, d1:d2, y1:y2, x1:x2] += gaussian_importance_map
186 | full_probs[:, :, d1:d2, y1:y2, x1:x2] += prediction
187 |
188 | full_probs /= count_predictions
189 | return full_probs
190 |
191 | def save_nii(args, pred, label, name, affine): # bs, c, WHD
192 | seg_pred_2class = np.asarray(np.around(pred), dtype=np.uint8)
193 | seg_pred_0 = seg_pred_2class[:, 0, :, :, :]
194 | seg_pred_1 = seg_pred_2class[:, 1, :, :, :]
195 | seg_pred = np.zeros_like(seg_pred_0)
196 | if name[0][0:3]!='spl':
197 | seg_pred = np.where(seg_pred_0 == 1, 1, seg_pred)
198 | seg_pred = np.where(seg_pred_1 == 1, 2, seg_pred)
199 | else:  # spleen: the organ is the only foreground class
200 | seg_pred = seg_pred_0
201 |
202 | label_0 = label[:, 0, :, :, :]
203 | label_1 = label[:, 1, :, :, :]
204 | seg_label = np.zeros_like(label_0)
205 | seg_label = np.where(label_0 == 1, 1, seg_label)
206 | seg_label = np.where(label_1 == 1, 2, seg_label)
207 |
208 | if name[0][0:3]!='cas':
209 | seg_pred = seg_pred.transpose((0, 2, 3, 1))
210 | seg_label = seg_label.transpose((0, 2, 3, 1))
211 |
212 | # save
213 | for tt in range(seg_pred.shape[0]):
214 | seg_pred_tt = seg_pred[tt]
215 | seg_label_tt = seg_label[tt]
216 | seg_pred_tt = nib.Nifti1Image(seg_pred_tt, affine=affine[tt])
217 | seg_label_tt = nib.Nifti1Image(seg_label_tt, affine=affine[tt])
218 | if not os.path.exists(args.save_path):
219 | os.makedirs(args.save_path)
220 | seg_label_save_p = os.path.join(args.save_path + '/%s_label.nii.gz' % (name[tt]))
221 | seg_pred_save_p = os.path.join(args.save_path + '/%s_pred.nii.gz' % (name[tt]))
222 | nib.save(seg_label_tt, seg_label_save_p)
223 | nib.save(seg_pred_tt, seg_pred_save_p)
224 | return None
225 |
226 | def validate(args, input_size, model, ValLoader, num_classes, engine):
227 |
228 | val_loss = torch.zeros(size=(7, 1)).cuda() # np.zeros(shape=(7, 1))
229 | val_Dice = torch.zeros(size=(7, 2)).cuda() # np.zeros(shape=(7, 2))
230 | count = torch.zeros(size=(7, 2)).cuda() # np.zeros(shape=(7, 2))
231 |
232 | for index, batch in enumerate(ValLoader):
233 | # print('%d processed' % (index))
234 | image, label, name, task_id, affine = batch
235 |
236 | with torch.no_grad():
237 |
238 | pred_sigmoid = predict_sliding(args, model, image.numpy(), input_size, num_classes, task_id)
239 |
240 | # loss = loss_seg_DICE.forward(pred, label) + loss_seg_CE.forward(pred, label)
241 | loss = torch.tensor(1).cuda()
242 | val_loss[task_id[0], 0] += loss
243 |
244 | if label[0, 0, 0, 0, 0] == -1:
245 | dice_c1 = torch.from_numpy(np.array([-999]))
246 | else:
247 | dice_c1 = dice_score(pred_sigmoid[:, 0, :, :, :], label[:, 0, :, :, :])
248 | val_Dice[task_id[0], 0] += dice_c1
249 | count[task_id[0], 0] += 1
250 | if label[0, 1, 0, 0, 0] == -1:
251 | dice_c2 = torch.from_numpy(np.array([-999]))
252 | else:
253 | dice_c2 = dice_score(pred_sigmoid[:, 1, :, :, :], label[:, 1, :, :, :])
254 | val_Dice[task_id[0], 1] += dice_c2
255 | count[task_id[0], 1] += 1
256 |
257 | print('Task%d-%s loss:%.4f Organ:%.4f Tumor:%.4f' % (task_id, name, loss.item(), dice_c1.item(), dice_c2.item()))
258 |
259 | # save
260 | save_nii(args, pred_sigmoid, label, name, affine)
261 |
262 | count[count == 0] = 1
263 | val_Dice = val_Dice / count
264 | val_loss = val_loss / count.max(axis=1)[0].unsqueeze(1)
265 |
266 | reduce_val_loss = torch.zeros_like(val_loss).cuda()
267 | reduce_val_Dice = torch.zeros_like(val_Dice).cuda()
268 | for i in range(val_loss.shape[0]):
269 | reduce_val_loss[i] = engine.all_reduce_tensor(val_loss[i])
270 | reduce_val_Dice[i] = engine.all_reduce_tensor(val_Dice[i])
271 |
272 | if args.local_rank == 0:
273 | print("Sum results")
274 | for t in range(7):
275 | print('Sum: Task%d- loss:%.4f Organ:%.4f Tumor:%.4f' % (t, reduce_val_loss[t, 0], reduce_val_Dice[t, 0], reduce_val_Dice[t, 1]))
276 |
277 | return reduce_val_loss.mean(), reduce_val_Dice
278 |
279 |
280 |
281 |
282 | def main():
283 | """Create the model and start the training."""
284 | parser = get_arguments()
285 | print(parser)
286 |
287 | with Engine(custom_parser=parser) as engine:
288 | args = parser.parse_args()
289 | if args.num_gpus > 1:
290 | torch.cuda.set_device(args.local_rank)
291 |
292 | d, h, w = map(int, args.input_size.split(','))
293 | input_size = (d, h, w)
294 |
295 | cudnn.benchmark = True
296 | seed = 1234
297 | if engine.distributed:
298 | seed = args.local_rank
299 | torch.manual_seed(seed)
300 | if torch.cuda.is_available():
301 | torch.cuda.manual_seed(seed)
302 |
303 | # Create network.
304 | model = UNet3D(num_classes=args.num_classes, weight_std=args.weight_std)
305 |
306 | model = nn.DataParallel(model)
307 |
308 | model.eval()
309 |
310 | device = torch.device('cuda:{}'.format(args.local_rank))
311 | model.to(device)
312 |
313 | if args.num_gpus > 1:
314 | model = engine.data_parallel(model)
315 |
316 | # load checkpoint...
317 | if args.reload_from_checkpoint:
318 | print('loading from checkpoint: {}'.format(args.reload_path))
319 | if os.path.exists(args.reload_path):
320 | if args.FP16:
321 | checkpoint = torch.load(args.reload_path, map_location=torch.device('cpu'))
322 | model.load_state_dict(checkpoint['model'])
323 | # optimizer.load_state_dict(checkpoint['optimizer'])
324 | # amp.load_state_dict(checkpoint['amp'])
325 | else:
326 | model.load_state_dict(torch.load(args.reload_path, map_location=torch.device('cpu')))
327 | else:
328 | print('File not exists in the reload path: {}'.format(args.reload_path))
329 |
330 |
331 | valloader, val_sampler = engine.get_test_loader(
332 | MOTSValDataSet(args.data_dir, args.val_list))
333 |
334 | print('validate ...')
335 | val_loss, val_Dice = validate(args, input_size, [model], valloader, args.num_classes, engine)
336 |
337 | print('Validate \n 0Liver={:.4} 0LiverT={:.4} \n 1Kidney={:.4} 1KidneyT={:.4} \n'
338 | ' 2Hepa={:.4} 2HepaT={:.4} \n 3Panc={:.4} 3PancT={:.4} \n 4ColonT={:.4} \n 5LungT={:.4} \n 6Spleen={:.4}'
339 | .format(val_Dice[0, 0].item(), val_Dice[0, 1].item(),
340 | val_Dice[1, 0].item(), val_Dice[1, 1].item(), val_Dice[2, 0].item(),
341 | val_Dice[2, 1].item(),
342 | val_Dice[3, 0].item(), val_Dice[3, 1].item(), val_Dice[4, 1].item(),
343 | val_Dice[5, 1].item(), val_Dice[6, 0].item()))
344 |
345 | end = timeit.default_timer()
346 | print(end - start, 'seconds')
347 |
348 |
349 | if __name__ == '__main__':
350 | main()
351 |
--------------------------------------------------------------------------------
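
For reference, the tiling arithmetic in `predict_sliding` with the default `--input_size` of 64,192,192 and 50% overlap works out as follows (the volume dimensions below are hypothetical, for illustration only):

    from math import ceil

    tile_size = (64, 192, 192)                                  # (D, H, W) patch, matching --input_size
    overlap = 1 / 2
    strideD = ceil(tile_size[0] * (1 - overlap))                # 32
    strideHW = ceil(tile_size[1] * (1 - overlap))               # 96
    D, H, W = 128, 512, 512                                     # hypothetical resampled CT volume
    tile_deps = int(ceil((D - tile_size[0]) / strideD) + 1)     # 3
    tile_rows = int(ceil((H - tile_size[1]) / strideHW) + 1)    # 5
    tile_cols = int(ceil((W - tile_size[2]) / strideHW) + 1)    # 5
    # -> 3 * 5 * 5 = 75 overlapping patches, each predicted with 8-fold flip TTA and
    #    blended with the Gaussian importance map before dividing by count_predictions.
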
/a_DynConv/postp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import random
5 | import collections
6 | import torch
7 | import torchvision
8 | import cv2
9 | from torch.utils import data
10 | import matplotlib.pyplot as plt
11 | import nibabel as nib
12 | from skimage.measure import label as LAB
13 | from skimage.transform import resize
14 | import SimpleITK as sitk
15 | import argparse
16 |
17 | from medpy.metric import hd95
18 |
19 | def get_arguments():
20 | """Parse all the arguments provided from the CLI.
21 |
22 | Returns:
23 | A namespace of parsed arguments.
24 | """
25 | parser = argparse.ArgumentParser(description="Dynconv post processing!")
26 |
27 | parser.add_argument("--img_folder_path", type=str, default='outputs/dodnet/')
28 |
29 | return parser.parse_args()
30 |
31 | args = get_arguments()
32 |
33 | def continues_region_extract_organ(label, keep_region_nums): # keep_region_nums=1
34 | mask = False*np.zeros_like(label)
35 | regions = np.where(label>=1, np.ones_like(label), np.zeros_like(label))
36 | L, n = LAB(regions, neighbors=4, background=0, connectivity=2, return_num=True)
37 |
38 | #
39 | ary_num = np.zeros(shape=(n+1,1))
40 | for i in range(0, n+1):
41 | ary_num[i] = np.sum(L==i)
42 | max_index = np.argsort(-ary_num, axis=0)
43 | count=1
44 | for i in range(1, n+1):
45 | if count<=keep_region_nums: # keep
46 | mask = np.where(L == max_index[i][0], label, mask)
47 | count+=1
48 | label = np.where(mask==True, label, np.zeros_like(label))
49 | return label
50 |
51 | def continues_region_extract_tumor(label): #
52 |
53 | regions = np.where(label>=1, np.ones_like(label), np.zeros_like(label))
54 | L, n = LAB(regions, neighbors=4, background=0, connectivity=2, return_num=True)
55 |
56 | for i in range(1, n+1):
57 | if np.sum(L==i)<=50 and n>1: # remove default 50
58 | label = np.where(L == i, np.zeros_like(label), label)
59 |
60 | return label
61 |
62 |
63 |
64 | def dice_score(preds, labels):
65 | preds = preds[np.newaxis, :]
66 | labels = labels[np.newaxis, :]
67 | assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match"
68 | predict = preds.view().reshape(preds.shape[0], -1)
69 | target = labels.view().reshape(labels.shape[0], -1)
70 |
71 | num = np.sum(np.multiply(predict, target), axis=1)
72 | den = np.sum(predict, axis=1) + np.sum(target, axis=1) + 1
73 |
74 | dice = 2 * num / den
75 |
76 | return dice.mean()
77 |
78 |
79 | def task_index(name):
80 | if "liver" in name:
81 | return 0
82 | if "case" in name:
83 | return 1
84 | if "hepa" in name:
85 | return 2
86 | if "pancreas" in name:
87 | return 3
88 | if "colon" in name:
89 | return 4
90 | if "lung" in name:
91 | return 5
92 | if "spleen" in name:
93 | return 6
94 |
95 | def compute_HD95(ref, pred):
96 | num_ref = np.sum(ref)
97 | num_pred = np.sum(pred)
98 |
99 | if num_ref == 0:
100 | if num_pred == 0:
101 | return 0
102 | else:
103 | return 373.12866
104 | elif num_pred == 0 and num_ref != 0:
105 | return 373.12866
106 | else:
107 | return hd95(pred, ref, (1, 1, 1))
108 |
109 | val_Dice = np.zeros(shape=(7, 2))
110 | val_HD = np.zeros(shape=(7, 2))
111 | count = np.zeros(shape=(7, 2))
112 |
113 | for root, dirs, files in os.walk(args.img_folder_path):
114 | for i in sorted(files):
115 | if i[-12:-7] != 'label':
116 | continue
117 | i2 = i[:-12]+'pred'+i[-7:]
118 | i_file = root + i
119 | i2_file = root + i2
120 | predNII = nib.load(i2_file)
121 | labelNII = nib.load(i_file)
122 | pred = predNII.get_data()
123 | label = labelNII.get_data()
124 |
125 | # post-processing
126 |
127 | task_id = task_index(i)
128 | if task_id == 0 or task_id == 1 or task_id == 3:
129 | pred_organ = (pred >= 1)
130 | pred_tumor = (pred == 2)
131 | label_organ = (label >= 1)
132 | label_tumor = (label == 2)
133 |
134 | elif task_id == 2:
135 | pred_organ = (pred == 1)
136 | pred_tumor = (pred == 2)
137 | label_organ = (label == 1)
138 | label_tumor = (label == 2)
139 |
140 | elif task_id == 4 or task_id == 5:
141 | pred_organ = None
142 | pred_tumor = (pred == 2)
143 | label_organ = None
144 | label_tumor = (label == 2)
145 | elif task_id == 6:
146 | pred_organ = (pred == 1)
147 | pred_tumor = None
148 | label_organ = (label == 1)
149 | label_tumor = None
150 | else:
151 | print("No such a task!")
152 |
153 | if task_id == 0:
154 | pred_organ = continues_region_extract_organ(pred_organ, 1)
155 | pred_tumor = np.where(pred_organ == True, pred_tumor, np.zeros_like(pred_tumor))
156 | pred_tumor = continues_region_extract_tumor(pred_tumor)
157 | elif task_id == 1:
158 | pred_organ = continues_region_extract_organ(pred_organ, 2)
159 | pred_tumor = np.where(pred_organ == True, pred_tumor, np.zeros_like(pred_tumor))
160 | pred_tumor = continues_region_extract_organ(pred_tumor, 1)
161 | elif task_id == 2:
162 | pred_tumor = continues_region_extract_tumor(pred_tumor)
163 | elif task_id == 3:
164 | pred_organ = continues_region_extract_organ(pred_organ, 1)
165 | pred_tumor = np.where(pred_organ == True, pred_tumor, np.zeros_like(pred_tumor))
166 | pred_tumor = continues_region_extract_tumor(pred_tumor)
167 | elif task_id == 4:
168 | pred_tumor = continues_region_extract_organ(pred_tumor, 1)
169 | elif task_id == 5:
170 | pred_tumor = continues_region_extract_organ(pred_tumor, 1)
171 | elif task_id == 6:
172 | pred_organ = continues_region_extract_organ(pred_organ, 1)
173 | else:
174 | print("No such a task index!!!")
175 |
176 | if label_organ is not None:
177 | dice_c1 = dice_score(pred_organ, label_organ)
178 | HD_c1 = compute_HD95(label_organ, pred_organ)
179 | val_Dice[task_id, 0] += dice_c1
180 | val_HD[task_id, 0] += HD_c1
181 | count[task_id, 0] += 1
182 | else:
183 | dice_c1=-999
184 | HD_c1=999
185 | if label_tumor is not None:
186 | dice_c2 = dice_score(pred_tumor, label_tumor)
187 | HD_c2 = compute_HD95(label_tumor, pred_tumor)
188 | val_Dice[task_id, 1] += dice_c2
189 | val_HD[task_id, 1] += HD_c2
190 | count[task_id, 1] += 1
191 | else:
192 | dice_c2=-999
193 | HD_c2=999
194 | print("%s: Organ_Dice %f, tumor_Dice %f | Organ_HD %f, tumor_HD %f" % (i[:-13], dice_c1, dice_c2, HD_c1, HD_c2))
195 |
196 | count[count == 0] = 1
197 | val_Dice = val_Dice / count
198 | val_HD = val_HD / count
199 |
200 | print("Sum results")
201 | for t in range(7):
202 | print('Sum: Task%d- Organ_Dice:%.4f Tumor_Dice:%.4f | Organ_HD:%.4f Tumor_HD:%.4f' % (t, val_Dice[t, 0], val_Dice[t, 1], val_HD[t,0], val_HD[t,1]))
203 |
--------------------------------------------------------------------------------
/a_DynConv/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, sys
3 |
4 | sys.path.append("..")
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.utils import data
9 | import numpy as np
10 | import pickle
11 | import cv2
12 | import torch.optim as optim
13 | import scipy.misc
14 | import torch.backends.cudnn as cudnn
15 | import torch.nn.functional as F
16 | import matplotlib.pyplot as plt
17 |
18 | import os.path as osp
19 | from unet3D_DynConv882 import UNet3D
20 | from MOTSDataset import MOTSDataSet, my_collate
21 |
22 | import random
23 | import timeit
24 | from tensorboardX import SummaryWriter
25 | from loss_functions import loss
26 |
27 | from sklearn import metrics
28 | from math import ceil
29 |
30 | from engine import Engine
31 | from apex import amp
32 | from apex.parallel import convert_syncbn_model
33 |
34 | start = timeit.default_timer()
35 |
36 |
37 | def str2bool(v):
38 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
39 | return True
40 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
41 | return False
42 | else:
43 | raise argparse.ArgumentTypeError('Boolean value expected.')
44 |
45 |
46 | def get_arguments():
47 |
48 | parser = argparse.ArgumentParser(description="unet3D_DynConv882")
49 |
50 | parser.add_argument("--data_dir", type=str, default='../dataset/')
51 | parser.add_argument("--train_list", type=str, default='list/MOTS/MOTS_train.txt')
52 | parser.add_argument("--val_list", type=str, default='list/MOTS/xx.txt')
53 | parser.add_argument("--snapshot_dir", type=str, default='snapshots/fold1/')
54 | parser.add_argument("--reload_path", type=str, default='snapshots/fold1/xx.pth')
55 | parser.add_argument("--reload_from_checkpoint", type=str2bool, default=False)
56 | parser.add_argument("--input_size", type=str, default='64,64,64')
57 | parser.add_argument("--batch_size", type=int, default=2)
58 | parser.add_argument("--num_gpus", type=int, default=1)
59 | parser.add_argument('--local_rank', type=int, default=0)
60 | parser.add_argument("--FP16", type=str2bool, default=False)
61 | parser.add_argument("--num_epochs", type=int, default=500)
62 | parser.add_argument("--itrs_each_epoch", type=int, default=250)
63 | parser.add_argument("--patience", type=int, default=3)
64 | parser.add_argument("--start_epoch", type=int, default=0)
65 | parser.add_argument("--val_pred_every", type=int, default=10)
66 | parser.add_argument("--learning_rate", type=float, default=1e-3)
67 | parser.add_argument("--num_classes", type=int, default=2)
68 | parser.add_argument("--num_workers", type=int, default=1)
69 | parser.add_argument("--weight_std", type=str2bool, default=True)
70 | parser.add_argument("--momentum", type=float, default=0.9)
71 | parser.add_argument("--power", type=float, default=0.9)
72 | parser.add_argument("--weight_decay", type=float, default=0.0005)
73 | parser.add_argument("--ignore_label", type=int, default=255)
74 | parser.add_argument("--is_training", action="store_true")
75 | parser.add_argument("--random_mirror", type=str2bool, default=True)
76 | parser.add_argument("--random_scale", type=str2bool, default=True)
77 | parser.add_argument("--random_seed", type=int, default=1234)
78 | parser.add_argument("--gpu", type=str, default='None')
79 | return parser
80 |
81 |
82 | def lr_poly(base_lr, iter, max_iter, power):
83 | return base_lr * ((1 - float(iter) / max_iter) ** (power))
84 |
85 |
86 | def adjust_learning_rate(optimizer, i_iter, lr, num_steps, power):
87 | """Polynomially decays the learning rate according to lr_poly."""
88 | lr = lr_poly(lr, i_iter, num_steps, power)
89 | optimizer.param_groups[0]['lr'] = lr
90 | return lr
91 |
92 |
93 | def main():
94 | """Create the model and start the training."""
95 | parser = get_arguments()
96 | print(parser)
97 |
98 | with Engine(custom_parser=parser) as engine:
99 | args = parser.parse_args()
100 | if args.num_gpus > 1:
101 | torch.cuda.set_device(args.local_rank)
102 |
103 | writer = SummaryWriter(args.snapshot_dir)
104 |
105 | if not args.gpu == 'None':
106 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
107 |
108 | d, h, w = map(int, args.input_size.split(','))
109 | input_size = (d, h, w)
110 |
111 | cudnn.benchmark = True
112 | seed = args.random_seed
113 | if engine.distributed:
114 | seed = args.local_rank
115 | torch.manual_seed(seed)
116 | if torch.cuda.is_available():
117 | torch.cuda.manual_seed(seed)
118 |
119 | # Create model
120 | model = UNet3D(num_classes=args.num_classes, weight_std=args.weight_std)
121 |
122 | model.train()
123 |
124 | device = torch.device('cuda:{}'.format(args.local_rank))
125 | model.to(device)
126 |
127 | optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=0.99, nesterov=True)
128 |
129 | if args.FP16:
130 | print("Note: Using FP16 during training************")
131 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
132 |
133 | if args.num_gpus > 1:
134 | model = engine.data_parallel(model)
135 |
136 | # load checkpoint...
137 | if args.reload_from_checkpoint:
138 | print('loading from checkpoint: {}'.format(args.reload_path))
139 | if os.path.exists(args.reload_path):
140 | if args.FP16:
141 | checkpoint = torch.load(args.reload_path, map_location=torch.device('cpu'))
142 | model.load_state_dict(checkpoint['model'])
143 | optimizer.load_state_dict(checkpoint['optimizer'])
144 | amp.load_state_dict(checkpoint['amp'])
145 | else:
146 | model.load_state_dict(torch.load(args.reload_path, map_location=torch.device('cpu')))
147 | else:
148 | print('File not exists in the reload path: {}'.format(args.reload_path))
149 |
150 | loss_seg_DICE = loss.DiceLoss4MOTS(num_classes=args.num_classes).to(device)
151 | loss_seg_CE = loss.CELoss4MOTS(num_classes=args.num_classes, ignore_index=255).to(device)
152 |
153 | if not os.path.exists(args.snapshot_dir):
154 | os.makedirs(args.snapshot_dir)
155 |
156 | trainloader, train_sampler = engine.get_train_loader(
157 | MOTSDataSet(args.data_dir, args.train_list, max_iters=args.itrs_each_epoch * args.batch_size, crop_size=input_size, scale=args.random_scale, mirror=args.random_mirror),
158 | collate_fn=my_collate)
159 |
160 | all_tr_loss = []
161 | all_va_loss = []
162 | train_loss_MA = None
163 | val_loss_MA = None
164 |
165 | val_best_loss = 999999
166 |
167 | for epoch in range(args.num_epochs):
168 | if epoch < args.start_epoch:
169 | continue
170 |
171 | if engine.distributed:
172 | train_sampler.set_epoch(epoch)
173 |
174 | epoch_loss = []
175 | adjust_learning_rate(optimizer, epoch, args.learning_rate, args.num_epochs, args.power)
176 |
177 | for iter, batch in enumerate(trainloader):
178 |
179 | images = torch.from_numpy(batch['image']).cuda()
180 | labels = torch.from_numpy(batch['label']).cuda()
181 | volumeName = batch['name']
182 | task_ids = batch['task_id']
183 |
184 | optimizer.zero_grad()
185 | preds = model(images, task_ids)
186 |
187 | term_seg_Dice = loss_seg_DICE.forward(preds, labels)
188 | term_seg_BCE = loss_seg_CE.forward(preds, labels)
189 | term_all = term_seg_Dice + term_seg_BCE
190 |
191 | reduce_Dice = engine.all_reduce_tensor(term_seg_Dice)
192 | reduce_BCE = engine.all_reduce_tensor(term_seg_BCE)
193 | reduce_all = engine.all_reduce_tensor(term_all)
194 |
195 | if args.FP16:
196 | with amp.scale_loss(term_all, optimizer) as scaled_loss:
197 | scaled_loss.backward()
198 | else:
199 | term_all.backward()
200 | optimizer.step()
201 |
202 | epoch_loss.append(float(reduce_all))
203 |
204 | if (args.local_rank == 0):
205 | print(
206 | 'Epoch {}: {}/{}, lr = {:.4}, loss_seg_Dice = {:.4}, loss_seg_BCE = {:.4}, loss_Sum = {:.4}'.format( \
207 | epoch, iter, len(trainloader), optimizer.param_groups[0]['lr'], reduce_Dice.item(),
208 | reduce_BCE.item(), reduce_all.item()))
209 |
210 | epoch_loss = np.mean(epoch_loss)
211 |
212 | all_tr_loss.append(epoch_loss)
213 |
214 | if (args.local_rank == 0):
215 | print('Epoch_sum {}: lr = {:.4}, loss_Sum = {:.4}'.format(epoch, optimizer.param_groups[0]['lr'],
216 | epoch_loss.item()))
217 | writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
218 | writer.add_scalar('Train_loss', epoch_loss.item(), epoch)
219 |
220 | if (epoch >= 0) and (args.local_rank == 0) and (((epoch % 50 == 0) and (epoch >= 800)) or (epoch % 50 == 0)):
221 | print('save model ...')
222 | if args.FP16:
223 | checkpoint = {
224 | 'model': model.state_dict(),
225 | 'optimizer': optimizer.state_dict(),
226 | 'amp': amp.state_dict()
227 | }
228 | torch.save(checkpoint, osp.join(args.snapshot_dir, 'MOTS_DynConv_' + args.snapshot_dir.split('/')[-2] + '_e' + str(epoch) + '.pth'))
229 | else:
230 | torch.save(model.state_dict(),osp.join(args.snapshot_dir, 'MOTS_DynConv_' + args.snapshot_dir.split('/')[-2] + '_e' + str(epoch) + '.pth'))
231 |
232 | if (epoch >= args.num_epochs - 1) and (args.local_rank == 0):
233 | print('save model ...')
234 | if args.FP16:
235 | checkpoint = {
236 | 'model': model.state_dict(),
237 | 'optimizer': optimizer.state_dict(),
238 | 'amp': amp.state_dict()
239 | }
240 | torch.save(checkpoint, osp.join(args.snapshot_dir, 'MOTS_DynConv_' + args.snapshot_dir.split('/')[-2] + '_final_e' + str(epoch) + '.pth'))
241 | else:
242 | torch.save(model.state_dict(),osp.join(args.snapshot_dir, 'MOTS_DynConv_' + args.snapshot_dir.split('/')[-2] + '_final_e' + str(epoch) + '.pth'))
243 | break
244 |
245 | end = timeit.default_timer()
246 | print(end - start, 'seconds')
247 |
248 |
249 | if __name__ == '__main__':
250 | main()
251 |
--------------------------------------------------------------------------------
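
`adjust_learning_rate` above applies the polynomial decay from `lr_poly` once per epoch. With the default settings (`learning_rate=1e-3`, `num_epochs=500`, `power=0.9`) the schedule looks roughly like this:

    def lr_poly(base_lr, it, max_iter, power):
        return base_lr * ((1 - float(it) / max_iter) ** power)

    for epoch in (0, 100, 250, 499):
        print(epoch, lr_poly(1e-3, epoch, 500, 0.9))
    # 0 -> 1.0e-03, 100 -> ~8.2e-04, 250 -> ~5.4e-04, 499 -> ~3.7e-06
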
/a_DynConv/unet3D_DynConv882.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn import functional as F
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 | import torch
6 | import numpy as np
7 | from torch.autograd import Variable
8 | import matplotlib.pyplot as plt
9 | affine_par = True
10 | import functools
11 |
12 | import sys, os
13 |
14 | in_place = True
15 |
16 | class Conv3d(nn.Conv3d):
17 |
18 | def __init__(self, in_channels, out_channels, kernel_size, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), groups=1, bias=False):
19 | super(Conv3d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
20 |
21 | def forward(self, x):
22 | weight = self.weight
23 | weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True).mean(dim=4, keepdim=True)
24 | weight = weight - weight_mean
25 | std = torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view(-1, 1, 1, 1, 1)
26 | weight = weight / std.expand_as(weight)
27 | return F.conv3d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
28 |
29 |
30 | def conv3x3x3(in_planes, out_planes, kernel_size=(3,3,3), stride=(1,1,1), padding=1, dilation=1, bias=False, weight_std=False):
31 | "3x3x3 convolution with padding"
32 | if weight_std:
33 | return Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
34 | else:
35 | return nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
36 |
37 |
38 |
39 |
40 | class NoBottleneck(nn.Module):
41 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1, weight_std=False):
42 | super(NoBottleneck, self).__init__()
43 | self.weight_std = weight_std
44 | self.gn1 = nn.GroupNorm(16, inplanes)
45 | self.conv1 = conv3x3x3(inplanes, planes, kernel_size=(3, 3, 3), stride=stride, padding=(1,1,1),
46 | dilation=dilation * multi_grid, bias=False, weight_std=self.weight_std)
47 | self.relu = nn.ReLU(inplace=in_place)
48 |
49 | self.gn2 = nn.GroupNorm(16, planes)
50 | self.conv2 = conv3x3x3(planes, planes, kernel_size=(3, 3, 3), stride=1, padding=(1,1,1),
51 | dilation=dilation * multi_grid, bias=False, weight_std=self.weight_std)
52 | self.downsample = downsample
53 | self.dilation = dilation
54 | self.stride = stride
55 |
56 | def forward(self, x):
57 | residual = x
58 |
59 | out = self.gn1(x)
60 | out = self.relu(out)
61 | out = self.conv1(out)
62 |
63 |
64 | out = self.gn2(out)
65 | out = self.relu(out)
66 | out = self.conv2(out)
67 |
68 | if self.downsample is not None:
69 | residual = self.downsample(x)
70 |
71 | out = out + residual
72 |
73 | return out
74 |
75 |
76 | class unet3D(nn.Module):
77 | def __init__(self, layers, num_classes=3, weight_std = False):
78 | self.inplanes = 128
79 | self.weight_std = weight_std
80 | super(unet3D, self).__init__()
81 |
82 | self.conv1 = conv3x3x3(1, 32, stride=[1, 1, 1], weight_std=self.weight_std)
83 |
84 | self.layer0 = self._make_layer(NoBottleneck, 32, 32, layers[0], stride=(1, 1, 1))
85 | self.layer1 = self._make_layer(NoBottleneck, 32, 64, layers[1], stride=(2, 2, 2))
86 | self.layer2 = self._make_layer(NoBottleneck, 64, 128, layers[2], stride=(2, 2, 2))
87 | self.layer3 = self._make_layer(NoBottleneck, 128, 256, layers[3], stride=(2, 2, 2))
88 | self.layer4 = self._make_layer(NoBottleneck, 256, 256, layers[4], stride=(2, 2, 2))
89 |
90 | self.fusionConv = nn.Sequential(
91 | nn.GroupNorm(16, 256),
92 | nn.ReLU(inplace=in_place),
93 | conv3x3x3(256, 256, kernel_size=(1, 1, 1), padding=(0, 0, 0), weight_std=self.weight_std)
94 | )
95 |
96 | self.upsamplex2 = nn.Upsample(scale_factor=2, mode='trilinear')
97 |
98 | self.x8_resb = self._make_layer(NoBottleneck, 256, 128, 1, stride=(1, 1, 1))
99 | self.x4_resb = self._make_layer(NoBottleneck, 128, 64, 1, stride=(1, 1, 1))
100 | self.x2_resb = self._make_layer(NoBottleneck, 64, 32, 1, stride=(1, 1, 1))
101 | self.x1_resb = self._make_layer(NoBottleneck, 32, 32, 1, stride=(1, 1, 1))
102 |
103 | self.precls_conv = nn.Sequential(
104 | nn.GroupNorm(16, 32),
105 | nn.ReLU(inplace=in_place),
106 | nn.Conv3d(32, 8, kernel_size=1)
107 | )
108 |
109 | self.GAP = nn.Sequential(
110 | nn.GroupNorm(16, 256),
111 | nn.ReLU(inplace=in_place),
112 | torch.nn.AdaptiveAvgPool3d((1,1,1))
113 | )
114 | self.controller = nn.Conv3d(256+7, 162, kernel_size=1, stride=1, padding=0)
115 |
116 | def _make_layer(self, block, inplanes, planes, blocks, stride=(1, 1, 1), dilation=1, multi_grid=1):
117 | downsample = None
118 | if stride[0] != 1 or stride[1] != 1 or stride[2] != 1 or inplanes != planes:
119 | downsample = nn.Sequential(
120 | nn.GroupNorm(16, inplanes),
121 | nn.ReLU(inplace=in_place),
122 | conv3x3x3(inplanes, planes, kernel_size=(1, 1, 1), stride=stride, padding=0,
123 | weight_std=self.weight_std),
124 | )
125 |
126 | layers = []
127 | generate_multi_grid = lambda index, grids: grids[index % len(grids)] if isinstance(grids, tuple) else 1
128 | layers.append(block(inplanes, planes, stride, dilation=dilation, downsample=downsample,
129 | multi_grid=generate_multi_grid(0, multi_grid), weight_std=self.weight_std))
130 | # self.inplanes = planes
131 | for i in range(1, blocks):
132 | layers.append(
133 | block(planes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid),
134 | weight_std=self.weight_std))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def encoding_task(self, task_id):
139 | N = task_id.shape[0]
140 | task_encoding = torch.zeros(size=(N, 7))
141 | for i in range(N):
142 | task_encoding[i, task_id[i]]=1
143 | return task_encoding.cuda()
144 |
145 | def parse_dynamic_params(self, params, channels, weight_nums, bias_nums):
146 | assert params.dim() == 2
147 | assert len(weight_nums) == len(bias_nums)
148 | assert params.size(1) == sum(weight_nums) + sum(bias_nums)
149 |
150 | num_insts = params.size(0)
151 | num_layers = len(weight_nums)
152 |
153 | params_splits = list(torch.split_with_sizes(
154 | params, weight_nums + bias_nums, dim=1
155 | ))
156 |
157 | weight_splits = params_splits[:num_layers]
158 | bias_splits = params_splits[num_layers:]
159 |
160 | for l in range(num_layers):
161 | if l < num_layers - 1:
162 | weight_splits[l] = weight_splits[l].reshape(num_insts * channels, -1, 1, 1, 1)
163 | bias_splits[l] = bias_splits[l].reshape(num_insts * channels)
164 | else:
165 | weight_splits[l] = weight_splits[l].reshape(num_insts * 2, -1, 1, 1, 1)
166 | bias_splits[l] = bias_splits[l].reshape(num_insts * 2)
167 |
168 | return weight_splits, bias_splits
169 |
170 | def heads_forward(self, features, weights, biases, num_insts):
171 | assert features.dim() == 5
172 | n_layers = len(weights)
173 | x = features
174 | for i, (w, b) in enumerate(zip(weights, biases)):
175 | x = F.conv3d(
176 | x, w, bias=b,
177 | stride=1, padding=0,
178 | groups=num_insts
179 | )
180 | if i < n_layers - 1:
181 | x = F.relu(x)
182 | return x
183 |
184 | def forward(self, input, task_id):
185 |
186 | x = self.conv1(input)
187 | x = self.layer0(x)
188 | skip0 = x
189 |
190 | x = self.layer1(x)
191 | skip1 = x
192 |
193 | x = self.layer2(x)
194 | skip2 = x
195 |
196 | x = self.layer3(x)
197 | skip3 = x
198 |
199 | x = self.layer4(x)
200 |
201 | x = self.fusionConv(x)
202 |
203 | # generate conv filters for classification layer
204 | task_encoding = self.encoding_task(task_id)
205 | task_encoding.unsqueeze_(2).unsqueeze_(2).unsqueeze_(2)
206 | x_feat = self.GAP(x)
207 | x_cond = torch.cat([x_feat, task_encoding], 1)
208 | params = self.controller(x_cond)
209 | params.squeeze_(-1).squeeze_(-1).squeeze_(-1)
210 |
211 |
212 | # x8
213 | x = self.upsamplex2(x)
214 | x = x + skip3
215 | x = self.x8_resb(x)
216 |
217 | # x4
218 | x = self.upsamplex2(x)
219 | x = x + skip2
220 | x = self.x4_resb(x)
221 |
222 | # x2
223 | x = self.upsamplex2(x)
224 | x = x + skip1
225 | x = self.x2_resb(x)
226 |
227 | # x1
228 | x = self.upsamplex2(x)
229 | x = x + skip0
230 | x = self.x1_resb(x)
231 |
232 | head_inputs = self.precls_conv(x)
233 |
234 | N, _, D, H, W = head_inputs.size()
235 | head_inputs = head_inputs.reshape(1, -1, D, H, W)
236 |
237 | weight_nums, bias_nums = [], []
238 | weight_nums.append(8*8)
239 | weight_nums.append(8*8)
240 | weight_nums.append(8*2)
241 | bias_nums.append(8)
242 | bias_nums.append(8)
243 | bias_nums.append(2)
244 | weights, biases = self.parse_dynamic_params(params, 8, weight_nums, bias_nums)
245 |
246 | logits = self.heads_forward(head_inputs, weights, biases, N)
247 |
248 | logits = logits.reshape(-1, 2, D, H, W)
249 |
250 | return logits
251 |
252 | def UNet3D(num_classes=1, weight_std=False):
253 | print("Using DynConv 8,8,2")
254 | model = unet3D([1, 2, 2, 2, 2], num_classes, weight_std)
255 | return model
--------------------------------------------------------------------------------
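
The 162 output channels of `self.controller` are exactly the number of parameters needed for the three dynamically generated 1x1x1 head convolutions (8 -> 8 -> 8 -> 2 channels), as split in `forward()`:

    weight_nums = [8 * 8, 8 * 8, 8 * 2]   # per-layer conv weights: 64 + 64 + 16
    bias_nums   = [8, 8, 2]               # per-layer biases:        8 +  8 +  2
    assert sum(weight_nums) + sum(bias_nums) == 162   # matches nn.Conv3d(256 + 7, 162, ...)
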
/data_list/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/data_list/.DS_Store
--------------------------------------------------------------------------------
/data_list/MOTS/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/data_list/MOTS/.DS_Store
--------------------------------------------------------------------------------
/data_list/MOTS/MOTS_test.txt:
--------------------------------------------------------------------------------
1 | liver_0
2 | liver_102
3 | liver_107
4 | liver_111
5 | liver_116
6 | liver_120
7 | liver_125
8 | liver_13
9 | liver_17
10 | liver_21
11 | liver_26
12 | liver_30
13 | liver_35
14 | liver_4
15 | liver_44
16 | liver_49
17 | liver_53
18 | liver_58
19 | liver_62
20 | liver_67
21 | liver_71
22 | liver_76
23 | liver_80
24 | liver_85
25 | liver_9
26 | liver_94
27 | liver_99
28 | kidney_00000
29 | kidney_00005
30 | kidney_00010
31 | kidney_00015
32 | kidney_00020
33 | kidney_00025
34 | kidney_00030
35 | kidney_00035
36 | kidney_00040
37 | kidney_00045
38 | kidney_00050
39 | kidney_00055
40 | kidney_00060
41 | kidney_00065
42 | kidney_00070
43 | kidney_00075
44 | kidney_00080
45 | kidney_00085
46 | kidney_00090
47 | kidney_00095
48 | kidney_00100
49 | kidney_00105
50 | kidney_00110
51 | kidney_00115
52 | kidney_00120
53 | kidney_00125
54 | kidney_00130
55 | kidney_00135
56 | kidney_00140
57 | kidney_00145
58 | kidney_00150
59 | kidney_00155
60 | kidney_00160
61 | kidney_00165
62 | kidney_00170
63 | kidney_00175
64 | kidney_00180
65 | kidney_00185
66 | kidney_00190
67 | kidney_00195
68 | kidney_00200
69 | kidney_00205
70 | hepaticvessel_001
71 | hepaticvessel_008
72 | hepaticvessel_018
73 | hepaticvessel_025
74 | hepaticvessel_030
75 | hepaticvessel_040
76 | hepaticvessel_050
77 | hepaticvessel_058
78 | hepaticvessel_066
79 | hepaticvessel_072
80 | hepaticvessel_080
81 | hepaticvessel_085
82 | hepaticvessel_090
83 | hepaticvessel_096
84 | hepaticvessel_103
85 | hepaticvessel_112
86 | hepaticvessel_119
87 | hepaticvessel_127
88 | hepaticvessel_133
89 | hepaticvessel_140
90 | hepaticvessel_146
91 | hepaticvessel_154
92 | hepaticvessel_161
93 | hepaticvessel_167
94 | hepaticvessel_175
95 | hepaticvessel_183
96 | hepaticvessel_192
97 | hepaticvessel_197
98 | hepaticvessel_203
99 | hepaticvessel_210
100 | hepaticvessel_217
101 | hepaticvessel_223
102 | hepaticvessel_230
103 | hepaticvessel_236
104 | hepaticvessel_244
105 | hepaticvessel_256
106 | hepaticvessel_265
107 | hepaticvessel_271
108 | hepaticvessel_279
109 | hepaticvessel_285
110 | hepaticvessel_291
111 | hepaticvessel_299
112 | hepaticvessel_308
113 | hepaticvessel_320
114 | hepaticvessel_325
115 | hepaticvessel_333
116 | hepaticvessel_341
117 | hepaticvessel_350
118 | hepaticvessel_361
119 | hepaticvessel_369
120 | hepaticvessel_375
121 | hepaticvessel_384
122 | hepaticvessel_391
123 | hepaticvessel_400
124 | hepaticvessel_407
125 | hepaticvessel_416
126 | hepaticvessel_424
127 | hepaticvessel_432
128 | hepaticvessel_440
129 | hepaticvessel_445
130 | hepaticvessel_455
131 | pancreas_001
132 | pancreas_012
133 | pancreas_021
134 | pancreas_032
135 | pancreas_042
136 | pancreas_049
137 | pancreas_056
138 | pancreas_067
139 | pancreas_075
140 | pancreas_083
141 | pancreas_089
142 | pancreas_095
143 | pancreas_101
144 | pancreas_106
145 | pancreas_113
146 | pancreas_122
147 | pancreas_129
148 | pancreas_138
149 | pancreas_149
150 | pancreas_160
151 | pancreas_170
152 | pancreas_179
153 | pancreas_186
154 | pancreas_196
155 | pancreas_201
156 | pancreas_210
157 | pancreas_215
158 | pancreas_224
159 | pancreas_229
160 | pancreas_236
161 | pancreas_244
162 | pancreas_254
163 | pancreas_261
164 | pancreas_267
165 | pancreas_275
166 | pancreas_280
167 | pancreas_287
168 | pancreas_293
169 | pancreas_298
170 | pancreas_303
171 | pancreas_310
172 | pancreas_316
173 | pancreas_325
174 | pancreas_330
175 | pancreas_339
176 | pancreas_346
177 | pancreas_354
178 | pancreas_360
179 | pancreas_366
180 | pancreas_374
181 | pancreas_379
182 | pancreas_387
183 | pancreas_393
184 | pancreas_401
185 | pancreas_409
186 | pancreas_414
187 | pancreas_421
188 | colon_001
189 | colon_009
190 | colon_024
191 | colon_029
192 | colon_036
193 | colon_042
194 | colon_053
195 | colon_065
196 | colon_075
197 | colon_088
198 | colon_096
199 | colon_103
200 | colon_111
201 | colon_118
202 | colon_126
203 | colon_134
204 | colon_140
205 | colon_145
206 | colon_157
207 | colon_164
208 | colon_171
209 | colon_185
210 | colon_196
211 | colon_206
212 | colon_214
213 | colon_219
214 | lung_001
215 | lung_009
216 | lung_018
217 | lung_026
218 | lung_033
219 | lung_041
220 | lung_046
221 | lung_053
222 | lung_059
223 | lung_066
224 | lung_074
225 | lung_081
226 | lung_093
227 | spleen_10
228 | spleen_17
229 | spleen_21
230 | spleen_27
231 | spleen_32
232 | spleen_44
233 | spleen_52
234 | spleen_60
235 | spleen_9
236 |
--------------------------------------------------------------------------------
/dataset/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/dataset/.DS_Store
--------------------------------------------------------------------------------
/dataset/list/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/dataset/list/.DS_Store
--------------------------------------------------------------------------------
/dataset/re_spacing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import nibabel as nib
4 | from skimage.transform import resize
5 | from tqdm import tqdm
6 | import matplotlib.pyplot as plt
7 | import SimpleITK as sitk
8 |
9 | spacing = {
10 | 0: [1.5, 0.8, 0.8],
11 | 1: [1.5, 0.8, 0.8],
12 | 2: [1.5, 0.8, 0.8],
13 | 3: [1.5, 0.8, 0.8],
14 | 4: [1.5, 0.8, 0.8],
15 | 5: [1.5, 0.8, 0.8],
16 | 6: [1.5, 0.8, 0.8],
17 | }
18 |
19 | ori_path = './0123456'
20 | new_path = './0123456_spacing_same'
21 |
22 | count = -1
23 | for root1, dirs1, _ in os.walk(ori_path):
24 | for i_dirs1 in tqdm(sorted(dirs1)): # 0Liver
25 | # if i_dirs1 != '0Liver':
26 | # continue
27 | ###########################################################################
28 | if i_dirs1 == '1Kidney':
29 | for root2, dirs2, files2 in os.walk(os.path.join(root1, i_dirs1)):
30 |
31 | for root3, dirs3, files3 in os.walk(os.path.join(root2, 'origin')):
32 |
33 | for i_dirs3 in sorted(dirs3): # case_00000
34 | # if int(i_dirs3[-2:])!=4:
35 | # continue
36 |
37 | for root4, dirs4, files4 in os.walk(os.path.join(root3, i_dirs3)):
38 | for i_files4 in sorted(files4):
39 | # read img
40 | print("Processing %s" % (i_files4))
41 | img_path = os.path.join(root4, i_files4)
42 | imageITK = sitk.ReadImage(img_path)
43 | image = sitk.GetArrayFromImage(imageITK)
44 | ori_size = np.array(imageITK.GetSize())[[2, 1, 0]]
45 | ori_spacing = np.array(imageITK.GetSpacing())[[2, 1, 0]]
46 | ori_origin = imageITK.GetOrigin()
47 | ori_direction = imageITK.GetDirection()
48 |
49 | task_id = int(i_dirs1[0])
50 | target_spacing = np.array(spacing[task_id])
51 |
52 | if ori_spacing[0] < 0 or ori_spacing[1] < 0 or ori_spacing[2] < 0:
53 | print("error")
54 | spc_ratio = ori_spacing / target_spacing
55 |
56 | data_type = image.dtype
57 | if i_files4 != 'segmentation.nii.gz':
58 | data_type = np.int32
59 |
60 | if i_files4 == 'segmentation.nii.gz':
61 | order = 0
62 | mode_ = 'edge'
63 | else:
64 | order = 3
65 | mode_ = 'constant'
66 |
67 |                                     image = image.astype(float)
68 |
69 | image_resize = resize(image, (
70 | int(ori_size[0] * spc_ratio[0]), int(ori_size[1] * spc_ratio[1]),
71 |                                         int(ori_size[2] * spc_ratio[2])), order=order, mode=mode_, cval=0, clip=True,
72 | preserve_range=True)
73 |
74 | image_resize = np.round(image_resize).astype(data_type)
75 |
76 | # save
77 | save_path = os.path.join(new_path, i_dirs1, 'origin', i_dirs3)
78 | if not os.path.exists(save_path):
79 | os.makedirs(save_path)
80 | saveITK = sitk.GetImageFromArray(image_resize)
81 | saveITK.SetSpacing(target_spacing[[2, 1, 0]])
82 | saveITK.SetOrigin(ori_origin)
83 | saveITK.SetDirection(ori_direction)
84 | sitk.WriteImage(saveITK, os.path.join(save_path, i_files4))
85 |
86 | #############################################################################
87 | for root2, dirs2, files2 in os.walk(os.path.join(root1, i_dirs1)):
88 | for i_dirs2 in sorted(dirs2): # imagesTr
89 |
90 | for root3, dirs3, files3 in os.walk(os.path.join(root2, i_dirs2)):
91 | for i_files3 in sorted(files3):
92 | if i_files3[0] == '.':
93 | continue
94 | # read img
95 | print("Processing %s" % (i_files3))
96 | img_path = os.path.join(root3, i_files3)
97 | imageITK = sitk.ReadImage(img_path)
98 | image = sitk.GetArrayFromImage(imageITK)
99 | ori_size = np.array(imageITK.GetSize())[[2, 1, 0]]
100 | ori_spacing = np.array(imageITK.GetSpacing())[[2, 1, 0]]
101 | ori_origin = imageITK.GetOrigin()
102 | ori_direction = imageITK.GetDirection()
103 |
104 | task_id = int(i_dirs1[0])
105 | target_spacing = np.array(spacing[task_id])
106 | spc_ratio = ori_spacing / target_spacing
107 |
108 | data_type = image.dtype
109 | if i_dirs2 != 'labelsTr':
110 | data_type = np.int32
111 |
112 | if i_dirs2 == 'labelsTr':
113 | order = 0
114 | mode_ = 'edge'
115 | else:
116 | order = 3
117 | mode_ = 'constant'
118 |
119 |                         image = image.astype(float)
120 |
121 | image_resize = resize(image, (int(ori_size[0] * spc_ratio[0]), int(ori_size[1] * spc_ratio[1]),
122 | int(ori_size[2] * spc_ratio[2])),
123 | order=order, mode=mode_, cval=0, clip=True, preserve_range=True)
124 | image_resize = np.round(image_resize).astype(data_type)
125 |
126 | # save
127 | save_path = os.path.join(new_path, i_dirs1, i_dirs2)
128 | if not os.path.exists(save_path):
129 | os.makedirs(save_path)
130 | saveITK = sitk.GetImageFromArray(image_resize)
131 | saveITK.SetSpacing(target_spacing[[2, 1, 0]])
132 | saveITK.SetOrigin(ori_origin)
133 | saveITK.SetDirection(ori_direction)
134 | sitk.WriteImage(saveITK, os.path.join(save_path, i_files3))
135 |
--------------------------------------------------------------------------------
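Note: re_spacing.py walks the raw MOTS folders and resamples every volume to a fixed voxel spacing of 1.5 x 0.8 x 0.8 mm (z, y, x). The snippet below is a minimal, self-contained sketch of that per-volume step, not part of the repository; the file paths are hypothetical, and a label map would use order=0 and mode='edge' instead, exactly as in the loop above.

    import numpy as np
    import SimpleITK as sitk
    from skimage.transform import resize

    target_spacing = np.array([1.5, 0.8, 0.8])               # (z, y, x), as in the `spacing` dict
    itk_img = sitk.ReadImage('./case.nii.gz')                 # hypothetical input path
    image = sitk.GetArrayFromImage(itk_img).astype(float)     # array comes back as (z, y, x)
    ori_spacing = np.array(itk_img.GetSpacing())[[2, 1, 0]]   # reorder (x, y, z) -> (z, y, x)

    # new shape = old shape scaled by the spacing ratio
    new_shape = tuple(np.round(np.array(image.shape) * ori_spacing / target_spacing).astype(int))

    # order=3 (cubic) for intensity images; a segmentation would use order=0, mode='edge'
    resampled = resize(image, new_shape, order=3, mode='constant', cval=0,
                       clip=True, preserve_range=True)

    out = sitk.GetImageFromArray(np.round(resampled).astype(np.int32))
    out.SetSpacing(target_spacing[[2, 1, 0]].tolist())        # back to (x, y, z) for ITK
    out.SetOrigin(itk_img.GetOrigin())
    out.SetDirection(itk_img.GetDirection())
    sitk.WriteImage(out, './case_respaced.nii.gz')            # hypothetical output path

--------------------------------------------------------------------------------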
/engine.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import time
4 | import argparse
5 |
6 | import torch
7 | import torch.distributed as dist
8 |
9 | from utils.logger import get_logger
10 | from utils.pyt_utils import all_reduce_tensor, extant_file
11 |
12 | try:
13 | from apex.parallel import DistributedDataParallel, SyncBatchNorm
14 | except ImportError:
15 | raise ImportError(
16 | "Please install apex from https://www.github.com/nvidia/apex .")
17 |
18 |
19 | logger = get_logger()
20 |
21 |
22 | class Engine(object):
23 | def __init__(self, custom_parser=None):
24 | logger.info(
25 | "PyTorch Version {}".format(torch.__version__))
26 | self.devices = None
27 | self.distributed = False
28 |
29 | if custom_parser is None:
30 | self.parser = argparse.ArgumentParser()
31 | else:
32 | assert isinstance(custom_parser, argparse.ArgumentParser)
33 | self.parser = custom_parser
34 |
35 | self.inject_default_parser()
36 | self.args = self.parser.parse_args()
37 |
38 | self.continue_state_object = self.args.continue_fpath
39 |
40 | # if not self.args.gpu == 'None':
41 | # os.environ["CUDA_VISIBLE_DEVICES"]=self.args.gpu
42 |
43 | if 'WORLD_SIZE' in os.environ:
44 | self.distributed = int(os.environ['WORLD_SIZE']) > 1
45 | print("WORLD_SIZE is %d" % (int(os.environ['WORLD_SIZE'])))
46 | if self.distributed:
47 | self.local_rank = self.args.local_rank
48 | self.world_size = int(os.environ['WORLD_SIZE'])
49 | torch.cuda.set_device(self.local_rank)
50 | dist.init_process_group(backend="nccl", init_method='env://')
51 | self.devices = [i for i in range(self.world_size)]
52 | else:
53 | gpus = os.environ["CUDA_VISIBLE_DEVICES"]
54 | self.devices = [i for i in range(len(gpus.split(',')))]
55 |
56 | def inject_default_parser(self):
57 | p = self.parser
58 | p.add_argument('-d', '--devices', default='',
59 | help='set data parallel training')
60 | p.add_argument('-c', '--continue', type=extant_file,
61 | metavar="FILE",
62 | dest="continue_fpath",
63 | help='continue from one certain checkpoint')
64 | # p.add_argument('--local_rank', default=0, type=int,
65 | # help='process rank on node')
66 |
67 | def data_parallel(self, model):
68 | if self.distributed:
69 | model = DistributedDataParallel(model)
70 | else:
71 | model = torch.nn.DataParallel(model)
72 | return model
73 |
74 | def get_train_loader(self, train_dataset, collate_fn=None):
75 | train_sampler = None
76 | is_shuffle = True
77 | batch_size = self.args.batch_size
78 |
79 | if self.distributed:
80 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
81 | batch_size = self.args.batch_size // self.world_size
82 | is_shuffle = False
83 |
84 | train_loader = torch.utils.data.DataLoader(train_dataset,
85 | batch_size=batch_size,
86 | num_workers=self.args.num_workers,
87 | drop_last=False,
88 | shuffle=is_shuffle,
89 | pin_memory=True,
90 | sampler=train_sampler,
91 | collate_fn=collate_fn)
92 |
93 | return train_loader, train_sampler
94 |
95 | def get_test_loader(self, test_dataset):
96 | test_sampler = None
97 | is_shuffle = False
98 | batch_size = self.args.batch_size
99 |
100 | if self.distributed:
101 | test_sampler = torch.utils.data.distributed.DistributedSampler(
102 | test_dataset)
103 | batch_size = self.args.batch_size // self.world_size
104 |
105 | test_loader = torch.utils.data.DataLoader(test_dataset,
106 | batch_size=1,
107 | num_workers=self.args.num_workers,
108 | drop_last=False,
109 | shuffle=is_shuffle,
110 | pin_memory=True,
111 | sampler=test_sampler)
112 |
113 | return test_loader, test_sampler
114 |
115 |
116 | def all_reduce_tensor(self, tensor, norm=True):
117 | if self.distributed:
118 | return all_reduce_tensor(tensor, world_size=self.world_size, norm=norm)
119 | else:
120 | return torch.mean(tensor)
121 |
122 |
123 | def __enter__(self):
124 | return self
125 |
126 | def __exit__(self, type, value, tb):
127 | torch.cuda.empty_cache()
128 | if type is not None:
129 | logger.warning(
130 |                 "An exception occurred inside the Engine context, "
131 |                 "giving up the running process")
132 | return False
133 |
--------------------------------------------------------------------------------
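Usage sketch (not part of the repository): how a training script such as TransDoD/train.py can drive Engine. It assumes apex is installed (the import above is mandatory) and that CUDA_VISIBLE_DEVICES is set; batch_size and num_workers are the arguments Engine reads from the parsed args, and the tensor dataset below only stands in for the real MOTS dataset.

    import os
    import argparse
    import torch
    from engine import Engine

    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')   # Engine reads this when not distributed

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--local_rank', type=int, default=0)

    with Engine(custom_parser=parser) as engine:
        data = torch.utils.data.TensorDataset(torch.randn(8, 1, 16, 64, 64))  # stand-in dataset
        train_loader, train_sampler = engine.get_train_loader(data)
        for epoch in range(2):
            if engine.distributed:                       # DistributedSampler needs the epoch for shuffling
                train_sampler.set_epoch(epoch)
            for (images,) in train_loader:
                pass                                     # forward/backward would go here

When launched through torch.distributed.launch (as in run_script.sh), WORLD_SIZE is set in the environment and the same code switches to NCCL distributed training with a DistributedSampler.

--------------------------------------------------------------------------------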
/loss_functions/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/loss_functions/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/loss_functions/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/loss_functions/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/loss_functions/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/loss_functions/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import numpy as np
6 | import scipy.ndimage as nd
7 | from matplotlib import pyplot as plt
8 | from torch import Tensor, einsum
9 |
10 |
11 | class BinaryDiceLoss(nn.Module):
12 | def __init__(self, smooth=1, p=2, reduction='mean'):
13 | super(BinaryDiceLoss, self).__init__()
14 | self.smooth = smooth
15 | self.p = p
16 | self.reduction = reduction
17 |
18 | def forward(self, predict, target):
19 | assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
20 | predict = predict.contiguous().view(predict.shape[0], -1)
21 | target = target.contiguous().view(target.shape[0], -1)
22 |
23 | num = torch.sum(torch.mul(predict, target), dim=1)
24 | den = torch.sum(predict, dim=1) + torch.sum(target, dim=1) + self.smooth
25 |
26 | dice_score = 2*num / den
27 | dice_loss = 1 - dice_score
28 |
29 | dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0]
30 |
31 | return dice_loss_avg
32 |
33 | class DiceLoss4MOTS(nn.Module):
34 | def __init__(self, weight=None, ignore_index=None, num_classes=3, **kwargs):
35 | super(DiceLoss4MOTS, self).__init__()
36 | self.kwargs = kwargs
37 | self.weight = weight
38 | self.ignore_index = ignore_index
39 | self.num_classes = num_classes
40 | self.dice = BinaryDiceLoss(**self.kwargs)
41 |
42 | def forward(self, predict, target):
43 |
44 | total_loss = []
45 |         predict = torch.sigmoid(predict)
46 |
47 | for i in range(self.num_classes):
48 | if i != self.ignore_index:
49 | dice_loss = self.dice(predict[:, i], target[:, i])
50 | if self.weight is not None:
51 | assert self.weight.shape[0] == self.num_classes, \
52 | 'Expect weight shape [{}], get[{}]'.format(self.num_classes, self.weight.shape[0])
53 |                     dice_loss *= self.weight[i]
54 | total_loss.append(dice_loss)
55 |
56 | total_loss = torch.stack(total_loss)
57 | total_loss = total_loss[total_loss==total_loss]
58 |
59 | return total_loss.sum()/total_loss.shape[0]
60 |
61 | class CELoss4MOTS(nn.Module):
62 | def __init__(self, ignore_index=None,num_classes=3, **kwargs):
63 | super(CELoss4MOTS, self).__init__()
64 | self.kwargs = kwargs
65 | self.num_classes = num_classes
66 | self.ignore_index = ignore_index
67 | self.criterion = nn.BCEWithLogitsLoss(reduction='none')
68 |
69 | def weight_function(self, mask):
70 | weights = torch.ones_like(mask).float()
71 | voxels_sum = mask.shape[0] * mask.shape[1] * mask.shape[2]
72 | for i in range(2):
73 | voxels_i = [mask == i][0].sum().cpu().numpy()
74 | w_i = np.log(voxels_sum / voxels_i).astype(np.float32)
75 | weights = torch.where(mask == i, w_i * torch.ones_like(weights).float(), weights)
76 |
77 | return weights
78 |
79 | def forward(self, predict, target):
80 | assert predict.shape == target.shape, 'predict & target shape do not match'
81 |
82 | total_loss = []
83 | for i in range(self.num_classes):
84 | if i != self.ignore_index:
85 | ce_loss = self.criterion(predict[:, i], target[:, i])
86 | ce_loss = torch.mean(ce_loss, dim=[1,2,3])
87 |
88 | ce_loss_avg = ce_loss[target[:, i, 0, 0, 0] != -1].sum() / ce_loss[target[:, i, 0, 0, 0] != -1].shape[0]
89 |
90 | total_loss.append(ce_loss_avg)
91 |
92 | total_loss = torch.stack(total_loss)
93 | total_loss = total_loss[total_loss == total_loss]
94 |
95 | return total_loss.sum()/total_loss.shape[0]
96 |
--------------------------------------------------------------------------------
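Usage sketch (illustrative, with random tensors; not part of the repository): both losses expect 5-D inputs of shape (N, num_classes, D, H, W), and a class channel whose target is filled with -1 is excluded from the average, which lets a sample skip classes it has no annotation for.

    import torch
    from loss_functions.loss import DiceLoss4MOTS, CELoss4MOTS

    logits = torch.randn(2, 3, 8, 32, 32)                    # raw network outputs (N, C, D, H, W)
    target = torch.randint(0, 2, (2, 3, 8, 32, 32)).float()  # binary mask per class channel

    dice = DiceLoss4MOTS(num_classes=3)(logits, target)
    ce = CELoss4MOTS(num_classes=3)(logits, target)
    loss = dice + ce                                         # a typical combined objective
    print(loss.item())

--------------------------------------------------------------------------------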
/run_script.sh:
--------------------------------------------------------------------------------
1 | cd TransDoD/
2 |
3 | # Training
4 | CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py \
5 | --train_list='MOTS/MOTS_train.txt' \
6 | --snapshot_dir='snapshots/TransDoDNet/' \
7 | --nnUNet_preprocessed='/path/to/nnUNet_preprocessed' \
8 | --input_size='64,192,192' \
9 | --learning_rate=2e-4 \
10 | --batch_size=2 \
11 | --num_gpus=2 \
12 | --num_epochs=1000
13 |
14 | # Testing
15 | CUDA_VISIBLE_DEVICES=0 python test.py \
16 | --val_list='MOTS/MOTS_test.txt' \
17 | --nnUNet_preprocessed='/path/to/nnUNet_preprocessed' \
18 | --reload_path='/path/to/checkpoint.pth' \
19 | --reload_from_checkpoint=True \
20 | --save_path='outputs/TransDoDNet' \
21 | --input_size='64,192,192'
--------------------------------------------------------------------------------
/utils/ParaFlop.py:
--------------------------------------------------------------------------------
1 | # coding:utf8
2 | import torch
3 | import torchvision
4 |
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import torchvision.models as models
8 |
9 | import numpy as np
10 | from collections import OrderedDict
11 | import pandas as pd
12 | import torch.nn.functional as F
13 |
14 | ## usage: in train.py or test.py, call print_model_parm_nums(model)
15 | ##        and print_model_parm_flops(model)
16 | def print_model_parm_nums(model):
17 | total = sum([param.nelement() for param in model.parameters()])
18 | print(' + Number of params: %.2f(e6)' % (total / 1e6))
19 |
20 |
21 | def print_model_parm_flops(model):
22 | # prods = {}
23 | # def save_prods(self, input, output):
24 | # print 'flops:{}'.format(self.__class__.__name__)
25 | # print 'input:{}'.format(input)
26 | # print '_dim:{}'.format(input[0].dim())
27 | # print 'input_shape:{}'.format(np.prod(input[0].shape))
28 | # grads.append(np.prod(input[0].shape))
29 |
30 | prods = {}
31 |
32 | def save_hook(name):
33 | def hook_per(self, input, output):
34 | # print 'flops:{}'.format(self.__class__.__name__)
35 | # print 'input:{}'.format(input)
36 | # print '_dim:{}'.format(input[0].dim())
37 | # print 'input_shape:{}'.format(np.prod(input[0].shape))
38 | # prods.append(np.prod(input[0].shape))
39 | prods[name] = np.prod(input[0].shape)
40 | # prods.append(np.prod(input[0].shape))
41 |
42 | return hook_per
43 |
44 | list_1 = []
45 |
46 | def simple_hook(self, input, output):
47 | list_1.append(np.prod(input[0].shape))
48 |
49 | list_2 = {}
50 |
51 | def simple_hook2(self, input, output):
52 | list_2['names'] = np.prod(input[0].shape)
53 |
54 | multiply_adds = False
55 | list_conv = []
56 |
57 | def conv_hook(self, input, output):
58 | batch_size, input_channels, input_time, input_height, input_width = input[0].size()
59 | output_channels, output_time, output_height, output_width = output[0].size()
60 |
61 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] * (
62 | self.in_channels / self.groups) * (2 if multiply_adds else 1)
63 | bias_ops = 1 if self.bias is not None else 0
64 |
65 | params = output_channels * (kernel_ops + bias_ops)
66 | flops = batch_size * params * output_time * output_height * output_width
67 | list_conv.append(flops)
68 |
69 | list_linear = []
70 |
71 | def linear_hook(self, input, output):
72 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1
73 |
74 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
75 | bias_ops = self.bias.nelement()
76 |
77 | flops = batch_size * (weight_ops + bias_ops)
78 | list_linear.append(flops)
79 |
80 | list_bn = []
81 |
82 | def bn_hook(self, input, output):
83 | list_bn.append(input[0].nelement())
84 |
85 | list_fc = []
86 |
87 | def fc_hook(self, input, output):
88 |         list_fc.append(input[0].nelement())
89 |
90 | list_relu = []
91 |
92 | def relu_hook(self, input, output):
93 | list_relu.append(input[0].nelement())
94 |
95 | list_pooling = []
96 |
97 | # def pooling_hook(self, input, output):
98 | # batch_size, input_channels, input_time,input_height, input_width = input[0].size()
99 | # output_channels, output_time, output_height, output_width = output[0].size()
100 |
101 | # kernel_ops = self.kernel_size * self.kernel_size*self.kernel_size
102 | # bias_ops = 0
103 | # params = output_channels * (kernel_ops + bias_ops)
104 | # flops = batch_size * params * output_height * output_width * output_time
105 |
106 | # list_pooling.append(flops)
107 |
108 | def foo(net):
109 | childrens = list(net.children())
110 | if not childrens:
111 | if isinstance(net, torch.nn.Conv3d):
112 | # net.register_forward_hook(save_hook(net.__class__.__name__))
113 | # net.register_forward_hook(simple_hook)
114 | # net.register_forward_hook(simple_hook2)
115 | net.register_forward_hook(conv_hook)
116 | if isinstance(net, torch.nn.Linear):
117 | net.register_forward_hook(linear_hook)
118 | if isinstance(net, torch.nn.BatchNorm3d):
119 | net.register_forward_hook(bn_hook)
120 |             if isinstance(net, torch.nn.ReLU):
121 |                 net.register_forward_hook(relu_hook)
124 | # if isinstance(net, torch.nn.MaxPool3d) or isinstance(net, torch.nn.AvgPool2d):
125 | # net.register_forward_hook(pooling_hook)
126 | return
127 | for c in childrens:
128 | foo(c)
129 |
130 | foo(model)
131 | model.cuda()
132 | # input = Variable(torch.rand(1,16,256,256).unsqueeze(0), requires_grad = True)
133 | # output = model(input)
134 | input = torch.rand(1, 64, 192, 192).unsqueeze(0).cuda()
135 | # input_res = torch.rand(1, 80, 160, 160).unsqueeze(0).cuda()
136 | # output = model([input, input_res])
137 | # output = model(input, torch.tensor([0]).type(torch.long))
138 |
139 | N = 1
140 | task_encoding = torch.zeros(size=(N, 7, 7)).cuda()
141 | for b in range(N):
142 | for i in range(7):
143 | for j in range(7):
144 | if i == j:
145 | task_encoding[b, i, j] = 1
146 | task_encoding.unsqueeze_(-1).unsqueeze_(-1).unsqueeze_(-1)
147 | output = model(input, [], task_encoding)
148 |
149 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn)+sum(list_relu))
150 | print(' + Number of FLOPs: %.5f(e9)' % (total_flops / 1e9))
151 |
152 |
153 |
154 | def get_names_dict(model):
155 | """
156 | Recursive walk to get names including path
157 | """
158 | names = {}
159 |
160 | def _get_names(module, parent_name=''):
161 | for key, module in module.named_children():
162 | name = parent_name + '.' + key if parent_name else key
163 | names[name] = module
164 | if isinstance(module, torch.nn.Module):
165 | _get_names(module, parent_name=name)
166 |
167 | _get_names(model)
168 | return names
169 |
170 | def torch_summarize_df(input_size, model, weights=False, input_shape=True, nb_trainable=False):
171 | """
172 | Summarizes torch model by showing trainable parameters and weights.
173 |
174 | author: wassname
175 | url: https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7
176 | license: MIT
177 |
178 | Modified from:
179 | - https://github.com/pytorch/pytorch/issues/2001#issuecomment-313735757
180 | - https://gist.github.com/wassname/0fb8f95e4272e6bdd27bd7df386716b7/
181 |
182 | Usage:
183 | import torchvision.models as models
184 | model = models.alexnet()
185 | df = torch_summarize_df(input_size=(3, 224,224), model=model)
186 | print(df)
187 |
188 | # name class_name input_shape output_shape nb_params
189 | # 1 features=>0 Conv2d (-1, 3, 224, 224) (-1, 64, 55, 55) 23296#(3*11*11+1)*64
190 | # 2 features=>1 ReLU (-1, 64, 55, 55) (-1, 64, 55, 55) 0
191 | # ...
192 | """
193 |
194 | def register_hook(module):
195 | def hook(module, input, output):
196 | name = ''
197 | for key, item in names.items():
198 | if item == module:
199 | name = key
200 | #
201 | class_name = str(module.__class__).split('.')[-1].split("'")[0]
202 | module_idx = len(summary)
203 |
204 | m_key = module_idx + 1
205 |
206 | summary[m_key] = OrderedDict()
207 | summary[m_key]['name'] = name
208 | summary[m_key]['class_name'] = class_name
209 | if input_shape:
210 | summary[m_key][
211 | 'input_shape'] = (-1,) + tuple(input[0].size())[1:]
212 | summary[m_key]['output_shape'] = (-1,) + tuple(output.size())[1:]
213 | if weights:
214 | summary[m_key]['weights'] = list(
215 | [tuple(p.size()) for p in module.parameters()])
216 |
217 | # summary[m_key]['trainable'] = any([p.requires_grad for p in module.parameters()])
218 | if nb_trainable:
219 | params_trainable = sum(
220 | [torch.LongTensor(list(p.size())).prod() for p in module.parameters() if p.requires_grad])
221 | summary[m_key]['nb_trainable'] = params_trainable
222 | params = sum([torch.LongTensor(list(p.size())).prod() for p in module.parameters()])
223 | summary[m_key]['nb_params'] = params
224 |
225 | if not isinstance(module, nn.Sequential) and \
226 | not isinstance(module, nn.ModuleList) and \
227 | not (module == model):
228 | hooks.append(module.register_forward_hook(hook))
229 |
230 | # Names are stored in parent and path+name is unique not the name
231 | names = get_names_dict(model)
232 |
233 | # check if there are multiple inputs to the network
234 | if isinstance(input_size[0], (list, tuple)):
235 | x = [Variable(torch.rand(1, *in_size)) for in_size in input_size]
236 | else:
237 | x = Variable(torch.rand(1, *input_size))
238 |
239 | if next(model.parameters()).is_cuda:
240 | x = x.cuda()
241 |
242 | # create properties
243 | summary = OrderedDict()
244 | hooks = []
245 |
246 | # register hook
247 | model.apply(register_hook)
248 |
249 | # make a forward pass
250 | model(x)
251 |
252 | # remove these hooks
253 | for h in hooks:
254 | h.remove()
255 |
256 | # make dataframe
257 | df_summary = pd.DataFrame.from_dict(summary, orient='index')
258 |
259 | print(df_summary)
260 |
261 | return df_summary
262 |
--------------------------------------------------------------------------------
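Usage sketch (not part of the repository): print_model_parm_nums works on any nn.Module, while print_model_parm_flops builds its own dummy input and 7-way one-hot task encoding and then calls model(input, [], task_encoding), so it expects a model with that three-argument forward; only the parameter count is shown here.

    import torch.nn as nn
    from utils.ParaFlop import print_model_parm_nums

    # tiny stand-in network; any 3-D model can be passed the same way
    model = nn.Sequential(nn.Conv3d(1, 8, 3, padding=1), nn.ReLU(), nn.Conv3d(8, 1, 1))
    print_model_parm_nums(model)    # prints " + Number of params: 0.00(e6)" for this tiny model

--------------------------------------------------------------------------------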
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/ParaFlop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/utils/__pycache__/ParaFlop.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/utils/__pycache__/logger.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/misc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/utils/__pycache__/misc.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/my_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/utils/__pycache__/my_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/pyt_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianpengz/DoDNet/f06ce1b02988b6894a85f925623c6bcc6012284f/utils/__pycache__/pyt_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 |
5 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')
6 | _default_level = logging.getLevelName(_default_level_name.upper())
7 |
8 |
9 | class LogFormatter(logging.Formatter):
10 | log_fout = None
11 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] '
12 | date = '%(asctime)s '
13 | msg = '%(message)s'
14 |
15 | def format(self, record):
16 | if record.levelno == logging.DEBUG:
17 | mcl, mtxt = self._color_dbg, 'DBG'
18 | elif record.levelno == logging.WARNING:
19 | mcl, mtxt = self._color_warn, 'WRN'
20 | elif record.levelno == logging.ERROR:
21 | mcl, mtxt = self._color_err, 'ERR'
22 | else:
23 | mcl, mtxt = self._color_normal, ''
24 |
25 | if mtxt:
26 | mtxt += ' '
27 |
28 | if self.log_fout:
29 | self.__set_fmt(self.date_full + mtxt + self.msg)
30 | formatted = super(LogFormatter, self).format(record)
31 | return formatted
32 |
33 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
34 | formatted = super(LogFormatter, self).format(record)
35 |
36 | return formatted
37 |
38 | if sys.version_info.major < 3:
39 | def __set_fmt(self, fmt):
40 | self._fmt = fmt
41 | else:
42 | def __set_fmt(self, fmt):
43 | self._style._fmt = fmt
44 |
45 | @staticmethod
46 | def _color_dbg(msg):
47 | return '\x1b[36m{}\x1b[0m'.format(msg)
48 |
49 | @staticmethod
50 | def _color_warn(msg):
51 | return '\x1b[1;31m{}\x1b[0m'.format(msg)
52 |
53 | @staticmethod
54 | def _color_err(msg):
55 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg)
56 |
57 | @staticmethod
58 | def _color_omitted(msg):
59 | return '\x1b[35m{}\x1b[0m'.format(msg)
60 |
61 | @staticmethod
62 | def _color_normal(msg):
63 | return msg
64 |
65 | @staticmethod
66 | def _color_date(msg):
67 | return '\x1b[32m{}\x1b[0m'.format(msg)
68 |
69 |
70 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter):
71 | logger = logging.getLogger()
72 | logger.setLevel(_default_level)
73 | del logger.handlers[:]
74 |
75 | if log_dir and log_file:
76 | if not os.path.isdir(log_dir):
77 | os.makedirs(log_dir)
78 | LogFormatter.log_fout = True
79 | file_handler = logging.FileHandler(log_file, mode='a')
80 | file_handler.setLevel(logging.INFO)
81 |         file_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
82 | logger.addHandler(file_handler)
83 |
84 | stream_handler = logging.StreamHandler()
85 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
86 | stream_handler.setLevel(0)
87 | logger.addHandler(stream_handler)
88 | return logger
89 |
--------------------------------------------------------------------------------
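Usage sketch (not part of the repository): engine.py and pyt_utils.py call get_logger() with no arguments, which installs a colourised console handler on the root logger; passing log_dir and log_file in addition writes the same records to a file (the paths in the commented line are hypothetical).

    from utils.logger import get_logger

    logger = get_logger()                      # console handler only
    logger.info('training started')
    logger.warning('rendered with the WRN prefix')
    # file_logger = get_logger('snapshots/logs', 'snapshots/logs/train.log')

--------------------------------------------------------------------------------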
/utils/my_utils.py:
--------------------------------------------------------------------------------
1 | # Author: Zylo117
2 |
3 | import os
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from glob import glob
9 | from torch import nn
10 | from torchvision.ops import nms
11 | from typing import Union
12 | import uuid
13 |
14 | # from utils.sync_batchnorm import SynchronizedBatchNorm2d
15 |
16 |
17 | def invert_affine(metas: Union[float, list, tuple], preds):
18 | for i in range(len(preds)):
19 | if len(preds[i]['rois']) == 0:
20 | continue
21 | else:
22 |             if isinstance(metas, float):
23 | preds[i]['rois'][:, [0, 2]] = preds[i]['rois'][:, [0, 2]] / metas
24 | preds[i]['rois'][:, [1, 3]] = preds[i]['rois'][:, [1, 3]] / metas
25 | else:
26 | new_w, new_h, old_w, old_h, padding_w, padding_h = metas[i]
27 | preds[i]['rois'][:, [0, 2]] = preds[i]['rois'][:, [0, 2]] / (new_w / old_w)
28 | preds[i]['rois'][:, [1, 3]] = preds[i]['rois'][:, [1, 3]] / (new_h / old_h)
29 | return preds
30 |
31 |
32 | def aspectaware_resize_padding(image, width, height, interpolation=None, means=None):
33 | old_h, old_w, c = image.shape
34 | if old_w > old_h:
35 | new_w = width
36 | new_h = int(width / old_w * old_h)
37 | else:
38 | new_w = int(height / old_h * old_w)
39 | new_h = height
40 |
41 | canvas = np.zeros((height, height, c), np.float32)
42 | if means is not None:
43 | canvas[...] = means
44 |
45 | if new_w != old_w or new_h != old_h:
46 | if interpolation is None:
47 | image = cv2.resize(image, (new_w, new_h))
48 | else:
49 | image = cv2.resize(image, (new_w, new_h), interpolation=interpolation)
50 |
51 | padding_h = height - new_h
52 | padding_w = width - new_w
53 |
54 | if c > 1:
55 | canvas[:new_h, :new_w] = image
56 | else:
57 | if len(image.shape) == 2:
58 | canvas[:new_h, :new_w, 0] = image
59 | else:
60 | canvas[:new_h, :new_w] = image
61 |
62 | return canvas, new_w, new_h, old_w, old_h, padding_w, padding_h,
63 |
64 |
65 | def preprocess(*image_path, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
66 | ori_imgs = [cv2.imread(img_path) for img_path in image_path]
67 | normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
68 | imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
69 | means=None) for img in normalized_imgs]
70 | framed_imgs = [img_meta[0] for img_meta in imgs_meta]
71 | framed_metas = [img_meta[1:] for img_meta in imgs_meta]
72 |
73 | return ori_imgs, framed_imgs, framed_metas
74 |
75 | def preprocess_video(*frame_from_video, max_size=512, mean=(0.406, 0.456, 0.485), std=(0.225, 0.224, 0.229)):
76 | ori_imgs = frame_from_video
77 | normalized_imgs = [(img / 255 - mean) / std for img in ori_imgs]
78 | imgs_meta = [aspectaware_resize_padding(img[..., ::-1], max_size, max_size,
79 | means=None) for img in normalized_imgs]
80 | framed_imgs = [img_meta[0] for img_meta in imgs_meta]
81 | framed_metas = [img_meta[1:] for img_meta in imgs_meta]
82 |
83 | return ori_imgs, framed_imgs, framed_metas
84 |
85 | def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold):
86 | transformed_anchors = regressBoxes(anchors, regression)
87 | transformed_anchors = clipBoxes(transformed_anchors, x)
88 | scores = torch.max(classification, dim=2, keepdim=True)[0]
89 | scores_over_thresh = (scores > threshold)[:, :, 0]
90 | out = []
91 |     for i in range(x.shape[0]):
92 |         if scores_over_thresh[i].sum() == 0:
93 |             out.append({
94 |                 'rois': np.array(()),
95 |                 'class_ids': np.array(()),
96 |                 'scores': np.array(()),
97 |             })
98 |             continue
99 | classification_per = classification[i, scores_over_thresh[i, :], ...].permute(1, 0)
100 | transformed_anchors_per = transformed_anchors[i, scores_over_thresh[i, :], ...]
101 | scores_per = scores[i, scores_over_thresh[i, :], ...]
102 | anchors_nms_idx = nms(transformed_anchors_per, scores_per[:, 0], iou_threshold=iou_threshold)
103 |
104 | if anchors_nms_idx.shape[0] != 0:
105 | scores_, classes_ = classification_per[:, anchors_nms_idx].max(dim=0)
106 | boxes_ = transformed_anchors_per[anchors_nms_idx, :]
107 |
108 | out.append({
109 | 'rois': boxes_.cpu().numpy(),
110 | 'class_ids': classes_.cpu().numpy(),
111 | 'scores': scores_.cpu().numpy(),
112 | })
113 | else:
114 | out.append({
115 | 'rois': np.array(()),
116 | 'class_ids': np.array(()),
117 | 'scores': np.array(()),
118 | })
119 |
120 | return out
121 |
122 |
123 | def display(preds, imgs, obj_list, imshow=True, imwrite=False):
124 | for i in range(len(imgs)):
125 | if len(preds[i]['rois']) == 0:
126 | continue
127 |
128 | for j in range(len(preds[i]['rois'])):
129 |             (x1, y1, x2, y2) = preds[i]['rois'][j].astype(int)
130 | cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)
131 | obj = obj_list[preds[i]['class_ids'][j]]
132 | score = float(preds[i]['scores'][j])
133 |
134 | cv2.putText(imgs[i], '{}, {:.3f}'.format(obj, score),
135 | (x1, y1 + 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
136 | (255, 255, 0), 1)
137 | if imshow:
138 | cv2.imshow('img', imgs[i])
139 | cv2.waitKey(0)
140 |
141 | if imwrite:
142 | os.makedirs('test/', exist_ok=True)
143 | cv2.imwrite(f'test/{uuid.uuid4().hex}.jpg', imgs[i])
144 |
145 |
146 |
147 | class CustomDataParallel(nn.DataParallel):
148 | """
149 | force splitting data to all gpus instead of sending all data to cuda:0 and then moving around.
150 | """
151 |
152 | def __init__(self, module, num_gpus):
153 | super().__init__(module)
154 | self.num_gpus = num_gpus
155 |
156 | def scatter(self, inputs, kwargs, device_ids):
157 | # More like scatter and data prep at the same time. The point is we prep the data in such a way
158 | # that no scatter is necessary, and there's no need to shuffle stuff around different GPUs.
159 | devices = ['cuda:' + str(x) for x in range(self.num_gpus)]
160 | splits = inputs[0].shape[0] // self.num_gpus
161 |
162 | return [(inputs[0][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True),
163 | inputs[1][splits * device_idx: splits * (device_idx + 1)].to(f'cuda:{device_idx}', non_blocking=True))
164 | for device_idx in range(len(devices))], \
165 | [kwargs] * len(devices)
166 |
167 |
168 | def get_last_weights(weights_path):
169 | weights_path = glob(weights_path + f'/*.pth')
170 | weights_path = sorted(weights_path,
171 | key=lambda x: int(x.rsplit('_')[-1].rsplit('.')[0]),
172 | reverse=True)[0]
173 | print(f'using weights {weights_path}')
174 | return weights_path
175 |
176 |
177 | def init_weights(model):
178 | for name, module in model.named_modules():
179 | is_conv_layer = isinstance(module, nn.Conv2d)
180 |
181 | if is_conv_layer:
182 | nn.init.kaiming_uniform_(module.weight.data)
183 |
184 | if module.bias is not None:
185 | module.bias.data.zero_()
186 |
187 |
188 |
189 | def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
190 | """
191 | Re-start from checkpoint
192 | """
193 | if not os.path.isfile(ckp_path):
194 | return
195 | print("Found checkpoint at {}".format(ckp_path))
196 |
197 | # open checkpoint file
198 | checkpoint = torch.load(ckp_path, map_location="cpu")
199 |
200 | # key is what to look for in the checkpoint file
201 | # value is the object to load
202 | # example: {'state_dict': model}
203 | for key, value in kwargs.items():
204 | if key in checkpoint and value is not None:
205 | try:
206 | msg = value.load_state_dict(checkpoint[key], strict=False)
207 | print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
208 | except TypeError:
209 | try:
210 | msg = value.load_state_dict(checkpoint[key])
211 | print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
212 | except ValueError:
213 | print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
214 | else:
215 | print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
216 |
217 | # re load variable important for the run
218 | if run_variables is not None:
219 | for var_name in run_variables:
220 | if var_name in checkpoint:
221 | run_variables[var_name] = checkpoint[var_name]
--------------------------------------------------------------------------------
/utils/pyt_utils.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | import os
3 | import sys
4 | import time
5 | import argparse
6 | from collections import OrderedDict, defaultdict
7 |
8 | import torch
9 | import torch.utils.model_zoo as model_zoo
10 | import torch.distributed as dist
11 |
12 | from .logger import get_logger
13 |
14 | logger = get_logger()
15 |
16 |
17 | def reduce_tensor(tensor, dst=0, op=dist.ReduceOp.SUM, world_size=1):
18 | tensor = tensor.clone()
19 | dist.reduce(tensor, dst, op)
20 | if dist.get_rank() == dst:
21 | tensor.div_(world_size)
22 |
23 | return tensor
24 |
25 |
26 | def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM, world_size=1, norm=True):
27 | tensor = tensor.clone()
28 | dist.all_reduce(tensor, op)
29 | if norm:
30 | tensor.div_(world_size)
31 |
32 | return tensor
33 |
34 |
35 | def extant_file(x):
36 | if not os.path.exists(x):
37 | raise argparse.ArgumentTypeError("{0} does not exist".format(x))
38 | return x
39 |
40 |
--------------------------------------------------------------------------------
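Usage sketch (not part of the repository): extant_file is an argparse type that rejects paths which do not exist, which is how Engine validates --continue; reduce_tensor and all_reduce_tensor additionally require an initialised torch.distributed process group, so they are only meaningful inside a distributed run.

    import argparse
    from utils.pyt_utils import extant_file

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--continue', type=extant_file, dest='continue_fpath')
    args = parser.parse_args(['--continue', 'README.md'])   # any existing path passes the check
    print(args.continue_fpath)

--------------------------------------------------------------------------------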