├── .gitignore ├── LICENSE ├── README.md ├── SCPMNet.py └── siou_plus_plus_loss.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 xdluo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [[MedIA2021](https://www.sciencedirect.com/science/article/abs/pii/S1361841521003327)] SCPM-Net: An Anchor-free 3D Lung Nodule Detection Network using Sphere Representation and Center Points Matching 2 | * [***News***] The training and testing code for stronger versions of CPMNet & SCPMNet is now available at [CPMNetV2](https://github.com/zunzhumu/CPMNetv2). 3 | * Code for the MICCAI 2020 paper (early accepted) "CPM-Net: A 3D Center-Points Matching Network for Pulmonary Nodule Detection in CT Scans" ([MICCAI2020](https://link.springer.com/chapter/10.1007/978-3-030-59725-2_53)) and its journal extension published in **Medical Image Analysis**, "SCPM-Net: An Anchor-free 3D Lung Nodule Detection Network using Sphere Representation and Center Points Matching" ([MedIA2021](https://www.sciencedirect.com/science/article/abs/pii/S1361841521003327)). 
4 | 5 | If you find it useful for your research, please consider citing the following: 6 | 7 | @inproceedings{song2020cpm, 8 | title={CPM-Net: A 3D Center-Points Matching Network for Pulmonary Nodule Detection in CT Scans}, 9 | author={Song, Tao and Chen, Jieneng and Luo, Xiangde and Huang, Yechong and Liu, Xinglong and Huang, Ning and Chen, Yinan and Ye, Zhaoxiang and Sheng, Huaqiang and Zhang, Shaoting and others}, 10 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 11 | pages={550--559}, 12 | year={2020}, 13 | organization={Springer} 14 | } 15 | 16 | @article{luo2021scpmnet, 17 | title={SCPM-Net: An anchor-free 3D lung nodule detection network using sphere representation and center points matching}, 18 | author={Luo, Xiangde and Song, Tao and Wang, Guotai and Chen, Jieneng and Chen, Yinan and Li, Kang and Metaxas, Dimitris N and Zhang, Shaoting}, 19 | journal={Medical Image Analysis}, 20 | volume={75}, 21 | pages={102287}, 22 | year={2022}, 23 | publisher={Elsevier} 24 | } 25 | 26 | 27 | The training code will be released soon. For any questions, please contact **[Xiangde](https://luoxd1996.github.io)**. 
# ------------------------------------------------------------------------------
# /SCPMNet.py
# ------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np


def Conv_Block(in_planes, out_planes, stride=1):
    """Build a 3x3x3 convolution (no bias) followed by BatchNorm3d and ReLU.

    Padding is 1, so the spatial size is preserved when ``stride == 1``.
    """
    return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                                   padding=1, bias=False),
                         nn.BatchNorm3d(out_planes),
                         nn.ReLU(inplace=True))


def conv3x3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3x3 Conv3d with padding 1."""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class Norm(nn.Module):
    """Thin wrapper around ``nn.BatchNorm3d`` used by the backbone layers.

    The ``normal`` attribute name is part of the state-dict keys; keep it.
    """

    def __init__(self, N):
        super(Norm, self).__init__()
        self.normal = nn.BatchNorm3d(N)

    def forward(self, x):
        return self.normal(x)


class BasicBlock(nn.Module):
    """Residual block with two 3x3x3 convolutions (ResNet "basic" block)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = Norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = Norm(planes)
        # Optional projection applied to the shortcut when shape changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual 1x1x1 -> 3x3x3 -> 1x1x1 block whose output has ``planes * 4`` channels."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = Norm(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = Norm(planes)
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = Norm(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        # Optional projection applied to the shortcut when shape changes.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SAC(nn.Module):
    """Blend of three parallel 3x3x3 convolutions with dilation 1, 2 and 3.

    The branch outputs are summed with softmax-normalised learnable weights.
    Each branch uses padding equal to its dilation, so all branches keep the
    input spatial size. Attribute names are kept for state-dict compatibility.
    """

    def __init__(self, input_channel, out_channel):
        super(SAC, self).__init__()

        self.conv_1 = nn.Conv3d(
            input_channel, out_channel, kernel_size=3, stride=1, padding=1)
        self.conv_3 = nn.Conv3d(
            input_channel, out_channel, kernel_size=3, stride=1, padding=2, dilation=2)
        self.conv_5 = nn.Conv3d(
            input_channel, out_channel, kernel_size=3, stride=1, padding=3, dilation=3)
        self.weights = nn.Parameter(torch.ones(3))
        self.softmax = nn.Softmax(0)

    def forward(self, inputs):
        feat_1 = self.conv_1(inputs)
        feat_3 = self.conv_3(inputs)
        feat_5 = self.conv_5(inputs)
        weights = self.softmax(self.weights)
        feat = feat_1 * weights[0] + feat_3 * weights[1] + feat_5 * weights[2]
        return feat


class Pyramid_3D(nn.Module):
    """Top-down 3D feature pyramid over four backbone stages.

    Each coarser level is 1x1x1-projected to ``feature_size`` channels,
    nearest-upsampled by 2 and added to the next finer level, so every
    C_k must be exactly twice the spatial size of C_{k+1}. With
    ``using_sac=True`` the per-level 3x3x3 smoothing conv is replaced by SAC.
    """

    def __init__(self, C2_size, C3_size, C4_size, C5_size, feature_size=256, using_sac=False):
        super(Pyramid_3D, self).__init__()

        self.P5_1 = nn.Conv3d(C5_size, feature_size,
                              kernel_size=1, stride=1, padding=0)
        self.P5_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
        self.P5_2 = nn.Conv3d(feature_size, feature_size, kernel_size=3, stride=1,
                              padding=1) if not using_sac else SAC(feature_size, feature_size)

        self.P4_1 = nn.Conv3d(C4_size, feature_size,
                              kernel_size=1, stride=1, padding=0)
        self.P4_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
        self.P4_2 = nn.Conv3d(feature_size, feature_size, kernel_size=3, stride=1,
                              padding=1) if not using_sac else SAC(feature_size, feature_size)

        self.P3_1 = nn.Conv3d(C3_size, feature_size,
                              kernel_size=1, stride=1, padding=0)
        self.P3_upsampled = nn.Upsample(scale_factor=2, mode='nearest')
        self.P3_2 = nn.Conv3d(feature_size, feature_size, kernel_size=3, stride=1,
                              padding=1) if not using_sac else SAC(feature_size, feature_size)

        self.P2_1 = nn.Conv3d(C2_size, feature_size,
                              kernel_size=1, stride=1, padding=0)
        self.P2_2 = nn.Conv3d(feature_size, feature_size, kernel_size=3, stride=1,
                              padding=1) if not using_sac else SAC(feature_size, feature_size)

    def forward(self, inputs):
        """Fuse ``[C2, C3, C4, C5]`` (fine to coarse) into ``[P2, P3, P4, P5]``."""
        C2, C3, C4, C5 = inputs

        P5_x = self.P5_1(C5)
        P5_upsampled_x = self.P5_upsampled(P5_x)
        P5_x = self.P5_2(P5_x)

        P4_x = self.P4_1(C4)
        P4_x = P5_upsampled_x + P4_x
        P4_upsampled_x = self.P4_upsampled(P4_x)
        P4_x = self.P4_2(P4_x)

        P3_x = self.P3_1(C3)
        P3_x = P3_x + P4_upsampled_x
        P3_upsampled_x = self.P3_upsampled(P3_x)
        P3_x = self.P3_2(P3_x)

        P2_x = self.P2_1(C2)
        P2_x = P2_x + P3_upsampled_x
        P2_x = self.P2_2(P2_x)

        return [P2_x, P3_x, P4_x, P5_x]


class Attention_SE_CA(nn.Module):
    """Squeeze-and-excitation style channel attention for 5-D feature maps.

    Global average pooling feeds two ``channel -> channel`` linear layers
    (ReLU, then Sigmoid); the resulting per-channel gates rescale the input.
    """

    def __init__(self, channel):
        super(Attention_SE_CA, self).__init__()
        self.Global_Pool = nn.AdaptiveAvgPool3d(1)
        self.FC1 = nn.Sequential(nn.Linear(channel, channel),
                                 nn.ReLU(), )
        self.FC2 = nn.Sequential(nn.Linear(channel, channel),
                                 nn.Sigmoid(), )

    def forward(self, x):
        G = self.Global_Pool(x)          # (N, C, 1, 1, 1)
        G = G.view(G.size(0), -1)        # (N, C)
        fc1 = self.FC1(G)
        fc2 = self.FC2(fc1)              # per-channel gates in (0, 1)
        # Restore the three spatial axes so the gates broadcast over D, H, W.
        fc2 = fc2[:, :, None, None, None]
        return fc2 * x


class SCPMNet(nn.Module):
    """Anchor-free 3D detector: residual backbone + SE attention + 3D FPN + shared heads.

    Args:
        block: ``BasicBlock`` or ``Bottleneck``.
        layers: number of residual blocks in each of the four backbone stages.

    Raises:
        ValueError: if ``block`` is not one of the supported block classes.

    Attribute names (including the ``atttion*`` spelling) are part of the
    state-dict keys and must not be renamed, or saved weights stop loading.
    """

    def __init__(self, block, layers):
        super(SCPMNet, self).__init__()
        self.inplanes = 32
        self.conv1 = nn.Conv3d(1, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = Norm(32)
        self.conv2 = nn.Conv3d(32, 32, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = Norm(32)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 32, layers[0])
        self.layer2 = self._make_layer(block, 64, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 64, layers[3], stride=2)
        self.atttion1 = Attention_SE_CA(32)
        self.atttion2 = Attention_SE_CA(32)
        self.atttion3 = Attention_SE_CA(64)
        self.atttion4 = Attention_SE_CA(64)
        # "+ 3": each FPN level is concatenated with a 3-channel side input.
        self.conv_1 = Conv_Block(64 + 3, 64)
        self.conv_2 = Conv_Block(64 + 3, 64)
        self.conv_3 = Conv_Block(64 + 3, 64)
        self.conv_4 = Conv_Block(64 + 3, 64)
        self.conv_8x = Conv_Block(64, 64)
        self.conv_4x = Conv_Block(64, 64)
        self.conv_2x = Conv_Block(64, 64)
        # Heads are shared between both output branches in forward().
        self.convc = nn.Conv3d(64, 1, kernel_size=1, stride=1)
        self.convr = nn.Conv3d(64, 1, kernel_size=1, stride=1)
        self.convo = nn.Conv3d(64, 3, kernel_size=1, stride=1)
        if block is BasicBlock:
            fpn_sizes = [self.layer1[layers[0]-1].conv2.out_channels, self.layer2[layers[1]-1].conv2.out_channels,
                         self.layer3[layers[2]-1].conv2.out_channels, self.layer4[layers[3]-1].conv2.out_channels]
        elif block is Bottleneck:
            fpn_sizes = [self.layer1[layers[0]-1].conv3.out_channels, self.layer2[layers[1]-1].conv3.out_channels,
                         self.layer3[layers[2]-1].conv3.out_channels, self.layer4[layers[3]-1].conv3.out_channels]
        else:
            # Previously this fell through to a NameError on ``fpn_sizes``.
            raise ValueError(
                'Unsupported block type {!r}; expected BasicBlock or Bottleneck'.format(block))

        self.fpn = Pyramid_3D(
            fpn_sizes[0], fpn_sizes[1], fpn_sizes[2], fpn_sizes[3], feature_size=64)  # 256

        for m in self.modules():
            if isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one may change stride/width."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                Norm(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, c_2, c_4, c_8, c_16):
        """Run the detector.

        Args:
            x: input volume of shape (N, 1, D, H, W). D, H and W should be
                multiples of 16 so every 2x upsample lines up with its skip.
            c_2, c_4, c_8, c_16: 3-channel tensors concatenated to the FPN
                levels at 1/2, 1/4, 1/8 and 1/16 of the input resolution
                (presumably coordinate maps, as named in the ``__main__``
                demo -- confirm against the data pipeline).

        Returns:
            dict with 'Cls1', 'Reg1', 'Off1' predicted from the 1/2-resolution
            FPN level and 'Cls2', 'Reg2', 'Off2' predicted from the top-down
            path upsampled back to 1/2 resolution; 1, 1 and 3 channels.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.atttion1(x)
        x1 = self.layer1(x)
        x1 = self.atttion2(x1)
        x2 = self.layer2(x1)
        x2 = self.atttion3(x2)
        x3 = self.layer3(x2)
        x3 = self.atttion4(x3)
        x4 = self.layer4(x3)
        feats = self.fpn([x1, x2, x3, x4])
        feats[0] = torch.cat([feats[0], c_2], 1)
        feats[0] = self.conv_1(feats[0])
        feats[1] = torch.cat([feats[1], c_4], 1)
        feats[1] = self.conv_2(feats[1])
        feats[2] = torch.cat([feats[2], c_8], 1)
        feats[2] = self.conv_3(feats[2])
        feats[3] = torch.cat([feats[3], c_16], 1)
        feats[3] = self.conv_4(feats[3])

        # F.upsample is deprecated; F.interpolate is the drop-in replacement.
        feat_8x = F.interpolate(
            feats[3], scale_factor=2, mode='nearest') + feats[2]
        feat_8x = self.conv_8x(feat_8x)
        feat_4x = F.interpolate(
            feat_8x, scale_factor=2, mode='nearest') + feats[1]
        feat_4x = self.conv_4x(feat_4x)
        feat_2x = F.interpolate(feat_4x, scale_factor=2, mode='nearest')
        feat_2x = self.conv_2x(feat_2x)
        Cls1 = self.convc(feats[0])
        Cls2 = self.convc(feat_2x)
        Reg1 = self.convr(feats[0])
        Reg2 = self.convr(feat_2x)
        Off1 = self.convo(feats[0])
        Off2 = self.convo(feat_2x)
        output = {}
        output['Cls1'] = Cls1
        output['Reg1'] = Reg1
        output['Off1'] = Off1
        output['Cls2'] = Cls2
        output['Reg2'] = Reg2
        output['Off2'] = Off2
        return output


def scpmnet18(**kwargs):
    """Build SCPMNet with BasicBlock stages of depth [2, 2, 3, 3].

    Note this is deeper than the standard ResNet-18 layout [2, 2, 2, 2].
    ``kwargs`` are forwarded to ``SCPMNet``, which accepts no extra options.
    """
    model = SCPMNet(BasicBlock, [2, 2, 3, 3], **kwargs)
    return model


if __name__ == '__main__':
    # Smoke test: fall back to CPU when no GPU is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    volume = torch.ones(1, 1, 96, 96, 96).to(device)
    coord_2 = torch.ones(1, 3, 48, 48, 48).to(device)
    coord_4 = torch.ones(1, 3, 24, 24, 24).to(device)
    coord_8 = torch.ones(1, 3, 12, 12, 12).to(device)
    coord_16 = torch.ones(1, 3, 6, 6, 6).to(device)
    net = scpmnet18().to(device)
    net.eval()
    with torch.no_grad():
        out = net(volume, coord_2, coord_4, coord_8, coord_16)
    for key, value in out.items():
        print(key, tuple(value.shape))
    print('finish')


# ------------------------------------------------------------------------------
# /siou_plus_plus_loss.py
# ------------------------------------------------------------------------------
import torch


def SIoU_Plus_Plus_3D(gz, gy, gx, gr, pz, py, px, pr, eps=1e-4):
    """Sphere-IoU++ loss between a ground-truth and a predicted sphere.

    Args:
        gz, gy, gx, gr: ground-truth sphere centre coordinates and radius.
        pz, py, px, pr: predicted sphere centre coordinates and radius.
        eps: stabiliser added under the square root of the centre distance
            and to the cosine-rule denominators.

    All arguments must be single-element tensors: the branch selection uses
    Python ``if`` on tensor comparisons.

    Returns:
        ``1 - (siou - dist / (dist + pr + gr))`` where ``siou`` is
        * 0 if the spheres do not overlap,
        * ``(r_small / r_large) ** 3`` if one sphere lies inside the other,
        * otherwise the intersection-over-union of the two spheres (from the
          spherical-cap volumes) minus ``eta``, the angle between the radii
          at the intersection circle divided by pi.
    """
    pi = 3.1415926  # original approximation kept so loss values are unchanged
    dist = torch.sqrt((gz - pz) ** 2 + (gy - py) ** 2 + (gx - px) ** 2 + eps)
    if gr + pr >= dist:
        if gr + dist < pr:
            # Ground-truth sphere lies entirely inside the prediction.
            pr = torch.clamp(pr, min=1e-8)
            siou = (gr / pr) ** 3
        elif pr + dist < gr:
            # Predicted sphere lies entirely inside the ground truth.
            gr = torch.clamp(gr, min=1e-8)
            siou = (pr / gr) ** 3
        else:
            # Partial overlap: the intersection is two spherical caps whose
            # heights follow from the law of cosines.
            cos1 = (gr ** 2 + dist ** 2 - pr ** 2) / (2 * gr * dist + eps)
            h1 = gr * (1 - cos1)
            v1 = pi * gr * h1 ** 2 - pi * h1 ** 3 / 3
            cos2 = (pr ** 2 + dist ** 2 - gr ** 2) / (2 * pr * dist + eps)
            h2 = pr * (1 - cos2)
            v2 = pi * pr * h2 ** 2 - pi * h2 ** 3 / 3
            ua = (pi * 4 * (pr ** 3 + gr ** 3) / 3) - (v1 + v2)
            ua = torch.clamp(ua, min=1e-8)
            cos_eta = (gr ** 2 - dist ** 2 + pr ** 2) / (2 * gr * pr + eps)
            # Rounding near tangency can push the cosine past +/-1 (acos ->
            # NaN) and acos' is infinite at +/-1; keep it strictly inside.
            cos_eta = torch.clamp(cos_eta, -1.0 + 1e-7, 1.0 - 1e-7)
            eta = torch.acos(cos_eta) / pi
            siou = (v1 + v2) / ua - eta
    else:
        # Disjoint spheres: zero overlap. zeros_like keeps dtype/device in line
        # with the inputs; gradients still flow through dist_ratio below.
        siou = torch.zeros_like(dist)
    dist_ratio = dist / (dist + pr + gr)
    sdiou = siou - dist_ratio
    sdiou_loss = 1.0 - sdiou
    return sdiou_loss