├── LICENSE
├── README.md
├── ResUnet3d_pytorch.py
├── Unet2d_pytorch.py
├── Unet3d_pytorch.py
├── compute3DSSIM.py
├── dicom2Nii.py
├── extract23DPatch4MultiModalImg.py
├── extract23DPatch4SingleModalImg.py
├── loss_functions.py
├── nnBuildUnits.py
├── runCTRecon.py
├── runCTRecon3d.py
├── runTesting_Recon.py
├── runTesting_Reconv2.py
├── shuffleDataAmongSubjects_2d.py
├── shuffleDataAmongSubjects_3d.py
└── utils.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Dong Nie

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# medSynthesisV1
This is a copy of the package for our medical image synthesis work with LRes-ResUnet and GAN (WGAN-GP) in the PyTorch framework. It is a simple extension of our paper "Medical Image Synthesis with Deep Convolutional Adversarial Networks". You are also welcome to visit our TensorFlow version through this link:
https://github.com/ginobilinie/medSynthesis

# How to run the PyTorch code
The main entry points for the code are runCTRecon.py and runCTRecon3d.py. Currently the 2D/2.5D version runs fine, while the discriminator of the 3D version only supports the BCE loss. I suggest using the W-distance (WGAN-GP), since its hyper-parameters are easier to tune.

I assume you have installed:
- Python 2.x (e.g., 2.7.x; for Python 3.x, change a few things: `.next()` to `.__next__()` (or the built-in `next()`), and `xrange()` to `range()`)
- PyTorch (>= 0.3.0)
- SimpleITK
- NumPy

Steps to run the code:
1. Use extract23DPatch4MultiModalImg.py (or extract23DPatch4SingleModalImg.py for a single input modality) to extract patches from the training and validation images, and save them in HDF5 format. Since only limited annotated data can be acquired in medical imaging, we usually use patches as the training unit. Put all these h5 files into two folders (training and validation), and remember the paths to them.
2. Choose the generator: 1. UNet, 2. ResUNet, 3. UNet_LRes, or 4. ResUNet_LRes (default: 4).
   Note: for low-dose xx to standard-dose xx (such as low-dose PET to standard-dose PET, or low-dose CT to high-dose CT) or low-resolution xx to high-resolution xx (e.g., 3T to 7T MRI), we suggest using ResUNet_LRes (4), which contains a long-skip connection.
   If the input and output modalities are quite different, we suggest using UNet_LRes (3).
3. Choose the discriminator if you want to use the GAN framework (we provide WGAN-GP and the basic GAN).
4. Choose the loss function: 1. LossL1, 2. lossRTL1, or 3. MSE (default: 3).
5. Set up the hyper-parameters in runCTRecon.py (or runCTRecon3d.py for 3D).
   You have to set the paths to the training h5 files (path_patients_h5), the validation h5 files (path_patients_h5_test), and the testing images (path_test) in this Python file.
   You also have to set up all the other configuration choices, such as the network, the discriminator, the loss functions (including additional losses, e.g., the gradient difference loss), the initial learning rate, and the learning-rate decay during training (used even with the Adam optimizer).
6. Run `python runCTRecon.py` (or `python runCTRecon3d.py` for 3D) for the training stage.
7. Run `python runTesting_Reconv2.py` for the testing stage.

If this is helpful to your work, please cite the papers:
# Cite

@inproceedings{nie2017medical,
  title={Medical image synthesis with context-aware generative adversarial networks},
  author={Nie, Dong and Trullo, Roger and Lian, Jun and Petitjean, Caroline and Ruan, Su and Wang, Qian and Shen, Dinggang},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={417--425},
  year={2017},
  organization={Springer}
}
@article{nie2018medical,
  title={Medical Image Synthesis with Deep Convolutional Adversarial Networks},
  author={Nie, Dong and Trullo, Roger and Lian, Jun and Wang, Li and Petitjean, Caroline and Ruan, Su and Wang, Qian and Shen, Dinggang},
  journal={IEEE Transactions on Biomedical Engineering},
  year={2018},
  publisher={IEEE}
}

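# Sketch: extracting paired training patches
Step 1 above trains on patches rather than whole volumes. The repository's extraction scripts are not reproduced here. The following is a minimal NumPy sketch of how spatially aligned input/target patches can be cut from a pair of co-registered volumes with a sliding window. The function name `extract_paired_patches`, the patch size and the stride are hypothetical choices for illustration, not the settings used by extract23DPatch4SingleModalImg.py or extract23DPatch4MultiModalImg.py.

```python
import numpy as np


def extract_paired_patches(source, target, patch_size=(32, 32, 32), stride=(16, 16, 16)):
    """Cut spatially aligned 3-D patches from two co-registered volumes.

    patch_size and stride are illustrative assumptions, not the values
    used by the repository's extraction scripts.
    Returns two arrays of shape (num_patches, *patch_size).
    """
    if source.shape != target.shape:
        raise ValueError("source and target volumes must have the same shape")
    pd, ph, pw = patch_size
    sd, sh, sw = stride
    depth, height, width = source.shape
    src_patches, tgt_patches = [], []
    # Slide a window over every position where a full patch fits;
    # voxels beyond the last full window at each border are not covered.
    for z in range(0, depth - pd + 1, sd):
        for y in range(0, height - ph + 1, sh):
            for x in range(0, width - pw + 1, sw):
                src_patches.append(source[z:z + pd, y:y + ph, x:x + pw])
                tgt_patches.append(target[z:z + pd, y:y + ph, x:x + pw])
    if not src_patches:
        # Volume smaller than one patch: return empty arrays with the right shape.
        empty = np.empty((0,) + tuple(patch_size), dtype=source.dtype)
        return empty, empty.copy()
    return np.stack(src_patches), np.stack(tgt_patches)


if __name__ == "__main__":
    mr = np.random.rand(64, 64, 48).astype(np.float32)  # stand-in input volume
    ct = np.random.rand(64, 64, 48).astype(np.float32)  # stand-in target volume
    x, y = extract_paired_patches(mr, ct)
    print(x.shape, y.shape)  # (18, 32, 32, 32) (18, 32, 32, 32)
```

With a stride of half the patch size, neighbouring patches overlap. A 64 x 64 x 48 volume yields 3 x 3 x 2 = 18 patches. The patch arrays could then be written to the training and validation folders, for example with h5py. The dataset names that the training scripts expect inside each h5 file are not shown here.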

# Dataset
BTW, you can download a real medical image synthesis dataset for reconstructing standard-dose CT from low-dose CT via this link: https://www.aapm.org/GrandChallenge/LowDoseCT/

Also, there are some MRI synthesis datasets available:
http://brain-development.org/ixi-dataset/

Tumor prediction:
https://www.med.upenn.edu/sbia/brats2018/data.html

fastMRI:
https://fastmri.med.nyu.edu/

ISLES2015:
http://www.isles-challenge.org/ISLES2015/

# Upload your brain MRI, predict the corresponding CT

If you're interested, you can send me a copy of your data (for example, brain MRI). I'll infer the corresponding CT and send a copy of the predicted CT back to you. My email is dongnie.at.cs.unc.edu.

# A version with parallel training and adversarial confidence learning will be uploaded soon.

# License
medSynthesis is released under the MIT License (refer to the LICENSE file for details).
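# Sketch: the long-skip (LRes) residual connection
UNet_LRes and ResUNet_LRes are recommended in step 2 for low-dose to standard-dose and low- to high-resolution tasks. Both end with a long-skip connection. Their `forward(x, res_x)` first unsqueezes `res_x` from (N, H, W) to (N, 1, H, W) when it has three dimensions. It then adds `res_x` to the output of the final 1x1 convolution with `torch.add`. As a result, the network body only has to predict the difference between the image it is given and the target. The NumPy sketch below reproduces just that final addition. `long_skip_add` and `toy_body` are hypothetical names; `toy_body` is a stand-in for the encoder-decoder, not code from this repository.

```python
import numpy as np


def long_skip_add(last, res_x):
    """Add res_x to the last-layer output, mirroring the end of the *_LRes forward().

    last:  (N, C, H, W) output of the final 1x1 convolution.
    res_x: (N, H, W) or (N, C, H, W) image to add back.
    """
    res_x = np.asarray(res_x)
    if res_x.ndim == 3:
        # Equivalent of res_x.unsqueeze(1): (N, H, W) -> (N, 1, H, W)
        res_x = res_x[:, np.newaxis]
    return last + res_x


def toy_body(x):
    # Hypothetical stand-in for the UNet/ResUNet body: predicts a constant correction.
    return np.full_like(x, 0.5)


if __name__ == "__main__":
    low_quality = np.ones((2, 1, 4, 4))  # batch of 2 single-channel 4x4 images
    out = long_skip_add(toy_body(low_quality), low_quality[:, 0])
    print(out.shape, out[0, 0, 0, 0])  # (2, 1, 4, 4) 1.5
```

The addition sits outside the encoder-decoder, which is the only difference at the output: UNet and ResUNet simply return `self.last(up4)`, while the LRes variants add `res_x` to it.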

--------------------------------------------------------------------------------
/ResUnet3d_pytorch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init
import numpy as np

'''
Ordinary UNet Conv Block
'''
class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        self.conv = nn.Conv3d(in_size, out_size, kernel_size, stride=1, padding=1)
        self.bn = nn.BatchNorm3d(out_size)
        self.conv2 = nn.Conv3d(out_size, out_size, kernel_size, stride=1, padding=1)
        self.bn2 = nn.BatchNorm3d(out_size)
        self.activation = activation


        init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0))
        init.constant(self.conv.bias,0)
        init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0))
        init.constant(self.conv2.bias,0)
    def forward(self, x):
        out = self.activation(self.bn(self.conv(x)))
        out = self.activation(self.bn2(self.conv2(out)))

        return out


'''
two-layer residual unit: two conv with BN/relu and identity mapping
'''
class residualUnit(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3,stride=1, padding=1, activation=F.relu):
        super(residualUnit, self).__init__()
        self.conv1 = nn.Conv3d(in_size, out_size, kernel_size, stride=1, padding=1)
        init.xavier_uniform(self.conv1.weight, gain = np.sqrt(2.0)) #or gain=1
        init.constant(self.conv1.bias, 0)
        self.conv2 = nn.Conv3d(out_size, out_size, kernel_size, stride=1, padding=1)
        init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0)) #or gain=1
        init.constant(self.conv2.bias, 0)
        self.activation = activation
        self.bn1 = nn.BatchNorm3d(out_size)
        self.bn2 = nn.BatchNorm3d(out_size)
        self.in_size = in_size
        self.out_size = out_size
        if in_size != out_size:
            self.convX = nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0)
            self.bnX = nn.BatchNorm3d(out_size)

    def forward(self, x):
        out1 = self.activation(self.bn1(self.conv1(x)))
        # the second conv uses its own BN layer (bn2), as in the 2D version
        out2 = self.activation(self.bn2(self.conv2(out1)))
        if self.in_size!=self.out_size:
            # project the input with a 1x1x1 conv so its channel count matches out2
            bridge = self.activation(self.bnX(self.convX(x)))
        else:
            # identity mapping when the channel counts already match
            bridge = x
        output = torch.add(out2, bridge)

        return output


'''
Ordinary UNet-Up Conv Block
'''
class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose3d(in_size, out_size, 2, stride=2)
        self.bnup = nn.BatchNorm3d(out_size)
        self.conv = nn.Conv3d(in_size, out_size, kernel_size, stride=1, padding=1)
        self.bn = nn.BatchNorm3d(out_size)
        self.conv2 = nn.Conv3d(out_size, out_size, kernel_size, stride=1, padding=1)
        self.bn2 = nn.BatchNorm3d(out_size)
        self.activation = activation
        init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0))
        init.constant(self.up.bias,0)
        init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0))
        init.constant(self.conv.bias,0)
        init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0))
        init.constant(self.conv2.bias,0)

    def center_crop(self, layer, target_size):
        batch_size, n_channels, layer_width, layer_height, layer_depth = layer.size()
        xy1 = (layer_width - target_size) // 2
        # crop all three spatial dimensions (assumes cubic feature maps)
        return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size), xy1:(xy1 + target_size)]

    def forward(self, x, bridge):
        up = self.up(x)
        up = self.activation(self.bnup(up))
        crop1 = self.center_crop(bridge, up.size()[2])
        out = torch.cat([up, crop1], 1)

        out = self.activation(self.bn(self.conv(out)))
        out = self.activation(self.bn2(self.conv2(out)))

        return out



'''
Ordinary Residual UNet-Up Conv Block
'''
class UNetUpResBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpResBlock, self).__init__()
        self.up = nn.ConvTranspose3d(in_size, out_size, 2, stride=2)
        self.bnup = nn.BatchNorm3d(out_size)

        init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0))
        init.constant(self.up.bias,0)

        self.activation = activation

        self.resUnit = residualUnit(in_size, out_size, kernel_size = kernel_size)

    def center_crop(self, layer, target_size):
        batch_size, n_channels, layer_width, layer_height, layer_depth = layer.size()
        xy1 = (layer_width - target_size) // 2
        return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size), xy1:(xy1 + target_size)]

    def forward(self, x, bridge):
        #print 'x.shape: ',x.shape
        up = self.activation(self.bnup(self.up(x)))
        #crop1 = self.center_crop(bridge, up.size()[2])
        #print 'up.shape: ',up.shape, ' crop1.shape: ',crop1.shape
        crop1 = bridge
        out = torch.cat([up, crop1], 1)

        out = self.resUnit(out)
        # out = self.activation(self.bn2(self.conv2(out)))

        return out


'''
Ordinary UNet
'''
class UNet(nn.Module):
    def __init__(self, in_channel = 1, n_classes = 4):
        super(UNet, self).__init__()
        # self.imsize = imsize

        self.activation = F.relu

        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)
        self.pool3 = nn.MaxPool3d(2)
        # self.pool4 = nn.MaxPool3d(2)


        self.conv_block1_64 = UNetConvBlock(in_channel, 32)
        self.conv_block64_128 = UNetConvBlock(32, 64)
        self.conv_block128_256 = UNetConvBlock(64, 128)
        self.conv_block256_512 = UNetConvBlock(128, 256)
        # self.conv_block512_1024 = UNetConvBlock(512, 1024)
        # this kind of symmetric design is awesome, it automatically solves the number of channels during upsampling
        # self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(256, 128)
        self.up_block256_128 = UNetUpBlock(128, 64)
        self.up_block128_64 = UNetUpBlock(64, 32)

        self.last = nn.Conv3d(32, n_classes, 1, stride=1)


    def forward(self, x):
        # print 'line 70 ',x.size()
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        # pool4 = self.pool4(block4)
        #
        # block5 = self.conv_block512_1024(pool4)
        #
        # up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(block4, block3)

        up3 = self.up_block256_128(up2, block2)

        up4 = self.up_block128_64(up3, block1)

        return self.last(up4)


'''
Ordinary ResUNet
'''


class ResUNet(nn.Module):
    def __init__(self, in_channel=1, n_classes=4):
        super(ResUNet, self).__init__()
        # self.imsize = imsize

        self.activation = F.relu

        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)
        self.pool3 = nn.MaxPool3d(2)
        # self.pool4 = nn.MaxPool3d(2)

        self.conv_block1_64 = UNetConvBlock(in_channel, 32)
        self.conv_block64_128 = residualUnit(32, 64)
        self.conv_block128_256 = residualUnit(64, 128)
        self.conv_block256_512 = residualUnit(128, 256)
        # self.conv_block512_1024 = residualUnit(512, 1024)
        # this kind of symmetric design is awesome, it automatically solves the number of channels during upsampling
        # self.up_block1024_512 = UNetUpResBlock(1024, 512)
        self.up_block512_256 = UNetUpResBlock(256, 128)
        self.up_block256_128 = UNetUpResBlock(128, 64)
        self.up_block128_64 = UNetUpResBlock(64, 32)

        self.last = nn.Conv3d(32, n_classes, 1, stride=1)

    def forward(self, x):
        # print 'line 70 ',x.size()
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        # pool4 = self.pool4(block4)
        #
        # block5 = self.conv_block512_1024(pool4)
        #
        # up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(block4, block3)

        up3 = self.up_block256_128(up2, block2)

        up4 = self.up_block128_64(up3, block1)

        return self.last(up4)


'''
UNet (lateral connection) with long-skip residual connection (from 1st to last layer)
'''
class UNet_LRes(nn.Module):
    def __init__(self, in_channel = 1, n_classes = 4):
        super(UNet_LRes, self).__init__()
        # self.imsize = imsize

        self.activation = F.relu

        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)
        self.pool3 = nn.MaxPool3d(2)
        # self.pool4 = nn.MaxPool3d(2)

        self.conv_block1_64 = UNetConvBlock(in_channel, 32)
        self.conv_block64_128 = UNetConvBlock(32, 64)
        self.conv_block128_256 = UNetConvBlock(64, 128)
        self.conv_block256_512 = UNetConvBlock(128, 256)
        # self.conv_block512_1024 = UNetConvBlock(512, 1024)
        # this kind of symmetric design is awesome, it automatically solves the number of channels during upsampling
        # self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(256, 128)
        self.up_block256_128 = UNetUpBlock(128, 64)
        self.up_block128_64 = UNetUpBlock(64, 32)

        self.last = nn.Conv3d(32, n_classes, 1, stride=1)


    def forward(self, x, res_x):
        # print 'line 70 ',x.size()
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        # pool4 = self.pool4(block4)

        # block5 = self.conv_block512_1024(pool4)
        #
        # up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(block4, block3)

        up3 = self.up_block256_128(up2, block2)

        up4 = self.up_block128_64(up3, block1)

        last = self.last(up4)
        #print 'res_x.shape is ',res_x.shape,' and last.shape is ',last.shape
        if len(res_x.shape)==3:
            res_x = res_x.unsqueeze(1)
        out = torch.add(last,res_x)

        #print 'out.shape is ',out.shape
        return out


'''
ResUNet (lateral connection) with long-skip residual connection (from 1st to last layer)
'''


class ResUNet_LRes(nn.Module):
    def __init__(self, in_channel=1, n_classes=4, dp_prob=0):
        super(ResUNet_LRes, self).__init__()
        # self.imsize = imsize

        self.activation = F.relu

        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)
        self.pool3 = nn.MaxPool3d(2)
        # self.pool4 = nn.MaxPool3d(2)

        self.conv_block1_64 = UNetConvBlock(in_channel, 32)
        self.conv_block64_128 = residualUnit(32, 64)
        self.conv_block128_256 = residualUnit(64, 128)
        self.conv_block256_512 = residualUnit(128, 256)
        # self.conv_block512_1024 = residualUnit(512, 1024)
        # this kind of symmetric design is awesome, it automatically solves the number of channels during upsampling
        # self.up_block1024_512 = UNetUpResBlock(1024, 512)
        self.up_block512_256 = UNetUpResBlock(256, 128)
        self.up_block256_128 = UNetUpResBlock(128, 64)
        self.up_block128_64 = UNetUpResBlock(64, 32)
        self.Dropout = nn.Dropout3d(p=dp_prob)
        self.last = nn.Conv3d(32, n_classes, 1, stride=1)

    def forward(self, x, res_x):
        # print 'line 70 ',x.size()
        block1 = self.conv_block1_64(x)
        # print 'block1.shape: ', block1.shape
        pool1 = self.pool1(block1)
        # print 'pool1.shape: ', block1.shape
        pool1_dp = self.Dropout(pool1)
        # print 'pool1_dp.shape: ', pool1_dp.shape
        block2 = self.conv_block64_128(pool1_dp)
        pool2 = self.pool2(block2)

        pool2_dp = self.Dropout(pool2)

        block3 = self.conv_block128_256(pool2_dp)
        pool3 = self.pool3(block3)

        pool3_dp = self.Dropout(pool3)

        block4 = self.conv_block256_512(pool3_dp)
        # pool4 = self.pool4(block4)
        #
        # pool4_dp = self.Dropout(pool4)
        #
        # # block5 = self.conv_block512_1024(pool4_dp)
        #
        # up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(block4, block3)

        up3 = self.up_block256_128(up2, block2)

        up4 = self.up_block128_64(up3, block1)

        last = self.last(up4)
        # print 'res_x.shape is ',res_x.shape,' and last.shape is ',last.shape
        if len(res_x.shape) == 3:
            res_x = res_x.unsqueeze(1)
        out = torch.add(last, res_x)

        # print 'out.shape is ',out.shape
        return out



'''
Discriminator for the reconstruction project
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        #you can make abbreviations for conv and fc, this is not necessary
        #class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
        self.conv1 = nn.Conv3d(1,32,9)
        self.bn1 = nn.BatchNorm3d(32)
        self.conv2 = nn.Conv3d(32,64,5)
        self.bn2 = nn.BatchNorm3d(64)
        self.conv3 = nn.Conv3d(64,64,5)
        self.bn3 = nn.BatchNorm3d(64)
        self.fc1 = nn.Linear(64*4*4,512)
        #self.bn3= nn.BatchNorm1d(6)
        self.fc2 = nn.Linear(512,64)
        self.fc3 = nn.Linear(64,1)


    def forward(self,x):
        # print 'line 114: x shape: ',x.size()
        #x = F.max_pool3d(F.relu(self.bn1(self.conv1(x))),(2,2,2))#conv->relu->pool
        x = F.max_pool3d(F.relu(self.conv1(x)),(2,2,2))#conv->relu->pool

        x = F.max_pool3d(F.relu(self.conv2(x)),(2,2,2))#conv->relu->pool

        x = F.max_pool3d(F.relu(self.conv3(x)),(2,2,2))#conv->relu->pool

        # reshape into a vector; the returned tensor shares the same data but has a different shape (like reshape in MATLAB)
        x = x.view(-1,self.num_of_flat_features(x))
        x = F.relu(self.fc1(x))

        x = F.relu(self.fc2(x))

        x = self.fc3(x)

        #x = F.sigmoid(x)
        #print 'min,max,mean of x in 0st layer',x.min(),x.max(),x.mean()
        return x

    def num_of_flat_features(self,x):
        size=x.size()[1:]# we do not consider the batch dimension
        num_features=1
        for s in size:
            num_features*=s
        return num_features

--------------------------------------------------------------------------------
/Unet2d_pytorch.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.init as init
import numpy as np

'''
Ordinary UNet Conv Block
'''
class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
        super(UNetConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_size)
        self.activation = activation


        init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0))
        init.constant(self.conv.bias,0)
        init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0))
        init.constant(self.conv2.bias,0)
    def forward(self, x):
        out = self.activation(self.bn(self.conv(x)))
        out = self.activation(self.bn2(self.conv2(out)))

        return out


'''
two-layer residual unit: two conv with BN/relu and identity mapping
'''
class residualUnit(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3,stride=1, padding=1, activation=F.relu):
        super(residualUnit, self).__init__()
        self.conv1 = nn.Conv2d(in_size, out_size, kernel_size, stride=1, padding=1)
        init.xavier_uniform(self.conv1.weight, gain = np.sqrt(2.0)) #or gain=1
        init.constant(self.conv1.bias, 0)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, stride=1, padding=1)
        init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0)) #or gain=1
        init.constant(self.conv2.bias, 0)
        self.activation = activation
        self.bn1 = nn.BatchNorm2d(out_size)
        self.bn2 = nn.BatchNorm2d(out_size)
        self.in_size = in_size
        self.out_size = out_size
        if in_size != out_size:
            self.convX = nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0)
            self.bnX = nn.BatchNorm2d(out_size)

    def forward(self, x):
        out1 = self.activation(self.bn1(self.conv1(x)))
        out2 = self.activation(self.bn2(self.conv2(out1)))
        if self.in_size!=self.out_size:
            # project the input with a 1x1 conv so its channel count matches out2
            bridge = self.activation(self.bnX(self.convX(x)))
        else:
            # identity mapping when the channel counts already match
            bridge = x
        output = torch.add(out2, bridge)

        return output


'''
Ordinary UNet-Up Conv Block
'''
class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.bnup = nn.BatchNorm2d(out_size)
        self.conv = nn.Conv2d(in_size, out_size, kernel_size, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(out_size)
        self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_size)
        self.activation = activation
        init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0))
        init.constant(self.up.bias,0)
        init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0))
        init.constant(self.conv.bias,0)
        init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0))
        init.constant(self.conv2.bias,0)

    def center_crop(self, layer, target_size):
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_size) // 2
        return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]

    def forward(self, x, bridge):
        up = self.up(x)
        up = self.activation(self.bnup(up))
        crop1 = self.center_crop(bridge, up.size()[2])
        out = torch.cat([up, crop1], 1)

        out = self.activation(self.bn(self.conv(out)))
        out = self.activation(self.bn2(self.conv2(out)))

        return out



'''
Ordinary Residual UNet-Up Conv Block
'''
class UNetUpResBlock(nn.Module):
    def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
        super(UNetUpResBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.bnup = nn.BatchNorm2d(out_size)

        init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0))
        init.constant(self.up.bias,0)

        self.activation = activation

        self.resUnit = residualUnit(in_size, out_size, kernel_size = kernel_size)

    def center_crop(self, layer, target_size):
        batch_size, n_channels, layer_width, layer_height = layer.size()
        xy1 = (layer_width - target_size) // 2
        return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]

    def forward(self, x, bridge):
        up = self.activation(self.bnup(self.up(x)))
        crop1 = self.center_crop(bridge, up.size()[2])
        out = torch.cat([up, crop1], 1)

        out = self.resUnit(out)
        # out = self.activation(self.bn2(self.conv2(out)))

        return out


'''
Ordinary UNet
'''
class UNet(nn.Module):
    def __init__(self, in_channel = 1, n_classes = 4):
        super(UNet, self).__init__()
        # self.imsize = imsize

        self.activation = F.relu

        self.pool1 = nn.MaxPool2d(2)
        self.pool2 = nn.MaxPool2d(2)
        self.pool3 = nn.MaxPool2d(2)
        self.pool4 = nn.MaxPool2d(2)

        self.conv_block1_64 = UNetConvBlock(in_channel, 64)
        self.conv_block64_128 = UNetConvBlock(64, 128)
        self.conv_block128_256 = UNetConvBlock(128, 256)
        self.conv_block256_512 = UNetConvBlock(256, 512)
        self.conv_block512_1024 = UNetConvBlock(512, 1024)
        # this kind of symmetric design is awesome, it automatically solves the number of channels during upsampling
        self.up_block1024_512 = UNetUpBlock(1024, 512)
        self.up_block512_256 = UNetUpBlock(512, 256)
        self.up_block256_128 = UNetUpBlock(256, 128)
        self.up_block128_64 = UNetUpBlock(128, 64)

        self.last = nn.Conv2d(64, n_classes, 1, stride=1)


    def forward(self, x):
        # print 'line 70 ',x.size()
        block1 = self.conv_block1_64(x)
        pool1 = self.pool1(block1)

        block2 = self.conv_block64_128(pool1)
        pool2 = self.pool2(block2)

        block3 = self.conv_block128_256(pool2)
        pool3 = self.pool3(block3)

        block4 = self.conv_block256_512(pool3)
        pool4 = self.pool4(block4)

        block5 = self.conv_block512_1024(pool4)

        up1 = self.up_block1024_512(block5, block4)

        up2 = self.up_block512_256(up1, block3)
181 | 182 | up3 = self.up_block256_128(up2, block2) 183 | 184 | up4 = self.up_block128_64(up3, block1) 185 | 186 | return self.last(up4) 187 | 188 | 189 | ''' 190 | Ordinary ResUNet 191 | ''' 192 | 193 | 194 | class ResUNet(nn.Module): 195 | def __init__(self, in_channel=1, n_classes=4): 196 | super(ResUNet, self).__init__() 197 | # self.imsize = imsize 198 | 199 | self.activation = F.relu 200 | 201 | self.pool1 = nn.MaxPool2d(2) 202 | self.pool2 = nn.MaxPool2d(2) 203 | self.pool3 = nn.MaxPool2d(2) 204 | self.pool4 = nn.MaxPool2d(2) 205 | 206 | self.conv_block1_64 = UNetConvBlock(in_channel, 64) 207 | self.conv_block64_128 = residualUnit(64, 128) 208 | self.conv_block128_256 = residualUnit(128, 256) 209 | self.conv_block256_512 = residualUnit(256, 512) 210 | self.conv_block512_1024 = residualUnit(512, 1024) 211 | # this kind of symmetric design is awesome, it automatically solves the number of channels during upsamping 212 | self.up_block1024_512 = UNetUpResBlock(1024, 512) 213 | self.up_block512_256 = UNetUpResBlock(512, 256) 214 | self.up_block256_128 = UNetUpResBlock(256, 128) 215 | self.up_block128_64 = UNetUpResBlock(128, 64) 216 | 217 | self.last = nn.Conv2d(64, n_classes, 1, stride=1) 218 | 219 | def forward(self, x): 220 | # print 'line 70 ',x.size() 221 | block1 = self.conv_block1_64(x) 222 | pool1 = self.pool1(block1) 223 | 224 | block2 = self.conv_block64_128(pool1) 225 | pool2 = self.pool2(block2) 226 | 227 | block3 = self.conv_block128_256(pool2) 228 | pool3 = self.pool3(block3) 229 | 230 | block4 = self.conv_block256_512(pool3) 231 | pool4 = self.pool4(block4) 232 | 233 | block5 = self.conv_block512_1024(pool4) 234 | 235 | up1 = self.up_block1024_512(block5, block4) 236 | 237 | up2 = self.up_block512_256(up1, block3) 238 | 239 | up3 = self.up_block256_128(up2, block2) 240 | 241 | up4 = self.up_block128_64(up3, block1) 242 | 243 | return self.last(up4) 244 | 245 | 246 | ''' 247 | UNet (lateral connection) with long-skip residual connection (from 1st to 
last layer) 248 | ''' 249 | class UNet_LRes(nn.Module): 250 | def __init__(self, in_channel = 1, n_classes = 4): 251 | super(UNet_LRes, self).__init__() 252 | # self.imsize = imsize 253 | 254 | self.activation = F.relu 255 | 256 | self.pool1 = nn.MaxPool2d(2) 257 | self.pool2 = nn.MaxPool2d(2) 258 | self.pool3 = nn.MaxPool2d(2) 259 | self.pool4 = nn.MaxPool2d(2) 260 | 261 | self.conv_block1_64 = UNetConvBlock(in_channel, 64) 262 | self.conv_block64_128 = UNetConvBlock(64, 128) 263 | self.conv_block128_256 = UNetConvBlock(128, 256) 264 | self.conv_block256_512 = UNetConvBlock(256, 512) 265 | self.conv_block512_1024 = UNetConvBlock(512, 1024) 266 | # this kind of symmetric design is awesome, it automatically solves the number of channels during upsamping 267 | self.up_block1024_512 = UNetUpBlock(1024, 512) 268 | self.up_block512_256 = UNetUpBlock(512, 256) 269 | self.up_block256_128 = UNetUpBlock(256, 128) 270 | self.up_block128_64 = UNetUpBlock(128, 64) 271 | 272 | self.last = nn.Conv2d(64, n_classes, 1, stride=1) 273 | 274 | 275 | def forward(self, x, res_x): 276 | # print 'line 70 ',x.size() 277 | block1 = self.conv_block1_64(x) 278 | pool1 = self.pool1(block1) 279 | 280 | block2 = self.conv_block64_128(pool1) 281 | pool2 = self.pool2(block2) 282 | 283 | block3 = self.conv_block128_256(pool2) 284 | pool3 = self.pool3(block3) 285 | 286 | block4 = self.conv_block256_512(pool3) 287 | pool4 = self.pool4(block4) 288 | 289 | block5 = self.conv_block512_1024(pool4) 290 | 291 | up1 = self.up_block1024_512(block5, block4) 292 | 293 | up2 = self.up_block512_256(up1, block3) 294 | 295 | up3 = self.up_block256_128(up2, block2) 296 | 297 | up4 = self.up_block128_64(up3, block1) 298 | 299 | last = self.last(up4) 300 | #print 'res_x.shape is ',res_x.shape,' and last.shape is ',last.shape 301 | if len(res_x.shape)==3: 302 | res_x = res_x.unsqueeze(1) 303 | out = torch.add(last,res_x) 304 | 305 | #print 'out.shape is ',out.shape 306 | return out 307 | 308 | 309 | ''' 310 | ResUNet 
(lateral connection) with long-skip residual connection (from 1st to last layer) 311 | ''' 312 | 313 | 314 | class ResUNet_LRes(nn.Module): 315 | def __init__(self, in_channel=1, n_classes=4, dp_prob=0): 316 | super(ResUNet_LRes, self).__init__() 317 | # self.imsize = imsize 318 | 319 | self.activation = F.relu 320 | 321 | self.pool1 = nn.MaxPool2d(2) 322 | self.pool2 = nn.MaxPool2d(2) 323 | self.pool3 = nn.MaxPool2d(2) 324 | self.pool4 = nn.MaxPool2d(2) 325 | 326 | self.conv_block1_64 = UNetConvBlock(in_channel, 64) 327 | self.conv_block64_128 = residualUnit(64, 128) 328 | self.conv_block128_256 = residualUnit(128, 256) 329 | self.conv_block256_512 = residualUnit(256, 512) 330 | self.conv_block512_1024 = residualUnit(512, 1024) 331 | # this symmetric design automatically resolves the number of channels during upsampling 332 | self.up_block1024_512 = UNetUpResBlock(1024, 512) 333 | self.up_block512_256 = UNetUpResBlock(512, 256) 334 | self.up_block256_128 = UNetUpResBlock(256, 128) 335 | self.up_block128_64 = UNetUpResBlock(128, 64) 336 | self.Dropout = nn.Dropout2d(p=dp_prob) # applied after each pooling step; dp_prob=0 disables it 337 | self.last = nn.Conv2d(64, n_classes, 1, stride=1) 338 | 339 | def forward(self, x, res_x): 340 | # print 'line 70 ',x.size() 341 | block1 = self.conv_block1_64(x) 342 | pool1 = self.pool1(block1) 343 | 344 | pool1_dp = self.Dropout(pool1) 345 | 346 | block2 = self.conv_block64_128(pool1_dp) 347 | pool2 = self.pool2(block2) 348 | 349 | pool2_dp = self.Dropout(pool2) 350 | 351 | block3 = self.conv_block128_256(pool2_dp) 352 | pool3 = self.pool3(block3) 353 | 354 | pool3_dp = self.Dropout(pool3) 355 | 356 | block4 = self.conv_block256_512(pool3_dp) 357 | pool4 = self.pool4(block4) 358 | 359 | pool4_dp = self.Dropout(pool4) 360 | 361 | block5 = self.conv_block512_1024(pool4_dp) 362 | 363 | up1 = self.up_block1024_512(block5, block4) 364 | 365 | up2 = self.up_block512_256(up1, block3) 366 | 367 | up3 = self.up_block256_128(up2, block2) 368 | 369 | up4 =
self.up_block128_64(up3, block1) 370 | 371 | last = self.last(up4) 372 | # print 'res_x.shape is ',res_x.shape,' and last.shape is ',last.shape 373 | if len(res_x.shape) == 3: 374 | res_x = res_x.unsqueeze(1) 375 | out = torch.add(last, res_x) 376 | 377 | # print 'out.shape is ',out.shape 378 | return out 379 | 380 | 381 | 382 | ''' 383 | Discriminator for the reconstruction project 384 | ''' 385 | class Discriminator(nn.Module): 386 | def __init__(self): 387 | super(Discriminator,self).__init__() 388 | #abbreviating the conv and fc layer names is optional 389 | #class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True) 390 | self.conv1 = nn.Conv2d(1,32,(9,9)) 391 | self.bn1 = nn.BatchNorm2d(32) # bn1-bn3 are defined but not used in forward() 392 | self.conv2 = nn.Conv2d(32,64,(5,5)) 393 | self.bn2 = nn.BatchNorm2d(64) 394 | self.conv3 = nn.Conv2d(64,64,(5,5)) 395 | self.bn3 = nn.BatchNorm2d(64) 396 | self.fc1 = nn.Linear(64*4*4,512) # 64*4*4 assumes a 64x64 input: 64->56->28->24->12->8->4 397 | #self.bn3= nn.BatchNorm1d(6) 398 | self.fc2 = nn.Linear(512,64) 399 | self.fc3 = nn.Linear(64,1) 400 | 401 | 402 | def forward(self,x): 403 | # print 'line 114: x shape: ',x.size() 404 | #x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))),(2,2))#conv->relu->pool 405 | x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))#conv->relu->pool 406 | 407 | x = F.max_pool2d(F.relu(self.conv2(x)),(2,2))#conv->relu->pool 408 | 409 | x = F.max_pool2d(F.relu(self.conv3(x)),(2,2))#conv->relu->pool 410 | 411 | #reshape into a vector: view() returns a tensor that shares the same data but has a different shape, like reshape in MATLAB 412 | x = x.view(-1,self.num_of_flat_features(x)) 413 | x = F.relu(self.fc1(x)) 414 | 415 | x = F.relu(self.fc2(x)) 416 | 417 | x = self.fc3(x) 418 | 419 | #x = F.sigmoid(x) 420 | #print 'min,max,mean of x in 0st layer',x.min(),x.max(),x.mean() 421 | return x 422 | 423 | def 
num_of_flat_features(self,x): 424 | size=x.size()[1:]#we do not consider the batch dimension 425 | num_features=1 426 | for s in
size: 427 | num_features*=s 428 | return num_features 429 | -------------------------------------------------------------------------------- /Unet3d_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class UNet3D(nn.Module): 5 | def __init__(self, in_channel, n_classes): 6 | super(UNet3D, self).__init__() # initialize nn.Module before assigning attributes 7 | self.in_channel = in_channel 8 | self.n_classes = n_classes 9 | self.ec0 = self.encoder(self.in_channel, 32, bias=False, batchnorm=False) 10 | self.ec1 = self.encoder(32, 64, bias=False, batchnorm=False) 11 | self.ec2 = self.encoder(64, 64, bias=False, batchnorm=False) 12 | self.ec3 = self.encoder(64, 128, bias=False, batchnorm=False) 13 | self.ec4 = self.encoder(128, 128, bias=False, batchnorm=False) 14 | self.ec5 = self.encoder(128, 256, bias=False, batchnorm=False) 15 | self.ec6 = self.encoder(256, 256, bias=False, batchnorm=False) 16 | self.ec7 = self.encoder(256, 512, bias=False, batchnorm=False) 17 | 18 | self.pool0 = nn.MaxPool3d(2) 19 | self.pool1 = nn.MaxPool3d(2) 20 | self.pool2 = nn.MaxPool3d(2) 21 | 22 | self.dc9 = self.decoder(512, 512, kernel_size=4, stride=2, padding=1, bias=False) 23 | self.dc8 = self.decoder(256 + 512, 256, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.dc7 = self.decoder(256, 256, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.dc6 = self.decoder(256, 256, kernel_size=4, stride=2, padding=1, bias=False) 26 | self.dc5 = self.decoder(128 + 256, 128, kernel_size=3, stride=1, padding=1, bias=False) 27 | self.dc4 = self.decoder(128, 128, kernel_size=3, stride=1, padding=1, bias=False) 28 | self.dc3 = self.decoder(128, 128, kernel_size=4, stride=2, padding=1, bias=False) 29 | self.dc2 = self.decoder(64 + 128, 64, kernel_size=3, stride=1, padding=1, bias=False) 30 | self.dc1 = self.decoder(64, 64, kernel_size=3, stride=1, padding=1, bias=False) 31 | self.dc0 = self.decoder(64, 
n_classes, kernel_size=1, stride=1,
bias=False) 32 | 33 | 34 | def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, 35 | bias=True, batchnorm=False): 36 | if batchnorm: 37 | layer = nn.Sequential( 38 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 39 | nn.BatchNorm3d(out_channels), # Conv3d outputs are 5D, so BatchNorm3d (not BatchNorm2d) is required 40 | nn.ReLU()) 41 | else: 42 | layer = nn.Sequential( 43 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias), 44 | nn.ReLU()) 45 | return layer 46 | 47 | 48 | def decoder(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 49 | output_padding=0, bias=True): 50 | layer = nn.Sequential( 51 | nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, 52 | padding=padding, output_padding=output_padding, bias=bias), 53 | nn.ReLU()) 54 | return layer 55 | 56 | def forward(self, x): 57 | e0 = self.ec0(x) 58 | syn0 = self.ec1(e0) 59 | e1 = self.pool0(syn0) 60 | e2 = self.ec2(e1) 61 | syn1 = self.ec3(e2) 62 | # print 'syn size1: ',syn1.size() 63 | del e0, e1, e2 64 | 65 | e3 = self.pool1(syn1) 66 | e4 = self.ec4(e3) 67 | syn2 = self.ec5(e4) 68 | # print 'syn size2: ',syn2.size() 69 | del e3, e4 70 | 71 | e5 = self.pool2(syn2) 72 | e6 = self.ec6(e5) 73 | e7 = self.ec7(e6) 74 | # print 'e7: ',e7.size() 75 | del e5, e6 76 | dc9 = self.dc9(e7) 77 | # print 'dc9: ',dc9.size() 78 | d9 = torch.cat((dc9, syn2),dim=1) # reuse dc9 rather than running self.dc9 a second time 79 | del e7, dc9, syn2 80 | 81 | d8 = self.dc8(d9) 82 | d7 = self.dc7(d8) 83 | del d9, d8 84 | 85 | d6 = torch.cat((self.dc6(d7), syn1),dim=1) 86 | del d7, syn1 87 | 88 | d5 = self.dc5(d6) 89 | d4 = self.dc4(d5) 90 | del d6, d5 91 | 92 | d3 = torch.cat((self.dc3(d4), syn0),dim=1) 93 | del d4, syn0 94 | 95 | d2 = self.dc2(d3) 96 | d1 = self.dc1(d2) 97 | del d3, d2 98 | 99 | d0 = self.dc0(d1) 100 | return d0 101 | 
-------------------------------------------------------------------------------- /compute3DSSIM.py:
-------------------------------------------------------------------------------- 1 | ''' 2 | Target: Compute structural similarity (SSIM) between two 3D volumes 3 | Created on Jan 22, 2018 4 | Author: Dong Nie 5 | 6 | reference from: http://simpleitk-prototype.readthedocs.io/en/latest/user_guide/plot_image.html 7 | ''' 8 | 9 | import SimpleITK as sitk 10 | 11 | from multiprocessing import Pool 12 | import os 13 | import h5py 14 | import numpy as np 15 | import scipy.io as scio 16 | # from morpologicalTransformation import denoiseImg_closing, denoiseImg_isolation # module not included in this repo; only the commented-out denoising code below uses it 17 | from skimage import measure 18 | 19 | 20 | path = '/shenlab/lab_stor5/dongnie/3T7T/results/' 21 | 22 | 23 | def main(): 24 | ids = [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29] 25 | ids = range(0, 30) 26 | ids = [1] 27 | for id in ids: 28 | # datafn = os.path.join(path,'Case%02d.mhd'%id) 29 | # outdatafn = os.path.join(path,'Case%02d.nii.gz'%id) 30 | # 31 | # dataOrg = sitk.ReadImage(datafn) 32 | # dataMat = sitk.GetArrayFromImage(dataOrg) 33 | # #gtMat=np.transpose(gtMat,(2,1,0)) 34 | # dataVol = sitk.GetImageFromArray(dataMat) 35 | # sitk.WriteImage(dataVol,outdatafn) 36 | 37 | # datafn = os.path.join(path, 'img1.mhd') 38 | # dataOrg = sitk.ReadImage(datafn) 39 | # spacing = dataOrg.GetSpacing() 40 | # origin = dataOrg.GetOrigin() 41 | # direction = dataOrg.GetDirection() 42 | # dataMat = sitk.GetArrayFromImage(dataOrg) 43 | 44 | gtfn = os.path.join(path, 'S1to1_7t.nii.gz') 45 | gtOrg = sitk.ReadImage(gtfn) 46 | gtMat = sitk.GetArrayFromImage(gtOrg) 47 | # gtMat=np.transpose(gtMat,(2,1,0)) 48 | gtMat = gtMat.astype(np.float32) 49 | gtMat = (gtMat - np.amin(gtMat))/(np.amax(gtMat)-np.amin(gtMat)) 50 | print np.amax(gtMat),',', np.amin(gtMat), 'dtype: ',gtMat.dtype, ',',gtMat.shape 51 | 52 | prefn = os.path.join(path,'preSub1_1112_195000.nii.gz') 53 | preOrg = sitk.ReadImage(prefn) 54 | preMat = 
sitk.GetArrayFromImage(preOrg) 55 | preMat = preMat.astype(np.float32) 56 | preMat =
(preMat - np.amin(preMat))/(np.amax(preMat)-np.amin(preMat)) 57 | # preMat = np.transpose(preMat,(2,0,1)) 58 | print np.amax(preMat),',', np.amin(preMat), 'dtype:',preMat.dtype, ',',preMat.shape 59 | 60 | # note: multichannel=True treats the last axis as channels (2D SSIM averaged over slices); multichannel=False uses 3D windows 61 | ssim_3d_sk = measure.compare_ssim(gtMat, preMat, multichannel=True, gaussian_weights=True, data_range=1.0, 62 | use_sample_covariance=False) 63 | print ssim_3d_sk 64 | 65 | # ssim_3d_sk = measure.structural_similarity(gtMat, preMat, multichannel=True, gaussian_weights=True, data_range=1.0, 66 | # use_sample_covariance=False) 67 | # gtMat1 = denoiseImg_closing(gtMat, kernel=np.ones((20, 20, 20))) 68 | # gtMat2 = gtMat + gtMat1 69 | # gtMat2[np.where(gtMat2 > 1)] = 1 70 | # gtMat = gtMat2 71 | # gtMat = denoiseImg_isolation(gtMat, struct=np.ones((3, 3, 3))) 72 | # 73 | # gtMat = gtMat.astype(np.uint8) 74 | 75 | # ind1 = np.where((gtMat==1)&(preMat==1)) 76 | # preMat[ind1] = 0 77 | # ind2 = np.where((gtMat==2)&(preMat==2)) 78 | # preMat[ind2] = 0 79 | # ind3 = np.where((gtMat==3)&(preMat==3)) 80 | # preMat[ind3] = 0 81 | # errorMat = preMat 82 | # 83 | # outgtfn = os.path.join(path, 'sgm_errormap_sub1.nii.gz') 84 | # errorVol = sitk.GetImageFromArray(errorMat) 85 | # errorVol.SetSpacing(spacing) 86 | # errorVol.SetOrigin(origin) 87 | # errorVol.SetDirection(direction) 88 | # sitk.WriteImage(errorVol, outgtfn) 89 | 90 | 91 | # 92 | # prefn='preSub%d_as32_v12.nii'%id 93 | # preOrg=sitk.ReadImage(prefn) 94 | # preMat=sitk.GetArrayFromImage(preOrg) 95 | # preMat=np.transpose(preMat,(2,1,0)) 96 | # preVol=sitk.GetImageFromArray(preMat) 97 | # sitk.WriteImage(preVol,prefn) 98 | 99 | 100 | if __name__ == '__main__': 101 | main() 102 | -------------------------------------------------------------------------------- /dicom2Nii.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 05/02, at Chapel Hill 3 | Dong 4 | 5 | convert dicom series 
to nifti format 6 | ''' 7 | import numpy 8 | import SimpleITK as sitk 9 | import os 10 |
11 | 12 | 13 | class ScanFile(object): 14 | def __init__(self,directory,prefix=None,postfix=None): 15 | self.directory=directory 16 | self.prefix=prefix 17 | self.postfix=postfix 18 | 19 | def scan_files(self): 20 | files_list=[] 21 | 22 | for dirpath,dirnames,filenames in os.walk(self.directory): 23 | ''''' 24 | dirpath is a string, the path to the directory. 25 | dirnames is a list of the names of the subdirectories in dirpath (excluding '.' and '..'). 26 | filenames is a list of the names of the non-directory files in dirpath. 27 | ''' 28 | for special_file in filenames: 29 | if self.postfix: 30 | if special_file.endswith(self.postfix): 31 | files_list.append(os.path.join(dirpath,special_file)) 32 | elif self.prefix: 33 | if special_file.startswith(self.prefix): 34 | files_list.append(os.path.join(dirpath,special_file)) 35 | else: 36 | files_list.append(os.path.join(dirpath,special_file)) 37 | 38 | return files_list 39 | 40 | def scan_subdir(self): 41 | subdir_list=[] 42 | for dirpath,dirnames,files in os.walk(self.directory): 43 | subdir_list.append(dirpath) 44 | return subdir_list 45 | 46 | 47 | def main(): 48 | path='/home/dongnie/warehouse/pelvicSeg/newData/pelvic_0118/' 49 | subpath='atkinson_lafayette' 50 | outfn=subpath+'.nii.gz' 51 | inputdir=path+subpath 52 | scan=ScanFile(path) 53 | subdirs=scan.scan_subdir() 54 | for subdir in subdirs: 55 | if subdir==path or subdir=='..': 56 | continue 57 | 58 | print 'subdir is, ',subdir 59 | 60 | ss=subdir.split('/') 61 | print 'ss is, ',ss, 'and s7 is, ',ss[7] 62 | # ss[7] is the subject folder name for the path above; adjust the index if the path depth changes 63 | outfn=ss[7]+'.nii.gz' 64 | 65 | reader = sitk.ImageSeriesReader() 66 | 67 | dicom_names = reader.GetGDCMSeriesFileNames(subdir) 68 | reader.SetFileNames(dicom_names) 69 | 70 | image = reader.Execute() 71 | 72 | size = image.GetSize() 73 | print( "Image size:", size[0], size[1], size[2] ) 74 | 75 | print( "Writing image:", outfn) 76 | 77 | 
sitk.WriteImage(image,outfn) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 |
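For reference, here is a minimal Python 3 sketch of the recursive suffix/prefix scan that the `ScanFile` classes in these scripts perform. The standalone function `scan_files` and its sorted return value are illustrative assumptions, not part of the repository. The point it shows is that the boolean returned by `str.endswith`/`str.startswith` has to gate the `append`; otherwise every file is collected regardless of the filter.

```python
import os


def scan_files(directory, prefix=None, postfix=None):
    """Recursively collect file paths under `directory` (hypothetical helper).

    Follows the same precedence as ScanFile.scan_files in this repo:
    if `postfix` is given it is the only filter; otherwise `prefix` is used;
    with neither, every file is returned.
    """
    matches = []
    for dirpath, _dirnames, filenames in os.walk(directory):
        for name in filenames:
            if postfix:
                keep = name.endswith(postfix)
            elif prefix:
                keep = name.startswith(prefix)
            else:
                keep = True
            if keep:  # only append files that actually pass the filter
                matches.append(os.path.join(dirpath, name))
    # os.walk order depends on the filesystem; sort for reproducibility
    return sorted(matches)
```

Sorting is a design choice here: `os.walk` yields entries in filesystem order, so sorting keeps the subject order reproducible across machines.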
-------------------------------------------------------------------------------- /extract23DPatch4MultiModalImg.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Target: Crop patches from various kinds of medical images (hdr, nii, mha, mhd, raw, etc.) and store them as hdf5 files 4 | for single-scale patches 5 | Created in June, 2016 6 | Author: Dong Nie 7 | ''' 8 | 9 | 10 | 11 | import SimpleITK as sitk 12 | 13 | from multiprocessing import Pool 14 | import os, argparse 15 | import h5py 16 | import numpy as np 17 | 18 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg") 19 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data") 20 | 21 | global opt 22 | opt = parser.parse_args() 23 | 24 | 25 | d1=5 26 | d2=64 27 | d3=64 28 | dFA=[d1,d2,d3] # size of patches of input data 29 | dSeg=[1,64,64] # size of patches of label data 30 | step1=1 31 | step2=32 32 | step3=32 33 | step=[step1,step2,step3] 34 | 35 | 36 | class ScanFile(object): 37 | def __init__(self,directory,prefix=None,postfix=None): 38 | self.directory=directory 39 | self.prefix=prefix 40 | self.postfix=postfix 41 | 42 | def scan_files(self): 43 | files_list=[] 44 | 45 | for dirpath,dirnames,filenames in os.walk(self.directory): 46 | ''''' 47 | dirpath is a string, the path to the directory. 48 | dirnames is a list of the names of the subdirectories in dirpath (excluding '.' and '..'). 49 | filenames is a list of the names of the non-directory files in dirpath.
50 | ''' 51 | for special_file in filenames: 52 | if self.postfix: 53 | if special_file.endswith(self.postfix): 54 | files_list.append(os.path.join(dirpath,special_file)) 55 | elif self.prefix: 56 | if special_file.startswith(self.prefix): 57 | files_list.append(os.path.join(dirpath,special_file)) 58 | else: 59 | files_list.append(os.path.join(dirpath,special_file)) 60 | 61 | return files_list 62 | 63 | def scan_subdir(self): 64 | subdir_list=[] 65 | for dirpath,dirnames,files in os.walk(self.directory): 66 | subdir_list.append(dirpath) 67 | return subdir_list 68 | 69 | 70 | 71 | ''' 72 | Not needed for training itself; this function is used to generate the hdf5 patch database 73 | ''' 74 | def extractPatch4OneSubject(matFA, matMR, matSeg, matMask, fileID ,d, step, rate): 75 | 76 | eps=5e-2 77 | rate1=1.0/2 78 | rate2=1.0/4 79 | [row,col,leng]=matFA.shape 80 | cubicCnt=0 81 | estNum=40000 82 | trainFA=np.zeros([estNum,1, dFA[0],dFA[1],dFA[2]],dtype=np.float16) 83 | trainSeg=np.zeros([estNum,1,dSeg[0],dSeg[1],dSeg[2]],dtype=np.float16) 84 | trainMR=np.zeros([estNum,1,dFA[0],dFA[1],dFA[2]],dtype=np.float16) 85 | 86 | print 'trainFA shape, ',trainFA.shape 87 | #pad the input so that each input patch is centered on its label patch 88 | margin1=(dFA[0]-dSeg[0])//2 89 | margin2=(dFA[1]-dSeg[1])//2 90 | margin3=(dFA[2]-dSeg[2])//2 91 | cubicCnt=0 92 | marginD=[margin1,margin2,margin3] 93 | print 'matFA shape is ',matFA.shape 94 | matFAOut=np.zeros([row+2*marginD[0],col+2*marginD[1],leng+2*marginD[2]],dtype=np.float16) 95 | print 'matFAOut shape is ',matFAOut.shape 96 | matFAOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matFA 97 | 98 | matMROut=np.zeros([row+2*marginD[0],col+2*marginD[1],leng+2*marginD[2]],dtype=np.float16) 99 | print 'matMROut shape is ',matMROut.shape 100 | matMROut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMR 101 | 102 | 
matSegOut=np.zeros([row+2*marginD[0],col+2*marginD[1],leng+2*marginD[2]],dtype=np.float16) 103 |
matSegOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matSeg 104 | 105 | 106 | matMaskOut=np.zeros([row+2*marginD[0],col+2*marginD[1],leng+2*marginD[2]],dtype=np.float16) 107 | matMaskOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMask 108 | 109 | #for mageFA, enlarge it by padding 110 | if margin1!=0: 111 | matFAOut[0:marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matFA[marginD[0]-1::-1,:,:] #reverse 0:marginD[0] 112 | matFAOut[row+marginD[0]:matFAOut.shape[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matFA[matFA.shape[0]-1:row-marginD[0]-1:-1,:,:] #we'd better flip it along the 1st dimension 113 | if margin2!=0: 114 | matFAOut[marginD[0]:row+marginD[0],0:marginD[1],marginD[2]:leng+marginD[2]]=matFA[:,marginD[1]-1::-1,:] #we'd flip it along the 2nd dimension 115 | matFAOut[marginD[0]:row+marginD[0],col+marginD[1]:matFAOut.shape[1],marginD[2]:leng+marginD[2]]=matFA[:,matFA.shape[1]-1:col-marginD[1]-1:-1,:] #we'd flip it along the 2nd dimension 116 | if margin3!=0: 117 | matFAOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],0:marginD[2]]=matFA[:,:,marginD[2]-1::-1] #we'd better flip it along the 3rd dimension 118 | matFAOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]+leng:matFAOut.shape[2]]=matFA[:,:,matFA.shape[2]-1:leng-marginD[2]-1:-1] 119 | 120 | #for matMR, enlarge it by padding 121 | if margin1!=0: 122 | matMROut[0:marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMR[marginD[0]-1::-1,:,:] #reverse 0:marginD[0] 123 | matMROut[row+marginD[0]:matMROut.shape[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMR[matMR.shape[0]-1:row-marginD[0]-1:-1,:,:] #we'd better flip it along the 1st dimension 124 | if margin2!=0: 125 | matMROut[marginD[0]:row+marginD[0],0:marginD[1],marginD[2]:leng+marginD[2]]=matMR[:,marginD[1]-1::-1,:] #we'd flip it along the 2nd dimension 126 | 
matMROut[marginD[0]:row+marginD[0],col+marginD[1]:matMROut.shape[1],marginD[2]:leng+marginD[2]]=matMR[:,matMR.shape[1]-1:col-marginD[1]-1:-1,:] #we'd flip it along the 2nd dimension 127 | if margin3!=0: 128 | matMROut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],0:marginD[2]]=matMR[:,:,marginD[2]-1::-1] #we'd better flip it along the 3rd dimension 129 | matMROut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]+leng:matMROut.shape[2]]=matMR[:,:,matMR.shape[2]-1:leng-marginD[2]-1:-1] 130 | 131 | #for matseg, enlarge it by padding 132 | if margin1!=0: 133 | matSegOut[0:marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matSeg[marginD[0]-1::-1,:,:] #reverse 0:marginD[0] 134 | matSegOut[row+marginD[0]:matSegOut.shape[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matSeg[matSeg.shape[0]-1:row-marginD[0]-1:-1,:,:] #we'd better flip it along the 1st dimension 135 | if margin2!=0: 136 | matSegOut[marginD[0]:row+marginD[0],0:marginD[1],marginD[2]:leng+marginD[2]]=matSeg[:,marginD[1]-1::-1,:] #we'd flip it along the 2nd dimension 137 | matSegOut[marginD[0]:row+marginD[0],col+marginD[1]:matSegOut.shape[1],marginD[2]:leng+marginD[2]]=matSeg[:,matSeg.shape[1]-1:col-marginD[1]-1:-1,:] #we'd flip it along the 2nd dimension 138 | if margin3!=0: 139 | matSegOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],0:marginD[2]]=matSeg[:,:,marginD[2]-1::-1] #we'd better flip it along the 3rd dimension 140 | matSegOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]+leng:matSegOut.shape[2]]=matSeg[:,:,matSeg.shape[2]-1:leng-marginD[2]-1:-1] 141 | 142 | #for matseg, enlarge it by padding 143 | if margin1!=0: 144 | matMaskOut[0:marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMask[marginD[0]-1::-1,:,:] #reverse 0:marginD[0] 145 | matMaskOut[row+marginD[0]:matMaskOut.shape[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMask[matMask.shape[0]-1:row-marginD[0]-1:-1,:,:] #we'd better flip it along the 
1st dimension 146 | if margin2!=0: 147 | matMaskOut[marginD[0]:row+marginD[0],0:marginD[1],marginD[2]:leng+marginD[2]]=matMask[:,marginD[1]-1::-1,:] #we'd flip it along the 2nd dimension 148 | matMaskOut[marginD[0]:row+marginD[0],col+marginD[1]:matMaskOut.shape[1],marginD[2]:leng+marginD[2]]=matMask[:,matMask.shape[1]-1:col-marginD[1]-1:-1,:] #we'd flip it along the 2nd dimension 149 | if margin3!=0: 150 | matMaskOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],0:marginD[2]]=matMask[:,:,marginD[2]-1::-1] #we'd better flip it along the 3rd dimension 151 | matMaskOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]+leng:matMaskOut.shape[2]]=matMask[:,:,matMask.shape[2]-1:leng-marginD[2]-1:-1] 152 | 153 | dsfactor = rate 154 | 155 | for i in range(0,row-dSeg[0],step[0]): 156 | for j in range(0,col-dSeg[1],step[1]): 157 | for k in range(0,leng-dSeg[2],step[2]): 158 | volMask = matMaskOut[i:i+dSeg[0],j:j+dSeg[1],k:k+dSeg[2]] 159 | if np.sum(volMask)maxV)] = maxV 308 | # print 'maxV is: ',np.ndarray.max(mrimg) 309 | # mu=np.mean(mrimg) # we should have a fixed std and mean 310 | # std = np.std(mrimg) 311 | # mrnp = (mrimg - mu)/std 312 | # print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp) 313 | 314 | #matLPET = (mrimg - meanLPET)/(stdLPET) 315 | #print 'lpet: maxV,',np.ndarray.max(matLPET),' minV, ',np.ndarray.min(matLPET), ' meanV: ', np.mean(matLPET), ' stdV: ', np.std(matLPET) 316 | 317 | # matLPET = (mrnp - minLPET)/(maxPercentLPET-minLPET) 318 | # print 'lpet: maxV,',np.ndarray.max(matLPET),' minV, ',np.ndarray.min(matLPET), ' meanV: ', np.mean(matLPET), ' stdV: ', np.std(matLPET) 319 | 320 | 321 | 322 | 323 | # maxV1, minV1 = np.percentile(mrimg1, [99.5 ,1]) 324 | # print 'maxV1 is: ',np.ndarray.max(mrimg1) 325 | # mrimg1[np.where(mrimg1>maxV1)] = maxV1 326 | # print 'maxV1 is: ',np.ndarray.max(mrimg1) 327 | # mu1 = np.mean(mrimg1) # we should have a fixed std and mean 328 | # std1 = np.std(mrimg1) 329 | # mrnp1 = (mrimg1 - 
mu1)/std1 330 | # print 'maxV1,',np.ndarray.max(mrnp1),' minV, ',np.ndarray.min(mrnp1) 331 | 332 | # ctnp[np.where(ctnp>maxPercentCT)] = maxPercentCT 333 | # matCT = (ctnp - meanCT)/stdCT 334 | # print 'ct: maxV,',np.ndarray.max(matCT),' minV, ',np.ndarray.min(matCT), 'meanV: ', np.mean(matCT), 'stdV: ', np.std(matCT) 335 | 336 | 337 | 338 | 339 | # maxVal = np.amax(labelimg) 340 | # minVal = np.amin(labelimg) 341 | # print 'maxV is: ', maxVal, ' minVal is: ', minVal 342 | # mu=np.mean(labelimg) # we should have a fixed std and mean 343 | # std = np.std(labelimg) 344 | # 345 | # labelimg = (labelimg - minVal)/(maxVal - minVal) 346 | # 347 | # print 'maxV,',np.ndarray.max(labelimg),' minV, ',np.ndarray.min(labelimg) 348 | #you can do what you want here for for your label img 349 | 350 | # matSPET = (labelimg - minSPET)/(maxPercentSPET-minSPET) 351 | # print 'spet: maxV,',np.ndarray.max(matSPET),' minV, ',np.ndarray.min(matSPET), ' meanV: ',np.mean(matSPET), ' stdV: ', np.std(matSPET) 352 | 353 | sdir = filename.split('/') 354 | print 'sdir is, ',sdir, 'and s5 is, ',sdir[5] 355 | lpet_fn = sdir[5] 356 | words = lpet_fn.split('_') 357 | print 'words are, ',words 358 | ind = int(words[0]) 359 | 360 | 361 | fileID = words[0] 362 | rate = 1 363 | cubicCnt = extractPatch4OneSubject(matLPET, matCT, matSPET, maskimg, fileID,dSeg,step,rate) 364 | #cubicCnt = extractPatch4OneSubject(mrnp, matCT, hpetnp, maskimg, fileID,dSeg,step,rate) 365 | print '# of patches is ', cubicCnt 366 | 367 | # reverse along the 1st dimension 368 | rmrimg = matLPET[matLPET.shape[0] - 1::-1, :, :] 369 | rmatCT = matCT[matCT.shape[0] - 1::-1, :, :] 370 | rlabelimg = matSPET[matSPET.shape[0] - 1::-1, :, :] 371 | rmaskimg = maskimg[maskimg.shape[0] - 1::-1, :, :] 372 | fileID = words[0]+'r' 373 | cubicCnt = extractPatch4OneSubject(rmrimg, rmatCT, rlabelimg, rmaskimg, fileID,dSeg,step,rate) 374 | print '# of patches is ', cubicCnt 375 | 376 | if __name__ == '__main__': 377 | main() 378 | 
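The mirror padding in `extractPatch4OneSubject` fills each margin with the border slices in reverse order (for example `matFA[marginD[0]-1::-1,:,:]`), which is what NumPy's `np.pad(..., mode='symmetric')` produces along that axis. The sketch below is a hypothetical Python 3 helper, `extract_patches`, not the repository's function: it pads with `np.pad`, then slides a window and skips patches whose mask sum is below `eps`, returning arrays instead of writing hdf5 files. Assumptions to note: `np.pad` also fills the corner regions that the original leaves at zero (these only differ when more than one margin is nonzero, which is not the case for the default `dFA=[5,64,64]`, `dSeg=[1,64,64]`); the mask is checked on the same unpadded window as the target patch; and the loop bounds include the last valid patch position.

```python
import numpy as np


def extract_patches(src, tgt, mask, d_in=(5, 64, 64), d_out=(1, 64, 64),
                    step=(1, 32, 32), eps=5e-2):
    """Hypothetical sketch of 2.5D patch extraction with mirror padding.

    src, tgt, mask: 3D arrays of identical shape.
    Returns (inputs, targets) shaped (N, 1, *d_in) and (N, 1, *d_out).
    """
    # integer margins; `//` keeps them ints under Python 3
    margins = [(a - b) // 2 for a, b in zip(d_in, d_out)]
    # 'symmetric' repeats the edge slice, like matFA[m-1::-1] in the original
    src_pad = np.pad(src, [(m, m) for m in margins], mode="symmetric")
    rows, cols, lens = src.shape
    inputs, targets = [], []
    for i in range(0, rows - d_out[0] + 1, step[0]):
        for j in range(0, cols - d_out[1] + 1, step[1]):
            for k in range(0, lens - d_out[2] + 1, step[2]):
                out_sl = (slice(i, i + d_out[0]),
                          slice(j, j + d_out[1]),
                          slice(k, k + d_out[2]))
                if mask[out_sl].sum() < eps:
                    continue  # skip (nearly) empty regions
                # window starting at i in padded coordinates is centered on target index i
                inputs.append(src_pad[i:i + d_in[0], j:j + d_in[1], k:k + d_in[2]])
                targets.append(tgt[out_sl])
    # reshape(-1, ...) also handles the case where no patch passes the mask
    inputs = np.asarray(inputs, dtype=np.float32).reshape(-1, 1, *d_in)
    targets = np.asarray(targets, dtype=np.float32).reshape(-1, 1, *d_out)
    return inputs, targets
```

Returning arrays keeps the sketch testable; writing them with `h5py` as the original does is a separate step.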
-------------------------------------------------------------------------------- /extract23DPatch4SingleModalImg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Target: Crop patches from various kinds of medical images (hdr, nii, mha, mhd, raw, etc.) and store them as hdf5 files 3 | for single input modality 4 | Created in June, 2016 5 | Author: Dong Nie 6 | ''' 7 | 8 | import SimpleITK as sitk 9 | 10 | from multiprocessing import Pool 11 | import os, argparse 12 | import h5py 13 | import numpy as np 14 | 15 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg") 16 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data") 17 | 18 | global opt 19 | opt = parser.parse_args() 20 | 21 | # input patch size 22 | d1 = 5 23 | d2 = 64 24 | d3 = 64 25 | # patch sizes: dFA for the input, dSeg for the output 26 | dFA = [d1, d2, d3] # size of patches of input data 27 | dSeg = [1, 64, 64] # size of patches of label data 28 | # stride for extracting patches along the volume 29 | step1 = 1 30 | step2 = 16 31 | step3 = 16 32 | step = [step1, step2, step3] 33 | 34 | 35 | class ScanFile(object): 36 | def __init__(self, directory, prefix=None, postfix=None): 37 | self.directory = directory 38 | self.prefix = prefix 39 | self.postfix = postfix 40 | 41 | def scan_files(self): 42 | files_list = [] 43 | 44 | for dirpath, dirnames, filenames in os.walk(self.directory): 45 | ''''' 46 | dirpath is a string, the path to the directory. 47 | dirnames is a list of the names of the subdirectories in dirpath (excluding '.' and '..'). 48 | filenames is a list of the names of the non-directory files in dirpath.
49 | ''' 50 | for special_file in filenames: 51 | if self.postfix: 52 | if special_file.endswith(self.postfix): 53 | files_list.append(os.path.join(dirpath, special_file)) 54 | elif self.prefix: 55 | if special_file.startswith(self.prefix): 56 | files_list.append(os.path.join(dirpath, special_file)) 57 | else: 58 | files_list.append(os.path.join(dirpath, special_file)) 59 | 60 | return files_list 61 | 62 | def scan_subdir(self): 63 | subdir_list = [] 64 | for dirpath, dirnames, files in os.walk(self.directory): 65 | subdir_list.append(dirpath) 66 | return subdir_list 67 | 68 | 69 | ''' 70 | Not needed for training itself; this function is used to generate the hdf5 patch database 71 | ''' 72 | 73 | 74 | def extractPatch4OneSubject(matFA, matSeg, matMask, fileID, d, step, rate): 75 | eps = 5e-2 76 | rate1 = 1.0 / 2 77 | rate2 = 1.0 / 4 78 | [row, col, leng] = matFA.shape 79 | cubicCnt = 0 80 | estNum = 40000 81 | trainFA = np.zeros([estNum, 1, dFA[0], dFA[1], dFA[2]], dtype=np.float16) 82 | trainSeg = np.zeros([estNum, 1, dSeg[0], dSeg[1], dSeg[2]], dtype=np.float16) 83 | 84 | print 'trainFA shape, ', trainFA.shape 85 | # pad the input so that each input patch is centered on its label patch 86 | margin1 = (dFA[0] - dSeg[0]) // 2 87 | margin2 = (dFA[1] - dSeg[1]) // 2 88 | margin3 = (dFA[2] - dSeg[2]) // 2 89 | cubicCnt = 0 90 | marginD = [margin1, margin2, margin3] 91 | print 'matFA shape is ', matFA.shape 92 | matFAOut = np.zeros([row + 2 * marginD[0], col + 2 * marginD[1], leng + 2 * marginD[2]], dtype=np.float16) 93 | print 'matFAOut shape is ', matFAOut.shape 94 | matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA 95 | 96 | matSegOut = np.zeros([row + 2 * marginD[0], col + 2 * marginD[1], leng + 2 * marginD[2]], dtype=np.float16) 97 | matSegOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matSeg 98 | 99 | matMaskOut = np.zeros([row + 2 * marginD[0], col + 2 * 
marginD[1], leng + 2 * marginD[2]], dtype=np.float16) 100 |
matMaskOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matMask 101 | 102 | # for mageFA, enlarge it by padding 103 | if margin1 != 0: 104 | matFAOut[0:marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA[marginD[0] - 1::-1, :, 105 | :] # reverse 0:marginD[0] 106 | matFAOut[row + marginD[0]:matFAOut.shape[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA[ 107 | matFA.shape[ 108 | 0] - 1:row - 109 | marginD[ 110 | 0] - 1:-1, 111 | :, 112 | :] # we'd better flip it along the 1st dimension 113 | if margin2 != 0: 114 | matFAOut[marginD[0]:row + marginD[0], 0:marginD[1], marginD[2]:leng + marginD[2]] = matFA[:, marginD[1] - 1::-1, 115 | :] # we'd flip it along the 2nd dimension 116 | matFAOut[marginD[0]:row + marginD[0], col + marginD[1]:matFAOut.shape[1], marginD[2]:leng + marginD[2]] = matFA[ 117 | :, 118 | matFA.shape[ 119 | 1] - 1:col - 120 | marginD[ 121 | 1] - 1:-1, 122 | :] # we'd flip it along the 2nd dimension 123 | if margin3 != 0: 124 | matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 0:marginD[2]] = matFA[:, :, marginD[ 125 | 2] - 1::-1] # we'd better flip it along the 3rd dimension 126 | matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2] + leng:matFAOut.shape[2]] = matFA[ 127 | :, :, 128 | matFA.shape[ 129 | 2] - 1:leng - 130 | marginD[ 131 | 2] - 1:-1] 132 | # for matseg, enlarge it by padding 133 | if margin1 != 0: 134 | matSegOut[0:marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matSeg[marginD[0] - 1::-1, 135 | :, 136 | :] # reverse 0:marginD[0] 137 | matSegOut[row + marginD[0]:matSegOut.shape[0], marginD[1]:col + marginD[1], 138 | marginD[2]:leng + marginD[2]] = matSeg[matSeg.shape[0] - 1:row - marginD[0] - 1:-1, :, 139 | :] # we'd better flip it along the 1st dimension 140 | if margin2 != 0: 141 | matSegOut[marginD[0]:row + marginD[0], 0:marginD[1], marginD[2]:leng + marginD[2]] = matSeg[:, 
142 | marginD[1] - 1::-1, 143 | :] # we'd flip it along the 2nd dimension 144 | matSegOut[marginD[0]:row + marginD[0], col + marginD[1]:matSegOut.shape[1], 145 | marginD[2]:leng + marginD[2]] = matSeg[:, matSeg.shape[1] - 1:col - marginD[1] - 1:-1, 146 | :] # we'd flip it along the 2nd dimension 147 | if margin3 != 0: 148 | matSegOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 0:marginD[2]] = matSeg[:, :, marginD[ 149 | 2] - 1::-1] # we'd better flip it along the 3rd dimension 150 | matSegOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 151 | marginD[2] + leng:matSegOut.shape[2]] = matSeg[:, :, matSeg.shape[2] - 1:leng - marginD[2] - 1:-1] 152 | 153 | # for matseg, enlarge it by padding 154 | if margin1 != 0: 155 | matMaskOut[0:marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matMask[ 156 | marginD[0] - 1::-1, :, 157 | :] # reverse 0:marginD[0] 158 | matMaskOut[row + marginD[0]:matMaskOut.shape[0], marginD[1]:col + marginD[1], 159 | marginD[2]:leng + marginD[2]] = matMask[matMask.shape[0] - 1:row - marginD[0] - 1:-1, :, 160 | :] # we'd better flip it along the 1st dimension 161 | if margin2 != 0: 162 | matMaskOut[marginD[0]:row + marginD[0], 0:marginD[1], marginD[2]:leng + marginD[2]] = matMask[:, 163 | marginD[1] - 1::-1, 164 | :] # we'd flip it along the 2nd dimension 165 | matMaskOut[marginD[0]:row + marginD[0], col + marginD[1]:matMaskOut.shape[1], 166 | marginD[2]:leng + marginD[2]] = matMask[:, matMask.shape[1] - 1:col - marginD[1] - 1:-1, 167 | :] # we'd flip it along the 2nd dimension 168 | if margin3 != 0: 169 | matMaskOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 0:marginD[2]] = matMask[:, :, marginD[ 170 | 2] - 1::-1] # we'd better flip it along the 3rd dimension 171 | matMaskOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 172 | marginD[2] + leng:matMaskOut.shape[2]] = matMask[:, :, matMask.shape[2] - 1:leng - marginD[2] - 1:-1] 173 | 174 | dsfactor = rate 175 | 176 | for i 
in range(0, row - dSeg[0], step[0]): 177 | for j in range(0, col - dSeg[1], step[1]): 178 | for k in range(0, leng - dSeg[2], step[2]): 179 | volMask = matMaskOut[i:i + dSeg[0], j:j + dSeg[1], k:k + dSeg[2]] 180 | if np.sum(volMask) < eps: 181 | continue 182 | # index at scale 1 183 | volSeg = matSeg[i:i + dSeg[0], j:j + dSeg[1], k:k + dSeg[2]] 184 | volFA = matFAOut[i:i + dFA[0], j:j + dFA[1], k:k + dFA[2]] 185 | 186 | trainFA[cubicCnt, 0, :, :, :] = volFA # 32*32*32 187 | 188 | trainSeg[cubicCnt, 0, :, :, :] = volSeg # 24*24*24 189 | # count the patch only after storing it, so slots 0..cubicCnt-1 are all filled 190 | cubicCnt = cubicCnt + 1 191 | 192 | trainFA = trainFA[0:cubicCnt, :, :, :, :] 193 | trainSeg = trainSeg[0:cubicCnt, :, :, :, :] 194 | 195 | with h5py.File('./trainMRCT_snorm_64_%s.h5' % fileID, 'w') as f: 196 | f['dataMR'] = trainFA 197 | f['dataCT'] = trainSeg 198 | 199 | with open('./trainMRCT2D_snorm_64_list.txt', 'a') as f: 200 | f.write('./trainMRCT_snorm_64_%s.h5\n' % fileID) 201 | return cubicCnt 202 | 203 | 204 | def main(): 205 | print opt 206 | path = '/home/niedong/Data4LowDosePET/data_niigz_scale/' 207 | path = '/shenlab/lab_stor5/dongnie/brain_mr2ct/original_data/' # path to the data, change to your own path 208 | scan = ScanFile(path, postfix='_mr.hdr') # the file suffix that identifies your source images, change to your own style 209 | filenames = scan.scan_files() 210 | 211 | # for input 212 | maxSource = 149.366742 213 | maxPercentSource = 7.76 214 | minSource = 0.00055037 215 | meanSource = 0.27593288 216 | stdSource = 0.75747500 217 | 218 | # for output 219 | maxTarget = 27279 220 | maxPercentTarget = 1320 221 | minTarget = -1023 222 | meanTarget = -601.1929 223 | stdTarget = 475.034 224 | 225 | for filename in filenames: 226 | 227 | print 'source filename: ', filename 228 | 229 | source_fn = filename 230 | target_fn = filename.replace('_mr.hdr', '_ct.hdr') 231 | 232 | imgOrg = sitk.ReadImage(source_fn) 233 | sourcenp = sitk.GetArrayFromImage(imgOrg) 234 | 235 | imgOrg1 = sitk.ReadImage(target_fn) 236 | targetnp =
sitk.GetArrayFromImage(imgOrg1) 237 | 238 | maskimg = sourcenp 239 | 240 | mu = np.mean(sourcenp) 241 | 242 | if opt.how2normalize == 1: 243 | maxV, minV = np.percentile(sourcenp, [99, 1]) 244 | print 'maxV,', maxV, ' minV, ', minV 245 | sourcenp = (sourcenp - mu) / (maxV - minV) 246 | print 'unique value: ', np.unique(targetnp) 247 | 248 | # for training data in pelvicSeg 249 | if opt.how2normalize == 2: 250 | maxV, minV = np.percentile(sourcenp, [99, 1]) 251 | print 'maxV,', maxV, ' minV, ', minV 252 | sourcenp = (sourcenp - mu) / (maxV - minV) 253 | print 'unique value: ', np.unique(targetnp) 254 | 255 | # for training data in pelvicSegRegH5 256 | if opt.how2normalize == 3: 257 | std = np.std(sourcenp) 258 | sourcenp = (sourcenp - mu) / std 259 | print 'maxV,', np.ndarray.max(sourcenp), ' minV, ', np.ndarray.min(sourcenp) 260 | 261 | if opt.how2normalize == 4: 262 | maxSource = 149.366742 263 | maxPercentSource = 7.76 264 | minSource = 0.00055037 265 | meanSource = 0.27593288 266 | stdSource = 0.75747500 267 | 268 | # for target 269 | maxTarget = 27279 270 | maxPercentTarget = 1320 271 | minTarget = -1023 272 | meanTarget = -601.1929 273 | stdTarget = 475.034 274 | 275 | matSource = (sourcenp - minSource) / (maxPercentSource - minSource) 276 | matTarget = (targetnp - meanTarget) / stdTarget 277 | 278 | if opt.how2normalize == 5: 279 | # for target 280 | maxTarget = 27279 281 | maxPercentTarget = 1320 282 | minTarget = -1023 283 | meanTarget = -601.1929 284 | stdTarget = 475.034 285 | 286 | print 'target, max: ', np.amax(targetnp), ' target, min: ', np.amin(targetnp) 287 | 288 | # matSource = (sourcenp - meanSource) / (stdSource) 289 | matSource = sourcenp 290 | matTarget = (targetnp - meanTarget) / stdTarget 291 | 292 | if opt.how2normalize == 6: 293 | maxPercentSource, minPercentSource = np.percentile(sourcenp, [99.5, 0]) 294 | maxPercentTarget, minPercentTarget = np.percentile(targetnp, [99.5, 0]) 295 | print 'maxPercentSource: ', maxPercentSource, ' 
minPercentSource: ', minPercentSource, ' maxPercentTarget: ', maxPercentTarget, 'minPercentTarget: ', minPercentTarget 296 | 297 | matSource = (sourcenp - minPercentSource) / (maxPercentSource - minPercentSource) #input 298 | #output, use input's statistical (if there is big difference between input and output, you can find a simple relation between input and output and then include this relation to normalize output with input's statistical) 299 | matTarget = (targetnp - minPercentSource) / (maxPercentSource - minPercentSource) 300 | 301 | print 'maxSource: ', np.amax(matSource), ' maxTarget: ', np.amax(matTarget) 302 | print 'minSource: ', np.amin(matSource), ' minTarget: ', np.amin(matTarget) 303 | 304 | # maxV, minV = np.percentile(mrimg, [99.5, 0]) 305 | # print 'maxV is: ',np.ndarray.max(mrimg) 306 | # mrimg[np.where(mrimg>maxV)] = maxV 307 | # print 'maxV is: ',np.ndarray.max(mrimg) 308 | # mu=np.mean(mrimg) # we should have a fixed std and mean 309 | # std = np.std(mrimg) 310 | # mrnp = (mrimg - mu)/std 311 | # print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp) 312 | 313 | # matLPET = (mrimg - meanLPET)/(stdLPET) 314 | # print 'lpet: maxV,',np.ndarray.max(matLPET),' minV, ',np.ndarray.min(matLPET), ' meanV: ', np.mean(matLPET), ' stdV: ', np.std(matLPET) 315 | 316 | # matLPET = (mrnp - minLPET)/(maxPercentLPET-minLPET) 317 | # print 'lpet: maxV,',np.ndarray.max(matLPET),' minV, ',np.ndarray.min(matLPET), ' meanV: ', np.mean(matLPET), ' stdV: ', np.std(matLPET) 318 | 319 | # maxV1, minV1 = np.percentile(mrimg1, [99.5 ,1]) 320 | # print 'maxV1 is: ',np.ndarray.max(mrimg1) 321 | # mrimg1[np.where(mrimg1>maxV1)] = maxV1 322 | # print 'maxV1 is: ',np.ndarray.max(mrimg1) 323 | # mu1 = np.mean(mrimg1) # we should have a fixed std and mean 324 | # std1 = np.std(mrimg1) 325 | # mrnp1 = (mrimg1 - mu1)/std1 326 | # print 'maxV1,',np.ndarray.max(mrnp1),' minV, ',np.ndarray.min(mrnp1) 327 | 328 | # ctnp[np.where(ctnp>maxPercentCT)] = maxPercentCT 329 | # 
matCT = (ctnp - meanCT)/stdCT 330 | # print 'ct: maxV,',np.ndarray.max(matCT),' minV, ',np.ndarray.min(matCT), 'meanV: ', np.mean(matCT), 'stdV: ', np.std(matCT) 331 | 332 | # maxVal = np.amax(labelimg) 333 | # minVal = np.amin(labelimg) 334 | # print 'maxV is: ', maxVal, ' minVal is: ', minVal 335 | # mu=np.mean(labelimg) # we should have a fixed std and mean 336 | # std = np.std(labelimg) 337 | # 338 | # labelimg = (labelimg - minVal)/(maxVal - minVal) 339 | # 340 | # print 'maxV,',np.ndarray.max(labelimg),' minV, ',np.ndarray.min(labelimg) 341 | # you can do what you want here for for your label img 342 | 343 | # matSPET = (labelimg - minSPET)/(maxPercentSPET-minSPET) 344 | # print 'spet: maxV,',np.ndarray.max(matSPET),' minV, ',np.ndarray.min(matSPET), ' meanV: ',np.mean(matSPET), ' stdV: ', np.std(matSPET) 345 | 346 | sdir = filename.split('/') 347 | print 'sdir is, ', sdir, 'and s6 is, ', sdir[len(sdir)-1] 348 | lpet_fn = sdir[len(sdir)-1] 349 | words = lpet_fn.split('_') 350 | print 'words are, ', words 351 | # ind = int(words[0]) 352 | 353 | fileID = words[0] 354 | rate = 1 355 | cubicCnt = extractPatch4OneSubject(matSource, matTarget, maskimg, fileID, dSeg, step, rate) 356 | # cubicCnt = extractPatch4OneSubject(mrnp, matCT, hpetnp, maskimg, fileID,dSeg,step,rate) 357 | print '# of patches is ', cubicCnt 358 | 359 | # reverse along the 1st dimension 360 | rmatSource = matSource[matSource.shape[0] - 1::-1, :, :] 361 | rmatTarget = matTarget[matTarget.shape[0] - 1::-1, :, :] 362 | 363 | rmaskimg = maskimg[maskimg.shape[0] - 1::-1, :, :] 364 | fileID = words[0] + 'r' 365 | cubicCnt = extractPatch4OneSubject(rmatSource, rmatTarget, rmaskimg, fileID, dSeg, step, rate) 366 | print '# of patches is ', cubicCnt 367 | 368 | 369 | if __name__ == '__main__': 370 | main() 371 | -------------------------------------------------------------------------------- /loss_functions.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import os 3 | import SimpleITK as sitk 4 | import torch.nn as nn 5 | import tensorflow as tf # gdl_loss below is written with TensorFlow ops 6 | import torch.optim as optim 7 | import torch 8 | import torch.nn.init 9 | from torch.autograd import Variable 10 | 11 | 12 | def gdl_loss(gen_CT, gt_CT, alpha, batch_size_tf): 13 | """ 14 | Calculates the gradient difference loss (GDL) between the predicted and ground-truth CT images. 15 | 16 | @param gen_CT: The predicted CT images, an NHWC tensor with one channel. 17 | @param gt_CT: The ground-truth CT images, same shape as gen_CT. 18 | @param alpha: The power to which each gradient term is raised. 19 | @param batch_size_tf: The batch size; the summed loss is divided by it. 20 | @return: The GDL loss. 21 | """ 22 | # single scale: compare the image gradients of the prediction and the ground truth 23 | 24 | # create filters [-1, 1] and [[1],[-1]] for horizontal and vertical finite differences, respectively. 25 | pos = tf.constant(np.identity(1), dtype=tf.float32) 26 | neg = -1 * pos 27 | filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) # [-1, 1]; tf.pack was renamed tf.stack in TF 1.0 28 | filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]] 29 | strides = [1, 1, 1, 1] # stride of (1, 1) 30 | padding = 'SAME' 31 | 32 | gen_dx = tf.abs(tf.nn.conv2d(gen_CT, filter_x, strides, padding=padding)) 33 | gen_dy = tf.abs(tf.nn.conv2d(gen_CT, filter_y, strides, padding=padding)) 34 | gt_dx = tf.abs(tf.nn.conv2d(gt_CT, filter_x, strides, padding=padding)) 35 | gt_dy = tf.abs(tf.nn.conv2d(gt_CT, filter_y, strides, padding=padding)) 36 | 37 | grad_diff_x = tf.abs(gt_dx - gen_dx) 38 | grad_diff_y = tf.abs(gt_dy - gen_dy) 39 | 40 | gdl=tf.reduce_sum((grad_diff_x ** alpha + grad_diff_y ** alpha))/tf.cast(batch_size_tf,tf.float32) 41 | 42 | # summed over all pixels, averaged over the batch 43 | return gdl 44 | -------------------------------------------------------------------------------- /runCTRecon.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | import argparse, os 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as
np 7 | import torch.optim as optim 8 | import torch 9 | import torch.utils.data as data_utils 10 | from utils import * 11 | from Unet2d_pytorch import UNet, ResUNet, UNet_LRes, ResUNet_LRes, Discriminator 12 | from Unet3d_pytorch import UNet3D 13 | from nnBuildUnits import CrossEntropy3d, topK_RegLoss, RelativeThreshold_RegLoss, gdl_loss, adjust_learning_rate, calc_gradient_penalty 14 | import time 15 | import SimpleITK as sitk 16 | 17 | # Training settings 18 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg") 19 | parser.add_argument("--gpuID", type=int, default=1, help="which GPU to use (exported as CUDA_VISIBLE_DEVICES), Default: 1") 20 | parser.add_argument("--isAdLoss", action="store_true", help="is adversarial loss used?", default=False) 21 | parser.add_argument("--isWDist", action="store_true", help="is adversarial loss with WGAN-GP distance?", default=False) 22 | parser.add_argument("--lambda_AD", default=0.05, type=float, help="weight for adversarial loss, Default: 0.05") 23 | parser.add_argument("--lambda_D_WGAN_GP", default=10, type=float, help="weight for gradient penalty of WGAN-GP, Default: 10") 24 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data (1-6), Default: 6") 25 | parser.add_argument("--whichLoss", type=int, default=1, help="which loss to use: 1. L1 (default), 2. relative-threshold L1, 3. MSE") 26 | parser.add_argument("--isGDL", action="store_true", help="do we use GDL loss? (default=True, so it is always on)", default=True) 27 | parser.add_argument("--gdlNorm", default=2, type=int, help="p-norm for the gdl loss, Default: 2") 28 | parser.add_argument("--lambda_gdl", default=0.05, type=float, help="Weight for gdl loss, Default: 0.05") 29 | parser.add_argument("--whichNet", type=int, default=4, help="which network to use: 1. UNet, 2. ResUNet, 3. UNet_LRes and 4.
ResUNet_LRes (default)") 30 | parser.add_argument("--lossBase", type=int, default=1, help="The factor multiplying lossG_G, Default (1)") 31 | parser.add_argument("--batchSize", type=int, default=32, help="training batch size") 32 | parser.add_argument("--isMultiSource", action="store_true", help="is multiple modality used?", default=False) 33 | parser.add_argument("--numOfChannel_singleSource", type=int, default=5, help="# of channels for a 2D patch for the main modality (Default, 5)") 34 | parser.add_argument("--numOfChannel_allSource", type=int, default=5, help="# of channels for a 2D patch for all the concatenated modalities (Default, 5)") 35 | parser.add_argument("--numofIters", type=int, default=200000, help="number of iterations to train for") 36 | parser.add_argument("--showTrainLossEvery", type=int, default=100, help="number of iterations to show train loss") 37 | parser.add_argument("--saveModelEvery", type=int, default=5000, help="number of iterations to save the model") 38 | parser.add_argument("--showValPerformanceEvery", type=int, default=1000, help="number of iterations to show validation performance") 39 | parser.add_argument("--showTestPerformanceEvery", type=int, default=5000, help="number of iterations to show test performance") 40 | parser.add_argument("--lr", type=float, default=5e-3, help="Learning Rate. Default=5e-3") 41 | parser.add_argument("--lr_netD", type=float, default=5e-3, help="Learning Rate for discriminator.
Default=5e-3") 42 | parser.add_argument("--dropout_rate", default=0.2, type=float, help="prob to drop neurons to zero, Default: 0.2") 43 | parser.add_argument("--decLREvery", type=int, default=10000, help="Multiplies the learning rates by lrDecRate (netG) and lrDecRate_netD (netD, when an adversarial loss is used) every n iterations, Default: n=10000") 44 | parser.add_argument("--lrDecRate", type=float, default=0.5, help="The factor by which the learning rate of netG is multiplied at each decay, Default=0.5") 45 | parser.add_argument("--lrDecRate_netD", type=float, default=0.1, help="The factor by which the learning rate of netD is multiplied at each decay, Default=0.1") 46 | parser.add_argument("--cuda", action="store_true", help="Use cuda?", default=True) 47 | parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)") 48 | parser.add_argument("--start_epoch", default=1, type=int, help="Iteration to start from (useful on restarts), Default: 1") 49 | parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1") 50 | parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9") 51 | parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="weight decay, Default: 1e-4") 52 | parser.add_argument("--RT_th", default=0.005, type=float, help="Relative thresholding: 0.005") 53 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)") 54 | parser.add_argument("--prefixModelName", default="/home/niedong/Data4LowDosePET/pytorch_UNet/resunet2d_dp_pet_BatchAug_sNorm_lres_bn_lr5e3_lrdec_base1_lossL1_lossGDL0p05_0705_", type=str, help="prefix of the to-be-saved model name") 55 | parser.add_argument("--prefixPredictedFN", default="preSub1_pet_BatchAug_sNorm_resunet_dp_lres_bn_lr5e3_lrdec_base1_lossL1_lossGDL0p05_0705_", type=str, help="prefix of the to-be-saved predicted filename") 56 | parser.add_argument("--test_input_file_name",default='sub13_mr.hdr',type=str, help="the input file name for testing subject") 57 |
parser.add_argument("--test_gt_file_name",default='sub13_ct.hdr',type=str, help="the ground-truth file name for testing subject") 58 | 59 | global opt, model 60 | opt = parser.parse_args() 61 | 62 | def main(): 63 | print opt 64 | 65 | # prefixModelName = 'Regressor_1112_' 66 | # prefixPredictedFN = 'preSub1_1112_' 67 | # showTrainLossEvery = 100 68 | # lr = 1e-4 69 | # showTestPerformanceEvery = 2000 70 | # saveModelEvery = 2000 71 | # decLREvery = 40000 72 | # numofIters = 200000 73 | # how2normalize = 0 74 | 75 | 76 | netD = Discriminator() 77 | netD.apply(weights_init) 78 | netD.cuda() 79 | 80 | optimizerD = optim.Adam(netD.parameters(),lr=opt.lr_netD) 81 | criterion_bce=nn.BCELoss() 82 | criterion_bce.cuda() 83 | 84 | #net=UNet() 85 | if opt.whichNet==1: 86 | net = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1) 87 | elif opt.whichNet==2: 88 | net = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1) 89 | elif opt.whichNet==3: 90 | net = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1) 91 | elif opt.whichNet==4: 92 | net = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1, dp_prob = opt.dropout_rate) 93 | #net.apply(weights_init) 94 | net.cuda() 95 | params = list(net.parameters()) 96 | print('len of params is ') 97 | print(len(params)) 98 | print('size of params is ') 99 | print(params[0].size()) 100 | 101 | 102 | 103 | optimizer = optim.Adam(net.parameters(),lr=opt.lr) 104 | criterion_L2 = nn.MSELoss() 105 | criterion_L1 = nn.L1Loss() 106 | criterion_RTL1 = RelativeThreshold_RegLoss(opt.RT_th) 107 | criterion_gdl = gdl_loss(opt.gdlNorm) 108 | #criterion = nn.CrossEntropyLoss() 109 | # criterion = nn.NLLLoss2d() 110 | 111 | given_weight = torch.cuda.FloatTensor([1,4,4,2]) 112 | 113 | criterion_3d = CrossEntropy3d(weight=given_weight) 114 | 115 | criterion_3d = criterion_3d.cuda() 116 | criterion_L2 = criterion_L2.cuda() 117 | criterion_L1 = criterion_L1.cuda() 118 | criterion_RTL1 = criterion_RTL1.cuda() 119 | 
criterion_gdl = criterion_gdl.cuda() 120 | 121 | #inputs=Variable(torch.randn(1000,1,32,32)) #here should be tensor instead of variable 122 | #targets=Variable(torch.randn(1000,10,1,1)) #here should be tensor instead of variable 123 | # trainset=data_utils.TensorDataset(inputs, targets) 124 | # trainloader = data_utils.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) 125 | # inputs=torch.randn(1000,1,32,32) 126 | # targets=torch.LongTensor(1000) 127 | 128 | path_test ='/home/niedong/DataCT/data_niigz/' 129 | path_patients_h5 = '/home/niedong/DataCT/h5Data_snorm/trainBatch2D_H5' 130 | path_patients_h5_val ='/home/niedong/DataCT/h5Data_snorm/valBatch2D_H5' 131 | # batch_size=10 132 | #data_generator = Generator_2D_slices(path_patients_h5,opt.batchSize,inputKey='data3T',outputKey='data7T') 133 | #data_generator_test = Generator_2D_slices(path_patients_h5_test,opt.batchSize,inputKey='data3T',outputKey='data7T') 134 | if opt.isMultiSource: 135 | data_generator = Generator_2D_slicesV1(path_patients_h5,opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET') 136 | data_generator_test = Generator_2D_slicesV1(path_patients_h5_val, opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET') 137 | else: 138 | data_generator = Generator_2D_slices(path_patients_h5,opt.batchSize,inputKey='dataMR',outputKey='dataCT') 139 | data_generator_test = Generator_2D_slices(path_patients_h5_val,opt.batchSize,inputKey='dataMR',outputKey='dataCT') 140 | 141 | #data_generator = Generator_2D_slicesV1(path_patients_h5,opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET') 142 | #data_generator_test = Generator_2D_slicesV1(path_patients_h5_val, opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET') 143 | if opt.resume: 144 | if os.path.isfile(opt.resume): 145 | print("=> loading checkpoint '{}'".format(opt.resume)) 146 | checkpoint = torch.load(opt.resume) 147 | net.load_state_dict(checkpoint['model']) 
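Below is a plain-Python sketch of the resume bookkeeping at this step. The helper names `lr_for_iter` and `resume_point` are hypothetical and not part of the repository. The sketch assumes the default flags (`--lr 5e-3`, `--lrDecRate 0.5`, `--decLREvery 10000`) and the decay rule in the training loop, which multiplies the rate at the end of every iteration divisible by `decLREvery`.

Two details matter here. First, the save block stores `'epoch': iter+1`, so that value already names the next iteration to run. Second, only the generator's weights are checkpointed. The optimizer is rebuilt with the initial `opt.lr`, so decays that happened before the checkpoint are not reapplied on resume.

```python
def lr_for_iter(it, base_lr, dec_rate, dec_every):
    """Learning rate in effect during 1-based iteration `it`.

    Assumes the rule used in runCTRecon.py: after iteration `it` finishes,
    if it % dec_every == 0 the rate is multiplied by dec_rate. Iterations
    1..dec_every therefore use base_lr, the next block base_lr * dec_rate, ...
    """
    return base_lr * dec_rate ** ((it - 1) // dec_every)


def resume_point(checkpoint, base_lr=5e-3, dec_rate=0.5, dec_every=10000):
    """Return (first iteration to run, generator learning rate to restore).

    The save block writes {'epoch': iter + 1, 'model': ...}, so 'epoch' is
    already the next iteration; adding 1 again would skip one iteration.
    """
    start = checkpoint['epoch']
    return start, lr_for_iter(start, base_lr, dec_rate, dec_every)
```

Consider a checkpoint written at iteration 10000. Under these assumptions, `resume_point({'epoch': 10001})` gives `(10001, 0.0025)`. One way to restore the schedule would be to pass that rate to `adjust_learning_rate(optimizer, lr)`, which the script already calls when decaying. This is a suggestion, not something the original code does.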
148 | # the checkpoint stores 'epoch': iter+1 at save time, i.e. the next iteration to run 149 | opt.start_epoch = checkpoint["epoch"] 150 | # net.load_state_dict(checkpoint["model"].state_dict()) 151 | else: 152 | print("=> no checkpoint found at '{}'".format(opt.resume)) 153 | ########### We'd better use a DataLoader to load a lot of data, and we should also train for several epochs ############### 154 | 155 | running_loss = 0.0 156 | start = time.time() 157 | for iter in range(opt.start_epoch, opt.numofIters+1): 158 | #print('iter %d'%iter) 159 | 160 | if opt.isMultiSource: 161 | inputs, exinputs, labels = data_generator.next() 162 | else: 163 | inputs, labels = data_generator.next() 164 | exinputs = inputs 165 | # inputs, exinputs, labels = data_generator.next() 166 | 167 | # xx = np.transpose(inputs,(5,64,64)) 168 | # inputs = np.transpose(inputs,(0,3,1,2)) 169 | inputs = np.squeeze(inputs) #5x64x64 170 | # exinputs = np.transpose(exinputs,(0,3,1,2)) 171 | exinputs = np.squeeze(exinputs) #5x64x64 172 | # print 'shape is ....',inputs.shape 173 | labels = np.squeeze(labels) #64x64 174 | # labels = labels.astype(int) 175 | 176 | inputs = inputs.astype(float) 177 | inputs = torch.from_numpy(inputs) 178 | inputs = inputs.float() 179 | exinputs = exinputs.astype(float) 180 | exinputs = torch.from_numpy(exinputs) 181 | exinputs = exinputs.float() 182 | labels = labels.astype(float) 183 | labels = torch.from_numpy(labels) 184 | labels = labels.float() 185 | #print type(inputs), type(exinputs) 186 | if opt.isMultiSource: 187 | source = torch.cat((inputs, exinputs),dim=1) 188 | else: 189 | source = inputs 190 | #source = inputs 191 | mid_slice = opt.numOfChannel_singleSource//2 192 | residual_source = inputs[:, mid_slice, ...]
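At this point the batch has been squeezed, and the middle slice has been taken as the long-residual input for the LRes networks (`whichNet` 3 or 4). The NumPy sketch below shows the same step. It assumes the generator yields arrays shaped (N, 1, C, H, W), the layout the extraction script writes with `trainFA[cubicCnt, 0, :, :, :] = volFA`. The function name `split_25d_batch` is hypothetical.

The sketch squeezes only axis 1. A bare `np.squeeze` also drops the batch axis when `--batchSize 1`, after which `inputs[:, mid_slice, ...]` would index the wrong axis.

```python
import numpy as np


def split_25d_batch(raw, num_channels=5):
    """Split a raw batch shaped (N, 1, C, H, W) into (source, center).

    Assumption: the HDF5 generator returns the layout written by the patch
    extractor (trainFA[cnt, 0, :, :, :] = volFA). Squeezing only axis 1 keeps
    the batch axis even when N == 1, unlike a bare np.squeeze(raw).
    """
    source = np.squeeze(np.asarray(raw, dtype=np.float32), axis=1)  # (N, C, H, W)
    if source.shape[1] != num_channels:
        raise ValueError("expected %d slices, got %d" % (num_channels, source.shape[1]))
    mid_slice = num_channels // 2        # 5 // 2 == 2: the central slice
    center = source[:, mid_slice, ...]   # (N, H, W), the long-residual input
    return source, center
```

The returned `center` plays the role of `residual_source`. With `--isMultiSource`, the script concatenates the extra modality only into `source`, so the residual still comes from the main modality's slice stack.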
193 | #inputs = inputs.cuda() 194 | #exinputs = exinputs.cuda() 195 | source = source.cuda() 196 | residual_source = residual_source.cuda() 197 | labels = labels.cuda() 198 | #we should consider different data to train 199 | 200 | #wrap them into Variable 201 | source, residual_source, labels = Variable(source),Variable(residual_source), Variable(labels) 202 | #inputs, exinputs, labels = Variable(inputs),Variable(exinputs), Variable(labels) 203 | 204 | ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z))) 205 | if opt.isAdLoss: 206 | #outputG = net(source,residual_source) #5x64x64->1*64x64 207 | if opt.whichNet == 3 or opt.whichNet == 4: 208 | outputG = net(source, residual_source) # 5x64x64->1*64x64 209 | else: 210 | outputG = net(source) # 5x64x64->1*64x64 211 | 212 | if len(labels.size())==3: 213 | labels = labels.unsqueeze(1) 214 | 215 | outputD_real = netD(labels) 216 | outputD_real = F.sigmoid(outputD_real) 217 | 218 | if len(outputG.size())==3: 219 | outputG = outputG.unsqueeze(1) 220 | 221 | outputD_fake = netD(outputG) 222 | outputD_fake = F.sigmoid(outputD_fake) 223 | netD.zero_grad() 224 | batch_size = inputs.size(0) 225 | real_label = torch.ones(batch_size,1) 226 | real_label = real_label.cuda() 227 | #print(real_label.size()) 228 | real_label = Variable(real_label) 229 | #print(outputD_real.size()) 230 | loss_real = criterion_bce(outputD_real,real_label) 231 | loss_real.backward() 232 | #train with fake data 233 | fake_label = torch.zeros(batch_size,1) 234 | # fake_label = torch.FloatTensor(batch_size) 235 | # fake_label.data.resize_(batch_size).fill_(0) 236 | fake_label = fake_label.cuda() 237 | fake_label = Variable(fake_label) 238 | loss_fake = criterion_bce(outputD_fake,fake_label) 239 | loss_fake.backward() 240 | 241 | lossD = loss_real + loss_fake 242 | # print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0] 243 | # print('loss for discriminator is %f'%lossD.data[0]) 244 | 
#update network parameters 245 | optimizerD.step() 246 | 247 | if opt.isWDist: 248 | one = torch.FloatTensor([1]) 249 | mone = one * -1 250 | one = one.cuda() 251 | mone = mone.cuda() 252 | 253 | netD.zero_grad() 254 | 255 | #outputG = net(source,residual_source) #5x64x64->1*64x64 256 | if opt.whichNet == 3 or opt.whichNet == 4: 257 | outputG = net(source, residual_source) # 5x64x64->1*64x64 258 | else: 259 | outputG = net(source) # 5x64x64->1*64x64 260 | 261 | if len(labels.size())==3: 262 | labels = labels.unsqueeze(1) 263 | 264 | outputD_real = netD(labels) 265 | 266 | if len(outputG.size())==3: 267 | outputG = outputG.unsqueeze(1) 268 | 269 | outputD_fake = netD(outputG) 270 | 271 | 272 | batch_size = inputs.size(0) 273 | 274 | D_real = outputD_real.mean() 275 | # print D_real 276 | D_real.backward(mone) 277 | 278 | 279 | D_fake = outputD_fake.mean() 280 | D_fake.backward(one) 281 | 282 | gradient_penalty = opt.lambda_D_WGAN_GP*calc_gradient_penalty(netD, labels.data, outputG.data) 283 | gradient_penalty.backward() 284 | 285 | D_cost = D_fake - D_real + gradient_penalty 286 | Wasserstein_D = D_real - D_fake 287 | 288 | optimizerD.step() 289 | 290 | 291 | ## (2) update G network: minimize the L1/L2 loss, maximize the D(G(x)) 292 | 293 | # print inputs.data.shape 294 | #outputG = net(source) #here I am not sure whether we should use twice or not 295 | if opt.whichNet == 3 or opt.whichNet == 4: 296 | outputG = net(source, residual_source) # 5x64x64->1*64x64 297 | else: 298 | outputG = net(source) # 5x64x64->1*64x64 299 | #outputG = net(source,residual_source) #5x64x64->1*64x64 300 | net.zero_grad() 301 | if opt.whichLoss==1: 302 | lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels)) 303 | elif opt.whichLoss==2: 304 | lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels)) 305 | else: 306 | lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels)) 307 | lossG_G = opt.lossBase * lossG_G 308 | 
lossG_G.backward(retain_graph=True) #compute gradients 309 | 310 | if opt.isGDL: 311 | lossG_gdl = opt.lambda_gdl * criterion_gdl(outputG,torch.unsqueeze(torch.squeeze(labels,1),1)) 312 | lossG_gdl.backward() #compute gradients 313 | 314 | if opt.isAdLoss: 315 | #we want to fool the discriminator, thus we pretend the label here to be real. Actually, we can explain from the 316 | #angle of equation (note the max and min difference for generator and discriminator) 317 | #outputG = net(inputs) 318 | #outputG = net(source,residual_source) #5x64x64->1*64x64 319 | if opt.whichNet == 3 or opt.whichNet == 4: 320 | outputG = net(source, residual_source) # 5x64x64->1*64x64 321 | else: 322 | outputG = net(source) # 5x64x64->1*64x64 323 | 324 | if len(outputG.size())==3: 325 | outputG = outputG.unsqueeze(1) 326 | 327 | outputD = netD(outputG) 328 | outputD = F.sigmoid(outputD) 329 | lossG_D = opt.lambda_AD*criterion_bce(outputD,real_label) #note, for generator, the label for outputG is real, because the G wants to confuse D 330 | lossG_D.backward() 331 | 332 | if opt.isWDist: 333 | #we want to fool the discriminator, thus we pretend the label here to be real. Actually, we can explain from the
Actually, we can explain from the 334 | #angle of equation (note the max and min difference for generator and discriminator) 335 | #outputG = net(inputs) 336 | #outputG = net(source,residual_source) #5x64x64->1*64x64 337 | if opt.whichNet == 3 or opt.whichNet == 4: 338 | outputG = net(source, residual_source) # 5x64x64->1*64x64 339 | else: 340 | outputG = net(source) # 5x64x64->1*64x64 341 | if len(outputG.size())==3: 342 | outputG = outputG.unsqueeze(1) 343 | 344 | outputD_fake = netD(outputG) 345 | 346 | outputD_fake = outputD_fake.mean() 347 | 348 | lossG_D = opt.lambda_AD*outputD_fake.mean() #note, for the generator we maximize the critic score D(G(x)) (backward with mone), because G wants to confuse D 349 | lossG_D.backward(mone) 350 | 351 | #for other losses, we can define the loss function following the pytorch tutorial 352 | 353 | optimizer.step() #update network parameters 354 | 355 | #print('loss for generator is %f'%lossG.data[0]) 356 | #print statistics 357 | running_loss = running_loss + lossG_G.data[0] 358 | 359 | 360 | if iter%opt.showTrainLossEvery==0: #print every opt.showTrainLossEvery iterations 361 | print '************************************************' 362 | print 'time now is: ' + time.asctime(time.localtime(time.time())) 363 | # print 'running loss is ',running_loss 364 | print 'average running loss for generator between iter [%d, %d] is: %.5f'%(iter - opt.showTrainLossEvery + 1,iter,running_loss/opt.showTrainLossEvery) 365 | 366 | print 'lossG_G is %.5f.'%(lossG_G.data[0]) 367 | 368 | if opt.isGDL: 369 | print 'loss for GDL loss is %f'%lossG_gdl.data[0] 370 | 371 | if opt.isAdLoss: 372 | print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0] 373 | print 'loss for discriminator is %f'%lossD.data[0] 374 | print 'adversarial loss for generator (lossG_D) is %f'%lossG_D.data[0] 375 | 376 | if opt.isWDist: 377 | print 'loss_real is ',torch.mean(D_real).data[0],'loss_fake is ',torch.mean(D_fake).data[0] 378 | print('loss for discriminator is %f'%Wasserstein_D.data[0], '
D cost is %f'%D_cost.data[0]) 379 | print 'adversarial loss for generator (lossG_D) is %f'%lossG_D.data[0] 380 | 381 | print 'cost time for iter [%d, %d] is %.2f'%(iter - opt.showTrainLossEvery + 1,iter, time.time()-start) 382 | print '************************************************' 383 | running_loss = 0.0 384 | start = time.time() 385 | if iter%opt.saveModelEvery==0: #save the model 386 | state = { 387 | 'epoch': iter+1, 388 | 'model': net.state_dict() 389 | } 390 | torch.save(state, opt.prefixModelName+'%d.pt'%iter) 391 | print 'save model: '+opt.prefixModelName+'%d.pt'%iter 392 | 393 | if opt.isAdLoss or opt.isWDist: 394 | torch.save(netD.state_dict(), opt.prefixModelName+'_net_D%d.pt'%iter) 395 | if iter%opt.decLREvery==0: 396 | opt.lr = opt.lr*opt.lrDecRate 397 | adjust_learning_rate(optimizer, opt.lr) 398 | if opt.isAdLoss or opt.isWDist: 399 | opt.lr_netD = opt.lr_netD*opt.lrDecRate_netD 400 | adjust_learning_rate(optimizerD, opt.lr_netD) 401 | 402 | 403 | if iter%opt.showValPerformanceEvery==0: #validate on one batch 404 | # to test on the validation dataset in the format of h5 405 | # draw the batch from the validation generator, not the training one 406 | if opt.isMultiSource: 407 | inputs, exinputs, labels = data_generator_test.next() 408 | else: 409 | inputs, labels = data_generator_test.next() 410 | exinputs = inputs 411 | 412 | # inputs = np.transpose(inputs,(0,3,1,2)) 413 | inputs = np.squeeze(inputs) 414 | 415 | # exinputs = np.transpose(exinputs, (0, 3, 1, 2)) 416 | exinputs = np.squeeze(exinputs) # 5x64x64 417 | 418 | labels = np.squeeze(labels) 419 | 420 | inputs = torch.from_numpy(inputs) 421 | inputs = inputs.float() 422 | exinputs = torch.from_numpy(exinputs) 423 | exinputs = exinputs.float() 424 | labels = torch.from_numpy(labels) 425 | labels = labels.float() 426 | mid_slice = opt.numOfChannel_singleSource // 2 427 | residual_source = inputs[:, mid_slice, ...]
428 | if opt.isMultiSource: 429 | source = torch.cat((inputs, exinputs), dim=1) 430 | else: 431 | source = inputs 432 | source = source.cuda() 433 | residual_source = residual_source.cuda() 434 | labels = labels.cuda() 435 | source,residual_source,labels = Variable(source),Variable(residual_source), Variable(labels) 436 | 437 | # source = inputs 438 | #outputG = net(inputs) 439 | #outputG = net(source,residual_source) #5x64x64->1*64x64 440 | if opt.whichNet == 3 or opt.whichNet == 4: 441 | outputG = net(source, residual_source) # 5x64x64->1*64x64 442 | else: 443 | outputG = net(source) # 5x64x64->1*64x64 444 | if opt.whichLoss == 1: 445 | lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels)) 446 | elif opt.whichLoss == 2: 447 | lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels)) 448 | else: 449 | lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels)) 450 | lossG_G = opt.lossBase * lossG_G 451 | print '.......come to validation stage: iter {}'.format(iter),'........' 
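The validation block below reports the gradient difference loss (GDL) next to the voxel-wise loss. For reference, here is a NumPy sketch of the quantity computed by the TensorFlow `gdl_loss` in loss_functions.py. It takes the absolute difference between predicted and ground-truth gradient magnitudes, raises it to `alpha`, sums it, and divides by the batch size.

This is an approximation under stated assumptions:

- It uses `np.diff` (valid differences, no padding) instead of SAME-padded convolutions, so border pixels are handled differently.
- The name `gdl_numpy` is hypothetical.
- The PyTorch `gdl_loss` this script actually imports from nnBuildUnits is not shown here and may use a different reduction.

```python
import numpy as np


def gdl_numpy(gen, gt, alpha=2):
    """Gradient difference loss for image batches shaped (N, H, W).

    Sketch only: finite differences come from np.diff (no padding), so the
    one-pixel border differs from the SAME-padded TensorFlow version.
    """
    gen = np.asarray(gen, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    # absolute intensity gradients along width (axis 2) and height (axis 1)
    gen_dx, gt_dx = np.abs(np.diff(gen, axis=2)), np.abs(np.diff(gt, axis=2))
    gen_dy, gt_dy = np.abs(np.diff(gen, axis=1)), np.abs(np.diff(gt, axis=1))
    # penalize mismatched gradient magnitudes, then average over the batch
    total = (np.abs(gt_dx - gen_dx) ** alpha).sum() + (np.abs(gt_dy - gen_dy) ** alpha).sum()
    return total / gen.shape[0]
```

Only gradients enter this loss, so a prediction that differs from the ground truth by a constant offset has zero GDL. This is why the term is paired with an L1 or L2 loss rather than used alone.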
452 | print 'lossG_G is %.5f.'%(lossG_G.data[0]) 453 | 454 | if opt.isGDL: 455 | lossG_gdl = criterion_gdl(outputG, torch.unsqueeze(torch.squeeze(labels,1),1)) 456 | print 'loss for GDL loss is %f'%lossG_gdl.data[0] 457 | 458 | if iter % opt.showTestPerformanceEvery == 0: # test one subject 459 | mr_test_itk=sitk.ReadImage(os.path.join(path_test,opt.test_input_file_name)) 460 | ct_test_itk=sitk.ReadImage(os.path.join(path_test,opt.test_input_file_name)) 461 | hpet_test_itk = sitk.ReadImage(os.path.join(path_test, opt.test_gt_file_name)) 462 | #mr_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_sourceCT.nii.gz')) 463 | #ct_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_extraCT.nii.gz')) 464 | #hpet_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub1_targetCT.nii.gz')) 465 | 466 | spacing = hpet_test_itk.GetSpacing() 467 | origin = hpet_test_itk.GetOrigin() 468 | direction = hpet_test_itk.GetDirection() 469 | 470 | mrnp=sitk.GetArrayFromImage(mr_test_itk) 471 | ctnp=sitk.GetArrayFromImage(ct_test_itk) 472 | hpetnp=sitk.GetArrayFromImage(hpet_test_itk) 473 | 474 | ##### specific normalization ##### 475 | # mu = np.mean(mrnp) 476 | # maxV, minV = np.percentile(mrnp, [99 ,25]) 477 | # #mrimg=mrimg 478 | # mrnp = (mrnp-minV)/(maxV-minV) 479 | 480 | 481 | 482 | #for training data in pelvicSeg 483 | if opt.how2normalize == 1: 484 | maxV, minV = np.percentile(mrnp, [99 ,1]) 485 | print 'maxV,',maxV,' minV, ',minV 486 | mrnp = (mrnp-mu)/(maxV-minV) 487 | print 'unique value: ',np.unique(ctnp) 488 | 489 | #for training data in pelvicSeg 490 | if opt.how2normalize == 2: 491 | maxV, minV = np.percentile(mrnp, [99 ,1]) 492 | print 'maxV,',maxV,' minV, ',minV 493 | mrnp = (mrnp-mu)/(maxV-minV) 494 | print 'unique value: ',np.unique(ctnp) 495 | 496 | #for training data in pelvicSegRegH5 497 | if opt.how2normalize== 3: 498 | std = np.std(mrnp) 499 | mrnp = (mrnp - mu)/std 500 | print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp) 501 | 502 | if 
opt.how2normalize == 4: 503 | maxLPET = 149.366742 504 | maxPercentLPET = 7.76 505 | minLPET = 0.00055037 506 | meanLPET = 0.27593288 507 | stdLPET = 0.75747500 508 | 509 | # for rsCT 510 | maxCT = 27279 511 | maxPercentCT = 1320 512 | minCT = -1023 513 | meanCT = -601.1929 514 | stdCT = 475.034 515 | 516 | # for s-pet 517 | maxSPET = 156.675962 518 | maxPercentSPET = 7.79 519 | minSPET = 0.00055037 520 | meanSPET = 0.284224789 521 | stdSPET = 0.7642257 522 | 523 | #matLPET = (mrnp - meanLPET) / (stdLPET) 524 | matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET) 525 | matCT = (ctnp - meanCT) / stdCT 526 | matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET) 527 | 528 | if opt.how2normalize == 5: 529 | # for rsCT 530 | maxCT = 27279 531 | maxPercentCT = 1320 532 | minCT = -1023 533 | meanCT = -601.1929 534 | stdCT = 475.034 535 | 536 | print 'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp) 537 | 538 | 539 | # matLPET = (mrnp - meanLPET) / (stdLPET) 540 | matLPET = mrnp 541 | matCT = (ctnp - meanCT) / stdCT 542 | matSPET = hpetnp 543 | 544 | if opt.how2normalize == 6: 545 | maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0]) 546 | maxPercentCT, minPercentCT = np.percentile(ctnp, [99.5, 0]) 547 | print 'maxPercentPET: ', maxPercentPET, ' minPercentPET: ', minPercentPET, ' maxPercentCT: ', maxPercentCT, 'minPercentCT: ', minPercentCT 548 | 549 | matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET) 550 | matSPET = (hpetnp - minPercentPET) / (maxPercentPET - minPercentPET) 551 | 552 | matCT = (ctnp - minPercentCT) / (maxPercentCT - minPercentCT) 553 | 554 | 555 | if not opt.isMultiSource: 556 | matFA = matLPET 557 | matGT = hpetnp 558 | 559 | print 'matFA shape: ',matFA.shape, ' matGT shape: ', matGT.shape 560 | matOut = testOneSubject_aver_res(matFA,matGT,[5,64,64],[1,64,64],[1,32,32],net,opt.prefixModelName+'%d.pt'%iter) 561 | print 'matOut shape: ',matOut.shape 562 | if opt.how2normalize==6: 563 | ct_estimated = matOut *
(maxPercentPET - minPercentPET) + minPercentPET 564 | else: 565 | ct_estimated = matOut 566 | 567 | 568 | itspsnr = psnr(ct_estimated, matGT) 569 | 570 | print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape 571 | print 'gt: ',matGT.dtype,' shape: ',matGT.shape 572 | print 'psnr = ',itspsnr 573 | volout = sitk.GetImageFromArray(ct_estimated) 574 | volout.SetSpacing(spacing) 575 | volout.SetOrigin(origin) 576 | volout.SetDirection(direction) 577 | sitk.WriteImage(volout,opt.prefixPredictedFN+'{}'.format(iter)+'.nii.gz') 578 | else: 579 | matFA = matLPET 580 | matGT = hpetnp 581 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape 582 | matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], net, 583 | opt.prefixModelName + '%d.pt' % iter) 584 | print 'matOut shape: ', matOut.shape 585 | if opt.how2normalize==6: 586 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET 587 | else: 588 | ct_estimated = matOut 589 | 590 | itspsnr = psnr(ct_estimated, matGT) 591 | 592 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape 593 | print 'gt: ', matGT.dtype, ' shape: ', matGT.shape 594 | print 'psnr = ', itspsnr 595 | volout = sitk.GetImageFromArray(ct_estimated) 596 | volout.SetSpacing(spacing) 597 | volout.SetOrigin(origin) 598 | volout.SetDirection(direction) 599 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(iter) + '.nii.gz') 600 | 601 | print('Finished Training') 602 | 603 | if __name__ == '__main__': 604 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpuID) 605 | main() 606 | 607 | -------------------------------------------------------------------------------- /runCTRecon3d.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | import argparse, os 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 |
import numpy as np 7 | import torch.optim as optim 8 | import torch 9 | import torch.utils.data as data_utils 10 | from utils import * 11 | from ResUnet3d_pytorch import UNet, ResUNet, UNet_LRes, ResUNet_LRes, Discriminator 12 | # from Unet3d_pytorch import UNet3D 13 | from nnBuildUnits import CrossEntropy3d, topK_RegLoss, RelativeThreshold_RegLoss, adjust_learning_rate 14 | import time 15 | import SimpleITK as sitk 16 | 17 | # Training settings 18 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg") 19 | parser.add_argument("--gpuID", type=int, default=3, help="ID of the GPU to use") 20 | parser.add_argument("--isAdLoss", action="store_true", help="is adversarial loss used?", default=False) 21 | parser.add_argument("--lambda_AD", default=0.05, type=float, help="weight of the adversarial loss, Default: 0.05") 22 | parser.add_argument("--how2normalize", type=int, default=5, help="how to normalize the data") 23 | parser.add_argument("--whichLoss", type=int, default=1, help="which loss to use: 1. LossL1 (default), 2. lossRTL1, 3. MSE") 24 | parser.add_argument("--whichNet", type=int, default=4, help="which network to use: 1. UNet, 2. ResUNet, 3. UNet_LRes and 4.
ResUNet_LRes (default, 4)") 25 | parser.add_argument("--lossBase", type=int, default=1, help="The base to multiply the lossG_G, Default (1)") 26 | parser.add_argument("--batchSize", type=int, default=10, help="training batch size") 27 | parser.add_argument("--isMultiSource", action="store_true", help="is multiple modality used?", default=False) 28 | parser.add_argument("--numOfChannel_singleSource", type=int, default=5, help="# of channels for a 2D patch for the main modality (Default, 5)") 29 | parser.add_argument("--numOfChannel_allSource", type=int, default=1, help="# of channels for a 2D patch for all the concatenated modalities (Default, 1)") 30 | parser.add_argument("--numofIters", type=int, default=200000, help="number of iterations to train for") 31 | parser.add_argument("--showTrainLossEvery", type=int, default=100, help="number of iterations to show train loss") 32 | parser.add_argument("--saveModelEvery", type=int, default=5000, help="number of iterations to save the model") 33 | parser.add_argument("--showValPerformanceEvery", type=int, default=1000, help="number of iterations to show validation performance") 34 | parser.add_argument("--showTestPerformanceEvery", type=int, default=5000, help="number of iterations to show test performance") 35 | parser.add_argument("--lr", type=float, default=5e-3, help="Learning Rate.
Default=5e-3") 36 | parser.add_argument("--dropout_rate", default=0.2, type=float, help="prob to drop neurons to zero: 0.2") 37 | parser.add_argument("--decLREvery", type=int, default=10000, help="Halve the learning rate every n iterations, Default: n=10000") 38 | parser.add_argument("--cuda", action="store_true", help="Use cuda?", default=True) 39 | parser.add_argument("--resume", default="/home/niedong/Data4LowDosePET/pytorch_UNet/resunet3d_dp_pet_BatchAug_noNorm_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0627_5000.pt", type=str, help="Path to the checkpoint to resume from") 40 | parser.add_argument("--start_epoch", default=1, type=int, help="Manual epoch number (useful on restarts)") 41 | parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1") 42 | parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9") 43 | parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="weight decay, Default: 1e-4") 44 | parser.add_argument("--RT_th", default=0.005, type=float, help="threshold for the relative-threshold L1 loss, Default: 0.005") 45 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)") 46 | parser.add_argument("--prefixModelName", default="/home/niedong/Data4LowDosePET/pytorch_UNet/resunet3d_dp_pet_BatchAug_noNorm_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0627_", type=str, help="prefix of the to-be-saved model name") 47 | parser.add_argument("--prefixPredictedFN", default="preSub1_pet_BatchAug_noNorm_resunet3d_dp_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0627_", type=str, help="prefix of the to-be-saved predicted filename") 48 | 49 | global opt, model 50 | opt = parser.parse_args() 51 | 52 | def main(): 53 | print opt 54 | 55 | # prefixModelName = 'Regressor_1112_' 56 | # prefixPredictedFN = 'preSub1_1112_' 57 | # showTrainLossEvery = 100 58 | # lr = 1e-4 59 | # showTestPerformanceEvery = 2000 60 | #
saveModelEvery = 2000 61 | # decLREvery = 40000 62 | # numofIters = 200000 63 | # how2normalize = 0 64 | 65 | 66 | netD = Discriminator() 67 | netD.apply(weights_init) 68 | netD.cuda() 69 | 70 | optimizerD = optim.Adam(netD.parameters(),lr=1e-3) 71 | criterion_bce=nn.BCELoss() 72 | criterion_bce.cuda() 73 | 74 | #net=UNet() 75 | if opt.whichNet==1: 76 | net = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1) 77 | elif opt.whichNet==2: 78 | net = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1) 79 | elif opt.whichNet==3: 80 | net = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1) 81 | elif opt.whichNet==4: 82 | net = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1, dp_prob = opt.dropout_rate) 83 | #net.apply(weights_init) 84 | net.cuda() 85 | params = list(net.parameters()) 86 | print('len of params is ') 87 | print(len(params)) 88 | print('size of params is ') 89 | print(params[0].size()) 90 | 91 | 92 | 93 | optimizer = optim.Adam(net.parameters(),lr=opt.lr) 94 | criterion_L2 = nn.MSELoss() 95 | criterion_L1 = nn.L1Loss() 96 | criterion_RTL1 = RelativeThreshold_RegLoss(opt.RT_th) 97 | #criterion = nn.CrossEntropyLoss() 98 | # criterion = nn.NLLLoss2d() 99 | 100 | given_weight = torch.cuda.FloatTensor([1,4,4,2]) 101 | 102 | criterion_3d = CrossEntropy3d(weight=given_weight) 103 | 104 | criterion_3d = criterion_3d.cuda() 105 | criterion_L2 = criterion_L2.cuda() 106 | criterion_L1 = criterion_L1.cuda() 107 | criterion_RTL1 = criterion_RTL1.cuda() 108 | 109 | #inputs=Variable(torch.randn(1000,1,32,32)) #here should be tensor instead of variable 110 | #targets=Variable(torch.randn(1000,10,1,1)) #here should be tensor instead of variable 111 | # trainset=data_utils.TensorDataset(inputs, targets) 112 | # trainloader = data_utils.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) 113 | # inputs=torch.randn(1000,1,32,32) 114 | # targets=torch.LongTensor(1000) 115 | 116 | path_test 
='/home/niedong/DataCT/data_niigz/' 117 | path_patients_h5 = '/home/niedong/DataCT/h5Data3D_noNorm/trainBatch3D_H5' 118 | path_patients_h5_test ='/home/niedong/DataCT/h5Data3D_noNorm/val3D_H5' 119 | # path_patients_h5_test ='/home/niedong/Data4LowDosePET/test2D_H5' 120 | # batch_size=10 121 | #data_generator = Generator_2D_slices(path_patients_h5,opt.batchSize,inputKey='data3T',outputKey='data7T') 122 | #data_generator_test = Generator_2D_slices(path_patients_h5_test,opt.batchSize,inputKey='data3T',outputKey='data7T') 123 | 124 | data_generator = Generator_3D_patches(path_patients_h5,opt.batchSize, inputKey='dataLPET', outputKey='dataHPET') 125 | data_generator_test = Generator_3D_patches(path_patients_h5_test,opt.batchSize, inputKey='dataLPET', outputKey='dataHPET') 126 | if opt.resume: 127 | if os.path.isfile(opt.resume): 128 | print("=> loading checkpoint '{}'".format(opt.resume)) 129 | checkpoint = torch.load(opt.resume) 130 | net.load_state_dict(checkpoint['model']) 131 | 132 | opt.start_epoch = checkpoint["epoch"] - 1 133 | # net.load_state_dict(checkpoint["model"].state_dict()) 134 | else: 135 | print("=> no checkpoint found at '{}'".format(opt.resume)) 136 | ########### We'd better use a DataLoader to load the data, and we should also train for several epochs ############### 137 | 138 | running_loss = 0.0 139 | start = time.time() 140 | for iter in range(opt.start_epoch, opt.numofIters+1): 141 | #print('iter %d'%iter) 142 | 143 | # inputs, exinputs, labels = data_generator.next() 144 | inputs, labels = data_generator.next() 145 | 146 | # xx = np.transpose(inputs,(5,64,64)) 147 | # print 'size of inputs: ', inputs.shape 148 | inputs = np.transpose(inputs,(0,4,1,2,3)) 149 | # inputs = np.squeeze(inputs) #16x64x64 150 | # exinputs = np.squeeze(exinputs) #5x64x64 151 | # print 'shape is ....',inputs.shape 152 | labels =
np.squeeze(labels) #64x64 153 | # labels = labels.astype(int) 154 | 155 | inputs = inputs.astype(float) 156 | inputs = torch.from_numpy(inputs) 157 | inputs = inputs.float() 158 | # exinputs = exinputs.astype(float) 159 | # exinputs = torch.from_numpy(exinputs) 160 | # exinputs = exinputs.float() 161 | labels = labels.astype(float) 162 | labels = torch.from_numpy(labels) 163 | labels = labels.float() 164 | #print type(inputs), type(exinputs) 165 | if opt.isMultiSource: 166 | # source = torch.cat((inputs, exinputs),dim=1) 167 | print 'you have to tune the multi source part' 168 | else: 169 | source = inputs 170 | #source = inputs 171 | # mid_slice = opt.numOfChannel_singleSource//2 172 | residual_source = inputs 173 | #inputs = inputs.cuda() 174 | #exinputs = exinputs.cuda() 175 | source = source.cuda() 176 | residual_source = residual_source.cuda() 177 | labels = labels.cuda() 178 | #we should consider different data to train 179 | 180 | #wrap them into Variable 181 | source, residual_source, labels = Variable(source),Variable(residual_source), Variable(labels) 182 | #inputs, exinputs, labels = Variable(inputs),Variable(exinputs), Variable(labels) 183 | 184 | ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z))) 185 | if opt.isAdLoss: 186 | if opt.whichNet == 3 or opt.whichNet == 4: 187 | outputG = net(source, residual_source) # 5x64x64->1*64x64 188 | else: 189 | outputG = net(source) # 5x64x64->1*64x64 190 | #outputG = net(source,residual_source) #5x64x64->1*64x64 191 | 192 | if len(labels.size())==3: 193 | labels = labels.unsqueeze(1) 194 | 195 | outputD_real = netD(labels) 196 | outputD_real = F.sigmoid(outputD_real) 197 | 198 | if len(outputG.size())==3: 199 | outputG = outputG.unsqueeze(1) 200 | 201 | outputD_fake = netD(outputG) 202 | outputD_fake = F.sigmoid(outputD_fake) 203 | netD.zero_grad() 204 | batch_size = inputs.size(0) 205 | real_label = torch.ones(batch_size,1) 206 | real_label = real_label.cuda() 207 | #print(real_label.size()) 208 | 
real_label = Variable(real_label) 209 | #print(outputD_real.size()) 210 | loss_real = criterion_bce(outputD_real,real_label) 211 | loss_real.backward() 212 | #train with fake data 213 | fake_label = torch.zeros(batch_size,1) 214 | # fake_label = torch.FloatTensor(batch_size) 215 | # fake_label.data.resize_(batch_size).fill_(0) 216 | fake_label = fake_label.cuda() 217 | fake_label = Variable(fake_label) 218 | loss_fake = criterion_bce(outputD_fake,fake_label) 219 | loss_fake.backward() 220 | 221 | lossD = loss_real + loss_fake 222 | # print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0] 223 | # print('loss for discriminator is %f'%lossD.data[0]) 224 | #update network parameters 225 | optimizerD.step() 226 | 227 | 228 | ## (2) update G network: minimize the L1/L2 loss, maximize the D(G(x)) 229 | 230 | # print inputs.data.shape 231 | #outputG = net(source) #here I am not sure whether we should use twice or not 232 | if opt.whichNet == 3 or opt.whichNet == 4: 233 | outputG = net(source, residual_source) # 5x64x64->1*64x64 234 | else: 235 | outputG = net(source) # 5x64x64->1*64x64 236 | #outputG = net(source,residual_source) #5x64x64->1*64x64 237 | net.zero_grad() 238 | if opt.whichLoss==1: 239 | lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels)) 240 | elif opt.whichLoss==2: 241 | lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels)) 242 | else: 243 | lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels)) 244 | lossG_G = opt.lossBase * lossG_G 245 | lossG_G.backward() #compute gradients 246 | 247 | if opt.isAdLoss: 248 | #we want to fool the discriminator, thus we pretend the label here to be real. 
Actually, we can explain from the 249 | #angel of equation (note the max and min difference for generator and discriminator) 250 | #outputG = net(inputs) 251 | #outputG = net(source,residual_source) #5x64x64->1*64x64 252 | if opt.whichNet == 3 or opt.whichNet == 4: 253 | outputG = net(source, residual_source) # 5x64x64->1*64x64 254 | else: 255 | outputG = net(source) # 5x64x64->1*64x64 256 | 257 | if len(outputG.size())==3: 258 | outputG = outputG.unsqueeze(1) 259 | 260 | outputD = netD(outputG) 261 | outputD = F.sigmoid(outputD) 262 | lossG_D = opt.lambda_AD*criterion_bce(outputD,real_label) #note, for generator, the label for outputG is real, because the G wants to confuse D 263 | lossG_D.backward() 264 | 265 | #for other losses, we can define the loss function following the pytorch tutorial 266 | 267 | optimizer.step() #update network parameters 268 | 269 | #print('loss for generator is %f'%lossG.data[0]) 270 | #print statistics 271 | running_loss = running_loss + lossG_G.data[0] 272 | 273 | 274 | if iter%opt.showTrainLossEvery==0: #print every 2000 mini-batches 275 | print '************************************************' 276 | print 'time now is: ' + time.asctime(time.localtime(time.time())) 277 | # print 'running loss is ',running_loss 278 | print 'average running loss for generator between iter [%d, %d] is: %.5f'%(iter - 100 + 1,iter,running_loss/100) 279 | 280 | print 'lossG_G is %.5f respectively.'%(lossG_G.data[0]) 281 | if opt.isAdLoss: 282 | print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0] 283 | print('loss for discriminator is %f'%lossD.data[0]) 284 | 285 | print 'cost time for iter [%d, %d] is %.2f'%(iter - 100 + 1,iter, time.time()-start) 286 | print '************************************************' 287 | running_loss = 0.0 288 | start = time.time() 289 | if iter%opt.saveModelEvery==0: #save the model 290 | state = { 291 | 'epoch': iter+1, 292 | 'model': net.state_dict() 293 | } 294 | 
torch.save(state, opt.prefixModelName+'%d.pt'%iter) 295 | print 'save model: '+opt.prefixModelName+'%d.pt'%iter 296 | 297 | if opt.isAdLoss: 298 | torch.save(netD.state_dict(), opt.prefixModelName+'_net_D%d.pt'%iter) 299 | if iter%opt.decLREvery==0: 300 | opt.lr = opt.lr*0.5 301 | adjust_learning_rate(optimizer, opt.lr) 302 | 303 | if iter%opt.showValPerformanceEvery==0: #test one subject 304 | # to test on the validation dataset in the format of h5 305 | # inputs,exinputs,labels = data_generator_test.next() 306 | inputs, labels = data_generator_test.next() 307 | 308 | inputs = np.transpose(inputs,(0,4,1,2,3)) 309 | # inputs = np.squeeze(inputs) 310 | 311 | # exinputs = np.transpose(exinputs, (0, 3, 1, 2)) 312 | # exinputs = np.squeeze(exinputs) # 5x64x64 313 | 314 | labels = np.squeeze(labels) 315 | 316 | inputs = torch.from_numpy(inputs) 317 | inputs = inputs.float() 318 | # exinputs = torch.from_numpy(exinputs) 319 | # exinputs = exinputs.float() 320 | labels = torch.from_numpy(labels) 321 | labels = labels.float() 322 | # mid_slice = opt.numOfChannel_singleSource // 2 323 | residual_source = inputs 324 | if opt.isMultiSource: 325 | # source = torch.cat((inputs, exinputs), dim=1) 326 | print 'you have to tune the multi source part' 327 | else: 328 | source = inputs 329 | source = source.cuda() 330 | residual_source = residual_source.cuda() 331 | labels = labels.cuda() 332 | source,residual_source,labels = Variable(source),Variable(residual_source), Variable(labels) 333 | 334 | # source = inputs 335 | #outputG = net(inputs) 336 | if opt.whichNet == 3 or opt.whichNet == 4: 337 | outputG = net(source, residual_source) # 5x64x64->1*64x64 338 | else: 339 | outputG = net(source) # 5x64x64->1*64x64 340 | #outputG = net(source,residual_source) #5x64x64->1*64x64 341 | 342 | if opt.whichLoss == 1: 343 | lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels)) 344 | elif opt.whichLoss == 2: 345 | lossG_G = criterion_RTL1(torch.squeeze(outputG), 
torch.squeeze(labels)) 346 | else: 347 | lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels)) 348 | lossG_G = opt.lossBase * lossG_G 349 | print '.......come to validation stage: iter {}'.format(iter),'........' 350 | print 'lossG_G is %.5f.'%(lossG_G.data[0]) 351 | 352 | if iter % opt.showTestPerformanceEvery == 0: # test one subject 353 | mr_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_sourceCT.nii.gz')) 354 | ct_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_extraCT.nii.gz')) 355 | hpet_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub1_targetCT.nii.gz')) 356 | 357 | spacing = hpet_test_itk.GetSpacing() 358 | origin = hpet_test_itk.GetOrigin() 359 | direction = hpet_test_itk.GetDirection() 360 | 361 | mrnp=sitk.GetArrayFromImage(mr_test_itk) 362 | ctnp=sitk.GetArrayFromImage(ct_test_itk) 363 | hpetnp=sitk.GetArrayFromImage(hpet_test_itk) 364 | 365 | ##### specific normalization ##### 366 | mu = np.mean(mrnp) # mean intensity, used by how2normalize 1, 2 and 3 below 367 | # maxV, minV = np.percentile(mrnp, [99 ,25]) 368 | # #mrimg=mrimg 369 | # mrnp = (mrnp-minV)/(maxV-minV) 370 | 371 | 372 | 373 | #for training data in pelvicSeg 374 | if opt.how2normalize == 1: 375 | maxV, minV = np.percentile(mrnp, [99 ,1]) 376 | print 'maxV,',maxV,' minV, ',minV 377 | mrnp = (mrnp-mu)/(maxV-minV) 378 | print 'unique value: ',np.unique(ctnp) 379 | 380 | #for training data in pelvicSeg 381 | if opt.how2normalize == 2: 382 | maxV, minV = np.percentile(mrnp, [99 ,1]) 383 | print 'maxV,',maxV,' minV, ',minV 384 | mrnp = (mrnp-mu)/(maxV-minV) 385 | print 'unique value: ',np.unique(ctnp) 386 | 387 | #for training data in pelvicSegRegH5 388 | if opt.how2normalize== 3: 389 | std = np.std(mrnp) 390 | mrnp = (mrnp - mu)/std 391 | print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp) 392 | 393 | if opt.how2normalize == 4: 394 | maxLPET = 149.366742 395 | maxPercentLPET = 7.76 396 | minLPET = 0.00055037 397 | meanLPET = 0.27593288 398 | stdLPET = 0.75747500 399 | 400 | # for rsCT 401 | maxCT =
27279 402 | maxPercentCT = 1320 403 | minCT = -1023 404 | meanCT = -601.1929 405 | stdCT = 475.034 406 | 407 | # for s-pet 408 | maxSPET = 156.675962 409 | maxPercentSPET = 7.79 410 | minSPET = 0.00055037 411 | meanSPET = 0.284224789 412 | stdSPET = 0.7642257 413 | 414 | #matLPET = (mrnp - meanLPET) / (stdLPET) 415 | matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET) 416 | matCT = (ctnp - meanCT) / stdCT 417 | matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET) 418 | 419 | if opt.how2normalize == 5: 420 | # for rsCT 421 | maxCT = 27279 422 | maxPercentCT = 1320 423 | minCT = -1023 424 | meanCT = -601.1929 425 | stdCT = 475.034 426 | 427 | print 'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp) 428 | 429 | 430 | # matLPET = (mrnp - meanLPET) / (stdLPET) 431 | matLPET = mrnp 432 | matCT = (ctnp - meanCT) / stdCT 433 | matSPET = hpetnp 434 | 435 | if not opt.isMultiSource: 436 | # matFA = matLPET 437 | # matGT = matSPET 438 | matFA = mrnp 439 | matGT = hpetnp 440 | print 'matFA shape: ',matFA.shape, ' matGT shape: ', matGT.shape 441 | matOut = testOneSubject_aver_res(matFA,matGT,[16,64,64],[16,64,64],[8,32,32],net,opt.prefixModelName+'%d.pt'%iter, nd=3) 442 | print 'matOut shape: ',matOut.shape 443 | ct_estimated = matOut 444 | 445 | itspsnr = psnr(ct_estimated, matGT) 446 | 447 | print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape 448 | print 'gt: ',matGT.dtype,' shape: ',matGT.shape 449 | print 'psnr = ',itspsnr 450 | volout = sitk.GetImageFromArray(ct_estimated) 451 | volout.SetSpacing(spacing) 452 | volout.SetOrigin(origin) 453 | volout.SetDirection(direction) 454 | sitk.WriteImage(volout,opt.prefixPredictedFN+'{}'.format(iter)+'.nii.gz') 455 | else: 456 | # matFA = matLPET 457 | # matGT = matSPET 458 | matFA = mrnp 459 | matGT = hpetnp 460 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape 461 | matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [16, 64, 64], [16, 64, 64], [8, 32, 32], net, 462 |
opt.prefixModelName + '%d.pt' % iter) 463 | print 'matOut shape: ', matOut.shape 464 | ct_estimated = matOut 465 | 466 | itspsnr = psnr(ct_estimated, matGT) 467 | 468 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape 469 | print 'gt: ', matGT.dtype, ' shape: ', matGT.shape 470 | print 'psnr = ', itspsnr 471 | volout = sitk.GetImageFromArray(ct_estimated) 472 | volout.SetSpacing(spacing) 473 | volout.SetOrigin(origin) 474 | volout.SetDirection(direction) 475 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(iter) + '.nii.gz') 476 | 477 | print('Finished Training') 478 | 479 | if __name__ == '__main__': 480 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpuID) 481 | main() 482 | 483 | -------------------------------------------------------------------------------- /runTesting_Recon.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | import argparse, os 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.optim as optim 8 | import torch 9 | import torch.utils.data as data_utils 10 | from Unet2d_pytorch import UNet, ResUNet, UNet_LRes, ResUNet_LRes, Discriminator 11 | from utils import * 12 | # from ganComponents import * 13 | # from nnBuildUnits import CrossEntropy2d 14 | # from nnBuildUnits import computeSampleAttentionWeight 15 | # from nnBuildUnits import adjust_learning_rate 16 | import time 17 | # from dataClean import denoiseImg,denoiseImg_isolation,denoiseImg_closing 18 | import SimpleITK as sitk 19 | 20 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg") 21 | 22 | parser.add_argument("--isSegReg", action="store_true", help="is Seg and Reg?", default=False) 23 | parser.add_argument("--isMultiSource", action="store_true", help="is multiple input modality used?", default=False) 24 | parser.add_argument("--whichLoss", type=int, default=1,
help="which loss to use: 1. LossL1 (default), 2. lossRTL1, 3. MSE") 25 | parser.add_argument("--whichNet", type=int, default=4, help="which network to use: 1. UNet, 2. ResUNet, 3. UNet_LRes and 4. ResUNet_LRes (default, 4)") 26 | parser.add_argument("--lossBase", type=int, default=1, help="The base to multiply the lossG_G, Default (1)") 27 | parser.add_argument("--batchSize", type=int, default=32, help="training batch size") 28 | parser.add_argument("--numOfChannel_singleSource", type=int, default=5, help="# of channels for a 2D patch for the main modality (Default, 5)") 29 | parser.add_argument("--numOfChannel_allSource", type=int, default=5, help="# of channels for a 2D patch for all the concatenated modalities (Default, 5)") 30 | parser.add_argument("--isResidualEnhancement", action="store_true", help="is residual learning operation enhanced?", default=False) 31 | parser.add_argument("--isViewExpansion", action="store_true", help="is view expanded?", default=True) 32 | parser.add_argument("--isAdLoss", action="store_true", help="is adversarial loss used?", default=True) 33 | parser.add_argument("--isSpatialDropOut", action="store_true", help="is spatial dropout used?", default=False) 34 | parser.add_argument("--isFocalLoss", action="store_true", help="is focal loss used?", default=False) 35 | parser.add_argument("--isSampleImportanceFromAd", action="store_true", help="is sample importance from adversarial network used?", default=False) 36 | parser.add_argument("--dropoutRate", type=float, default=0.25, help="Spatial Dropout Rate. Default=0.25") 37 | parser.add_argument("--lambdaAD", type=float, default=0, help="loss coefficient for AD loss. Default=0") 38 | parser.add_argument("--adImportance", type=float, default=0, help="Sample importance from AD network.
Default=0") 39 | parser.add_argument("--isFixedRegions", action="store_true", help="Are the organ regions roughly known?", default=False) 40 | #parser.add_argument("--modelPath", default="/home/niedong/Data4LowDosePET/pytorch_UNet/model/resunet2d_pet_Aug_noNorm_lres_bn_lr5e3_base1_lossL1_0p01_0624_200000.pt", type=str, help="prefix of the to-be-saved model name") 41 | parser.add_argument("--modelPath", default="/home/niedong/Data4LowDosePET/pytorch_UNet/model/resunet2d_dp_pet_BatchAug_sNorm_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0628_200000.pt", type=str, help="path to the trained model checkpoint to load") 42 | parser.add_argument("--prefixPredictedFN", default="pred_resunet2d_dp_pet_Aug_sNorm_lres_lrdce_bn_lr5e3_base1_lossL1_0628_20w_", type=str, help="prefix of the to-be-saved predicted filename") 43 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data") 44 | parser.add_argument("--resType", type=int, default=2, help="resType: 0: segmentation map (integer); 1: regression map (continuous); 2: segmentation map + probability map") 45 | 46 | def main(): 47 | opt = parser.parse_args() 48 | print opt 49 | 50 | path_test = '/home/niedong/Data4LowDosePET/data_niigz_scale/' 51 | 52 | if opt.whichNet==1: 53 | netG = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1) 54 | elif opt.whichNet==2: 55 | netG = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1) 56 | elif opt.whichNet==3: 57 | netG = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1) 58 | elif opt.whichNet==4: 59 | netG = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1) 60 | 61 | #netG.apply(weights_init) 62 | netG.cuda() 63 | 64 | checkpoint = torch.load(opt.modelPath) 65 | netG.load_state_dict(checkpoint['model']) 66 | 67 | 68 | ids = [1,6,11,16,21,26,31,36,41,46] #in one folder, we test these 10 subjects (the testing set) 69 | ids = [1] #test a single subject only 70 | 71 | ids =
['1_QFZ','2_LLQ','3_LMB','4_ZSL','5_CJB','11_TCL','15_WYL','21_PY','25_LYL','31_CZX','35_WLL','41_WQC','45_YXM'] 72 | for ind in ids: 73 | start = time.time() 74 | 75 | mr_test_itk = sitk.ReadImage(os.path.join(path_test,'%s_60s_suv.nii.gz'%ind))#input modality 76 | ct_test_itk = sitk.ReadImage(os.path.join(path_test,'%s_rsCT.nii.gz'%ind))#auxiliary modality 77 | hpet_test_itk = sitk.ReadImage(os.path.join(path_test, '%s_120s_suv.nii.gz'%ind))#output modality 78 | 79 | 80 | spacing = hpet_test_itk.GetSpacing() 81 | origin = hpet_test_itk.GetOrigin() 82 | direction = hpet_test_itk.GetDirection() 83 | 84 | mrnp = sitk.GetArrayFromImage(mr_test_itk) 85 | ctnp = sitk.GetArrayFromImage(ct_test_itk) 86 | hpetnp = sitk.GetArrayFromImage(hpet_test_itk) 87 | 88 | ##### specific normalization ##### 89 | mu = np.mean(mrnp) # mean intensity, used by how2normalize 1, 2 and 3 below 90 | # maxV, minV = np.percentile(mrnp, [99 ,25]) 91 | # #mrimg=mrimg 92 | # mrnp = (mrnp-minV)/(maxV-minV) 93 | 94 | # for training data in pelvicSeg 95 | if opt.how2normalize == 1: 96 | maxV, minV = np.percentile(mrnp, [99, 1]) 97 | print 'maxV,', maxV, ' minV, ', minV 98 | mrnp = (mrnp - mu) / (maxV - minV) 99 | print 'unique value: ', np.unique(ctnp) 100 | 101 | # for training data in pelvicSeg 102 | if opt.how2normalize == 2: 103 | maxV, minV = np.percentile(mrnp, [99, 1]) 104 | print 'maxV,', maxV, ' minV, ', minV 105 | mrnp = (mrnp - mu) / (maxV - minV) 106 | print 'unique value: ', np.unique(ctnp) 107 | 108 | # for training data in pelvicSegRegH5 109 | if opt.how2normalize == 3: 110 | std = np.std(mrnp) 111 | mrnp = (mrnp - mu) / std 112 | print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(mrnp) 113 | 114 | if opt.how2normalize == 4: 115 | maxLPET = 149.366742 116 | maxPercentLPET = 7.76 117 | minLPET = 0.00055037 118 | meanLPET = 0.27593288 119 | stdLPET = 0.75747500 120 | 121 | # for rsCT 122 | maxCT = 27279 123 | maxPercentCT = 1320 124 | minCT = -1023 125 | meanCT = -601.1929 126 | stdCT = 475.034 127 | 128 | # for s-pet 129 |
maxSPET = 156.675962 130 | maxPercentSPET = 7.79 131 | minSPET = 0.00055037 132 | meanSPET = 0.284224789 133 | stdSPET = 0.7642257 134 | 135 | # matLPET = (mrnp - meanLPET) / (stdLPET) 136 | matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET) 137 | matCT = (ctnp - meanCT) / stdCT 138 | matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET) 139 | 140 | if opt.how2normalize == 5: 141 | # for rsCT 142 | maxCT = 27279 143 | maxPercentCT = 1320 144 | minCT = -1023 145 | meanCT = -601.1929 146 | stdCT = 475.034 147 | 148 | print 'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp) 149 | 150 | # matLPET = (mrnp - meanLPET) / (stdLPET) 151 | matLPET = mrnp 152 | matCT = (ctnp - meanCT) / stdCT 153 | matSPET = hpetnp 154 | 155 | if opt.how2normalize == 6: 156 | maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0]) 157 | maxPercentCT, minPercentCT = np.percentile(ctnp, [99.5, 0]) 158 | print 'maxPercentPET: ', maxPercentPET, ' minPercentPET: ', minPercentPET, ' maxPercentCT: ', maxPercentCT, 'minPercentCT: ', minPercentCT 159 | 160 | matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET) 161 | matSPET = (hpetnp - minPercentPET) / (maxPercentPET - minPercentPET) 162 | 163 | matCT = (ctnp - minPercentCT) / (maxPercentCT - minPercentCT) 164 | 165 | if not opt.isMultiSource: 166 | matFA = matLPET 167 | matGT = hpetnp 168 | 169 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape 170 | matOut = testOneSubject_aver_res(matFA, matGT, [5, 64, 64], [1, 64, 64], [1, 16, 16], netG, opt.modelPath) 171 | print 'matOut shape: ', matOut.shape 172 | if opt.how2normalize == 6: 173 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET 174 | else: 175 | ct_estimated = matOut 176 | ct_estimated[np.where(mrnp==0)] = 0 177 | itspsnr = psnr(ct_estimated, matGT) 178 | 179 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape 180 | print 'gt: ', matGT.dtype, ' shape: ', matGT.shape 181 | print 'psnr = ', itspsnr
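With `how2normalize == 6`, the low-dose PET input is rescaled by its 0th and 99.5th percentiles, and the full-dose target is scaled with those same input-derived bounds. The network output is then mapped back with `matOut * (maxPercentPET - minPercentPET) + minPercentPET` before PSNR is computed against the raw target. Below is a minimal NumPy sketch of that forward/inverse pair, run on a toy array rather than a real SUV volume. `percentile_normalize` and `percentile_denormalize` are hypothetical names, not helpers from `utils.py`.

```python
import numpy as np


def percentile_normalize(vol, upper_pct=99.5, lower_pct=0.0):
    """Min-max scale `vol` with percentile bounds, as in how2normalize == 6.

    Returns the scaled array plus the (max, min) bounds needed to undo it.
    Values above the upper percentile end up > 1; nothing is clipped.
    """
    max_v, min_v = np.percentile(vol, [upper_pct, lower_pct])
    if max_v == min_v:
        raise ValueError("degenerate intensity range, cannot normalize")
    return (vol - min_v) / (max_v - min_v), (max_v, min_v)


def percentile_denormalize(scaled, bounds):
    """Invert percentile_normalize: scaled * (max - min) + min."""
    max_v, min_v = bounds
    return scaled * (max_v - min_v) + min_v


# Toy "low-dose" input; in the script the scaled array is what the network sees.
low_dose = np.arange(201, dtype=np.float64).reshape(3, 67)
scaled, bounds = percentile_normalize(low_dose)
restored = percentile_denormalize(scaled, bounds)
# The testing script also zeroes voxels that are zero in the input (background).
restored[low_dose == 0] = 0
print(np.allclose(restored, low_dose))  # True
```

The same `(max, min)` pair is reused for the target and for the prediction, so the inverse step returns the output to the intensity scale that PSNR is measured on. Because nothing is clipped, intensities above the 99.5th percentile simply map to values slightly above 1.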
182 | volout = sitk.GetImageFromArray(ct_estimated) 183 | volout.SetSpacing(spacing) 184 | volout.SetOrigin(origin) 185 | volout.SetDirection(direction) 186 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(ind) + '.nii.gz') 187 | else: 188 | matFA = matLPET 189 | matGT = hpetnp 190 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape 191 | matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 16, 16], netG, opt.modelPath) 192 | print 'matOut shape: ', matOut.shape 193 | if opt.how2normalize == 6: 194 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET 195 | else: 196 | ct_estimated = matOut 197 | 198 | ct_estimated[np.where(mrnp==0)] = 0 199 | itspsnr = psnr(ct_estimated, matGT) 200 | 201 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape 202 | print 'gt: ', matGT.dtype, ' shape: ', matGT.shape 203 | print 'psnr = ', itspsnr 204 | volout = sitk.GetImageFromArray(ct_estimated) 205 | volout.SetSpacing(spacing) 206 | volout.SetOrigin(origin) 207 | volout.SetDirection(direction) 208 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(ind) + '.nii.gz') 209 | 210 | if __name__ == '__main__': 211 | # testGradients() 212 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 213 | main() 214 | -------------------------------------------------------------------------------- /runTesting_Reconv2.py: -------------------------------------------------------------------------------- 1 | # from __future__ import print_function 2 | import argparse, os 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.optim as optim 8 | import torch 9 | import torch.utils.data as data_utils 10 | from Unet2d_pytorch import UNet, ResUNet, UNet_LRes, ResUNet_LRes, Discriminator 11 | from utils import * 12 | # from ganComponents import * 13 | # from nnBuildUnits import CrossEntropy2d 14 | # from
nnBuildUnits import computeSampleAttentionWeight 15 | # from nnBuildUnits import adjust_learning_rate 16 | import time 17 | # from dataClean import denoiseImg,denoiseImg_isolation,denoiseImg_closing 18 | import SimpleITK as sitk 19 | 20 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg") 21 | 22 | parser.add_argument("--gpuID", type=int, default=1, help="ID of the GPU to use (sets CUDA_VISIBLE_DEVICES)") 23 | parser.add_argument("--isSegReg", action="store_true", help="is Seg and Reg?", default=False) 24 | parser.add_argument("--isMultiSource", action="store_true", help="is multiple input modality used?", default=False) 25 | parser.add_argument("--whichLoss", type=int, default=1, help="which loss to use: 1. LossL1 (default), 2. lossRTL1, 3. MSE") 26 | parser.add_argument("--whichNet", type=int, default=2, help="which network to use: 1. UNet, 2. ResUNet (default), 3. UNet_LRes and 4. ResUNet_LRes") 27 | parser.add_argument("--lossBase", type=int, default=1, help="The base to multiply the lossG_G, Default (1)") 28 | parser.add_argument("--batchSize", type=int, default=32, help="training batch size") 29 | parser.add_argument("--numOfChannel_singleSource", type=int, default=5, help="# of channels for a 2D patch for the main modality (Default, 5)") 30 | parser.add_argument("--numOfChannel_allSource", type=int, default=5, help="# of channels for a 2D patch for all the concatenated modalities (Default, 5)") 31 | parser.add_argument("--isResidualEnhancement", action="store_true", help="is residual learning operation enhanced?", default=False) 32 | parser.add_argument("--isViewExpansion", action="store_true", help="is view expanded?", default=True) 33 | parser.add_argument("--isAdLoss", action="store_true", help="is adversarial loss used?", default=True) 34 | parser.add_argument("--isSpatialDropOut", action="store_true", help="is spatial dropout used?", default=False) 35 | parser.add_argument("--isFocalLoss", action="store_true", help="is focal loss used?", default=False) 36 |
parser.add_argument("--isSampleImportanceFromAd", action="store_true", help="is sample importance from adversarial network used?", default=False) 37 | parser.add_argument("--dropoutRate", type=float, default=0.25, help="Spatial Dropout Rate. Default=0.25") 38 | parser.add_argument("--lambdaAD", type=float, default=0, help="loss coefficient for AD loss. Default=0") 39 | parser.add_argument("--adImportance", type=float, default=0, help="Sample importance from AD network. Default=0") 40 | parser.add_argument("--isFixedRegions", action="store_true", help="Are the organ regions roughly known?", default=False) 41 | #parser.add_argument("--modelPath", default="/home/niedong/Data4LowDosePET/pytorch_UNet/model/resunet2d_pet_Aug_noNorm_lres_bn_lr5e3_base1_lossL1_0p01_0624_200000.pt", type=str, help="prefix of the to-be-saved model name") 42 | # parser.add_argument("--modelPath", default="/shenlab/lab_stor5/dongnie/brain_mr2ct/modelFiles/resunet2d_dp_brain_BatchAug_sNorm_lres_bn_lr5e3_lrnetD5e3_lrdec_base1_wgan_gp_1107_140000.pt", type=str, help="prefix of the to-be-saved model name") 43 | parser.add_argument("--modelPath", default="/shenlab/lab_stor/dongnie/brats2018/modelFiles/resunet2d_dp_brats_BatchAug_sNorm_bn_lr5e3_lrnetD5e3_lrdec0p5_lrDdec0p05_wgan_gp_1112_200000.pt", type=str, help="path of the trained model checkpoint to load") 44 | # parser.add_argument("--prefixPredictedFN", default="/shenlab/lab_stor5/dongnie/brain_mr2ct/res/testResult/predCT_brain_resunet2d_dp_Aug_sNorm_lres_lrdce_bn_lr5e3_lossL1_1107_14w_", type=str, help="prefix of the to-be-saved predicted filename") 45 | parser.add_argument("--prefixPredictedFN", default="/shenlab/lab_stor/dongnie/brats2018/res/testResult/predBrats_v1_resunet2d_dp_Aug_sNorm_lres_lrdce_bn_lr5e3_lossL1_1112_20w_", type=str, help="prefix of the to-be-saved predicted filename") 46 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data") 47 | parser.add_argument("--resType", type=int, default=1,
help="resType: 0: segmentation map (integer); 1: regression map (continuous); 2: segmentation map + probability map") 48 | 49 | global opt 50 | opt = parser.parse_args() 51 | 52 | 53 | def main(): 54 | print opt 55 | 56 | path_test = '/home/niedong/Data4LowDosePET/data_niigz_scale/' 57 | path_test = '/shenlab/lab_stor5/dongnie/brain_mr2ct/original_data/' 58 | path_test = '/shenlab/lab_stor/dongnie/brats2018/TrainData/HGG/Brats18_2013_11_1' 59 | 60 | if opt.whichNet==1: 61 | netG = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1) 62 | elif opt.whichNet==2: 63 | netG = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1) 64 | elif opt.whichNet==3: 65 | netG = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1) 66 | elif opt.whichNet==4: 67 | netG = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1) 68 | 69 | #netG.apply(weights_init) 70 | netG.cuda() 71 | 72 | checkpoint = torch.load(opt.modelPath) 73 | netG.load_state_dict(checkpoint['model']) 74 | 75 | 76 | ids = [1,6,11,16,21,26,31,36,41,46] # in one folder, we test these 10 subjects, which form the testing set 77 | ids = [1] # test a single subject 78 | 79 | ids = ['1_QFZ','2_LLQ','3_LMB','4_ZSL','5_CJB','11_TCL','15_WYL','21_PY','25_LYL','31_CZX','35_WLL','41_WQC','45_YXM'] 80 | ids = [2,3,4,5,8,9,10,13] 81 | ids = ['Brats18_2013_11_1'] 82 | for ind in ids: 83 | start = time.time() 84 | 85 | # mr_test_itk = sitk.ReadImage(os.path.join(path_test,'sub%d_mr.hdr'%ind))#input modality 86 | # ct_test_itk = sitk.ReadImage(os.path.join(path_test,'sub%d_ct.hdr'%ind))#auxiliary modality 87 | # hpet_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub%d_ct.hdr'%ind))#output modality 88 | 89 | mr_test_itk = sitk.ReadImage(os.path.join(path_test, 'Brats18_2013_11_1_t1ce.nii.gz')) 90 | ct_test_itk = sitk.ReadImage(os.path.join(path_test, 'Brats18_2013_11_1_t2.nii.gz')) 91 | 92 | spacing = mr_test_itk.GetSpacing() 93 | origin = mr_test_itk.GetOrigin() 94 | direction =
mr_test_itk.GetDirection() 95 | 96 | mrnp = sitk.GetArrayFromImage(mr_test_itk) 97 | ctnp = sitk.GetArrayFromImage(ct_test_itk) 98 | # hpetnp = sitk.GetArrayFromImage(hpet_test_itk) 99 | 100 | if opt.isMultiSource: 101 | hpet_test_itk = sitk.ReadImage(os.path.join(path_test, '%s_120s_suv.nii.gz' % ind)) 102 | hpetnp = sitk.GetArrayFromImage(hpet_test_itk) 103 | else: 104 | hpetnp = ctnp 105 | 106 | ##### specific normalization ##### 107 | mu = np.mean(mrnp) 108 | # maxV, minV = np.percentile(mrnp, [99 ,25]) 109 | # #mrimg=mrimg 110 | # mrnp = (mrnp-minV)/(maxV-minV) 111 | 112 | # for training data in pelvicSeg 113 | if opt.how2normalize == 1: 114 | maxV, minV = np.percentile(mrnp, [99, 1]) 115 | print 'maxV,', maxV, ' minV, ', minV 116 | mrnp = (mrnp - mu) / (maxV - minV) 117 | print 'unique value: ', np.unique(ctnp) 118 | 119 | # for training data in pelvicSeg 120 | if opt.how2normalize == 2: 121 | maxV, minV = np.percentile(mrnp, [99, 1]) 122 | print 'maxV,', maxV, ' minV, ', minV 123 | mrnp = (mrnp - mu) / (maxV - minV) 124 | print 'unique value: ', np.unique(ctnp) 125 | 126 | # for training data in pelvicSegRegH5 127 | if opt.how2normalize == 3: 128 | std = np.std(mrnp) 129 | mrnp = (mrnp - mu) / std 130 | print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(mrnp) 131 | 132 | if opt.how2normalize == 4: 133 | maxLPET = 149.366742 134 | maxPercentLPET = 7.76 135 | minLPET = 0.00055037 136 | meanLPET = 0.27593288 137 | stdLPET = 0.75747500 138 | 139 | # for rsCT 140 | maxCT = 27279 141 | maxPercentCT = 1320 142 | minCT = -1023 143 | meanCT = -601.1929 144 | stdCT = 475.034 145 | 146 | # for s-pet 147 | maxSPET = 156.675962 148 | maxPercentSPET = 7.79 149 | minSPET = 0.00055037 150 | meanSPET = 0.284224789 151 | stdSPET = 0.7642257 152 | 153 | # matLPET = (mrnp - meanLPET) / (stdLPET) 154 | matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET) 155 | matCT = (ctnp - meanCT) / stdCT 156 | matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET) 157 | 158 
| if opt.how2normalize == 5: 159 | # for rsCT 160 | maxCT = 27279 161 | maxPercentCT = 1320 162 | minCT = -1023 163 | meanCT = -601.1929 164 | stdCT = 475.034 165 | 166 | print 'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp) 167 | 168 | # matLPET = (mrnp - meanLPET) / (stdLPET) 169 | matLPET = mrnp 170 | matCT = (ctnp - meanCT) / stdCT 171 | matSPET = hpetnp 172 | 173 | if opt.how2normalize == 6: 174 | maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0]) 175 | maxPercentCT, minPercentCT = np.percentile(ctnp, [99.5, 0]) 176 | print 'maxPercentPET: ', maxPercentPET, ' minPercentPET: ', minPercentPET, ' maxPercentCT: ', maxPercentCT, 'minPercentCT: ', minPercentCT 177 | 178 | matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET) 179 | matCT = (ctnp - minPercentCT) / (maxPercentCT - minPercentCT) 180 | if opt.isMultiSource: 181 | matSPET = (hpetnp - minPercentPET) / (maxPercentPET - minPercentPET) 182 | 183 | 184 | 185 | if not opt.isMultiSource: 186 | matFA = matLPET 187 | matGT = hpetnp 188 | 189 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape,' max(matFA): ',np.amax(matFA),' min(matFA): ',np.amin(matFA) 190 | # matOut = testOneSubject_aver_res(matFA, matGT, [5, 64, 64], [1, 64, 64], [1, 16, 16], netG, opt.modelPath) 191 | if opt.whichNet == 3 or opt.whichNet == 4: 192 | matOut = testOneSubject_aver_res(matFA, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], netG, 193 | opt.modelPath) 194 | else: 195 | matOut = testOneSubject_aver(matFA, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], netG, 196 | opt.modelPath) 197 | print 'matOut shape: ', matOut.shape, ' max(matOut): ',np.amax(matOut),' min(matOut): ',np.amin(matOut) 198 | if opt.how2normalize == 6: 199 | # ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET 200 | ct_estimated = matOut * (maxPercentCT - minPercentCT) + minPercentCT 201 | else: 202 | ct_estimated = matOut 203 | #ct_estimated[np.where(mrnp==0)] = 0 204 | itspsnr = psnr(ct_estimated, 
matGT) 205 | 206 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape 207 | print 'gt: ', matGT.dtype, ' shape: ', matGT.shape 208 | print 'psnr = ', itspsnr 209 | volout = sitk.GetImageFromArray(ct_estimated) 210 | volout.SetSpacing(spacing) 211 | volout.SetOrigin(origin) 212 | volout.SetDirection(direction) 213 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(ind) + '.nii.gz') 214 | else: 215 | matFA = matLPET 216 | matGT = hpetnp 217 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape 218 | # matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 16, 16], netG, opt.modelPath) 219 | if opt.whichNet == 3 or opt.whichNet == 4: 220 | matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], netG, 221 | opt.modelPath) 222 | else: 223 | matOut = testOneSubject_aver_MultiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], netG, 224 | opt.modelPath) 225 | print 'matOut shape: ', matOut.shape 226 | if opt.how2normalize == 6: 227 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET 228 | else: 229 | ct_estimated = matOut 230 | 231 | #ct_estimated[np.where(mrnp==0)] = 0 232 | itspsnr = psnr(ct_estimated, matGT) 233 | 234 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape 235 | print 'gt: ', matGT.dtype, ' shape: ', matGT.shape 236 | print 'psnr = ', itspsnr 237 | volout = sitk.GetImageFromArray(ct_estimated) 238 | volout.SetSpacing(spacing) 239 | volout.SetOrigin(origin) 240 | volout.SetDirection(direction) 241 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(ind) + '.nii.gz') 242 | 243 | if __name__ == '__main__': 244 | # testGradients() 245 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpuID) 246 | main() 247 | -------------------------------------------------------------------------------- /shuffleDataAmongSubjects_2d.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import os 4 | 5 | ''' 6 | Merge the patches of every numInOneSub (5) subjects into one h5 file (patches are concatenated in file order, not randomly permuted) 7 | input: 8 | save_dir: directory holding the per-subject training h5 files 9 | savepath: directory where the merged h5 files are written 10 | output: 11 | train5x64x64_<batchID>.h5 files, each holding dataLPET, dataCT and dataHPET 12 | ''' 13 | 14 | 15 | def shuffleDataAmongSubjects(save_dir, savepath): 16 | # allfilenames = os.listdir(save_dir) 17 | # allfilenames = filter(lambda x: '.h5' in x and 'train' in x, allfilenames) 18 | 19 | nn = 200000 20 | # dataMR = np.zeros([nn, 1, 5, 64, 64], dtype=np.float16) 21 | dataLPET = np.zeros([nn,1, 5, 64, 64], dtype=np.float16) 22 | dataCT = np.zeros([nn,1, 5, 64, 64], dtype=np.float16) 23 | dataHPET = np.zeros([nn, 1, 1, 64, 64], dtype=np.float16) 24 | 25 | allfilenames = os.listdir(save_dir) 26 | # print allfilenames 27 | allfilenames = filter(lambda x: '.h5' in x and 'train' in x, allfilenames) 28 | # print allfilenames 29 | cnt = 0 30 | numInOneSub = 5 31 | batchID = 0 32 | startInd = 0 33 | savefilename = 'train5x64x64_' 34 | for i_file, filename in enumerate(allfilenames): 35 | 36 | with h5py.File(os.path.join(save_dir, filename), 'r') as h5f: 37 | print '*******path is ', os.path.join(save_dir, filename) 38 | dLPET = h5f['dataLPET'][:] 39 | dCT = h5f['dataCT'][:] 40 | dHPET = h5f['dataHPET'][:] 41 | 42 | unitNum = dLPET.shape[0] 43 | print 'unitNum: ', unitNum, 'dLPET shape: ', dLPET.shape 44 | 45 | dataLPET[startInd: (startInd + unitNum), ...] = dLPET 46 | dataCT[startInd: startInd + unitNum, ...] = dCT 47 | dataHPET[startInd: startInd + unitNum, ...] = dHPET 48 | 49 | startInd = startInd + unitNum 50 | 51 | cnt = cnt + 1 52 | 53 | if cnt == numInOneSub: 54 | batchID = batchID + 1 55 | dataLPET = dataLPET[0:startInd, ...] 56 | dataCT = dataCT[0:startInd, ...] 57 | dataHPET = dataHPET[0:startInd, ...]
58 | 59 | with h5py.File(os.path.join(savepath, savefilename + '{}.h5'.format(batchID)), 'w') as hf: 60 | hf.create_dataset('dataLPET', data=dataLPET) 61 | hf.create_dataset('dataCT', data=dataCT) 62 | hf.create_dataset('dataHPET', data=dataHPET) 63 | 64 | ############ initialization ############### 65 | cnt = 0 66 | startInd = 0 67 | print 'nn:', nn 68 | 69 | dataLPET = np.zeros([nn, 1, 5, 64, 64], dtype=np.float16) 70 | dataCT = np.zeros([nn, 1, 5, 64, 64], dtype=np.float16) 71 | dataHPET = np.zeros([nn, 1, 1, 64, 64], dtype=np.float16) 72 | 73 | # mean_train, std_train = 0., 0 74 | batchID = batchID + 1 75 | if startInd != 0: 76 | dataLPET = dataLPET[0:startInd, ...] 77 | dataCT = dataCT[0:startInd, ...] 78 | dataHPET = dataHPET[0:startInd, ...] 79 | 80 | with h5py.File(os.path.join(savepath, savefilename + '{}.h5'.format(batchID)), 'w') as hf: 81 | hf.create_dataset('dataLPET', data=dataLPET) 82 | hf.create_dataset('dataCT', data=dataCT) 83 | hf.create_dataset('dataHPET', data=dataHPET) 84 | 85 | return 86 | 87 | 88 | def main(): 89 | path = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/train2D_H5/' 90 | savepath = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/trainBatch2D_H5/' 91 | basePath = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/' 92 | path = basePath + 'train2D_H5/' 93 | savepath = basePath + 'trainBatch2D_H5/' 94 | shuffleDataAmongSubjects(path, savepath) 95 | 96 | 97 | if __name__ == "__main__": 98 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 99 | main() 100 | -------------------------------------------------------------------------------- /shuffleDataAmongSubjects_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import os 4 | 5 | ''' 6 | Merge the patches of every numInOneSub (5) subjects into one h5 file (patches are concatenated in file order, not randomly permuted) 7 | input: 8 | save_dir: directory holding the per-subject training h5 files 9 | savepath: directory where the merged h5 files are written 10 | output: 11 | train5x64x64_<batchID>.h5 files, each holding dataLPET, dataCT and dataHPET 12 | ''' 13 | 14 | # dFA = [3,168,112] # size
of patches of input data 15 | # dSeg = [1,168,112] # size of patches of label data 16 | 17 | dFA = [16, 64, 64] # size of patches of input data 18 | dSeg = [16, 64, 64] # size of patches of label data 19 | 20 | 21 | def shuffleDataAmongSubjects(save_dir, savepath): 22 | # allfilenames = os.listdir(save_dir) 23 | # allfilenames = filter(lambda x: '.h5' in x and 'train' in x, allfilenames) 24 | 25 | nn = 10000 26 | dataLPET = np.zeros([nn,1, 16, 64, 64], dtype=np.float16) 27 | dataCT = np.zeros([nn,1, 16, 64, 64], dtype=np.float16) 28 | dataHPET = np.zeros([nn, 1, 16, 64, 64], dtype=np.float16) 29 | # 30 | 31 | allfilenames = os.listdir(save_dir) 32 | # print allfilenames 33 | allfilenames = filter(lambda x: '.h5' in x and 'train' in x, allfilenames) 34 | # print allfilenames 35 | cnt = 0 36 | numInOneSub = 5 37 | batchID = 0 38 | startInd = 0 39 | 40 | savefilename = 'train5x64x64_' 41 | for i_file, filename in enumerate(allfilenames): 42 | 43 | with h5py.File(os.path.join(save_dir, filename), 'r') as h5f: 44 | print '*******path is ', os.path.join(save_dir, filename) 45 | dLPET = h5f['dataLPET'][:] 46 | dCT = h5f['dataCT'][:] 47 | dHPET = h5f['dataHPET'][:] 48 | 49 | unitNum = dLPET.shape[0] 50 | print 'unitNum: ', unitNum, 'dLPET shape: ', dLPET.shape 51 | 52 | dataLPET[startInd: (startInd + unitNum), ...] = dLPET 53 | dataCT[startInd: startInd + unitNum, ...] = dCT 54 | dataHPET[startInd: startInd + unitNum, ...] = dHPET 55 | 56 | startInd = startInd + unitNum 57 | 58 | cnt = cnt + 1 59 | 60 | if cnt == numInOneSub: 61 | batchID = batchID + 1 62 | dataLPET = dataLPET[0:startInd, ...] 63 | dataCT = dataCT[0:startInd, ...] 64 | dataHPET = dataHPET[0:startInd, ...]
65 | 66 | with h5py.File(os.path.join(savepath, savefilename + '{}.h5'.format(batchID)), 'w') as hf: 67 | hf.create_dataset('dataLPET', data=dataLPET) 68 | hf.create_dataset('dataCT', data=dataCT) 69 | hf.create_dataset('dataHPET', data=dataHPET) 70 | 71 | ############ initialization ############### 72 | cnt = 0 73 | startInd = 0 74 | print 'nn:', nn 75 | 76 | dataLPET = np.zeros([nn, 1, 16, 64, 64], dtype=np.float16) 77 | dataCT = np.zeros([nn, 1, 16, 64, 64], dtype=np.float16) 78 | dataHPET = np.zeros([nn, 1, 16, 64, 64], dtype=np.float16) 79 | 80 | # mean_train, std_train = 0., 0 81 | batchID = batchID + 1 82 | if startInd != 0: 83 | dataLPET = dataLPET[0:startInd, ...] 84 | dataCT = dataCT[0:startInd, ...] 85 | dataHPET = dataHPET[0:startInd, ...] 86 | 87 | with h5py.File(os.path.join(savepath, savefilename + '{}.h5'.format(batchID)), 'w') as hf: 88 | hf.create_dataset('dataLPET', data=dataLPET) 89 | hf.create_dataset('dataCT', data=dataCT) 90 | hf.create_dataset('dataHPET', data=dataHPET) 91 | 92 | return 93 | 94 | 95 | def main(): 96 | path = '/home/niedong/Data4LowDosePET/h5Data3D_noNorm/train3D_H5/' 97 | savepath = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/trainBatch3D_H5/' 98 | basePath = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/' 99 | # path = basePath + 'train2D_H5/' 100 | # savepath = basePath + 'trainBatch2D_H5/' 101 | shuffleDataAmongSubjects(path, savepath) 102 | 103 | 104 | if __name__ == "__main__": 105 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 106 | main() 107 | --------------------------------------------------------------------------------
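runTesting_Reconv2.py defaults to how2normalize == 6. In that branch, both testing scripts rescale each volume by its own 0th and 99.5th percentiles (no clipping, so voxels above the 99.5th percentile map above 1), then map the network output back with the inverse transform before calling psnr. The sketch below isolates that round trip under a few assumptions: the inputs are NumPy arrays, and percentile_normalize and percentile_denormalize are hypothetical helpers that are not part of this repository.

```python
from __future__ import print_function

import numpy as np


def percentile_normalize(vol, low_pct=0.0, high_pct=99.5):
    """Min-max scale `vol` by its own percentiles, as in the how2normalize == 6 branch.

    Hypothetical helper (not in the repository). Returns the scaled array and the
    (low, high) intensities needed to invert the mapping. No clipping is applied,
    so voxels above the high percentile map to values greater than 1.
    """
    high, low = np.percentile(vol, [high_pct, low_pct])
    if high == low:
        raise ValueError("zero intensity range between the chosen percentiles")
    return (vol - low) / (high - low), (low, high)


def percentile_denormalize(scaled, low, high):
    """Invert percentile_normalize: map a network output back to raw intensities."""
    return scaled * (high - low) + low


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    # synthetic stand-in for a volume (depth, height, width); not real data
    vol = rng.normal(-600.0, 475.0, size=(4, 32, 32))
    scaled, (low, high) = percentile_normalize(vol)
    restored = percentile_denormalize(scaled, low, high)
    print("scaled range:", scaled.min(), scaled.max())
    print("round trip exact:", np.allclose(restored, vol))
```

When reading the scripts, note which range each one uses to invert the mapping. runTesting_Recon.py multiplies back by the input (mrnp) percentile range, whereas the single-source branch of runTesting_Reconv2.py uses the target (ctnp) range, which matches the intensity scale of matGT.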