├── LICENSE
├── README.md
├── ResUnet3d_pytorch.py
├── Unet2d_pytorch.py
├── Unet3d_pytorch.py
├── compute3DSSIM.py
├── dicom2Nii.py
├── extract23DPatch4MultiModalImg.py
├── extract23DPatch4SingleModalImg.py
├── loss_functions.py
├── nnBuildUnits.py
├── runCTRecon.py
├── runCTRecon3d.py
├── runTesting_Recon.py
├── runTesting_Reconv2.py
├── shuffleDataAmongSubjects_2d.py
├── shuffleDataAmongSubjects_3d.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Dong Nie
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # medSynthesisV1
2 | This is a PyTorch package for medical image synthesis with LRes-ResUnet and GAN (WGAN-GP); it is a simple extension of our paper "Medical Image Synthesis with Deep Convolutional Adversarial Networks". You are also welcome to visit our TensorFlow version via this link:
3 | https://github.com/ginobilinie/medSynthesis
4 |
5 | # How to run the pytorch code
6 | The main entrance for the code is runCTRecon.py or runCTRecon3d.py. Currently the 2d/2.5d version runs fine; the discriminator of the 3d version only supports the BCE loss for now. Where available, we suggest the W-distance (WGAN-GP), since its hyper-parameters are easier to tune.
7 |
8 | I suppose you have installed:
9 | python 2.x (e.g., 2.7.x; for python 3.x, change a few calls: `.next()` -> `.__next__()`, `xrange()` -> `range()`)
10 | pytorch (>=0.3.0)
11 | simpleITK
12 | numpy
13 |
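A typical environment setup might look like this (a sketch; h5py is also needed by the patch-extraction scripts, and you should pick the pytorch build that matches your python/CUDA versions):

```
pip install numpy h5py SimpleITK
pip install torch   # any release >= 0.3.0
```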
14 | Steps to run the code:
15 | 1. use extract23DPatch4MultiModalImg.py (or extract23DPatch4SingleModalImg.py for a single input modality) to extract patches from the training and validation images (as only limited annotated data can be acquired in the medical image field, we usually use the patch as the training unit) and save them in hdf5 format. Put these h5 files into two folders (training, validation) and remember the paths to them.
16 | 2. choose the generator: 1. UNet, 2. ResUNet, 3. UNet_LRes or 4. ResUNet_LRes (default: 4).
17 | Note: for low-dose xx to standard-dose xx (such as low-dose PET to standard-dose PET, low-dose CT to standard-dose CT...) or low-resolution xx to high-resolution xx (e.g., 3T->7T), we suggest ResUNet_LRes (4), which contains a long skip connection.
18 | If the input modality and the output modality are quite different, we suggest UNet_LRes (3).
19 | 3. choose the discriminator if you want to use the GAN-framework (we provide wgan-gp and the basic GAN)
20 | 4. choose the loss function (1. LossL1, 2. lossRTL1, 3. MSE (default))
21 | 5. set up the hyper-parameters in runCTRecon.py (or, for 3d, runCTRecon3d.py)
22 | You have to set the paths to the training h5 files (path_patients_h5), the validation h5 files (path_patients_h5_test) and the testing images (path_test) in this python file
23 | You also have to set all the other config choices, such as the network choice, discriminator choice, loss functions (including additional losses, e.g., the gradient difference loss), the initial learning rate and whether to decrease the learning rate during training (even with the Adam solver); a sketch of these settings follows this list
24 | 6. run the code for the training stage: python runCTRecon.py (or, for 3d, runCTRecon3d.py)
25 | 7. run the code for the testing stage: python runTesting_Reconv2.py
26 |
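For step 5, the config section of runCTRecon.py looks roughly like the sketch below. The three path variables are the actual names used above; the remaining names (whichNet, lossBase, isAdLoss, lr) are illustrative placeholders for the corresponding options, not necessarily the verbatim identifiers in the file.

```python
# a rough sketch of the settings in runCTRecon.py, not the verbatim file
path_patients_h5 = '/path/to/train_h5/'       # training patches from step 1
path_patients_h5_test = '/path/to/val_h5/'    # validation patches from step 1
path_test = '/path/to/test_images/'           # held-out images for step 7
whichNet = 4     # 1: UNet, 2: ResUNet, 3: UNet_LRes, 4: ResUNet_LRes (step 2)
lossBase = 3     # 1: LossL1, 2: lossRTL1, 3: MSE (step 4)
isAdLoss = True  # enable the GAN/WGAN-GP discriminator branch (step 3)
lr = 1e-4        # initial learning rate, optionally decayed during training
```

With the patches extracted and the settings in place, the whole pipeline is:

```
python extract23DPatch4MultiModalImg.py   # step 1: write the h5 patch files
python runCTRecon.py                      # step 6: training (use runCTRecon3d.py for 3d)
python runTesting_Reconv2.py              # step 7: testing
```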
27 | If it is helpful to your work, please cite the papers:
28 | # Cite
29 |
30 | @inproceedings{nie2017medical,
31 | title={Medical image synthesis with context-aware generative adversarial networks},
32 | author={Nie, Dong and Trullo, Roger and Lian, Jun and Petitjean, Caroline and Ruan, Su and Wang, Qian and Shen, Dinggang},
33 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
34 | pages={417--425},
35 | year={2017},
36 | organization={Springer}
37 | }
38 | @article{nie2018medical,
39 | title={Medical Image Synthesis with Deep Convolutional Adversarial Networks},
40 | author={Nie, Dong and Trullo, Roger and Lian, Jun and Wang, Li and Petitjean, Caroline and Ruan, Su and Wang, Qian and Shen, Dinggang},
41 | journal={IEEE Transactions on Biomedical Engineering},
42 | year={2018},
43 | publisher={IEEE}
44 | }
45 |
46 |
47 | # Dataset
48 | BTW, you can download a real medical image synthesis dataset for reconstructing standard-dose scans from low-dose scans via this link (the AAPM Low Dose CT Grand Challenge): https://www.aapm.org/GrandChallenge/LowDoseCT/
49 |
50 | Also, there are some MRI synthesis datasets available:
51 | http://brain-development.org/ixi-dataset/
52 |
53 | Tumor prediction:
54 | https://www.med.upenn.edu/sbia/brats2018/data.html
55 |
56 | fastMRI:
57 | https://fastmri.med.nyu.edu/
58 |
59 | ISLES2015:
60 | http://www.isles-challenge.org/ISLES2015/
61 |
62 | # Upload your brain MRI, predict the corresponding CT
63 |
64 | If you're interested, you can send me a copy of your data (for example, brain MRI), and I'll run inference and send a copy of the predicted CT back to you. My email is dongnie.at.cs.unc.edu.
65 |
66 | # A parallel-training version with adversarial confidence learning will be uploaded soon.
67 |
68 | # License
69 | medSynthesis is released under the MIT License (refer to the LICENSE file for details).
70 |
--------------------------------------------------------------------------------
/ResUnet3d_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import torch.nn.init as init
6 | import numpy as np
7 |
8 | '''
9 | Ordinary UNet Conv Block
10 | '''
11 | class UNetConvBlock(nn.Module):
12 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
13 | super(UNetConvBlock, self).__init__()
14 | self.conv = nn.Conv3d(in_size, out_size, kernel_size, stride=1, padding=1)
15 | self.bn = nn.BatchNorm3d(out_size)
16 | self.conv2 = nn.Conv3d(out_size, out_size, kernel_size, stride=1, padding=1)
17 | self.bn2 = nn.BatchNorm3d(out_size)
18 | self.activation = activation
19 |
20 |
21 | init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0))
22 | init.constant(self.conv.bias,0)
23 | init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0))
24 | init.constant(self.conv2.bias,0)
25 | def forward(self, x):
26 | out = self.activation(self.bn(self.conv(x)))
27 | out = self.activation(self.bn2(self.conv2(out)))
28 |
29 | return out
30 |
31 |
32 | '''
33 | two-layer residual unit: two convs with BN/ReLU, plus a shortcut branch (1x1 conv + BN when in_size != out_size, identity otherwise)
34 | '''
35 | class residualUnit(nn.Module):
36 | def __init__(self, in_size, out_size, kernel_size=3,stride=1, padding=1, activation=F.relu):
37 | super(residualUnit, self).__init__()
38 | self.conv1 = nn.Conv3d(in_size, out_size, kernel_size, stride=1, padding=1)
39 | init.xavier_uniform(self.conv1.weight, gain = np.sqrt(2.0)) #or gain=1
40 | init.constant(self.conv1.bias, 0)
41 | self.conv2 = nn.Conv3d(out_size, out_size, kernel_size, stride=1, padding=1)
42 | init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0)) #or gain=1
43 | init.constant(self.conv2.bias, 0)
44 | self.activation = activation
45 | self.bn1 = nn.BatchNorm3d(out_size)
46 | self.bn2 = nn.BatchNorm3d(out_size)
47 | self.in_size = in_size
48 | self.out_size = out_size
49 | if in_size != out_size:
50 | self.convX = nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0)
51 | self.bnX = nn.BatchNorm3d(out_size)
52 |
53 | def forward(self, x):
54 | out1 = self.activation(self.bn1(self.conv1(x)))
55 | out2 = self.activation(self.bn2(self.conv2(out1)))  # second conv uses its own BN layer (bn2)
56 | if self.in_size!=self.out_size:
57 | bridge = self.activation(self.bnX(self.convX(x)))
58 | else: bridge = x  # identity shortcut; otherwise `bridge` is undefined when in_size == out_size
59 | output = torch.add(out2, bridge)
60 | return output
61 |
62 |
63 | '''
64 | Ordinary UNet-Up Conv Block
65 | '''
66 | class UNetUpBlock(nn.Module):
67 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
68 | super(UNetUpBlock, self).__init__()
69 | self.up = nn.ConvTranspose3d(in_size, out_size, 2, stride=2)
70 | self.bnup = nn.BatchNorm3d(out_size)
71 | self.conv = nn.Conv3d(in_size, out_size, kernel_size, stride=1, padding=1)
72 | self.bn = nn.BatchNorm3d(out_size)
73 | self.conv2 = nn.Conv3d(out_size, out_size, kernel_size, stride=1, padding=1)
74 | self.bn2 = nn.BatchNorm3d(out_size)
75 | self.activation = activation
76 | init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0))
77 | init.constant(self.up.bias,0)
78 | init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0))
79 | init.constant(self.conv.bias,0)
80 | init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0))
81 | init.constant(self.conv2.bias,0)
82 |
83 | def center_crop(self, layer, target_size):
84 | batch_size, n_channels, layer_width, layer_height, layer_depth = layer.size()
85 | xy1 = (layer_width - target_size) // 2
86 | return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size), xy1:(xy1 + target_size)]  # crop all three spatial dims
87 |
88 | def forward(self, x, bridge):
89 | up = self.up(x)
90 | up = self.activation(self.bnup(up))
91 | crop1 = self.center_crop(bridge, up.size()[2])
92 | out = torch.cat([up, crop1], 1)
93 |
94 | out = self.activation(self.bn(self.conv(out)))
95 | out = self.activation(self.bn2(self.conv2(out)))
96 |
97 | return out
98 |
99 |
100 |
101 | '''
102 | Ordinary Residual UNet-Up Conv Block
103 | '''
104 | class UNetUpResBlock(nn.Module):
105 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
106 | super(UNetUpResBlock, self).__init__()
107 | self.up = nn.ConvTranspose3d(in_size, out_size, 2, stride=2)
108 | self.bnup = nn.BatchNorm3d(out_size)
109 |
110 | init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0))
111 | init.constant(self.up.bias,0)
112 |
113 | self.activation = activation
114 |
115 | self.resUnit = residualUnit(in_size, out_size, kernel_size = kernel_size)
116 |
117 | def center_crop(self, layer, target_size):
118 | batch_size, n_channels, layer_width, layer_height, layer_depth = layer.size()
119 | xy1 = (layer_width - target_size) // 2
120 | return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size), xy1:(xy1 + target_size)]
121 |
122 | def forward(self, x, bridge):
123 | #print 'x.shape: ',x.shape
124 | up = self.activation(self.bnup(self.up(x)))
125 | #crop1 = self.center_crop(bridge, up.size()[2])
126 | #print 'up.shape: ',up.shape, ' crop1.shape: ',crop1.shape
127 | crop1 = bridge
128 | out = torch.cat([up, crop1], 1)
129 |
130 | out = self.resUnit(out)
131 | # out = self.activation(self.bn2(self.conv2(out)))
132 |
133 | return out
134 |
135 |
136 | '''
137 | Ordinary UNet
138 | '''
139 | class UNet(nn.Module):
140 | def __init__(self, in_channel = 1, n_classes = 4):
141 | super(UNet, self).__init__()
142 | # self.imsize = imsize
143 |
144 | self.activation = F.relu
145 |
146 | self.pool1 = nn.MaxPool3d(2)
147 | self.pool2 = nn.MaxPool3d(2)
148 | self.pool3 = nn.MaxPool3d(2)
149 | # self.pool4 = nn.MaxPool3d(2)
150 |
151 |
152 | self.conv_block1_64 = UNetConvBlock(in_channel, 32)
153 | self.conv_block64_128 = UNetConvBlock(32, 64)
154 | self.conv_block128_256 = UNetConvBlock(64, 128)
155 | self.conv_block256_512 = UNetConvBlock(128, 256)
156 | # self.conv_block512_1024 = UNetConvBlock(512, 1024)
157 | # this kind of symmetric design is awesome, it automatically resolves the number of channels during upsampling
158 | # self.up_block1024_512 = UNetUpBlock(1024, 512)
159 | self.up_block512_256 = UNetUpBlock(256, 128)
160 | self.up_block256_128 = UNetUpBlock(128, 64)
161 | self.up_block128_64 = UNetUpBlock(64, 32)
162 |
163 | self.last = nn.Conv3d(32, n_classes, 1, stride=1)
164 |
165 |
166 | def forward(self, x):
167 | # print 'line 70 ',x.size()
168 | block1 = self.conv_block1_64(x)
169 | pool1 = self.pool1(block1)
170 |
171 | block2 = self.conv_block64_128(pool1)
172 | pool2 = self.pool2(block2)
173 |
174 | block3 = self.conv_block128_256(pool2)
175 | pool3 = self.pool3(block3)
176 |
177 | block4 = self.conv_block256_512(pool3)
178 | # pool4 = self.pool4(block4)
179 | #
180 | # block5 = self.conv_block512_1024(pool4)
181 | #
182 | # up1 = self.up_block1024_512(block5, block4)
183 |
184 | up2 = self.up_block512_256(block4, block3)
185 |
186 | up3 = self.up_block256_128(up2, block2)
187 |
188 | up4 = self.up_block128_64(up3, block1)
189 |
190 | return self.last(up4)
191 |
192 |
193 | '''
194 | Ordinary ResUNet
195 | '''
196 |
197 |
198 | class ResUNet(nn.Module):
199 | def __init__(self, in_channel=1, n_classes=4):
200 | super(ResUNet, self).__init__()
201 | # self.imsize = imsize
202 |
203 | self.activation = F.relu
204 |
205 | self.pool1 = nn.MaxPool3d(2)
206 | self.pool2 = nn.MaxPool3d(2)
207 | self.pool3 = nn.MaxPool3d(2)
208 | # self.pool4 = nn.MaxPool3d(2)
209 |
210 | self.conv_block1_64 = UNetConvBlock(in_channel, 32)
211 | self.conv_block64_128 = residualUnit(32, 64)
212 | self.conv_block128_256 = residualUnit(64, 128)
213 | self.conv_block256_512 = residualUnit(128, 256)
214 | # self.conv_block512_1024 = residualUnit(512, 1024)
215 | # this kind of symmetric design is awesome, it automatically resolves the number of channels during upsampling
216 | # self.up_block1024_512 = UNetUpResBlock(1024, 512)
217 | self.up_block512_256 = UNetUpResBlock(256, 128)
218 | self.up_block256_128 = UNetUpResBlock(128, 64)
219 | self.up_block128_64 = UNetUpResBlock(64, 32)
220 |
221 | self.last = nn.Conv3d(32, n_classes, 1, stride=1)
222 |
223 | def forward(self, x):
224 | # print 'line 70 ',x.size()
225 | block1 = self.conv_block1_64(x)
226 | pool1 = self.pool1(block1)
227 |
228 | block2 = self.conv_block64_128(pool1)
229 | pool2 = self.pool2(block2)
230 |
231 | block3 = self.conv_block128_256(pool2)
232 | pool3 = self.pool3(block3)
233 |
234 | block4 = self.conv_block256_512(pool3)
235 | # pool4 = self.pool4(block4)
236 | #
237 | # block5 = self.conv_block512_1024(pool4)
238 | #
239 | # up1 = self.up_block1024_512(block5, block4)
240 |
241 | up2 = self.up_block512_256(block4, block3)
242 |
243 | up3 = self.up_block256_128(up2, block2)
244 |
245 | up4 = self.up_block128_64(up3, block1)
246 |
247 | return self.last(up4)
248 |
249 |
250 | '''
251 | UNet (lateral connection) with long-skip residual connection (from 1st to last layer)
252 | '''
253 | class UNet_LRes(nn.Module):
254 | def __init__(self, in_channel = 1, n_classes = 4):
255 | super(UNet_LRes, self).__init__()
256 | # self.imsize = imsize
257 |
258 | self.activation = F.relu
259 |
260 | self.pool1 = nn.MaxPool3d(2)
261 | self.pool2 = nn.MaxPool3d(2)
262 | self.pool3 = nn.MaxPool3d(2)
263 | # self.pool4 = nn.MaxPool3d(2)
264 |
265 | self.conv_block1_64 = UNetConvBlock(in_channel, 32)
266 | self.conv_block64_128 = UNetConvBlock(32, 64)
267 | self.conv_block128_256 = UNetConvBlock(64, 128)
268 | self.conv_block256_512 = UNetConvBlock(128, 256)
269 | # self.conv_block512_1024 = UNetConvBlock(512, 1024)
270 | # this kind of symmetric design is awesome, it automatically resolves the number of channels during upsampling
271 | # self.up_block1024_512 = UNetUpBlock(1024, 512)
272 | self.up_block512_256 = UNetUpBlock(256, 128)
273 | self.up_block256_128 = UNetUpBlock(128, 64)
274 | self.up_block128_64 = UNetUpBlock(64, 32)
275 |
276 | self.last = nn.Conv3d(32, n_classes, 1, stride=1)
277 |
278 |
279 | def forward(self, x, res_x):
280 | # print 'line 70 ',x.size()
281 | block1 = self.conv_block1_64(x)
282 | pool1 = self.pool1(block1)
283 |
284 | block2 = self.conv_block64_128(pool1)
285 | pool2 = self.pool2(block2)
286 |
287 | block3 = self.conv_block128_256(pool2)
288 | pool3 = self.pool3(block3)
289 |
290 | block4 = self.conv_block256_512(pool3)
291 | # pool4 = self.pool4(block4)
292 |
293 | # block5 = self.conv_block512_1024(pool4)
294 | #
295 | # up1 = self.up_block1024_512(block5, block4)
296 |
297 | up2 = self.up_block512_256(block4, block3)
298 |
299 | up3 = self.up_block256_128(up2, block2)
300 |
301 | up4 = self.up_block128_64(up3, block1)
302 |
303 | last = self.last(up4)
304 | #print 'res_x.shape is ',res_x.shape,' and last.shape is ',last.shape
305 | if len(res_x.shape)==3:
306 | res_x = res_x.unsqueeze(1)
307 | out = torch.add(last,res_x)
308 |
309 | #print 'out.shape is ',out.shape
310 | return out
311 |
312 |
313 | '''
314 | ResUNet (lateral connection) with long-skip residual connection (from 1st to last layer)
315 | '''
316 |
317 |
318 | class ResUNet_LRes(nn.Module):
319 | def __init__(self, in_channel=1, n_classes=4, dp_prob=0):
320 | super(ResUNet_LRes, self).__init__()
321 | # self.imsize = imsize
322 |
323 | self.activation = F.relu
324 |
325 | self.pool1 = nn.MaxPool3d(2)
326 | self.pool2 = nn.MaxPool3d(2)
327 | self.pool3 = nn.MaxPool3d(2)
328 | # self.pool4 = nn.MaxPool3d(2)
329 |
330 | self.conv_block1_64 = UNetConvBlock(in_channel, 32)
331 | self.conv_block64_128 = residualUnit(32, 64)
332 | self.conv_block128_256 = residualUnit(64, 128)
333 | self.conv_block256_512 = residualUnit(128, 256)
334 | # self.conv_block512_1024 = residualUnit(512, 1024)
335 | # this kind of symmetric design is awesome, it automatically resolves the number of channels during upsampling
336 | # self.up_block1024_512 = UNetUpResBlock(1024, 512)
337 | self.up_block512_256 = UNetUpResBlock(256, 128)
338 | self.up_block256_128 = UNetUpResBlock(128, 64)
339 | self.up_block128_64 = UNetUpResBlock(64, 32)
340 | self.Dropout = nn.Dropout3d(p=dp_prob)
341 | self.last = nn.Conv3d(32, n_classes, 1, stride=1)
342 |
343 | def forward(self, x, res_x):
344 | # print 'line 70 ',x.size()
345 | block1 = self.conv_block1_64(x)
346 | # print 'block1.shape: ', block1.shape
347 | pool1 = self.pool1(block1)
348 | # print 'pool1.shape: ', block1.shape
349 | pool1_dp = self.Dropout(pool1)
350 | # print 'pool1_dp.shape: ', pool1_dp.shape
351 | block2 = self.conv_block64_128(pool1_dp)
352 | pool2 = self.pool2(block2)
353 |
354 | pool2_dp = self.Dropout(pool2)
355 |
356 | block3 = self.conv_block128_256(pool2_dp)
357 | pool3 = self.pool3(block3)
358 |
359 | pool3_dp = self.Dropout(pool3)
360 |
361 | block4 = self.conv_block256_512(pool3_dp)
362 | # pool4 = self.pool4(block4)
363 | #
364 | # pool4_dp = self.Dropout(pool4)
365 | #
366 | # # block5 = self.conv_block512_1024(pool4_dp)
367 | #
368 | # up1 = self.up_block1024_512(block5, block4)
369 |
370 | up2 = self.up_block512_256(block4, block3)
371 |
372 | up3 = self.up_block256_128(up2, block2)
373 |
374 | up4 = self.up_block128_64(up3, block1)
375 |
376 | last = self.last(up4)
377 | # print 'res_x.shape is ',res_x.shape,' and last.shape is ',last.shape
378 | if len(res_x.shape) == 3:
379 | res_x = res_x.unsqueeze(1)
380 | out = torch.add(last, res_x)
381 |
382 | # print 'out.shape is ',out.shape
383 | return out
384 |
385 |
386 |
387 | '''
388 | Discriminator for the reconstruction project
389 | '''
390 | class Discriminator(nn.Module):
391 | def __init__(self):
392 | super(Discriminator,self).__init__()
393 | # you can use abbreviated names for the conv and fc layers; this is not necessary
394 | #class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
395 | self.conv1 = nn.Conv3d(1,32,9)
396 | self.bn1 = nn.BatchNorm3d(32)
397 | self.conv2 = nn.Conv3d(32,64,5)
398 | self.bn2 = nn.BatchNorm3d(64)
399 | self.conv3 = nn.Conv3d(64,64,5)
400 | self.bn3 = nn.BatchNorm3d(64)
401 | self.fc1 = nn.Linear(64*4*4,512)  # NOTE: the flattened size depends on the 3d patch size (a 64^3 input gives 64*4*4*4 features); adjust in_features to match
402 | #self.bn3= nn.BatchNorm1d(6)
403 | self.fc2 = nn.Linear(512,64)
404 | self.fc3 = nn.Linear(64,1)
405 |
406 |
407 | def forward(self,x):
408 | # print 'line 114: x shape: ',x.size()
409 | #x = F.max_pool3d(F.relu(self.bn1(self.conv1(x))),(2,2,2))#conv->relu->pool
410 | x = F.max_pool3d(F.relu(self.conv1(x)),(2,2,2))#conv->relu->pool
411 |
412 | x = F.max_pool3d(F.relu(self.conv2(x)),(2,2,2))#conv->relu->pool
413 |
414 | x = F.max_pool3d(F.relu(self.conv3(x)),(2,2,2))#conv->relu->pool
415 |
416 | # reshape into a vector; view returns a tensor that shares the same data but has a different shape, like reshape in matlab
417 | x = x.view(-1,self.num_of_flat_features(x))
418 | x = F.relu(self.fc1(x))
419 |
420 | x = F.relu(self.fc2(x))
421 |
422 | x = self.fc3(x)
423 |
424 | #x = F.sigmoid(x)
425 | #print 'min,max,mean of x in 0st layer',x.min(),x.max(),x.mean()
426 | return x
427 |
428 | def num_of_flat_features(self,x):
429 | size=x.size()[1:] # we do not consider the batch dimension
430 | num_features=1
431 | for s in size:
432 | num_features*=s
433 | return num_features
434 |
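# A minimal smoke-test sketch (for illustration): push a random patch through
# ResUNet_LRes and check the output shape. Each spatial dim must be divisible
# by 2^3 to survive the three pooling stages.
if __name__ == '__main__':
    net = ResUNet_LRes(in_channel=1, n_classes=1)
    x = Variable(torch.randn(2, 1, 16, 64, 64))   # NCDHW patch
    y = net(x, x)   # the long-skip residual adds the input back onto the output
    print(y.size())   # expected: (2, 1, 16, 64, 64)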
--------------------------------------------------------------------------------
/Unet2d_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import torch.nn.init as init
6 | import numpy as np
7 |
8 | '''
9 | Ordinary UNet Conv Block
10 | '''
11 | class UNetConvBlock(nn.Module):
12 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu):
13 | super(UNetConvBlock, self).__init__()
14 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, stride=1, padding=1)
15 | self.bn = nn.BatchNorm2d(out_size)
16 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, stride=1, padding=1)
17 | self.bn2 = nn.BatchNorm2d(out_size)
18 | self.activation = activation
19 |
20 |
21 | init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0))
22 | init.constant(self.conv.bias,0)
23 | init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0))
24 | init.constant(self.conv2.bias,0)
25 | def forward(self, x):
26 | out = self.activation(self.bn(self.conv(x)))
27 | out = self.activation(self.bn2(self.conv2(out)))
28 |
29 | return out
30 |
31 |
32 | '''
33 | two-layer residual unit: two convs with BN/ReLU, plus a shortcut branch (1x1 conv + BN when in_size != out_size, identity otherwise)
34 | '''
35 | class residualUnit(nn.Module):
36 | def __init__(self, in_size, out_size, kernel_size=3,stride=1, padding=1, activation=F.relu):
37 | super(residualUnit, self).__init__()
38 | self.conv1 = nn.Conv2d(in_size, out_size, kernel_size, stride=1, padding=1)
39 | init.xavier_uniform(self.conv1.weight, gain = np.sqrt(2.0)) #or gain=1
40 | init.constant(self.conv1.bias, 0)
41 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, stride=1, padding=1)
42 | init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0)) #or gain=1
43 | init.constant(self.conv2.bias, 0)
44 | self.activation = activation
45 | self.bn1 = nn.BatchNorm2d(out_size)
46 | self.bn2 = nn.BatchNorm2d(out_size)
47 | self.in_size = in_size
48 | self.out_size = out_size
49 | if in_size != out_size:
50 | self.convX = nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0)
51 | self.bnX = nn.BatchNorm2d(out_size)
52 |
53 | def forward(self, x):
54 | out1 = self.activation(self.bn1(self.conv1(x)))
55 | out2 = self.activation(self.bn2(self.conv2(out1)))
56 | if self.in_size!=self.out_size:
57 | bridge = self.activation(self.bnX(self.convX(x)))
58 | else: bridge = x  # identity shortcut; otherwise `bridge` is undefined when in_size == out_size
59 | output = torch.add(out2, bridge)
60 | return output
61 |
62 |
63 | '''
64 | Ordinary UNet-Up Conv Block
65 | '''
66 | class UNetUpBlock(nn.Module):
67 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
68 | super(UNetUpBlock, self).__init__()
69 | self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
70 | self.bnup = nn.BatchNorm2d(out_size)
71 | self.conv = nn.Conv2d(in_size, out_size, kernel_size, stride=1, padding=1)
72 | self.bn = nn.BatchNorm2d(out_size)
73 | self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, stride=1, padding=1)
74 | self.bn2 = nn.BatchNorm2d(out_size)
75 | self.activation = activation
76 | init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0))
77 | init.constant(self.up.bias,0)
78 | init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0))
79 | init.constant(self.conv.bias,0)
80 | init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0))
81 | init.constant(self.conv2.bias,0)
82 |
83 | def center_crop(self, layer, target_size):
84 | batch_size, n_channels, layer_width, layer_height = layer.size()
85 | xy1 = (layer_width - target_size) // 2
86 | return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]
87 |
88 | def forward(self, x, bridge):
89 | up = self.up(x)
90 | up = self.activation(self.bnup(up))
91 | crop1 = self.center_crop(bridge, up.size()[2])
92 | out = torch.cat([up, crop1], 1)
93 |
94 | out = self.activation(self.bn(self.conv(out)))
95 | out = self.activation(self.bn2(self.conv2(out)))
96 |
97 | return out
98 |
99 |
100 |
101 | '''
102 | Ordinary Residual UNet-Up Conv Block
103 | '''
104 | class UNetUpResBlock(nn.Module):
105 | def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False):
106 | super(UNetUpResBlock, self).__init__()
107 | self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
108 | self.bnup = nn.BatchNorm2d(out_size)
109 |
110 | init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0))
111 | init.constant(self.up.bias,0)
112 |
113 | self.activation = activation
114 |
115 | self.resUnit = residualUnit(in_size, out_size, kernel_size = kernel_size)
116 |
117 | def center_crop(self, layer, target_size):
118 | batch_size, n_channels, layer_width, layer_height = layer.size()
119 | xy1 = (layer_width - target_size) // 2
120 | return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)]
121 |
122 | def forward(self, x, bridge):
123 | up = self.activation(self.bnup(self.up(x)))
124 | crop1 = self.center_crop(bridge, up.size()[2])
125 | out = torch.cat([up, crop1], 1)
126 |
127 | out = self.resUnit(out)
128 | # out = self.activation(self.bn2(self.conv2(out)))
129 |
130 | return out
131 |
132 |
133 | '''
134 | Ordinary UNet
135 | '''
136 | class UNet(nn.Module):
137 | def __init__(self, in_channel = 1, n_classes = 4):
138 | super(UNet, self).__init__()
139 | # self.imsize = imsize
140 |
141 | self.activation = F.relu
142 |
143 | self.pool1 = nn.MaxPool2d(2)
144 | self.pool2 = nn.MaxPool2d(2)
145 | self.pool3 = nn.MaxPool2d(2)
146 | self.pool4 = nn.MaxPool2d(2)
147 |
148 | self.conv_block1_64 = UNetConvBlock(in_channel, 64)
149 | self.conv_block64_128 = UNetConvBlock(64, 128)
150 | self.conv_block128_256 = UNetConvBlock(128, 256)
151 | self.conv_block256_512 = UNetConvBlock(256, 512)
152 | self.conv_block512_1024 = UNetConvBlock(512, 1024)
153 | # this kind of symmetric design is awesome, it automatically resolves the number of channels during upsampling
154 | self.up_block1024_512 = UNetUpBlock(1024, 512)
155 | self.up_block512_256 = UNetUpBlock(512, 256)
156 | self.up_block256_128 = UNetUpBlock(256, 128)
157 | self.up_block128_64 = UNetUpBlock(128, 64)
158 |
159 | self.last = nn.Conv2d(64, n_classes, 1, stride=1)
160 |
161 |
162 | def forward(self, x):
163 | # print 'line 70 ',x.size()
164 | block1 = self.conv_block1_64(x)
165 | pool1 = self.pool1(block1)
166 |
167 | block2 = self.conv_block64_128(pool1)
168 | pool2 = self.pool2(block2)
169 |
170 | block3 = self.conv_block128_256(pool2)
171 | pool3 = self.pool3(block3)
172 |
173 | block4 = self.conv_block256_512(pool3)
174 | pool4 = self.pool4(block4)
175 |
176 | block5 = self.conv_block512_1024(pool4)
177 |
178 | up1 = self.up_block1024_512(block5, block4)
179 |
180 | up2 = self.up_block512_256(up1, block3)
181 |
182 | up3 = self.up_block256_128(up2, block2)
183 |
184 | up4 = self.up_block128_64(up3, block1)
185 |
186 | return self.last(up4)
187 |
188 |
189 | '''
190 | Ordinary ResUNet
191 | '''
192 |
193 |
194 | class ResUNet(nn.Module):
195 | def __init__(self, in_channel=1, n_classes=4):
196 | super(ResUNet, self).__init__()
197 | # self.imsize = imsize
198 |
199 | self.activation = F.relu
200 |
201 | self.pool1 = nn.MaxPool2d(2)
202 | self.pool2 = nn.MaxPool2d(2)
203 | self.pool3 = nn.MaxPool2d(2)
204 | self.pool4 = nn.MaxPool2d(2)
205 |
206 | self.conv_block1_64 = UNetConvBlock(in_channel, 64)
207 | self.conv_block64_128 = residualUnit(64, 128)
208 | self.conv_block128_256 = residualUnit(128, 256)
209 | self.conv_block256_512 = residualUnit(256, 512)
210 | self.conv_block512_1024 = residualUnit(512, 1024)
211 | # this kind of symmetric design is awesome, it automatically resolves the number of channels during upsampling
212 | self.up_block1024_512 = UNetUpResBlock(1024, 512)
213 | self.up_block512_256 = UNetUpResBlock(512, 256)
214 | self.up_block256_128 = UNetUpResBlock(256, 128)
215 | self.up_block128_64 = UNetUpResBlock(128, 64)
216 |
217 | self.last = nn.Conv2d(64, n_classes, 1, stride=1)
218 |
219 | def forward(self, x):
220 | # print 'line 70 ',x.size()
221 | block1 = self.conv_block1_64(x)
222 | pool1 = self.pool1(block1)
223 |
224 | block2 = self.conv_block64_128(pool1)
225 | pool2 = self.pool2(block2)
226 |
227 | block3 = self.conv_block128_256(pool2)
228 | pool3 = self.pool3(block3)
229 |
230 | block4 = self.conv_block256_512(pool3)
231 | pool4 = self.pool4(block4)
232 |
233 | block5 = self.conv_block512_1024(pool4)
234 |
235 | up1 = self.up_block1024_512(block5, block4)
236 |
237 | up2 = self.up_block512_256(up1, block3)
238 |
239 | up3 = self.up_block256_128(up2, block2)
240 |
241 | up4 = self.up_block128_64(up3, block1)
242 |
243 | return self.last(up4)
244 |
245 |
246 | '''
247 | UNet (lateral connection) with long-skip residual connection (from 1st to last layer)
248 | '''
249 | class UNet_LRes(nn.Module):
250 | def __init__(self, in_channel = 1, n_classes = 4):
251 | super(UNet_LRes, self).__init__()
252 | # self.imsize = imsize
253 |
254 | self.activation = F.relu
255 |
256 | self.pool1 = nn.MaxPool2d(2)
257 | self.pool2 = nn.MaxPool2d(2)
258 | self.pool3 = nn.MaxPool2d(2)
259 | self.pool4 = nn.MaxPool2d(2)
260 |
261 | self.conv_block1_64 = UNetConvBlock(in_channel, 64)
262 | self.conv_block64_128 = UNetConvBlock(64, 128)
263 | self.conv_block128_256 = UNetConvBlock(128, 256)
264 | self.conv_block256_512 = UNetConvBlock(256, 512)
265 | self.conv_block512_1024 = UNetConvBlock(512, 1024)
266 | # this kind of symmetric design is awesome, it automatically resolves the number of channels during upsampling
267 | self.up_block1024_512 = UNetUpBlock(1024, 512)
268 | self.up_block512_256 = UNetUpBlock(512, 256)
269 | self.up_block256_128 = UNetUpBlock(256, 128)
270 | self.up_block128_64 = UNetUpBlock(128, 64)
271 |
272 | self.last = nn.Conv2d(64, n_classes, 1, stride=1)
273 |
274 |
275 | def forward(self, x, res_x):
276 | # print 'line 70 ',x.size()
277 | block1 = self.conv_block1_64(x)
278 | pool1 = self.pool1(block1)
279 |
280 | block2 = self.conv_block64_128(pool1)
281 | pool2 = self.pool2(block2)
282 |
283 | block3 = self.conv_block128_256(pool2)
284 | pool3 = self.pool3(block3)
285 |
286 | block4 = self.conv_block256_512(pool3)
287 | pool4 = self.pool4(block4)
288 |
289 | block5 = self.conv_block512_1024(pool4)
290 |
291 | up1 = self.up_block1024_512(block5, block4)
292 |
293 | up2 = self.up_block512_256(up1, block3)
294 |
295 | up3 = self.up_block256_128(up2, block2)
296 |
297 | up4 = self.up_block128_64(up3, block1)
298 |
299 | last = self.last(up4)
300 | #print 'res_x.shape is ',res_x.shape,' and last.shape is ',last.shape
301 | if len(res_x.shape)==3:
302 | res_x = res_x.unsqueeze(1)
303 | out = torch.add(last,res_x)
304 |
305 | #print 'out.shape is ',out.shape
306 | return out
307 |
308 |
309 | '''
310 | ResUNet (lateral connection) with long-skip residual connection (from 1st to last layer)
311 | '''
312 |
313 |
314 | class ResUNet_LRes(nn.Module):
315 | def __init__(self, in_channel=1, n_classes=4, dp_prob=0):
316 | super(ResUNet_LRes, self).__init__()
317 | # self.imsize = imsize
318 |
319 | self.activation = F.relu
320 |
321 | self.pool1 = nn.MaxPool2d(2)
322 | self.pool2 = nn.MaxPool2d(2)
323 | self.pool3 = nn.MaxPool2d(2)
324 | self.pool4 = nn.MaxPool2d(2)
325 |
326 | self.conv_block1_64 = UNetConvBlock(in_channel, 64)
327 | self.conv_block64_128 = residualUnit(64, 128)
328 | self.conv_block128_256 = residualUnit(128, 256)
329 | self.conv_block256_512 = residualUnit(256, 512)
330 | self.conv_block512_1024 = residualUnit(512, 1024)
331 | # this kind of symmetric design is awesome, it automatically resolves the number of channels during upsampling
332 | self.up_block1024_512 = UNetUpResBlock(1024, 512)
333 | self.up_block512_256 = UNetUpResBlock(512, 256)
334 | self.up_block256_128 = UNetUpResBlock(256, 128)
335 | self.up_block128_64 = UNetUpResBlock(128, 64)
336 | self.Dropout = nn.Dropout2d(p=dp_prob)
337 | self.last = nn.Conv2d(64, n_classes, 1, stride=1)
338 |
339 | def forward(self, x, res_x):
340 | # print 'line 70 ',x.size()
341 | block1 = self.conv_block1_64(x)
342 | pool1 = self.pool1(block1)
343 |
344 | pool1_dp = self.Dropout(pool1)
345 |
346 | block2 = self.conv_block64_128(pool1_dp)
347 | pool2 = self.pool2(block2)
348 |
349 | pool2_dp = self.Dropout(pool2)
350 |
351 | block3 = self.conv_block128_256(pool2_dp)
352 | pool3 = self.pool3(block3)
353 |
354 | pool3_dp = self.Dropout(pool3)
355 |
356 | block4 = self.conv_block256_512(pool3_dp)
357 | pool4 = self.pool4(block4)
358 |
359 | pool4_dp = self.Dropout(pool4)
360 |
361 | block5 = self.conv_block512_1024(pool4_dp)
362 |
363 | up1 = self.up_block1024_512(block5, block4)
364 |
365 | up2 = self.up_block512_256(up1, block3)
366 |
367 | up3 = self.up_block256_128(up2, block2)
368 |
369 | up4 = self.up_block128_64(up3, block1)
370 |
371 | last = self.last(up4)
372 | # print 'res_x.shape is ',res_x.shape,' and last.shape is ',last.shape
373 | if len(res_x.shape) == 3:
374 | res_x = res_x.unsqueeze(1)
375 | out = torch.add(last, res_x)
376 |
377 | # print 'out.shape is ',out.shape
378 | return out
379 |
380 |
381 |
382 | '''
383 | Discriminator for the reconstruction project
384 | '''
385 | class Discriminator(nn.Module):
386 | def __init__(self):
387 | super(Discriminator,self).__init__()
388 | # you can use abbreviated names for the conv and fc layers; this is not necessary
389 | #class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
390 | self.conv1 = nn.Conv2d(1,32,(9,9))
391 | self.bn1 = nn.BatchNorm2d(32)
392 | self.conv2 = nn.Conv2d(32,64,(5,5))
393 | self.bn2 = nn.BatchNorm2d(64)
394 | self.conv3 = nn.Conv2d(64,64,(5,5))
395 | self.bn3 = nn.BatchNorm2d(64)
396 | self.fc1 = nn.Linear(64*4*4,512)  # a 64x64 input yields 4x4 feature maps after the three conv/pool stages
397 | #self.bn3= nn.BatchNorm1d(6)
398 | self.fc2 = nn.Linear(512,64)
399 | self.fc3 = nn.Linear(64,1)
400 |
401 |
402 | def forward(self,x):
403 | # print 'line 114: x shape: ',x.size()
404 | #x = F.max_pool2d(F.relu(self.bn1(self.conv1(x))),(2,2))#conv->relu->pool
405 | x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))#conv->relu->pool
406 |
407 | x = F.max_pool2d(F.relu(self.conv2(x)),(2,2))#conv->relu->pool
408 |
409 | x = F.max_pool2d(F.relu(self.conv3(x)),(2,2))#conv->relu->pool
410 |
411 | # reshape into a vector; view returns a tensor that shares the same data but has a different shape, like reshape in matlab
412 | x = x.view(-1,self.num_of_flat_features(x))
413 | x = F.relu(self.fc1(x))
414 |
415 | x = F.relu(self.fc2(x))
416 |
417 | x = self.fc3(x)
418 |
419 | #x = F.sigmoid(x)
420 | #print 'min,max,mean of x in 0st layer',x.min(),x.max(),x.mean()
421 | return x
422 |
423 | def num_of_flat_features(self,x):
424 | size=x.size()[1:] # we do not consider the batch dimension
425 | num_features=1
426 | for s in size:
427 | num_features*=s
428 | return num_features
429 |
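# A minimal smoke-test sketch (for illustration): the 2d ResUNet_LRes halves
# the resolution four times, so H and W must be divisible by 2^4.
if __name__ == '__main__':
    net = ResUNet_LRes(in_channel=1, n_classes=1)
    x = Variable(torch.randn(2, 1, 64, 64))   # NCHW patch
    y = net(x, x)
    print(y.size())   # expected: (2, 1, 64, 64)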
--------------------------------------------------------------------------------
/Unet3d_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class UNet3D(nn.Module):
5 | def __init__(self, in_channel, n_classes):
6 | self.in_channel = in_channel
7 | self.n_classes = n_classes
8 | super(UNet3D, self).__init__()
9 | self.ec0 = self.encoder(self.in_channel, 32, bias=False, batchnorm=False)
10 | self.ec1 = self.encoder(32, 64, bias=False, batchnorm=False)
11 | self.ec2 = self.encoder(64, 64, bias=False, batchnorm=False)
12 | self.ec3 = self.encoder(64, 128, bias=False, batchnorm=False)
13 | self.ec4 = self.encoder(128, 128, bias=False, batchnorm=False)
14 | self.ec5 = self.encoder(128, 256, bias=False, batchnorm=False)
15 | self.ec6 = self.encoder(256, 256, bias=False, batchnorm=False)
16 | self.ec7 = self.encoder(256, 512, bias=False, batchnorm=False)
17 |
18 | self.pool0 = nn.MaxPool3d(2)
19 | self.pool1 = nn.MaxPool3d(2)
20 | self.pool2 = nn.MaxPool3d(2)
21 |
22 | self.dc9 = self.decoder(512, 512, kernel_size=4, stride=2, padding=1, bias=False)
23 | self.dc8 = self.decoder(256 + 512, 256, kernel_size=3, stride=1, padding=1, bias=False)
24 | self.dc7 = self.decoder(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
25 | self.dc6 = self.decoder(256, 256, kernel_size=4, stride=2, padding=1, bias=False)
26 | self.dc5 = self.decoder(128 + 256, 128, kernel_size=3, stride=1, padding=1, bias=False)
27 | self.dc4 = self.decoder(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
28 | self.dc3 = self.decoder(128, 128, kernel_size=4, stride=2, padding=1, bias=False)
29 | self.dc2 = self.decoder(64 + 128, 64, kernel_size=3, stride=1, padding=1, bias=False)
30 | self.dc1 = self.decoder(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
31 | self.dc0 = self.decoder(64, n_classes, kernel_size=1, stride=1, bias=False)
32 |
33 |
34 | def encoder(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
35 | bias=True, batchnorm=False):
36 | if batchnorm:
37 | layer = nn.Sequential(
38 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
39 | nn.BatchNorm3d(out_channels),
40 | nn.ReLU())
41 | else:
42 | layer = nn.Sequential(
43 | nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias),
44 | nn.ReLU())
45 | return layer
46 |
47 |
48 | def decoder(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
49 | output_padding=0, bias=True):
50 | layer = nn.Sequential(
51 | nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
52 | padding=padding, output_padding=output_padding, bias=bias),
53 | nn.ReLU())
54 | return layer
55 |
56 | def forward(self, x):
57 | e0 = self.ec0(x)
58 | syn0 = self.ec1(e0)
59 | e1 = self.pool0(syn0)
60 | e2 = self.ec2(e1)
61 | syn1 = self.ec3(e2)
62 | # print 'syn size1: ',syn1.size()
63 | del e0, e1, e2
64 |
65 | e3 = self.pool1(syn1)
66 | e4 = self.ec4(e3)
67 | syn2 = self.ec5(e4)
68 | # print 'syn size2: ',syn2.size()
69 | del e3, e4
70 |
71 | e5 = self.pool2(syn2)
72 | e6 = self.ec6(e5)
73 | e7 = self.ec7(e6)
74 | # print 'e7: ',e7.size()
75 | del e5, e6
76 | dc9 = self.dc9(e7)
77 | # print 'dc9: ',dc9.size()
78 | d9 = torch.cat((dc9, syn2), dim=1)  # reuse dc9 rather than running the transposed conv twice
79 | del e7, syn2
80 |
81 | d8 = self.dc8(d9)
82 | d7 = self.dc7(d8)
83 | del d9, d8
84 |
85 | d6 = torch.cat((self.dc6(d7), syn1),dim=1)
86 | del d7, syn1
87 |
88 | d5 = self.dc5(d6)
89 | d4 = self.dc4(d5)
90 | del d6, d5
91 |
92 | d3 = torch.cat((self.dc3(d4), syn0),dim=1)
93 | del d4, syn0
94 |
95 | d2 = self.dc2(d3)
96 | d1 = self.dc1(d2)
97 | del d3, d2
98 |
99 | d0 = self.dc0(d1)
100 | return d0
101 |
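# A minimal smoke-test sketch (for illustration): UNet3D pools three times, so
# each spatial dim must be divisible by 2^3.
if __name__ == '__main__':
    from torch.autograd import Variable
    net = UNet3D(in_channel=1, n_classes=1)
    x = Variable(torch.randn(1, 1, 16, 64, 64))
    print(net(x).size())   # expected: (1, 1, 16, 64, 64)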
--------------------------------------------------------------------------------
/compute3DSSIM.py:
--------------------------------------------------------------------------------
1 | '''
2 | Target: Compute structure similarity (SSIM) between two 3D volumes
3 | Created on Jan, 22th 2018
4 | Author: Dong Nie
5 |
6 | reference from: http://simpleitk-prototype.readthedocs.io/en/latest/user_guide/plot_image.html
7 | '''
8 |
9 | import SimpleITK as sitk
10 |
11 | from multiprocessing import Pool
12 | import os
13 | import h5py
14 | import numpy as np
15 | import scipy.io as scio
16 | from morpologicalTransformation import denoiseImg_closing, denoiseImg_isolation
17 | from skimage import measure
18 |
19 |
20 | path = '/shenlab/lab_stor5/dongnie/3T7T/results/'
21 |
22 |
23 | def main():
24 | ids = [14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
25 | ids = range(0, 30)
26 | ids = [1]
27 | for id in ids:
28 | # datafn = os.path.join(path,'Case%02d.mhd'%id)
29 | # outdatafn = os.path.join(path,'Case%02d.nii.gz'%id)
30 | #
31 | # dataOrg = sitk.ReadImage(datafn)
32 | # dataMat = sitk.GetArrayFromImage(dataOrg)
33 | # #gtMat=np.transpose(gtMat,(2,1,0))
34 | # dataVol = sitk.GetImageFromArray(dataMat)
35 | # sitk.WriteImage(dataVol,outdatafn)
36 |
37 | # datafn = os.path.join(path, 'img1.mhd')
38 | # dataOrg = sitk.ReadImage(datafn)
39 | # spacing = dataOrg.GetSpacing()
40 | # origin = dataOrg.GetOrigin()
41 | # direction = dataOrg.GetDirection()
42 | # dataMat = sitk.GetArrayFromImage(dataOrg)
43 |
44 | gtfn = os.path.join(path, 'S1to1_7t.nii.gz')
45 | gtOrg = sitk.ReadImage(gtfn)
46 | gtMat = sitk.GetArrayFromImage(gtOrg)
47 | # gtMat=np.transpose(gtMat,(2,1,0))
48 | gtMat = gtMat.astype(np.float32)
49 | gtMat = (gtMat - np.amin(gtMat))/(np.amax(gtMat)-np.amin(gtMat))
50 | print np.amax(gtMat),',', np.amin(gtMat), 'dtype: ',gtMat.dtype, ',',gtMat.shape
51 |
52 | prefn = os.path.join(path,'preSub1_1112_195000.nii.gz')
53 | preOrg = sitk.ReadImage(prefn)
54 | preMat = sitk.GetArrayFromImage(preOrg)
55 | preMat = preMat.astype(np.float32)
56 | preMat = (preMat - np.amin(preMat))/(np.amax(preMat)-np.amin(preMat))
57 | # preMat = np.transpose(preMat,(2,0,1))
58 | print np.amax(preMat),',', np.amin(preMat), 'dtype:',preMat.dtype, ',',preMat.shape
59 |
60 |
61 | ssim_3d_sk = measure.compare_ssim(gtMat, preMat, multichannel=True, gaussian_weights=True, data_range=1.0,
62 | use_sample_covariance=False)
63 | print ssim_3d_sk
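# note: with multichannel=True, skimage treats the last axis as channels and
# averages per-channel 2d SSIM; omitting it on a 3d array would instead use a
# full 3d Gaussian window (a closer match to a truly volumetric SSIM)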
64 |
65 | # ssim_3d_sk = measure.structural_similarity(gtMat, preMat, multichannel=True, gaussian_weights=True, data_range=1.0,
66 | # use_sample_covariance=False)
67 | # gtMat1 = denoiseImg_closing(gtMat, kernel=np.ones((20, 20, 20)))
68 | # gtMat2 = gtMat + gtMat1
69 | # gtMat2[np.where(gtMat2 > 1)] = 1
70 | # gtMat = gtMat2
71 | # gtMat = denoiseImg_isolation(gtMat, struct=np.ones((3, 3, 3)))
72 | #
73 | # gtMat = gtMat.astype(np.uint8)
74 |
75 | # ind1 = np.where((gtMat==1)&(preMat==1))
76 | # preMat[ind1] = 0
77 | # ind2 = np.where((gtMat==2)&(preMat==2))
78 | # preMat[ind2] = 0
79 | # ind3 = np.where((gtMat==3)&(preMat==3))
80 | # preMat[ind3] = 0
81 | # errorMat = preMat
82 | #
83 | # outgtfn = os.path.join(path, 'sgm_errormap_sub1.nii.gz')
84 | # errorVol = sitk.GetImageFromArray(errorMat)
85 | # errorVol.SetSpacing(spacing)
86 | # errorVol.SetOrigin(origin)
87 | # errorVol.SetDirection(direction)
88 | # sitk.WriteImage(errorVol, outgtfn)
89 |
90 |
91 | #
92 | # prefn='preSub%d_as32_v12.nii'%id
93 | # preOrg=sitk.ReadImage(prefn)
94 | # preMat=sitk.GetArrayFromImage(preOrg)
95 | # preMat=np.transpose(preMat,(2,1,0))
96 | # preVol=sitk.GetImageFromArra(preMat)
97 | # sitk.WriteImage(preVol,prefn)
98 |
99 |
100 | if __name__ == '__main__':
101 | main()
102 |
--------------------------------------------------------------------------------
/dicom2Nii.py:
--------------------------------------------------------------------------------
1 | '''
2 | 05/02, at Chapel Hill
3 | Dong
4 |
5 | convert dicom series to nifti format
6 | '''
7 | import numpy
8 | import SimpleITK as sitk
9 | import os
10 |
11 |
12 |
13 | class ScanFile(object):
14 | def __init__(self,directory,prefix=None,postfix=None):
15 | self.directory=directory
16 | self.prefix=prefix
17 | self.postfix=postfix
18 |
19 | def scan_files(self):
20 | files_list=[]
21 |
22 | for dirpath,dirnames,filenames in os.walk(self.directory):
23 | '''''
24 | dirpath is a string, the path to the directory.
25 | dirnames is a list of the names of the subdirectories in dirpath (excluding '.' and '..').
26 | filenames is a list of the names of the non-directory files in dirpath.
27 | '''
28 | for special_file in filenames:
29 | if self.postfix:
30 | if special_file.endswith(self.postfix):
31 | files_list.append(os.path.join(dirpath,special_file))
32 | elif self.prefix:
33 | if special_file.startswith(self.prefix):
34 | files_list.append(os.path.join(dirpath,special_file))
35 | else:
36 | files_list.append(os.path.join(dirpath,special_file))
37 |
38 | return files_list
39 |
40 | def scan_subdir(self):
41 | subdir_list=[]
42 | for dirpath,dirnames,files in os.walk(self.directory):
43 | subdir_list.append(dirpath)
44 | return subdir_list
45 |
46 |
47 | def main():
48 | path='/home/dongnie/warehouse/pelvicSeg/newData/pelvic_0118/'
49 | subpath='atkinson_lafayette'
50 | outfn=subpath+'.nii.gz'
51 | inputdir=path+subpath
52 | scan=ScanFile(path)
53 | subdirs=scan.scan_subdir()
54 | for subdir in subdirs:
55 | if subdir==path or subdir=='..':
56 | continue
57 |
58 | print 'subdir is, ',subdir
59 |
60 | ss=subdir.split('/')
61 | print 'ss is, ',ss, 'and s7 is, ',ss[7]
62 |
63 | outfn=ss[7]+'.nii.gz'
64 |
65 | reader = sitk.ImageSeriesReader()
66 |
67 | dicom_names = reader.GetGDCMSeriesFileNames(subdir)
68 | reader.SetFileNames(dicom_names)
69 |
70 | image = reader.Execute()
71 |
72 | size = image.GetSize()
73 | print( "Image size:", size[0], size[1], size[2] )
74 |
75 | print( "Writing image:", outfn)
76 |
77 | sitk.WriteImage(image,outfn)
78 |
79 |
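# Expected layout (an illustration): every subject's DICOM series sits in its
# own sub-directory under `path`; each sub-directory is converted to
# <subdir_name>.nii.gz in the current working directory.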
80 | if __name__ == '__main__':
81 | main()
82 |
--------------------------------------------------------------------------------
/extract23DPatch4MultiModalImg.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | Target: Crop patches for kinds of medical images, such as hdr, nii, mha, mhd, raw and so on, and store them as hdf5 files
4 | for single-scale patches
5 | Created in June, 2016
6 | Author: Dong Nie
7 | '''
8 |
9 |
10 |
11 | import SimpleITK as sitk
12 |
13 | from multiprocessing import Pool
14 | import os, argparse
15 | import h5py
16 | import numpy as np
17 |
18 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg")
19 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data")
20 |
21 | global opt
22 | opt = parser.parse_args()
23 |
24 |
25 | d1=5
26 | d2=64
27 | d3=64
28 | dFA=[d1,d2,d3] # size of patches of input data
29 | dSeg=[1,64,64] # size of patches of label data
30 | step1=1
31 | step2=32
32 | step3=32
33 | step=[step1,step2,step3]
34 |
35 |
36 | class ScanFile(object):
37 | def __init__(self,directory,prefix=None,postfix=None):
38 | self.directory=directory
39 | self.prefix=prefix
40 | self.postfix=postfix
41 |
42 | def scan_files(self):
43 | files_list=[]
44 |
45 | for dirpath,dirnames,filenames in os.walk(self.directory):
46 | '''''
47 | dirpath is a string, the path to the directory.
48 | dirnames is a list of the names of the subdirectories in dirpath (excluding '.' and '..').
49 | filenames is a list of the names of the non-directory files in dirpath.
50 | '''
51 | for special_file in filenames:
52 | if self.postfix:
53 | if special_file.endswith(self.postfix):
54 | files_list.append(os.path.join(dirpath,special_file))
55 | elif self.prefix:
56 | if special_file.startswith(self.prefix):
57 | files_list.append(os.path.join(dirpath,special_file))
58 | else:
59 | files_list.append(os.path.join(dirpath,special_file))
60 |
61 | return files_list
62 |
63 | def scan_subdir(self):
64 | subdir_list=[]
65 | for dirpath,dirnames,files in os.walk(self.directory):
66 | subdir_list.append(dirpath)
67 | return subdir_list
68 |
69 |
70 |
71 | '''
72 | Helper to generate the hdf5 database: crops patches for one subject (not needed once the h5 files already exist)
73 | '''
74 | def extractPatch4OneSubject(matFA, matMR, matSeg, matMask, fileID ,d, step, rate):
75 |
76 | eps=5e-2
77 | rate1=1.0/2
78 | rate2=1.0/4
79 | [row,col,leng]=matFA.shape
80 | cubicCnt=0
81 | estNum=40000
82 | trainFA=np.zeros([estNum,1, dFA[0],dFA[1],dFA[2]],dtype=np.float16)
83 | trainSeg=np.zeros([estNum,1,dSeg[0],dSeg[1],dSeg[2]],dtype=np.float16)
84 | trainMR=np.zeros([estNum,1,dFA[0],dFA[1],dFA[2]],dtype=np.float16)
85 |
86 | print 'trainFA shape, ',trainFA.shape
87 | #to padding for input
88 | margin1=(dFA[0]-dSeg[0])/2
89 | margin2=(dFA[1]-dSeg[1])/2
90 | margin3=(dFA[2]-dSeg[2])/2
91 | cubicCnt=0
92 | marginD=[margin1,margin2,margin3]
93 | print 'matFA shape is ',matFA.shape
94 | matFAOut=np.zeros([row+2*marginD[0],col+2*marginD[1],leng+2*marginD[2]],dtype=np.float16)
95 | print 'matFAOut shape is ',matFAOut.shape
96 | matFAOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matFA
97 |
98 | matMROut=np.zeros([row+2*marginD[0],col+2*marginD[1],leng+2*marginD[2]],dtype=np.float16)
99 | print 'matMROut shape is ',matMROut.shape
100 | matMROut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMR
101 |
102 | matSegOut=np.zeros([row+2*marginD[0],col+2*marginD[1],leng+2*marginD[2]],dtype=np.float16)
103 | matSegOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matSeg
104 |
105 |
106 | matMaskOut=np.zeros([row+2*marginD[0],col+2*marginD[1],leng+2*marginD[2]],dtype=np.float16)
107 | matMaskOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMask
108 |
109 | #for mageFA, enlarge it by padding
110 | if margin1!=0:
111 | matFAOut[0:marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matFA[marginD[0]-1::-1,:,:] #reverse 0:marginD[0]
112 | matFAOut[row+marginD[0]:matFAOut.shape[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matFA[matFA.shape[0]-1:row-marginD[0]-1:-1,:,:] #we'd better flip it along the 1st dimension
113 | if margin2!=0:
114 | matFAOut[marginD[0]:row+marginD[0],0:marginD[1],marginD[2]:leng+marginD[2]]=matFA[:,marginD[1]-1::-1,:] #we'd flip it along the 2nd dimension
115 | matFAOut[marginD[0]:row+marginD[0],col+marginD[1]:matFAOut.shape[1],marginD[2]:leng+marginD[2]]=matFA[:,matFA.shape[1]-1:col-marginD[1]-1:-1,:] #we'd flip it along the 2nd dimension
116 | if margin3!=0:
117 | matFAOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],0:marginD[2]]=matFA[:,:,marginD[2]-1::-1] #we'd better flip it along the 3rd dimension
118 | matFAOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]+leng:matFAOut.shape[2]]=matFA[:,:,matFA.shape[2]-1:leng-marginD[2]-1:-1]
119 |
120 | #for matMR, enlarge it by padding
121 | if margin1!=0:
122 | matMROut[0:marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMR[marginD[0]-1::-1,:,:] #reverse 0:marginD[0]
123 | matMROut[row+marginD[0]:matMROut.shape[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMR[matMR.shape[0]-1:row-marginD[0]-1:-1,:,:] #we'd better flip it along the 1st dimension
124 | if margin2!=0:
125 | matMROut[marginD[0]:row+marginD[0],0:marginD[1],marginD[2]:leng+marginD[2]]=matMR[:,marginD[1]-1::-1,:] #we'd flip it along the 2nd dimension
126 | matMROut[marginD[0]:row+marginD[0],col+marginD[1]:matMROut.shape[1],marginD[2]:leng+marginD[2]]=matMR[:,matMR.shape[1]-1:col-marginD[1]-1:-1,:] #we'd flip it along the 2nd dimension
127 | if margin3!=0:
128 | matMROut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],0:marginD[2]]=matMR[:,:,marginD[2]-1::-1] #we'd better flip it along the 3rd dimension
129 | matMROut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]+leng:matMROut.shape[2]]=matMR[:,:,matMR.shape[2]-1:leng-marginD[2]-1:-1]
130 |
131 | #for matseg, enlarge it by padding
132 | if margin1!=0:
133 | matSegOut[0:marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matSeg[marginD[0]-1::-1,:,:] #reverse 0:marginD[0]
134 | matSegOut[row+marginD[0]:matSegOut.shape[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matSeg[matSeg.shape[0]-1:row-marginD[0]-1:-1,:,:] #we'd better flip it along the 1st dimension
135 | if margin2!=0:
136 | matSegOut[marginD[0]:row+marginD[0],0:marginD[1],marginD[2]:leng+marginD[2]]=matSeg[:,marginD[1]-1::-1,:] #we'd flip it along the 2nd dimension
137 | matSegOut[marginD[0]:row+marginD[0],col+marginD[1]:matSegOut.shape[1],marginD[2]:leng+marginD[2]]=matSeg[:,matSeg.shape[1]-1:col-marginD[1]-1:-1,:] #we'd flip it along the 2nd dimension
138 | if margin3!=0:
139 | matSegOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],0:marginD[2]]=matSeg[:,:,marginD[2]-1::-1] #we'd better flip it along the 3rd dimension
140 | matSegOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]+leng:matSegOut.shape[2]]=matSeg[:,:,matSeg.shape[2]-1:leng-marginD[2]-1:-1]
141 |
142 | #for matseg, enlarge it by padding
143 | if margin1!=0:
144 | matMaskOut[0:marginD[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMask[marginD[0]-1::-1,:,:] #reverse 0:marginD[0]
145 | matMaskOut[row+marginD[0]:matMaskOut.shape[0],marginD[1]:col+marginD[1],marginD[2]:leng+marginD[2]]=matMask[matMask.shape[0]-1:row-marginD[0]-1:-1,:,:] #we'd better flip it along the 1st dimension
146 | if margin2!=0:
147 | matMaskOut[marginD[0]:row+marginD[0],0:marginD[1],marginD[2]:leng+marginD[2]]=matMask[:,marginD[1]-1::-1,:] #we'd flip it along the 2nd dimension
148 | matMaskOut[marginD[0]:row+marginD[0],col+marginD[1]:matMaskOut.shape[1],marginD[2]:leng+marginD[2]]=matMask[:,matMask.shape[1]-1:col-marginD[1]-1:-1,:] #we'd flip it along the 2nd dimension
149 | if margin3!=0:
150 | matMaskOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],0:marginD[2]]=matMask[:,:,marginD[2]-1::-1] #we'd better flip it along the 3rd dimension
151 | matMaskOut[marginD[0]:row+marginD[0],marginD[1]:col+marginD[1],marginD[2]+leng:matMaskOut.shape[2]]=matMask[:,:,matMask.shape[2]-1:leng-marginD[2]-1:-1]
152 |
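# The reversed slices above implement mirror padding; a 1-d illustration:
#   a = np.array([1, 2, 3, 4, 5]); margin = 2
#   a[margin-1::-1]                         -> [2, 1]   (mirrored leading edge)
#   a[a.shape[0]-1:a.shape[0]-margin-1:-1]  -> [5, 4]   (mirrored trailing edge)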
153 | dsfactor = rate
154 |
155 | for i in range(0,row-dSeg[0],step[0]):
156 | for j in range(0,col-dSeg[1],step[1]):
157 | for k in range(0,leng-dSeg[2],step[2]):
158 | volMask = matMaskOut[i:i+dSeg[0],j:j+dSeg[1],k:k+dSeg[2]]
158 | if np.sum(volMask) < eps:  # presumably skips patches whose mask is (almost) empty
[original lines 159-307 are missing from this dump; the stray fragment "maxV)] = maxV" belonged to that span]
308 | # print 'maxV is: ',np.ndarray.max(mrimg)
309 | # mu=np.mean(mrimg) # we should have a fixed std and mean
310 | # std = np.std(mrimg)
311 | # mrnp = (mrimg - mu)/std
312 | # print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)
313 |
314 | #matLPET = (mrimg - meanLPET)/(stdLPET)
315 | #print 'lpet: maxV,',np.ndarray.max(matLPET),' minV, ',np.ndarray.min(matLPET), ' meanV: ', np.mean(matLPET), ' stdV: ', np.std(matLPET)
316 |
317 | # matLPET = (mrnp - minLPET)/(maxPercentLPET-minLPET)
318 | # print 'lpet: maxV,',np.ndarray.max(matLPET),' minV, ',np.ndarray.min(matLPET), ' meanV: ', np.mean(matLPET), ' stdV: ', np.std(matLPET)
319 |
320 |
321 |
322 |
323 | # maxV1, minV1 = np.percentile(mrimg1, [99.5 ,1])
324 | # print 'maxV1 is: ',np.ndarray.max(mrimg1)
325 | # mrimg1[np.where(mrimg1>maxV1)] = maxV1
326 | # print 'maxV1 is: ',np.ndarray.max(mrimg1)
327 | # mu1 = np.mean(mrimg1) # we should have a fixed std and mean
328 | # std1 = np.std(mrimg1)
329 | # mrnp1 = (mrimg1 - mu1)/std1
330 | # print 'maxV1,',np.ndarray.max(mrnp1),' minV, ',np.ndarray.min(mrnp1)
331 |
332 | # ctnp[np.where(ctnp>maxPercentCT)] = maxPercentCT
333 | # matCT = (ctnp - meanCT)/stdCT
334 | # print 'ct: maxV,',np.ndarray.max(matCT),' minV, ',np.ndarray.min(matCT), 'meanV: ', np.mean(matCT), 'stdV: ', np.std(matCT)
335 |
336 |
337 |
338 |
339 | # maxVal = np.amax(labelimg)
340 | # minVal = np.amin(labelimg)
341 | # print 'maxV is: ', maxVal, ' minVal is: ', minVal
342 | # mu=np.mean(labelimg) # we should have a fixed std and mean
343 | # std = np.std(labelimg)
344 | #
345 | # labelimg = (labelimg - minVal)/(maxVal - minVal)
346 | #
347 | # print 'maxV,',np.ndarray.max(labelimg),' minV, ',np.ndarray.min(labelimg)
348 | #you can do what you want here for for your label img
349 |
350 | # matSPET = (labelimg - minSPET)/(maxPercentSPET-minSPET)
351 | # print 'spet: maxV,',np.ndarray.max(matSPET),' minV, ',np.ndarray.min(matSPET), ' meanV: ',np.mean(matSPET), ' stdV: ', np.std(matSPET)
352 |
353 | sdir = filename.split('/')
354 | print 'sdir is, ',sdir, 'and s5 is, ',sdir[5]
355 | lpet_fn = sdir[5]
356 | words = lpet_fn.split('_')
357 | print 'words are, ',words
358 | ind = int(words[0])
359 |
360 |
361 | fileID = words[0]
362 | rate = 1
363 | cubicCnt = extractPatch4OneSubject(matLPET, matCT, matSPET, maskimg, fileID,dSeg,step,rate)
364 | #cubicCnt = extractPatch4OneSubject(mrnp, matCT, hpetnp, maskimg, fileID,dSeg,step,rate)
365 | print '# of patches is ', cubicCnt
366 |
367 | # reverse along the 1st dimension
368 | rmrimg = matLPET[matLPET.shape[0] - 1::-1, :, :]
369 | rmatCT = matCT[matCT.shape[0] - 1::-1, :, :]
370 | rlabelimg = matSPET[matSPET.shape[0] - 1::-1, :, :]
371 | rmaskimg = maskimg[maskimg.shape[0] - 1::-1, :, :]
372 | fileID = words[0]+'r'
373 | cubicCnt = extractPatch4OneSubject(rmrimg, rmatCT, rlabelimg, rmaskimg, fileID,dSeg,step,rate)
374 | print '# of patches is ', cubicCnt
375 |
376 | if __name__ == '__main__':
377 | main()
378 |
--------------------------------------------------------------------------------
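Side note on the boundary handling in extractPatch4OneSubject above: the hand-written slice reversals mirror the volume back into the padded margins, including the edge voxel, which is exactly numpy's 'symmetric' padding mode. A minimal sketch (illustrative, not part of the repository; mirror_pad is a hypothetical helper name), assuming the margins never exceed the volume size:

    import numpy as np

    def mirror_pad(mat, marginD):
        # edge-inclusive mirror padding on all three axes, matching the
        # slice-reversal logic in extractPatch4OneSubject
        return np.pad(mat, ((marginD[0], marginD[0]),
                            (marginD[1], marginD[1]),
                            (marginD[2], marginD[2])), mode='symmetric')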
/extract23DPatch4SingleModalImg.py:
--------------------------------------------------------------------------------
1 | '''
2 | Target: Crop patches from various kinds of medical images, such as hdr, nii, mha, mhd, raw and so on, and store them as hdf5 files;
3 | this version handles a single input modality
4 | Created in June, 2016
5 | Author: Dong Nie
6 | '''
7 |
8 | import SimpleITK as sitk
9 |
10 | from multiprocessing import Pool
11 | import os, argparse
12 | import h5py
13 | import numpy as np
14 |
15 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg")
16 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data")
17 |
18 | global opt
19 | opt = parser.parse_args()
20 |
21 | # input patch size
22 | d1 = 5
23 | d2 = 64
24 | d3 = 64
25 | # output patch size
26 | dFA = [d1, d2, d3] # size of patches of input data
27 | dSeg = [1, 64, 64] # size of pathes of label data
28 | # stride for extracting patches along the volume
29 | step1 = 1
30 | step2 = 16
31 | step3 = 16
32 | step = [step1, step2, step3]
33 |
34 |
35 | class ScanFile(object):
36 | def __init__(self, directory, prefix=None, postfix=None):
37 | self.directory = directory
38 | self.prefix = prefix
39 | self.postfix = postfix
40 |
41 | def scan_files(self):
42 | files_list = []
43 |
44 | for dirpath, dirnames, filenames in os.walk(self.directory):
45 |             '''
46 | dirpath is a string, the path to the directory.
47 | dirnames is a list of the names of the subdirectories in dirpath (excluding '.' and '..').
48 | filenames is a list of the names of the non-directory files in dirpath.
49 | '''
50 | for special_file in filenames:
51 | if self.postfix:
52 | if special_file.endswith(self.postfix):
53 | files_list.append(os.path.join(dirpath, special_file))
54 | elif self.prefix:
55 | if special_file.startswith(self.prefix):
56 | files_list.append(os.path.join(dirpath, special_file))
57 | else:
58 | files_list.append(os.path.join(dirpath, special_file))
59 |
60 | return files_list
61 |
62 | def scan_subdir(self):
63 | subdir_list = []
64 | for dirpath, dirnames, files in os.walk(self.directory):
65 | subdir_list.append(dirpath)
66 | return subdir_list
67 |
68 |
69 | '''
70 | Actually, we do not need this any more; it was useful for generating the hdf5 database
71 | '''
72 |
73 |
74 | def extractPatch4OneSubject(matFA, matSeg, matMask, fileID, d, step, rate):
75 | eps = 5e-2
76 | rate1 = 1.0 / 2
77 | rate2 = 1.0 / 4
78 | [row, col, leng] = matFA.shape
79 | cubicCnt = 0
80 | estNum = 40000
81 | trainFA = np.zeros([estNum, 1, dFA[0], dFA[1], dFA[2]], dtype=np.float16)
82 | trainSeg = np.zeros([estNum, 1, dSeg[0], dSeg[1], dSeg[2]], dtype=np.float16)
83 |
84 | print 'trainFA shape, ', trainFA.shape
85 |     # pad the input so that each dSeg-sized output patch has a full dFA-sized input patch around it
86 | margin1 = (dFA[0] - dSeg[0]) / 2
87 | margin2 = (dFA[1] - dSeg[1]) / 2
88 | margin3 = (dFA[2] - dSeg[2]) / 2
89 | cubicCnt = 0
90 | marginD = [margin1, margin2, margin3]
91 | print 'matFA shape is ', matFA.shape
92 | matFAOut = np.zeros([row + 2 * marginD[0], col + 2 * marginD[1], leng + 2 * marginD[2]], dtype=np.float16)
93 | print 'matFAOut shape is ', matFAOut.shape
94 | matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA
95 |
96 | matSegOut = np.zeros([row + 2 * marginD[0], col + 2 * marginD[1], leng + 2 * marginD[2]], dtype=np.float16)
97 | matSegOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matSeg
98 |
99 | matMaskOut = np.zeros([row + 2 * marginD[0], col + 2 * marginD[1], leng + 2 * marginD[2]], dtype=np.float16)
100 | matMaskOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matMask
101 |
102 |     # for matFA, enlarge it by mirror padding
103 |     if margin1 != 0:
104 |         matFAOut[0:marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA[marginD[0] - 1::-1, :, :]  # reverse 0:marginD[0]
106 |         matFAOut[row + marginD[0]:matFAOut.shape[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matFA[matFA.shape[0] - 1:row - marginD[0] - 1:-1, :, :]  # flip along the 1st dimension
113 |     if margin2 != 0:
114 |         matFAOut[marginD[0]:row + marginD[0], 0:marginD[1], marginD[2]:leng + marginD[2]] = matFA[:, marginD[1] - 1::-1, :]  # flip along the 2nd dimension
116 |         matFAOut[marginD[0]:row + marginD[0], col + marginD[1]:matFAOut.shape[1], marginD[2]:leng + marginD[2]] = matFA[:, matFA.shape[1] - 1:col - marginD[1] - 1:-1, :]  # flip along the 2nd dimension
123 |     if margin3 != 0:
124 |         matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 0:marginD[2]] = matFA[:, :, marginD[2] - 1::-1]  # flip along the 3rd dimension
126 |         matFAOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2] + leng:matFAOut.shape[2]] = matFA[:, :, matFA.shape[2] - 1:leng - marginD[2] - 1:-1]
132 |     # for matSeg, enlarge it by mirror padding
133 |     if margin1 != 0:
134 |         matSegOut[0:marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matSeg[marginD[0] - 1::-1, :, :]  # reverse 0:marginD[0]
137 |         matSegOut[row + marginD[0]:matSegOut.shape[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matSeg[matSeg.shape[0] - 1:row - marginD[0] - 1:-1, :, :]  # flip along the 1st dimension
140 |     if margin2 != 0:
141 |         matSegOut[marginD[0]:row + marginD[0], 0:marginD[1], marginD[2]:leng + marginD[2]] = matSeg[:, marginD[1] - 1::-1, :]  # flip along the 2nd dimension
144 |         matSegOut[marginD[0]:row + marginD[0], col + marginD[1]:matSegOut.shape[1], marginD[2]:leng + marginD[2]] = matSeg[:, matSeg.shape[1] - 1:col - marginD[1] - 1:-1, :]  # flip along the 2nd dimension
147 |     if margin3 != 0:
148 |         matSegOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 0:marginD[2]] = matSeg[:, :, marginD[2] - 1::-1]  # flip along the 3rd dimension
150 |         matSegOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2] + leng:matSegOut.shape[2]] = matSeg[:, :, matSeg.shape[2] - 1:leng - marginD[2] - 1:-1]
152 |
153 |     # for matMask, enlarge it by mirror padding
154 |     if margin1 != 0:
155 |         matMaskOut[0:marginD[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matMask[marginD[0] - 1::-1, :, :]  # reverse 0:marginD[0]
158 |         matMaskOut[row + marginD[0]:matMaskOut.shape[0], marginD[1]:col + marginD[1], marginD[2]:leng + marginD[2]] = matMask[matMask.shape[0] - 1:row - marginD[0] - 1:-1, :, :]  # flip along the 1st dimension
161 |     if margin2 != 0:
162 |         matMaskOut[marginD[0]:row + marginD[0], 0:marginD[1], marginD[2]:leng + marginD[2]] = matMask[:, marginD[1] - 1::-1, :]  # flip along the 2nd dimension
165 |         matMaskOut[marginD[0]:row + marginD[0], col + marginD[1]:matMaskOut.shape[1], marginD[2]:leng + marginD[2]] = matMask[:, matMask.shape[1] - 1:col - marginD[1] - 1:-1, :]  # flip along the 2nd dimension
168 |     if margin3 != 0:
169 |         matMaskOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], 0:marginD[2]] = matMask[:, :, marginD[2] - 1::-1]  # flip along the 3rd dimension
171 |         matMaskOut[marginD[0]:row + marginD[0], marginD[1]:col + marginD[1], marginD[2] + leng:matMaskOut.shape[2]] = matMask[:, :, matMask.shape[2] - 1:leng - marginD[2] - 1:-1]
173 |
174 | dsfactor = rate
175 |
176 | for i in range(0, row - dSeg[0], step[0]):
177 | for j in range(0, col - dSeg[1], step[1]):
178 | for k in range(0, leng - dSeg[2], step[2]):
179 | volMask = matMaskOut[i:i + dSeg[0], j:j + dSeg[1], k:k + dSeg[2]]
180 | if np.sum(volMask) < eps:
181 | continue
182 | cubicCnt = cubicCnt + 1
183 | # index at scale 1
184 | volSeg = matSeg[i:i + dSeg[0], j:j + dSeg[1], k:k + dSeg[2]]
185 | volFA = matFAOut[i:i + dFA[0], j:j + dFA[1], k:k + dFA[2]]
186 |
187 |
188 | trainFA[cubicCnt, 0, :, :, :] = volFA # 32*32*32
189 |
190 | trainSeg[cubicCnt, 0, :, :, :] = volSeg # 24*24*24
191 |
192 | trainFA = trainFA[0:cubicCnt, :, :, :, :]
193 | trainSeg = trainSeg[0:cubicCnt, :, :, :, :]
194 |
195 | with h5py.File('./trainMRCT_snorm_64_%s.h5' % fileID, 'w') as f:
196 | f['dataMR'] = trainFA
197 | f['dataCT'] = trainSeg
198 |
199 | with open('./trainMRCT2D_snorm_64_list.txt', 'a') as f:
200 | f.write('./trainMRCT_snorm_64_%s.h5\n' % fileID)
201 | return cubicCnt
202 |
203 |
204 | def main():
205 | print opt
206 | path = '/home/niedong/Data4LowDosePET/data_niigz_scale/'
207 | path = '/shenlab/lab_stor5/dongnie/brain_mr2ct/original_data/' # path to the data, change to your own path
208 | scan = ScanFile(path, postfix='_mr.hdr') # the specify item for your files, change to your own style
209 | filenames = scan.scan_files()
210 |
211 | # for input
212 | maxSource = 149.366742
213 | maxPercentSource = 7.76
214 | minSource = 0.00055037
215 | meanSource = 0.27593288
216 | stdSource = 0.75747500
217 |
218 | # for output
219 | maxTarget = 27279
220 | maxPercentTarget = 1320
221 | minTarget = -1023
222 | meanTarget = -601.1929
223 | stdTarget = 475.034
224 |
225 | for filename in filenames:
226 |
227 | print 'source filename: ', filename
228 |
229 | source_fn = filename
230 | target_fn = filename.replace('_mr.hdr', '_ct.hdr')
231 |
232 | imgOrg = sitk.ReadImage(source_fn)
233 | sourcenp = sitk.GetArrayFromImage(imgOrg)
234 |
235 | imgOrg1 = sitk.ReadImage(target_fn)
236 | targetnp = sitk.GetArrayFromImage(imgOrg1)
237 |
238 | maskimg = sourcenp
239 |
240 | mu = np.mean(sourcenp)
241 |
242 | if opt.how2normalize == 1:
243 | maxV, minV = np.percentile(sourcenp, [99, 1])
244 | print 'maxV,', maxV, ' minV, ', minV
245 | sourcenp = (sourcenp - mu) / (maxV - minV)
246 | print 'unique value: ', np.unique(targetnp)
247 |
248 | # for training data in pelvicSeg
249 | if opt.how2normalize == 2:
250 | maxV, minV = np.percentile(sourcenp, [99, 1])
251 | print 'maxV,', maxV, ' minV, ', minV
252 | sourcenp = (sourcenp - mu) / (maxV - minV)
253 | print 'unique value: ', np.unique(targetnp)
254 |
255 | # for training data in pelvicSegRegH5
256 | if opt.how2normalize == 3:
257 | std = np.std(sourcenp)
258 | sourcenp = (sourcenp - mu) / std
259 | print 'maxV,', np.ndarray.max(sourcenp), ' minV, ', np.ndarray.min(sourcenp)
260 |
261 | if opt.how2normalize == 4:
262 | maxSource = 149.366742
263 | maxPercentSource = 7.76
264 | minSource = 0.00055037
265 | meanSource = 0.27593288
266 | stdSource = 0.75747500
267 |
268 | # for target
269 | maxTarget = 27279
270 | maxPercentTarget = 1320
271 | minTarget = -1023
272 | meanTarget = -601.1929
273 | stdTarget = 475.034
274 |
275 | matSource = (sourcenp - minSource) / (maxPercentSource - minSource)
276 | matTarget = (targetnp - meanTarget) / stdTarget
277 |
278 | if opt.how2normalize == 5:
279 | # for target
280 | maxTarget = 27279
281 | maxPercentTarget = 1320
282 | minTarget = -1023
283 | meanTarget = -601.1929
284 | stdTarget = 475.034
285 |
286 | print 'target, max: ', np.amax(targetnp), ' target, min: ', np.amin(targetnp)
287 |
288 | # matSource = (sourcenp - meanSource) / (stdSource)
289 | matSource = sourcenp
290 | matTarget = (targetnp - meanTarget) / stdTarget
291 |
292 | if opt.how2normalize == 6:
293 | maxPercentSource, minPercentSource = np.percentile(sourcenp, [99.5, 0])
294 | maxPercentTarget, minPercentTarget = np.percentile(targetnp, [99.5, 0])
295 | print 'maxPercentSource: ', maxPercentSource, ' minPercentSource: ', minPercentSource, ' maxPercentTarget: ', maxPercentTarget, 'minPercentTarget: ', minPercentTarget
296 |
297 | matSource = (sourcenp - minPercentSource) / (maxPercentSource - minPercentSource) #input
298 |             #output: normalized with the input's statistics (if the input and output differ a lot, find a simple relation between them and use it to normalize the output with the input's statistics)
299 | matTarget = (targetnp - minPercentSource) / (maxPercentSource - minPercentSource)
300 |
301 | print 'maxSource: ', np.amax(matSource), ' maxTarget: ', np.amax(matTarget)
302 | print 'minSource: ', np.amin(matSource), ' minTarget: ', np.amin(matTarget)
303 |
304 | # maxV, minV = np.percentile(mrimg, [99.5, 0])
305 | # print 'maxV is: ',np.ndarray.max(mrimg)
306 | # mrimg[np.where(mrimg>maxV)] = maxV
307 | # print 'maxV is: ',np.ndarray.max(mrimg)
308 | # mu=np.mean(mrimg) # we should have a fixed std and mean
309 | # std = np.std(mrimg)
310 | # mrnp = (mrimg - mu)/std
311 | # print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)
312 |
313 | # matLPET = (mrimg - meanLPET)/(stdLPET)
314 | # print 'lpet: maxV,',np.ndarray.max(matLPET),' minV, ',np.ndarray.min(matLPET), ' meanV: ', np.mean(matLPET), ' stdV: ', np.std(matLPET)
315 |
316 | # matLPET = (mrnp - minLPET)/(maxPercentLPET-minLPET)
317 | # print 'lpet: maxV,',np.ndarray.max(matLPET),' minV, ',np.ndarray.min(matLPET), ' meanV: ', np.mean(matLPET), ' stdV: ', np.std(matLPET)
318 |
319 | # maxV1, minV1 = np.percentile(mrimg1, [99.5 ,1])
320 | # print 'maxV1 is: ',np.ndarray.max(mrimg1)
321 | # mrimg1[np.where(mrimg1>maxV1)] = maxV1
322 | # print 'maxV1 is: ',np.ndarray.max(mrimg1)
323 | # mu1 = np.mean(mrimg1) # we should have a fixed std and mean
324 | # std1 = np.std(mrimg1)
325 | # mrnp1 = (mrimg1 - mu1)/std1
326 | # print 'maxV1,',np.ndarray.max(mrnp1),' minV, ',np.ndarray.min(mrnp1)
327 |
328 | # ctnp[np.where(ctnp>maxPercentCT)] = maxPercentCT
329 | # matCT = (ctnp - meanCT)/stdCT
330 | # print 'ct: maxV,',np.ndarray.max(matCT),' minV, ',np.ndarray.min(matCT), 'meanV: ', np.mean(matCT), 'stdV: ', np.std(matCT)
331 |
332 | # maxVal = np.amax(labelimg)
333 | # minVal = np.amin(labelimg)
334 | # print 'maxV is: ', maxVal, ' minVal is: ', minVal
335 | # mu=np.mean(labelimg) # we should have a fixed std and mean
336 | # std = np.std(labelimg)
337 | #
338 | # labelimg = (labelimg - minVal)/(maxVal - minVal)
339 | #
340 | # print 'maxV,',np.ndarray.max(labelimg),' minV, ',np.ndarray.min(labelimg)
341 | # you can do what you want here for your label img
342 |
343 | # matSPET = (labelimg - minSPET)/(maxPercentSPET-minSPET)
344 | # print 'spet: maxV,',np.ndarray.max(matSPET),' minV, ',np.ndarray.min(matSPET), ' meanV: ',np.mean(matSPET), ' stdV: ', np.std(matSPET)
345 |
346 | sdir = filename.split('/')
347 |         print 'sdir is, ', sdir, ' and its last element is, ', sdir[len(sdir)-1]
348 | lpet_fn = sdir[len(sdir)-1]
349 | words = lpet_fn.split('_')
350 | print 'words are, ', words
351 | # ind = int(words[0])
352 |
353 | fileID = words[0]
354 | rate = 1
355 | cubicCnt = extractPatch4OneSubject(matSource, matTarget, maskimg, fileID, dSeg, step, rate)
356 | # cubicCnt = extractPatch4OneSubject(mrnp, matCT, hpetnp, maskimg, fileID,dSeg,step,rate)
357 | print '# of patches is ', cubicCnt
358 |
359 | # reverse along the 1st dimension
360 | rmatSource = matSource[matSource.shape[0] - 1::-1, :, :]
361 | rmatTarget = matTarget[matTarget.shape[0] - 1::-1, :, :]
362 |
363 | rmaskimg = maskimg[maskimg.shape[0] - 1::-1, :, :]
364 | fileID = words[0] + 'r'
365 | cubicCnt = extractPatch4OneSubject(rmatSource, rmatTarget, rmaskimg, fileID, dSeg, step, rate)
366 | print '# of patches is ', cubicCnt
367 |
368 |
369 | if __name__ == '__main__':
370 | main()
371 |
--------------------------------------------------------------------------------
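For reference, the h5 files written by extractPatch4OneSubject above store the patches as 5-D arrays under the keys 'dataMR' and 'dataCT'. A quick sketch for inspecting one of the generated files (the file name here is illustrative; with the default dFA/dSeg above, the shapes are (N, 1, 5, 64, 64) and (N, 1, 1, 64, 64)):

    import h5py

    with h5py.File('./trainMRCT_snorm_64_sub01.h5', 'r') as f:
        dataMR = f['dataMR'][:]  # input patches
        dataCT = f['dataCT'][:]  # target patches
        print 'MR patches: ', dataMR.shape, ' CT patches: ', dataCT.shape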
/loss_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import SimpleITK as sitk
4 | import tensorflow as tf  # gdl_loss below is written with TensorFlow ops
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch
8 | import torch.nn.init
9 | from torch.autograd import Variable
10 |
11 |
12 | def gdl_loss(gen_CT, gt_CT, alpha, batch_size_tf):
13 | """
14 | Calculates the sum of GDL losses between the predicted and ground truth frames.
15 |
16 |     @param gen_CT: The predicted CT image(s).
17 |     @param gt_CT: The ground truth CT image(s).
18 | @param alpha: The power to which each gradient term is raised.
19 |
20 | @return: The GDL loss.
21 | """
22 | # calculate the loss for each scale
23 |
24 | # create filters [-1, 1] and [[1],[-1]] for diffing to the left and down respectively.
25 | pos = tf.constant(np.identity(1), dtype=tf.float32)
26 | neg = -1 * pos
27 |     filter_x = tf.expand_dims(tf.stack([neg, pos]), 0)  # [-1, 1]; tf.pack was renamed to tf.stack in TF 1.0
28 |     filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [[1],[-1]]
29 | strides = [1, 1, 1, 1] # stride of (1, 1)
30 | padding = 'SAME'
31 |
32 | gen_dx = tf.abs(tf.nn.conv2d(gen_CT, filter_x, strides, padding=padding))
33 | gen_dy = tf.abs(tf.nn.conv2d(gen_CT, filter_y, strides, padding=padding))
34 | gt_dx = tf.abs(tf.nn.conv2d(gt_CT, filter_x, strides, padding=padding))
35 | gt_dy = tf.abs(tf.nn.conv2d(gt_CT, filter_y, strides, padding=padding))
36 |
37 | grad_diff_x = tf.abs(gt_dx - gen_dx)
38 | grad_diff_y = tf.abs(gt_dy - gen_dy)
39 |
40 | gdl=tf.reduce_sum((grad_diff_x ** alpha + grad_diff_y ** alpha))/tf.cast(batch_size_tf,tf.float32)
41 |
42 | # condense into one tensor and avg
43 | return gdl
44 |
--------------------------------------------------------------------------------
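Note that gdl_loss above is legacy TensorFlow code; the PyTorch training scripts import their gdl_loss from nnBuildUnits, which is not shown in this listing. Purely for reference, a minimal PyTorch sketch of the same gradient-difference idea (an illustrative re-implementation, assuming 4-D (N, C, H, W) tensors and a mean reduction):

    import torch

    def gdl_loss_sketch(gen, gt, alpha=2):
        # image gradients via finite differences along width (dx) and height (dy)
        gen_dx = torch.abs(gen[:, :, :, 1:] - gen[:, :, :, :-1])
        gen_dy = torch.abs(gen[:, :, 1:, :] - gen[:, :, :-1, :])
        gt_dx = torch.abs(gt[:, :, :, 1:] - gt[:, :, :, :-1])
        gt_dy = torch.abs(gt[:, :, 1:, :] - gt[:, :, :-1, :])
        # penalize the mismatch between predicted and ground-truth gradients
        grad_diff_x = torch.abs(gt_dx - gen_dx)
        grad_diff_y = torch.abs(gt_dy - gen_dy)
        return (grad_diff_x ** alpha).mean() + (grad_diff_y ** alpha).mean()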
/runCTRecon.py:
--------------------------------------------------------------------------------
1 | # from __future__ import print_function
2 | import argparse, os
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 | import torch.optim as optim
8 | import torch
9 | import torch.utils.data as data_utils
10 | from utils import *
11 | from Unet2d_pytorch import UNet, ResUNet, UNet_LRes, ResUNet_LRes, Discriminator
12 | from Unet3d_pytorch import UNet3D
13 | from nnBuildUnits import CrossEntropy3d, topK_RegLoss, RelativeThreshold_RegLoss, gdl_loss, adjust_learning_rate, calc_gradient_penalty
14 | import time
15 | import SimpleITK as sitk
16 |
17 | # Training settings
18 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg")
19 | parser.add_argument("--gpuID", type=int, default=1, help="which GPU to use")
20 | parser.add_argument("--isAdLoss", action="store_true", help="is adversarial loss used?", default=False)
21 | parser.add_argument("--isWDist", action="store_true", help="is adversarial loss with WGAN-GP distance?", default=False)
22 | parser.add_argument("--lambda_AD", default=0.05, type=float, help="weight for AD loss, Default: 0.05")
23 | parser.add_argument("--lambda_D_WGAN_GP", default=10, type=float, help="weight for gradient penalty of WGAN-GP, Default: 10")
24 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data")
25 | parser.add_argument("--whichLoss", type=int, default=1, help="which loss to use: 1. L1 (default), 2. RTL1, 3. MSE")
26 | parser.add_argument("--isGDL", action="store_true", help="do we use GDL loss?", default=True)
27 | parser.add_argument("--gdlNorm", default=2, type=int, help="p-norm for the gdl loss, Default: 2")
28 | parser.add_argument("--lambda_gdl", default=0.05, type=float, help="Weight for gdl loss, Default: 0.05")
29 | parser.add_argument("--whichNet", type=int, default=4, help="which network to use: 1. UNet, 2. ResUNet, 3. UNet_LRes and 4. ResUNet_LRes (default, 4)")
30 | parser.add_argument("--lossBase", type=int, default=1, help="The base to multiply the lossG_G, Default (1)")
31 | parser.add_argument("--batchSize", type=int, default=32, help="training batch size")
32 | parser.add_argument("--isMultiSource", action="store_true", help="is multiple modality used?", default=False)
33 | parser.add_argument("--numOfChannel_singleSource", type=int, default=5, help="# of channels for a 2D patch for the main modality (Default, 5)")
34 | parser.add_argument("--numOfChannel_allSource", type=int, default=5, help="# of channels for a 2D patch for all the concatenated modalities (Default, 5)")
35 | parser.add_argument("--numofIters", type=int, default=200000, help="number of iterations to train for")
36 | parser.add_argument("--showTrainLossEvery", type=int, default=100, help="number of iterations to show train loss")
37 | parser.add_argument("--saveModelEvery", type=int, default=5000, help="number of iterations to save the model")
38 | parser.add_argument("--showValPerformanceEvery", type=int, default=1000, help="number of iterations to show validation performance")
39 | parser.add_argument("--showTestPerformanceEvery", type=int, default=5000, help="number of iterations to show test performance")
40 | parser.add_argument("--lr", type=float, default=5e-3, help="Learning Rate. Default=5e-3")
41 | parser.add_argument("--lr_netD", type=float, default=5e-3, help="Learning Rate for discriminator. Default=5e-3")
42 | parser.add_argument("--dropout_rate", default=0.2, type=float, help="prob to drop neurons to zero: 0.2")
43 | parser.add_argument("--decLREvery", type=int, default=10000, help="decrease the learning rate (by lrDecRate) every n iterations, Default: n=10000")
44 | parser.add_argument("--lrDecRate", type=float, default=0.5, help="The weight for decreasing learning rate of netG Default=0.5")
45 | parser.add_argument("--lrDecRate_netD", type=float, default=0.1, help="The weight for decreasing learning rate of netD. Default=0.1")
46 | parser.add_argument("--cuda", action="store_true", help="Use cuda?", default=True)
47 | parser.add_argument("--resume", default="", type=str, help="Path to checkpoint (default: none)")
48 | parser.add_argument("--start_epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
49 | parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1")
50 | parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
51 | parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="weight decay, Default: 1e-4")
52 | parser.add_argument("--RT_th", default=0.005, type=float, help="Relative thresholding: 0.005")
53 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")
54 | parser.add_argument("--prefixModelName", default="/home/niedong/Data4LowDosePET/pytorch_UNet/resunet2d_dp_pet_BatchAug_sNorm_lres_bn_lr5e3_lrdec_base1_lossL1_lossGDL0p05_0705_", type=str, help="prefix of the to-be-saved model name")
55 | parser.add_argument("--prefixPredictedFN", default="preSub1_pet_BatchAug_sNorm_resunet_dp_lres_bn_lr5e3_lrdec_base1_lossL1_lossGDL0p05_0705_", type=str, help="prefix of the to-be-saved predicted filename")
56 | parser.add_argument("--test_input_file_name",default='sub13_mr.hdr',type=str, help="the input file name for testing subject")
57 | parser.add_argument("--test_gt_file_name",default='sub13_ct.hdr',type=str, help="the ground-truth file name for testing subject")
58 |
59 | global opt, model
60 | opt = parser.parse_args()
61 |
62 | def main():
63 | print opt
64 |
65 | # prefixModelName = 'Regressor_1112_'
66 | # prefixPredictedFN = 'preSub1_1112_'
67 | # showTrainLossEvery = 100
68 | # lr = 1e-4
69 | # showTestPerformanceEvery = 2000
70 | # saveModelEvery = 2000
71 | # decLREvery = 40000
72 | # numofIters = 200000
73 | # how2normalize = 0
74 |
75 |
76 | netD = Discriminator()
77 | netD.apply(weights_init)
78 | netD.cuda()
79 |
80 | optimizerD = optim.Adam(netD.parameters(),lr=opt.lr_netD)
81 | criterion_bce=nn.BCELoss()
82 | criterion_bce.cuda()
83 |
84 | #net=UNet()
85 | if opt.whichNet==1:
86 | net = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
87 | elif opt.whichNet==2:
88 | net = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
89 | elif opt.whichNet==3:
90 | net = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1)
91 | elif opt.whichNet==4:
92 | net = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1, dp_prob = opt.dropout_rate)
93 | #net.apply(weights_init)
94 | net.cuda()
95 | params = list(net.parameters())
96 | print('len of params is ')
97 | print(len(params))
98 | print('size of params is ')
99 | print(params[0].size())
100 |
101 |
102 |
103 | optimizer = optim.Adam(net.parameters(),lr=opt.lr)
104 | criterion_L2 = nn.MSELoss()
105 | criterion_L1 = nn.L1Loss()
106 | criterion_RTL1 = RelativeThreshold_RegLoss(opt.RT_th)
107 | criterion_gdl = gdl_loss(opt.gdlNorm)
108 | #criterion = nn.CrossEntropyLoss()
109 | # criterion = nn.NLLLoss2d()
110 |
111 | given_weight = torch.cuda.FloatTensor([1,4,4,2])
112 |
113 | criterion_3d = CrossEntropy3d(weight=given_weight)
114 |
115 | criterion_3d = criterion_3d.cuda()
116 | criterion_L2 = criterion_L2.cuda()
117 | criterion_L1 = criterion_L1.cuda()
118 | criterion_RTL1 = criterion_RTL1.cuda()
119 | criterion_gdl = criterion_gdl.cuda()
120 |
121 | #inputs=Variable(torch.randn(1000,1,32,32)) #here should be tensor instead of variable
122 | #targets=Variable(torch.randn(1000,10,1,1)) #here should be tensor instead of variable
123 | # trainset=data_utils.TensorDataset(inputs, targets)
124 | # trainloader = data_utils.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
125 | # inputs=torch.randn(1000,1,32,32)
126 | # targets=torch.LongTensor(1000)
127 |
128 | path_test ='/home/niedong/DataCT/data_niigz/'
129 | path_patients_h5 = '/home/niedong/DataCT/h5Data_snorm/trainBatch2D_H5'
130 | path_patients_h5_val ='/home/niedong/DataCT/h5Data_snorm/valBatch2D_H5'
131 | # batch_size=10
132 | #data_generator = Generator_2D_slices(path_patients_h5,opt.batchSize,inputKey='data3T',outputKey='data7T')
133 | #data_generator_test = Generator_2D_slices(path_patients_h5_test,opt.batchSize,inputKey='data3T',outputKey='data7T')
134 | if opt.isMultiSource:
135 | data_generator = Generator_2D_slicesV1(path_patients_h5,opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET')
136 | data_generator_test = Generator_2D_slicesV1(path_patients_h5_val, opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET')
137 | else:
138 | data_generator = Generator_2D_slices(path_patients_h5,opt.batchSize,inputKey='dataMR',outputKey='dataCT')
139 | data_generator_test = Generator_2D_slices(path_patients_h5_val,opt.batchSize,inputKey='dataMR',outputKey='dataCT')
140 |
141 | #data_generator = Generator_2D_slicesV1(path_patients_h5,opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET')
142 | #data_generator_test = Generator_2D_slicesV1(path_patients_h5_val, opt.batchSize, inputKey='dataLPET', segKey='dataCT', contourKey='dataHPET')
143 | if opt.resume:
144 | if os.path.isfile(opt.resume):
145 | print("=> loading checkpoint '{}'".format(opt.resume))
146 | checkpoint = torch.load(opt.resume)
147 | net.load_state_dict(checkpoint['model'])
149 |             opt.start_epoch = checkpoint["epoch"] + 1
150 | # net.load_state_dict(checkpoint["model"].state_dict())
151 | else:
152 | print("=> no checkpoint found at '{}'".format(opt.resume))
153 |     ########### It would be better to use a dataloader to load large amounts of data, and to train for several epochs ###############
155 | running_loss = 0.0
156 | start = time.time()
157 | for iter in range(opt.start_epoch, opt.numofIters+1):
158 | #print('iter %d'%iter)
159 | #print('iter %d'%iter)
160 | if opt.isMultiSource:
161 | inputs, exinputs, labels = data_generator.next()
162 | else:
163 | inputs, labels = data_generator.next()
164 | exinputs = inputs
165 | # inputs, exinputs, labels = data_generator.next()
166 |
167 | # xx = np.transpose(inputs,(5,64,64))
168 | # inputs = np.transpose(inputs,(0,3,1,2))
169 | inputs = np.squeeze(inputs) #5x64x64
170 | # exinputs = np.transpose(exinputs,(0,3,1,2))
171 | exinputs = np.squeeze(exinputs) #5x64x64
172 | # print 'shape is ....',inputs.shape
173 | labels = np.squeeze(labels) #64x64
174 | # labels = labels.astype(int)
175 |
176 | inputs = inputs.astype(float)
177 | inputs = torch.from_numpy(inputs)
178 | inputs = inputs.float()
179 | exinputs = exinputs.astype(float)
180 | exinputs = torch.from_numpy(exinputs)
181 | exinputs = exinputs.float()
182 | labels = labels.astype(float)
183 | labels = torch.from_numpy(labels)
184 | labels = labels.float()
185 | #print type(inputs), type(exinputs)
186 | if opt.isMultiSource:
187 | source = torch.cat((inputs, exinputs),dim=1)
188 | else:
189 | source = inputs
190 | #source = inputs
191 | mid_slice = opt.numOfChannel_singleSource//2
192 | residual_source = inputs[:, mid_slice, ...]
193 | #inputs = inputs.cuda()
194 | #exinputs = exinputs.cuda()
195 | source = source.cuda()
196 | residual_source = residual_source.cuda()
197 | labels = labels.cuda()
198 | #we should consider different data to train
199 |
200 | #wrap them into Variable
201 | source, residual_source, labels = Variable(source),Variable(residual_source), Variable(labels)
202 | #inputs, exinputs, labels = Variable(inputs),Variable(exinputs), Variable(labels)
203 |
204 | ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
205 | if opt.isAdLoss:
206 | #outputG = net(source,residual_source) #5x64x64->1*64x64
207 | if opt.whichNet == 3 or opt.whichNet == 4:
208 | outputG = net(source, residual_source) # 5x64x64->1*64x64
209 | else:
210 | outputG = net(source) # 5x64x64->1*64x64
211 |
212 | if len(labels.size())==3:
213 | labels = labels.unsqueeze(1)
214 |
215 | outputD_real = netD(labels)
216 | outputD_real = F.sigmoid(outputD_real)
217 |
218 | if len(outputG.size())==3:
219 | outputG = outputG.unsqueeze(1)
220 |
221 | outputD_fake = netD(outputG)
222 | outputD_fake = F.sigmoid(outputD_fake)
223 | netD.zero_grad()
224 | batch_size = inputs.size(0)
225 | real_label = torch.ones(batch_size,1)
226 | real_label = real_label.cuda()
227 | #print(real_label.size())
228 | real_label = Variable(real_label)
229 | #print(outputD_real.size())
230 | loss_real = criterion_bce(outputD_real,real_label)
231 | loss_real.backward()
232 | #train with fake data
233 | fake_label = torch.zeros(batch_size,1)
234 | # fake_label = torch.FloatTensor(batch_size)
235 | # fake_label.data.resize_(batch_size).fill_(0)
236 | fake_label = fake_label.cuda()
237 | fake_label = Variable(fake_label)
238 | loss_fake = criterion_bce(outputD_fake,fake_label)
239 | loss_fake.backward()
240 |
241 | lossD = loss_real + loss_fake
242 | # print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0]
243 | # print('loss for discriminator is %f'%lossD.data[0])
244 | #update network parameters
245 | optimizerD.step()
246 |
247 | if opt.isWDist:
248 | one = torch.FloatTensor([1])
249 | mone = one * -1
250 | one = one.cuda()
251 | mone = mone.cuda()
252 |
253 | netD.zero_grad()
254 |
255 | #outputG = net(source,residual_source) #5x64x64->1*64x64
256 | if opt.whichNet == 3 or opt.whichNet == 4:
257 | outputG = net(source, residual_source) # 5x64x64->1*64x64
258 | else:
259 | outputG = net(source) # 5x64x64->1*64x64
260 |
261 | if len(labels.size())==3:
262 | labels = labels.unsqueeze(1)
263 |
264 | outputD_real = netD(labels)
265 |
266 | if len(outputG.size())==3:
267 | outputG = outputG.unsqueeze(1)
268 |
269 | outputD_fake = netD(outputG)
270 |
271 |
272 | batch_size = inputs.size(0)
273 |
274 | D_real = outputD_real.mean()
275 | # print D_real
276 | D_real.backward(mone)
277 |
278 |
279 | D_fake = outputD_fake.mean()
280 | D_fake.backward(one)
281 |
282 | gradient_penalty = opt.lambda_D_WGAN_GP*calc_gradient_penalty(netD, labels.data, outputG.data)
283 | gradient_penalty.backward()
284 |
285 | D_cost = D_fake - D_real + gradient_penalty
286 | Wasserstein_D = D_real - D_fake
287 |
288 | optimizerD.step()
289 |
290 |
291 | ## (2) update G network: minimize the L1/L2 loss, maximize the D(G(x))
292 |
293 | # print inputs.data.shape
294 | #outputG = net(source) #here I am not sure whether we should use twice or not
295 | if opt.whichNet == 3 or opt.whichNet == 4:
296 | outputG = net(source, residual_source) # 5x64x64->1*64x64
297 | else:
298 | outputG = net(source) # 5x64x64->1*64x64
299 | #outputG = net(source,residual_source) #5x64x64->1*64x64
300 | net.zero_grad()
301 | if opt.whichLoss==1:
302 | lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels))
303 | elif opt.whichLoss==2:
304 | lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels))
305 | else:
306 | lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels))
307 | lossG_G = opt.lossBase * lossG_G
308 | lossG_G.backward(retain_graph=True) #compute gradients
309 |
310 | if opt.isGDL:
311 | lossG_gdl = opt.lambda_gdl * criterion_gdl(outputG,torch.unsqueeze(torch.squeeze(labels,1),1))
312 | lossG_gdl.backward() #compute gradients
313 |
314 | if opt.isAdLoss:
315 |             #we want to fool the discriminator, so we pretend the label here is real. This can also be explained from the
316 |             #angle of the objective (note the max/min difference between the generator and the discriminator)
317 | #outputG = net(inputs)
318 | #outputG = net(source,residual_source) #5x64x64->1*64x64
319 | if opt.whichNet == 3 or opt.whichNet == 4:
320 | outputG = net(source, residual_source) # 5x64x64->1*64x64
321 | else:
322 | outputG = net(source) # 5x64x64->1*64x64
323 |
324 | if len(outputG.size())==3:
325 | outputG = outputG.unsqueeze(1)
326 |
327 | outputD = netD(outputG)
328 | outputD = F.sigmoid(outputD)
329 | lossG_D = opt.lambda_AD*criterion_bce(outputD,real_label) #note, for generator, the label for outputG is real, because the G wants to confuse D
330 | lossG_D.backward()
331 |
332 | if opt.isWDist:
333 |             #we want to fool the discriminator, so we pretend the label here is real. This can also be explained from the
334 |             #angle of the objective (note the max/min difference between the generator and the discriminator)
335 | #outputG = net(inputs)
336 | #outputG = net(source,residual_source) #5x64x64->1*64x64
337 | if opt.whichNet == 3 or opt.whichNet == 4:
338 | outputG = net(source, residual_source) # 5x64x64->1*64x64
339 | else:
340 | outputG = net(source) # 5x64x64->1*64x64
341 | if len(outputG.size())==3:
342 | outputG = outputG.unsqueeze(1)
343 |
344 | outputD_fake = netD(outputG)
345 |
346 | outputD_fake = outputD_fake.mean()
347 |
348 | lossG_D = opt.lambda_AD*outputD_fake.mean() #note, for generator, the label for outputG is real, because the G wants to confuse D
349 | lossG_D.backward(mone)
350 |
351 | #for other losses, we can define the loss function following the pytorch tutorial
352 |
353 | optimizer.step() #update network parameters
354 |
355 | #print('loss for generator is %f'%lossG.data[0])
356 | #print statistics
357 | running_loss = running_loss + lossG_G.data[0]
358 |
359 |
360 |         if iter%opt.showTrainLossEvery==0: #print every showTrainLossEvery iterations
361 | print '************************************************'
362 | print 'time now is: ' + time.asctime(time.localtime(time.time()))
363 | # print 'running loss is ',running_loss
364 | print 'average running loss for generator between iter [%d, %d] is: %.5f'%(iter - 100 + 1,iter,running_loss/100)
365 |
366 | print 'lossG_G is %.5f respectively.'%(lossG_G.data[0])
367 |
368 | if opt.isGDL:
369 | print 'loss for GDL loss is %f'%lossG_gdl.data[0]
370 |
371 | if opt.isAdLoss:
372 | print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0]
373 | print 'loss for discriminator is %f'%lossD.data[0]
374 | print 'lossG_D for discriminator is %f'%lossG_D.data[0]
375 |
376 | if opt.isWDist:
377 | print 'loss_real is ',torch.mean(D_real).data[0],'loss_fake is ',torch.mean(D_fake).data[0]
378 |                 print 'loss for discriminator is %f'%Wasserstein_D.data[0], ' D cost is %f'%D_cost.data[0]
379 | print 'lossG_D for discriminator is %f'%lossG_D.data[0]
380 |
381 |             print 'time cost for iter [%d, %d] is %.2f'%(iter - 100 + 1,iter, time.time()-start)
382 | print '************************************************'
383 | running_loss = 0.0
384 | start = time.time()
385 | if iter%opt.saveModelEvery==0: #save the model
386 | state = {
387 | 'epoch': iter+1,
388 | 'model': net.state_dict()
389 | }
390 | torch.save(state, opt.prefixModelName+'%d.pt'%iter)
391 | print 'save model: '+opt.prefixModelName+'%d.pt'%iter
392 |
393 | if opt.isAdLoss or opt.isWDist:
394 | torch.save(netD.state_dict(), opt.prefixModelName+'_net_D%d.pt'%iter)
395 | if iter%opt.decLREvery==0:
396 | opt.lr = opt.lr*opt.lrDecRate
397 | adjust_learning_rate(optimizer, opt.lr)
398 | if opt.isAdLoss or opt.isWDist:
399 | opt.lr_netD = opt.lr_netD*opt.lrDecRate_netD
400 | adjust_learning_rate(optimizerD, opt.lr_netD)
401 |
402 |
403 | if iter%opt.showValPerformanceEvery==0: #test one subject
404 | # to test on the validation dataset in the format of h5
405 | # inputs,exinputs,labels = data_generator_test.next()
406 | if opt.isMultiSource:
407 | inputs, exinputs, labels = data_generator.next()
408 | else:
409 | inputs, labels = data_generator.next()
410 | exinputs = inputs
411 |
412 | # inputs = np.transpose(inputs,(0,3,1,2))
413 | inputs = np.squeeze(inputs)
414 |
415 | # exinputs = np.transpose(exinputs, (0, 3, 1, 2))
416 | exinputs = np.squeeze(exinputs) # 5x64x64
417 |
418 | labels = np.squeeze(labels)
419 |
420 | inputs = torch.from_numpy(inputs)
421 | inputs = inputs.float()
422 | exinputs = torch.from_numpy(exinputs)
423 | exinputs = exinputs.float()
424 | labels = torch.from_numpy(labels)
425 | labels = labels.float()
426 | mid_slice = opt.numOfChannel_singleSource // 2
427 | residual_source = inputs[:, mid_slice, ...]
428 | if opt.isMultiSource:
429 | source = torch.cat((inputs, exinputs), dim=1)
430 | else:
431 | source = inputs
432 | source = source.cuda()
433 | residual_source = residual_source.cuda()
434 | labels = labels.cuda()
435 | source,residual_source,labels = Variable(source),Variable(residual_source), Variable(labels)
436 |
437 | # source = inputs
438 | #outputG = net(inputs)
439 | #outputG = net(source,residual_source) #5x64x64->1*64x64
440 | if opt.whichNet == 3 or opt.whichNet == 4:
441 | outputG = net(source, residual_source) # 5x64x64->1*64x64
442 | else:
443 | outputG = net(source) # 5x64x64->1*64x64
444 | if opt.whichLoss == 1:
445 | lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels))
446 | elif opt.whichLoss == 2:
447 | lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels))
448 | else:
449 | lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels))
450 | lossG_G = opt.lossBase * lossG_G
451 | print '.......come to validation stage: iter {}'.format(iter),'........'
452 | print 'lossG_G is %.5f.'%(lossG_G.data[0])
453 |
454 | if opt.isGDL:
455 | lossG_gdl = criterion_gdl(outputG, torch.unsqueeze(torch.squeeze(labels,1),1))
456 | print 'loss for GDL loss is %f'%lossG_gdl.data[0]
457 |
458 | if iter % opt.showTestPerformanceEvery == 0: # test one subject
459 | mr_test_itk=sitk.ReadImage(os.path.join(path_test,opt.test_input_file_name))
460 | ct_test_itk=sitk.ReadImage(os.path.join(path_test,opt.test_input_file_name))
461 | hpet_test_itk = sitk.ReadImage(os.path.join(path_test, opt.test_gt_file_name))
462 | #mr_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_sourceCT.nii.gz'))
463 | #ct_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_extraCT.nii.gz'))
464 | #hpet_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub1_targetCT.nii.gz'))
465 |
466 | spacing = hpet_test_itk.GetSpacing()
467 | origin = hpet_test_itk.GetOrigin()
468 | direction = hpet_test_itk.GetDirection()
469 |
470 | mrnp=sitk.GetArrayFromImage(mr_test_itk)
471 | ctnp=sitk.GetArrayFromImage(ct_test_itk)
472 | hpetnp=sitk.GetArrayFromImage(hpet_test_itk)
473 |
474 | ##### specific normalization #####
475 | # mu = np.mean(mrnp)
476 | # maxV, minV = np.percentile(mrnp, [99 ,25])
477 | # #mrimg=mrimg
478 | # mrnp = (mrnp-minV)/(maxV-minV)
479 |
480 |
481 |
482 | #for training data in pelvicSeg
483 | if opt.how2normalize == 1:
484 | maxV, minV = np.percentile(mrnp, [99 ,1])
485 | print 'maxV,',maxV,' minV, ',minV
486 | mrnp = (mrnp-mu)/(maxV-minV)
487 | print 'unique value: ',np.unique(ctnp)
488 |
489 | #for training data in pelvicSeg
490 | if opt.how2normalize == 2:
491 | maxV, minV = np.percentile(mrnp, [99 ,1])
492 | print 'maxV,',maxV,' minV, ',minV
493 | mrnp = (mrnp-mu)/(maxV-minV)
494 | print 'unique value: ',np.unique(ctnp)
495 |
496 | #for training data in pelvicSegRegH5
497 | if opt.how2normalize== 3:
498 | std = np.std(mrnp)
499 | mrnp = (mrnp - mu)/std
500 | print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)
501 |
502 | if opt.how2normalize == 4:
503 | maxLPET = 149.366742
504 | maxPercentLPET = 7.76
505 | minLPET = 0.00055037
506 | meanLPET = 0.27593288
507 | stdLPET = 0.75747500
508 |
509 | # for rsCT
510 | maxCT = 27279
511 | maxPercentCT = 1320
512 | minCT = -1023
513 | meanCT = -601.1929
514 | stdCT = 475.034
515 |
516 | # for s-pet
517 | maxSPET = 156.675962
518 | maxPercentSPET = 7.79
519 | minSPET = 0.00055037
520 | meanSPET = 0.284224789
521 | stdSPET = 0.7642257
522 |
523 | #matLPET = (mrnp - meanLPET) / (stdLPET)
524 | matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET)
525 | matCT = (ctnp - meanCT) / stdCT
526 | matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET)
527 |
528 | if opt.how2normalize == 5:
529 | # for rsCT
530 | maxCT = 27279
531 | maxPercentCT = 1320
532 | minCT = -1023
533 | meanCT = -601.1929
534 | stdCT = 475.034
535 |
536 |                 print 'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp)
538 |
539 | # matLPET = (mrnp - meanLPET) / (stdLPET)
540 | matLPET = mrnp
541 | matCT = (ctnp - meanCT) / stdCT
542 | matSPET = hpetnp
543 |
544 | if opt.how2normalize == 6:
545 | maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0])
546 | maxPercentCT, minPercentCT = np.percentile(ctnp, [99.5, 0])
547 | print 'maxPercentPET: ', maxPercentPET, ' minPercentPET: ', minPercentPET, ' maxPercentCT: ', maxPercentCT, 'minPercentCT: ', minPercentCT
548 |
549 | matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET)
550 | matSPET = (hpetnp - minPercentPET) / (maxPercentPET - minPercentPET)
551 |
552 | matCT = (ctnp - minPercentCT) / (maxPercentCT - minPercentCT)
553 |
554 |
555 | if not opt.isMultiSource:
556 | matFA = matLPET
557 | matGT = hpetnp
558 |
559 | print 'matFA shape: ',matFA.shape, ' matGT shape: ', matGT.shape
560 | matOut = testOneSubject_aver_res(matFA,matGT,[5,64,64],[1,64,64],[1,32,32],net,opt.prefixModelName+'%d.pt'%iter)
561 | print 'matOut shape: ',matOut.shape
562 | if opt.how2normalize==6:
563 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
564 | else:
565 | ct_estimated = matOut
566 |
567 |
568 | itspsnr = psnr(ct_estimated, matGT)
569 |
570 | print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape
571 |                 print 'gt: ',ctnp.dtype,' shape: ',ctnp.shape
572 | print 'psnr = ',itspsnr
573 | volout = sitk.GetImageFromArray(ct_estimated)
574 | volout.SetSpacing(spacing)
575 | volout.SetOrigin(origin)
576 | volout.SetDirection(direction)
577 | sitk.WriteImage(volout,opt.prefixPredictedFN+'{}'.format(iter)+'.nii.gz')
578 | else:
579 | matFA = matLPET
580 | matGT = hpetnp
581 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape
582 | matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], net,
583 | opt.prefixModelName + '%d.pt' % iter)
584 | print 'matOut shape: ', matOut.shape
585 | if opt.how2normalize==6:
586 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
587 | else:
588 | ct_estimated = matOut
589 |
590 | itspsnr = psnr(ct_estimated, matGT)
591 |
592 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
593 |                 print 'gt: ', ctnp.dtype, ' shape: ', ctnp.shape
594 | print 'psnr = ', itspsnr
595 | volout = sitk.GetImageFromArray(ct_estimated)
596 | volout.SetSpacing(spacing)
597 | volout.SetOrigin(origin)
598 | volout.SetDirection(direction)
599 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(iter) + '.nii.gz')
600 |
601 | print('Finished Training')
602 |
603 | if __name__ == '__main__':
604 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpuID)
605 | main()
606 |
607 |
--------------------------------------------------------------------------------
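The WGAN-GP branch above relies on calc_gradient_penalty from nnBuildUnits, which is not included in this listing. For reference, a minimal sketch of the standard WGAN-GP gradient penalty (Gulrajani et al.), written in the pre-0.4 Variable style this repo uses; the function name and details here are illustrative, not the repo's actual implementation:

    import torch
    from torch.autograd import Variable, grad

    def calc_gradient_penalty_sketch(netD, real_data, fake_data):
        # sample random points on straight lines between real and fake samples
        batch_size = real_data.size(0)
        alpha = torch.rand(batch_size, 1, 1, 1).expand_as(real_data)
        if real_data.is_cuda:
            alpha = alpha.cuda()
        interpolates = Variable(alpha * real_data + (1 - alpha) * fake_data, requires_grad=True)
        disc_out = netD(interpolates)
        grad_outputs = torch.ones(disc_out.size())
        if real_data.is_cuda:
            grad_outputs = grad_outputs.cuda()
        # gradient of D with respect to the interpolated inputs
        gradients = grad(outputs=disc_out, inputs=interpolates, grad_outputs=grad_outputs,
                         create_graph=True, retain_graph=True, only_inputs=True)[0]
        # push the per-sample gradient norm towards 1 (the caller scales by lambda_D_WGAN_GP)
        return ((gradients.view(batch_size, -1).norm(2, dim=1) - 1) ** 2).mean()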
/runCTRecon3d.py:
--------------------------------------------------------------------------------
1 | # from __future__ import print_function
2 | import argparse, os
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 | import torch.optim as optim
8 | import torch
9 | import torch.utils.data as data_utils
10 | from utils import *
11 | from ResUnet3d_pytorch import UNet, ResUNet, UNet_LRes, ResUNet_LRes, Discriminator
12 | # from Unet3d_pytorch import UNet3D
13 | from nnBuildUnits import CrossEntropy3d, topK_RegLoss, RelativeThreshold_RegLoss, adjust_learning_rate
14 | import time
15 | import SimpleITK as sitk
16 |
17 | # Training settings
18 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg")
19 | parser.add_argument("--gpuID", type=int, default=3, help="which GPU to use")
20 | parser.add_argument("--isAdLoss", action="store_true", help="is adversarial loss used?", default=False)
21 | parser.add_argument("--lambda_AD", default=0.05, type=float, help="weight for AD loss, Default: 0.05")
22 | parser.add_argument("--how2normalize", type=int, default=5, help="how to normalize the data")
23 | parser.add_argument("--whichLoss", type=int, default=1, help="which loss to use: 1. L1 (default), 2. RTL1, 3. MSE")
24 | parser.add_argument("--whichNet", type=int, default=4, help="which network to use: 1. UNet, 2. ResUNet, 3. UNet_LRes and 4. ResUNet_LRes (default, 4)")
25 | parser.add_argument("--lossBase", type=int, default=1, help="The base to multiply the lossG_G, Default (1)")
26 | parser.add_argument("--batchSize", type=int, default=10, help="training batch size")
27 | parser.add_argument("--isMultiSource", action="store_true", help="is multiple modality used?", default=False)
28 | parser.add_argument("--numOfChannel_singleSource", type=int, default=5, help="# of channels for a 2D patch for the main modality (Default, 5)")
29 | parser.add_argument("--numOfChannel_allSource", type=int, default=1, help="# of channels for a 2D patch for all the concatenated modalities (Default, 5)")
30 | parser.add_argument("--numofIters", type=int, default=200000, help="number of iterations to train for")
31 | parser.add_argument("--showTrainLossEvery", type=int, default=100, help="number of iterations to show train loss")
32 | parser.add_argument("--saveModelEvery", type=int, default=5000, help="number of iterations to save the model")
33 | parser.add_argument("--showValPerformanceEvery", type=int, default=1000, help="number of iterations to show validation performance")
34 | parser.add_argument("--showTestPerformanceEvery", type=int, default=5000, help="number of iterations to show test performance")
35 | parser.add_argument("--lr", type=float, default=5e-3, help="Learning Rate. Default=5e-3")
36 | parser.add_argument("--dropout_rate", default=0.2, type=float, help="prob to drop neurons to zero: 0.2")
37 | parser.add_argument("--decLREvery", type=int, default=10000, help="halve the learning rate every n iterations, Default: n=10000")
38 | parser.add_argument("--cuda", action="store_true", help="Use cuda?", default=True)
39 | parser.add_argument("--resume", default="/home/niedong/Data4LowDosePET/pytorch_UNet/resunet3d_dp_pet_BatchAug_noNorm_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0627_5000.pt", type=str, help="Path to checkpoint (default: none)")
40 | parser.add_argument("--start_epoch", default=1, type=int, help="Manual epoch number (useful on restarts)")
41 | parser.add_argument("--threads", type=int, default=1, help="Number of threads for data loader to use, Default: 1")
42 | parser.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
43 | parser.add_argument("--weight-decay", "--wd", default=1e-4, type=float, help="weight decay, Default: 1e-4")
44 | parser.add_argument("--RT_th", default=0.005, type=float, help="Relative thresholding: 0.005")
45 | parser.add_argument("--pretrained", default="", type=str, help="path to pretrained model (default: none)")
46 | parser.add_argument("--prefixModelName", default="/home/niedong/Data4LowDosePET/pytorch_UNet/resunet3d_dp_pet_BatchAug_noNorm_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0627_", type=str, help="prefix of the to-be-saved model name")
47 | parser.add_argument("--prefixPredictedFN", default="preSub1_pet_BatchAug_noNorm_resunet3d_dp_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0627_", type=str, help="prefix of the to-be-saved predicted filename")
48 |
49 | global opt, model
50 | opt = parser.parse_args()
51 |
52 | def main():
53 | print opt
54 |
55 | # prefixModelName = 'Regressor_1112_'
56 | # prefixPredictedFN = 'preSub1_1112_'
57 | # showTrainLossEvery = 100
58 | # lr = 1e-4
59 | # showTestPerformanceEvery = 2000
60 | # saveModelEvery = 2000
61 | # decLREvery = 40000
62 | # numofIters = 200000
63 | # how2normalize = 0
64 |
65 |
66 | netD = Discriminator()
67 | netD.apply(weights_init)
68 | netD.cuda()
69 |
70 | optimizerD = optim.Adam(netD.parameters(),lr=1e-3)
71 | criterion_bce=nn.BCELoss()
72 | criterion_bce.cuda()
73 |
74 | #net=UNet()
75 | if opt.whichNet==1:
76 | net = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
77 | elif opt.whichNet==2:
78 | net = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
79 | elif opt.whichNet==3:
80 | net = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1)
81 | elif opt.whichNet==4:
82 | net = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1, dp_prob = opt.dropout_rate)
83 | #net.apply(weights_init)
84 | net.cuda()
85 | params = list(net.parameters())
86 | print('len of params is ')
87 | print(len(params))
88 | print('size of params is ')
89 | print(params[0].size())
90 |
91 |
92 |
93 | optimizer = optim.Adam(net.parameters(),lr=opt.lr)
94 | criterion_L2 = nn.MSELoss()
95 | criterion_L1 = nn.L1Loss()
96 | criterion_RTL1 = RelativeThreshold_RegLoss(opt.RT_th)
97 | #criterion = nn.CrossEntropyLoss()
98 | # criterion = nn.NLLLoss2d()
99 |
100 | given_weight = torch.cuda.FloatTensor([1,4,4,2])
101 |
102 | criterion_3d = CrossEntropy3d(weight=given_weight)
103 |
104 | criterion_3d = criterion_3d.cuda()
105 | criterion_L2 = criterion_L2.cuda()
106 | criterion_L1 = criterion_L1.cuda()
107 | criterion_RTL1 = criterion_RTL1.cuda()
108 |
109 | #inputs=Variable(torch.randn(1000,1,32,32)) #here should be tensor instead of variable
110 | #targets=Variable(torch.randn(1000,10,1,1)) #here should be tensor instead of variable
111 | # trainset=data_utils.TensorDataset(inputs, targets)
112 | # trainloader = data_utils.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
113 | # inputs=torch.randn(1000,1,32,32)
114 | # targets=torch.LongTensor(1000)
115 |
116 | path_test ='/home/niedong/DataCT/data_niigz/'
117 | path_patients_h5 = '/home/niedong/DataCT/h5Data3D_noNorm/trainBatch3D_H5'
118 | path_patients_h5_test ='/home/niedong/DataCT/h5Data3D_noNorm/val3D_H5'
119 | # path_patients_h5_test ='/home/niedong/Data4LowDosePET/test2D_H5'
120 | # batch_size=10
121 | #data_generator = Generator_2D_slices(path_patients_h5,opt.batchSize,inputKey='data3T',outputKey='data7T')
122 | #data_generator_test = Generator_2D_slices(path_patients_h5_test,opt.batchSize,inputKey='data3T',outputKey='data7T')
123 |
124 | data_generator = Generator_3D_patches(path_patients_h5,opt.batchSize, inputKey='dataLPET', outputKey='dataHPET')
125 | data_generator_test = Generator_3D_patches(path_patients_h5_test,opt.batchSize, inputKey='dataLPET', outputKey='dataHPET')
126 | if opt.resume:
127 | if os.path.isfile(opt.resume):
128 | print("=> loading checkpoint '{}'".format(opt.resume))
129 | checkpoint = torch.load(opt.resume)
130 | net.load_state_dict(checkpoint['model'])
132 |             opt.start_epoch = checkpoint["epoch"] - 1
133 | # net.load_state_dict(checkpoint["model"].state_dict())
134 | else:
135 | print("=> no checkpoint found at '{}'".format(opt.resume))
136 |     ########### It would be better to use a dataloader to load large amounts of data, and to train for several epochs ###############
138 | running_loss = 0.0
139 | start = time.time()
140 | for iter in range(opt.start_epoch, opt.numofIters+1):
141 | #print('iter %d'%iter)
142 |
143 | # inputs, exinputs, labels = data_generator.next()
144 | inputs, labels = data_generator.next()
145 |
146 | # xx = np.transpose(inputs,(5,64,64))
147 | # print 'size of inputs: ', inputs.shape
148 | inputs = np.transpose(inputs,(0,4,1,2,3))
149 | # inputs = np.squeeze(inputs) #16x64x64
150 | # exinputs = np.squeeze(exinputs) #5x64x64
151 | # print 'shape is ....',inputs.shape
152 | labels = np.squeeze(labels) #64x64
153 | # labels = labels.astype(int)
154 |
155 | inputs = inputs.astype(float)
156 | inputs = torch.from_numpy(inputs)
157 | inputs = inputs.float()
158 | # exinputs = exinputs.astype(float)
159 | # exinputs = torch.from_numpy(exinputs)
160 | # exinputs = exinputs.float()
161 | labels = labels.astype(float)
162 | labels = torch.from_numpy(labels)
163 | labels = labels.float()
164 | #print type(inputs), type(exinputs)
165 | if opt.isMultiSource:
166 | # source = torch.cat((inputs, exinputs),dim=1)
167 |             print 'the multi-source branch is not implemented for the 3d version; you have to tune this part yourself'
168 | else:
169 | source = inputs
170 | #source = inputs
171 | # mid_slice = opt.numOfChannel_singleSource//2
172 | residual_source = inputs
173 | #inputs = inputs.cuda()
174 | #exinputs = exinputs.cuda()
175 | source = source.cuda()
176 | residual_source = residual_source.cuda()
177 | labels = labels.cuda()
178 | #we should consider different data to train
179 |
180 | #wrap them into Variable
181 | source, residual_source, labels = Variable(source),Variable(residual_source), Variable(labels)
182 | #inputs, exinputs, labels = Variable(inputs),Variable(exinputs), Variable(labels)
183 |
184 | ## (1) update D network: maximize log(D(x)) + log(1 - D(G(z)))
185 | if opt.isAdLoss:
186 | if opt.whichNet == 3 or opt.whichNet == 4:
187 | outputG = net(source, residual_source) # 5x64x64->1*64x64
188 | else:
189 | outputG = net(source) # 5x64x64->1*64x64
190 | #outputG = net(source,residual_source) #5x64x64->1*64x64
191 |
192 | if len(labels.size())==3:
193 | labels = labels.unsqueeze(1)
194 |
195 | outputD_real = netD(labels)
196 | outputD_real = F.sigmoid(outputD_real)
197 |
198 | if len(outputG.size())==3:
199 | outputG = outputG.unsqueeze(1)
200 |
201 | outputD_fake = netD(outputG)
202 | outputD_fake = F.sigmoid(outputD_fake)
203 | netD.zero_grad()
204 | batch_size = inputs.size(0)
205 | real_label = torch.ones(batch_size,1)
206 | real_label = real_label.cuda()
207 | #print(real_label.size())
208 | real_label = Variable(real_label)
209 | #print(outputD_real.size())
210 | loss_real = criterion_bce(outputD_real,real_label)
211 | loss_real.backward()
212 | #train with fake data
213 | fake_label = torch.zeros(batch_size,1)
214 | # fake_label = torch.FloatTensor(batch_size)
215 | # fake_label.data.resize_(batch_size).fill_(0)
216 | fake_label = fake_label.cuda()
217 | fake_label = Variable(fake_label)
218 | loss_fake = criterion_bce(outputD_fake,fake_label)
219 | loss_fake.backward()
220 |
221 | lossD = loss_real + loss_fake
222 | # print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0]
223 | # print('loss for discriminator is %f'%lossD.data[0])
224 | #update network parameters
225 | optimizerD.step()
226 |
227 |
228 | ## (2) update G network: minimize the L1/L2 loss, maximize the D(G(x))
229 |
230 | # print inputs.data.shape
231 | #outputG = net(source) #here I am not sure whether we should use twice or not
232 | if opt.whichNet == 3 or opt.whichNet == 4:
233 | outputG = net(source, residual_source) # 5x64x64->1*64x64
234 | else:
235 | outputG = net(source) # 5x64x64->1*64x64
236 | #outputG = net(source,residual_source) #5x64x64->1*64x64
237 | net.zero_grad()
238 | if opt.whichLoss==1:
239 | lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels))
240 | elif opt.whichLoss==2:
241 | lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels))
242 | else:
243 | lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels))
244 | lossG_G = opt.lossBase * lossG_G
245 | lossG_G.backward() #compute gradients
246 |
247 | if opt.isAdLoss:
248 |                 #we want to fool the discriminator, so we label the generated output as real here. Equivalently, this
249 |                 #follows from the min-max objective (note the max/min swap between the discriminator and the generator)
250 | #outputG = net(inputs)
251 | #outputG = net(source,residual_source) #5x64x64->1*64x64
252 | if opt.whichNet == 3 or opt.whichNet == 4:
253 | outputG = net(source, residual_source) # 5x64x64->1*64x64
254 | else:
255 | outputG = net(source) # 5x64x64->1*64x64
256 |
257 | if len(outputG.size())==3:
258 | outputG = outputG.unsqueeze(1)
259 |
260 | outputD = netD(outputG)
261 | outputD = F.sigmoid(outputD)
262 | lossG_D = opt.lambda_AD*criterion_bce(outputD,real_label) #note, for generator, the label for outputG is real, because the G wants to confuse D
263 | lossG_D.backward()
264 |
265 | #for other losses, we can define the loss function following the pytorch tutorial
266 |
267 | optimizer.step() #update network parameters
268 |
269 | #print('loss for generator is %f'%lossG.data[0])
270 | #print statistics
271 | running_loss = running_loss + lossG_G.data[0]
272 |
273 |
274 |         if iter%opt.showTrainLossEvery==0: #print training statistics every opt.showTrainLossEvery iterations
275 | print '************************************************'
276 | print 'time now is: ' + time.asctime(time.localtime(time.time()))
277 | # print 'running loss is ',running_loss
278 |             print 'average running loss for generator between iter [%d, %d] is: %.5f'%(iter - opt.showTrainLossEvery + 1,iter,running_loss/opt.showTrainLossEvery)
279 |
280 |             print 'lossG_G is %.5f.'%(lossG_G.data[0])
281 | if opt.isAdLoss:
282 | print 'loss_real is ',loss_real.data[0],'loss_fake is ',loss_fake.data[0],'outputD_real is',outputD_real.data[0]
283 | print('loss for discriminator is %f'%lossD.data[0])
284 |
285 |             print 'cost time for iter [%d, %d] is %.2f'%(iter - opt.showTrainLossEvery + 1,iter, time.time()-start)
286 | print '************************************************'
287 | running_loss = 0.0
288 | start = time.time()
289 | if iter%opt.saveModelEvery==0: #save the model
290 | state = {
291 | 'epoch': iter+1,
292 | 'model': net.state_dict()
293 | }
294 | torch.save(state, opt.prefixModelName+'%d.pt'%iter)
295 | print 'save model: '+opt.prefixModelName+'%d.pt'%iter
296 |
297 | if opt.isAdLoss:
298 | torch.save(netD.state_dict(), opt.prefixModelName+'_net_D%d.pt'%iter)
299 | if iter%opt.decLREvery==0:
300 | opt.lr = opt.lr*0.5
301 | adjust_learning_rate(optimizer, opt.lr)
302 |
303 | if iter%opt.showValPerformanceEvery==0: #test one subject
304 | # to test on the validation dataset in the format of h5
305 | # inputs,exinputs,labels = data_generator_test.next()
306 | inputs, labels = data_generator_test.next()
307 |
308 | inputs = np.transpose(inputs,(0,4,1,2,3))
309 | # inputs = np.squeeze(inputs)
310 |
311 | # exinputs = np.transpose(exinputs, (0, 3, 1, 2))
312 | # exinputs = np.squeeze(exinputs) # 5x64x64
313 |
314 | labels = np.squeeze(labels)
315 |
316 | inputs = torch.from_numpy(inputs)
317 | inputs = inputs.float()
318 | # exinputs = torch.from_numpy(exinputs)
319 | # exinputs = exinputs.float()
320 | labels = torch.from_numpy(labels)
321 | labels = labels.float()
322 | # mid_slice = opt.numOfChannel_singleSource // 2
323 | residual_source = inputs
324 | if opt.isMultiSource:
325 | # source = torch.cat((inputs, exinputs), dim=1)
326 |                 print 'the multi-source branch is not implemented here; tune it before use'
327 | else:
328 | source = inputs
329 | source = source.cuda()
330 | residual_source = residual_source.cuda()
331 | labels = labels.cuda()
332 | source,residual_source,labels = Variable(source),Variable(residual_source), Variable(labels)
333 |
334 | # source = inputs
335 | #outputG = net(inputs)
336 | if opt.whichNet == 3 or opt.whichNet == 4:
337 | outputG = net(source, residual_source) # 5x64x64->1*64x64
338 | else:
339 | outputG = net(source) # 5x64x64->1*64x64
340 | #outputG = net(source,residual_source) #5x64x64->1*64x64
341 |
342 | if opt.whichLoss == 1:
343 | lossG_G = criterion_L1(torch.squeeze(outputG), torch.squeeze(labels))
344 | elif opt.whichLoss == 2:
345 | lossG_G = criterion_RTL1(torch.squeeze(outputG), torch.squeeze(labels))
346 | else:
347 | lossG_G = criterion_L2(torch.squeeze(outputG), torch.squeeze(labels))
348 | lossG_G = opt.lossBase * lossG_G
349 |             print '....... validation at iter {} .......'.format(iter)
350 | print 'lossG_G is %.5f.'%(lossG_G.data[0])
351 |
352 | if iter % opt.showTestPerformanceEvery == 0: # test one subject
353 | mr_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_sourceCT.nii.gz'))
354 | ct_test_itk=sitk.ReadImage(os.path.join(path_test,'sub1_extraCT.nii.gz'))
355 | hpet_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub1_targetCT.nii.gz'))
356 |
357 | spacing = hpet_test_itk.GetSpacing()
358 | origin = hpet_test_itk.GetOrigin()
359 | direction = hpet_test_itk.GetDirection()
360 |
361 | mrnp=sitk.GetArrayFromImage(mr_test_itk)
362 | ctnp=sitk.GetArrayFromImage(ct_test_itk)
363 | hpetnp=sitk.GetArrayFromImage(hpet_test_itk)
364 |
365 | ##### specific normalization #####
366 |             mu = np.mean(mrnp)  # needed by the how2normalize modes 1-3 below
367 | # maxV, minV = np.percentile(mrnp, [99 ,25])
368 | # #mrimg=mrimg
369 | # mrnp = (mrnp-minV)/(maxV-minV)
370 |
371 |
372 |
373 | #for training data in pelvicSeg
374 | if opt.how2normalize == 1:
375 | maxV, minV = np.percentile(mrnp, [99 ,1])
376 | print 'maxV,',maxV,' minV, ',minV
377 | mrnp = (mrnp-mu)/(maxV-minV)
378 | print 'unique value: ',np.unique(ctnp)
379 |
380 | #for training data in pelvicSeg
381 | if opt.how2normalize == 2:
382 | maxV, minV = np.percentile(mrnp, [99 ,1])
383 | print 'maxV,',maxV,' minV, ',minV
384 | mrnp = (mrnp-mu)/(maxV-minV)
385 | print 'unique value: ',np.unique(ctnp)
386 |
387 | #for training data in pelvicSegRegH5
388 | if opt.how2normalize== 3:
389 | std = np.std(mrnp)
390 | mrnp = (mrnp - mu)/std
391 | print 'maxV,',np.ndarray.max(mrnp),' minV, ',np.ndarray.min(mrnp)
392 |
393 | if opt.how2normalize == 4:
394 | maxLPET = 149.366742
395 | maxPercentLPET = 7.76
396 | minLPET = 0.00055037
397 | meanLPET = 0.27593288
398 | stdLPET = 0.75747500
399 |
400 | # for rsCT
401 | maxCT = 27279
402 | maxPercentCT = 1320
403 | minCT = -1023
404 | meanCT = -601.1929
405 | stdCT = 475.034
406 |
407 | # for s-pet
408 | maxSPET = 156.675962
409 | maxPercentSPET = 7.79
410 | minSPET = 0.00055037
411 | meanSPET = 0.284224789
412 | stdSPET = 0.7642257
413 |
414 | #matLPET = (mrnp - meanLPET) / (stdLPET)
415 | matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET)
416 | matCT = (ctnp - meanCT) / stdCT
417 | matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET)
418 |
419 | if opt.how2normalize == 5:
420 | # for rsCT
421 | maxCT = 27279
422 | maxPercentCT = 1320
423 | minCT = -1023
424 | meanCT = -601.1929
425 | stdCT = 475.034
426 |
427 |                 print 'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp)
428 | 
429 |
430 | # matLPET = (mrnp - meanLPET) / (stdLPET)
431 | matLPET = mrnp
432 | matCT = (ctnp - meanCT) / stdCT
433 | matSPET = hpetnp
434 |
435 | if not opt.isMultiSource:
436 | # matFA = matLPET
437 | # matGT = matSPET
438 | matFA = mrnp
439 | matGT = hpetnp
440 | print 'matFA shape: ',matFA.shape, ' matGT shape: ', matGT.shape
441 | matOut = testOneSubject_aver_res(matFA,matGT,[16,64,64],[16,64,64],[8,32,32],net,opt.prefixModelName+'%d.pt'%iter, nd=3)
442 | print 'matOut shape: ',matOut.shape
443 | ct_estimated = matOut
444 |
445 | itspsnr = psnr(ct_estimated, matGT)
446 |
447 | print 'pred: ',ct_estimated.dtype, ' shape: ',ct_estimated.shape
448 |                 print 'gt: ',matGT.dtype,' shape: ',matGT.shape
449 | print 'psnr = ',itspsnr
450 | volout = sitk.GetImageFromArray(ct_estimated)
451 | volout.SetSpacing(spacing)
452 | volout.SetOrigin(origin)
453 | volout.SetDirection(direction)
454 | sitk.WriteImage(volout,opt.prefixPredictedFN+'{}'.format(iter)+'.nii.gz')
455 | else:
456 | # matFA = matLPET
457 | # matGT = matSPET
458 | matFA = mrnp
459 | matGT = hpetnp
460 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape
461 | matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [16, 64, 64], [16, 64, 64], [8, 32, 32], net,
462 | opt.prefixModelName + '%d.pt' % iter)
463 | print 'matOut shape: ', matOut.shape
464 | ct_estimated = matOut
465 |
466 | itspsnr = psnr(ct_estimated, matGT)
467 |
468 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
469 |                 print 'gt: ', matGT.dtype, ' shape: ', matGT.shape
470 | print 'psnr = ', itspsnr
471 | volout = sitk.GetImageFromArray(ct_estimated)
472 | volout.SetSpacing(spacing)
473 | volout.SetOrigin(origin)
474 | volout.SetDirection(direction)
475 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(iter) + '.nii.gz')
476 |
477 | print('Finished Training')
478 |
479 | if __name__ == '__main__':
480 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpuID)
481 | main()
482 |
483 |
--------------------------------------------------------------------------------
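
The training loop in runCTRecon3d.py above alternates two updates: a discriminator step that pushes D(real) toward 1 and D(G(x)) toward 0 with BCE, and a generator step that combines a weighted voxel-wise L1/L2 term with an adversarial BCE term whose target is "real". The sketch below distills that alternation into one function. It is illustrative only: `netG`, `netD`, the optimizers, and the hyper-parameter names are stand-ins, the LRes-style generator signature `netG(source, residual_source)` is assumed, and it uses the plain tensor API (PyTorch >= 0.4) instead of `Variable`.

```python
import torch
import torch.nn.functional as F

def gan_train_step(netG, netD, optimizerG, optimizerD,
                   source, residual_source, labels,
                   lambda_AD=0.05, loss_base=1.0):
    # (1) discriminator step: D(real) -> 1, D(G(x)) -> 0
    netD.zero_grad()
    with torch.no_grad():                       # no generator gradients in the D step
        fake = netG(source, residual_source)
    d_real = torch.sigmoid(netD(labels))
    d_fake = torch.sigmoid(netD(fake))
    lossD = F.binary_cross_entropy(d_real, torch.ones_like(d_real)) \
          + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
    lossD.backward()
    optimizerD.step()

    # (2) generator step: voxel-wise L1 plus an adversarial term labeled
    # "real", because G tries to make D accept its output
    netG.zero_grad()
    fake = netG(source, residual_source)
    d_fake = torch.sigmoid(netD(fake))
    lossG = loss_base * F.l1_loss(fake.squeeze(), labels.squeeze()) \
          + lambda_AD * F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
    lossG.backward()
    optimizerG.step()
    return lossD.item(), lossG.item()
```
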
/runTesting_Recon.py:
--------------------------------------------------------------------------------
1 | # from __future__ import print_function
2 | import argparse, os
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 | import torch.optim as optim
8 | import torch
9 | import torch.utils.data as data_utils
10 | from Unet2d_pytorch import UNet, ResUNet, UNet_LRes, ResUNet_LRes, Discriminator
11 | from utils import *
12 | # from ganComponents import *
13 | # from nnBuildUnits import CrossEntropy2d
14 | # from nnBuildUnits import computeSampleAttentionWeight
15 | # from nnBuildUnits import adjust_learning_rate
16 | import time
17 | # from dataClean import denoiseImg,denoiseImg_isolation,denoiseImg_closing
18 | import SimpleITK as sitk
19 |
20 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg")
21 |
22 | parser.add_argument("--isSegReg", action="store_true", help="is Seg and Reg?", default=False)
23 | parser.add_argument("--isMultiSource", action="store_true", help="is multiple input modality used?", default=False)
24 | parser.add_argument("--whichLoss", type=int, default=1, help="which loss to use: 1. LossL1 (default), 2. lossRTL1, 3. MSE")
25 | parser.add_argument("--whichNet", type=int, default=4, help="which network to use: 1. UNet, 2. ResUNet, 3. UNet_LRes and 4. ResUNet_LRes (default: 4)")
26 | parser.add_argument("--lossBase", type=int, default=1, help="The base to multiply the lossG_G, Default (1)")
27 | parser.add_argument("--batchSize", type=int, default=32, help="training batch size")
28 | parser.add_argument("--numOfChannel_singleSource", type=int, default=5, help="# of channels for a 2D patch for the main modality (Default, 5)")
29 | parser.add_argument("--numOfChannel_allSource", type=int, default=5, help="# of channels for a 2D patch for all the concatenated modalities (Default, 5)")
30 | parser.add_argument("--isResidualEnhancement", action="store_true", help="is residual learning operation enhanced?", default=False)
31 | parser.add_argument("--isViewExpansion", action="store_true", help="is view expanded?", default=True)
32 | parser.add_argument("--isAdLoss", action="store_true", help="is adversarial loss used?", default=True)
33 | parser.add_argument("--isSpatialDropOut", action="store_true", help="is spatial dropout used?", default=False)
34 | parser.add_argument("--isFocalLoss", action="store_true", help="is focal loss used?", default=False)
35 | parser.add_argument("--isSampleImportanceFromAd", action="store_true", help="is sample importance from adversarial network used?", default=False)
36 | parser.add_argument("--dropoutRate", type=float, default=0.25, help="Spatial Dropout Rate. Default=0.25")
37 | parser.add_argument("--lambdaAD", type=float, default=0, help="loss coefficient for AD loss. Default=0")
38 | parser.add_argument("--adImportance", type=float, default=0, help="Sample importance from AD network. Default=0")
39 | parser.add_argument("--isFixedRegions", action="store_true", help="Is the organ regions roughly known?", default=False)
40 | #parser.add_argument("--modelPath", default="/home/niedong/Data4LowDosePET/pytorch_UNet/model/resunet2d_pet_Aug_noNorm_lres_bn_lr5e3_base1_lossL1_0p01_0624_200000.pt", type=str, help="prefix of the to-be-saved model name")
41 | parser.add_argument("--modelPath", default="/home/niedong/Data4LowDosePET/pytorch_UNet/model/resunet2d_dp_pet_BatchAug_sNorm_lres_bn_lr5e3_lrdec_base1_lossL1_0p005_0628_200000.pt", type=str, help="prefix of the to-be-saved model name")
42 | parser.add_argument("--prefixPredictedFN", default="pred_resunet2d_dp_pet_Aug_sNorm_lres_lrdce_bn_lr5e3_base1_lossL1_0628_20w_", type=str, help="prefix of the to-be-saved predicted filename")
43 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data")
44 | parser.add_argument("--resType", type=int, default=2, help="resType: 0: segmentation map (integer); 1: regression map (continuous); 2: segmentation map + probability map")
45 |
46 | def main():
47 | opt = parser.parse_args()
48 | print opt
49 |
50 | path_test = '/home/niedong/Data4LowDosePET/data_niigz_scale/'
51 |
52 | if opt.whichNet==1:
53 | netG = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
54 | elif opt.whichNet==2:
55 | netG = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
56 | elif opt.whichNet==3:
57 | netG = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1)
58 | elif opt.whichNet==4:
59 | netG = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1)
60 |
61 | #netG.apply(weights_init)
62 | netG.cuda()
63 |
64 | checkpoint = torch.load(opt.modelPath)
65 | netG.load_state_dict(checkpoint['model'])
66 |
67 |
68 |     ids = [1,6,11,16,21,26,31,36,41,46] #in one folder, we test the 10 subjects that form the testing set
69 |     ids = [1] #in one folder, we test the 10 subjects that form the testing set
70 |
71 | ids = ['1_QFZ','2_LLQ','3_LMB','4_ZSL','5_CJB','11_TCL','15_WYL','21_PY','25_LYL','31_CZX','35_WLL','41_WQC','45_YXM']
72 | for ind in ids:
73 | start = time.time()
74 |
75 | mr_test_itk = sitk.ReadImage(os.path.join(path_test,'%s_60s_suv.nii.gz'%ind))#input modality
76 |         ct_test_itk = sitk.ReadImage(os.path.join(path_test,'%s_rsCT.nii.gz'%ind))#auxiliary modality
77 | hpet_test_itk = sitk.ReadImage(os.path.join(path_test, '%s_120s_suv.nii.gz'%ind))#output modality
78 |
79 |
80 | spacing = hpet_test_itk.GetSpacing()
81 | origin = hpet_test_itk.GetOrigin()
82 | direction = hpet_test_itk.GetDirection()
83 |
84 | mrnp = sitk.GetArrayFromImage(mr_test_itk)
85 | ctnp = sitk.GetArrayFromImage(ct_test_itk)
86 | hpetnp = sitk.GetArrayFromImage(hpet_test_itk)
87 |
88 | ##### specific normalization #####
89 |         mu = np.mean(mrnp)  # needed by the how2normalize modes 1-3 below
90 | # maxV, minV = np.percentile(mrnp, [99 ,25])
91 | # #mrimg=mrimg
92 | # mrnp = (mrnp-minV)/(maxV-minV)
93 |
94 | # for training data in pelvicSeg
95 | if opt.how2normalize == 1:
96 | maxV, minV = np.percentile(mrnp, [99, 1])
97 | print 'maxV,', maxV, ' minV, ', minV
98 | mrnp = (mrnp - mu) / (maxV - minV)
99 | print 'unique value: ', np.unique(ctnp)
100 |
101 | # for training data in pelvicSeg
102 | if opt.how2normalize == 2:
103 | maxV, minV = np.percentile(mrnp, [99, 1])
104 | print 'maxV,', maxV, ' minV, ', minV
105 | mrnp = (mrnp - mu) / (maxV - minV)
106 | print 'unique value: ', np.unique(ctnp)
107 |
108 | # for training data in pelvicSegRegH5
109 | if opt.how2normalize == 3:
110 | std = np.std(mrnp)
111 | mrnp = (mrnp - mu) / std
112 | print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(mrnp)
113 |
114 | if opt.how2normalize == 4:
115 | maxLPET = 149.366742
116 | maxPercentLPET = 7.76
117 | minLPET = 0.00055037
118 | meanLPET = 0.27593288
119 | stdLPET = 0.75747500
120 |
121 | # for rsCT
122 | maxCT = 27279
123 | maxPercentCT = 1320
124 | minCT = -1023
125 | meanCT = -601.1929
126 | stdCT = 475.034
127 |
128 | # for s-pet
129 | maxSPET = 156.675962
130 | maxPercentSPET = 7.79
131 | minSPET = 0.00055037
132 | meanSPET = 0.284224789
133 | stdSPET = 0.7642257
134 |
135 | # matLPET = (mrnp - meanLPET) / (stdLPET)
136 | matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET)
137 | matCT = (ctnp - meanCT) / stdCT
138 | matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET)
139 |
140 | if opt.how2normalize == 5:
141 | # for rsCT
142 | maxCT = 27279
143 | maxPercentCT = 1320
144 | minCT = -1023
145 | meanCT = -601.1929
146 | stdCT = 475.034
147 |
148 | print 'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp)
149 |
150 | # matLPET = (mrnp - meanLPET) / (stdLPET)
151 | matLPET = mrnp
152 | matCT = (ctnp - meanCT) / stdCT
153 | matSPET = hpetnp
154 |
155 | if opt.how2normalize == 6:
156 | maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0])
157 | maxPercentCT, minPercentCT = np.percentile(ctnp, [99.5, 0])
158 | print 'maxPercentPET: ', maxPercentPET, ' minPercentPET: ', minPercentPET, ' maxPercentCT: ', maxPercentCT, 'minPercentCT: ', minPercentCT
159 |
160 | matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET)
161 | matSPET = (hpetnp - minPercentPET) / (maxPercentPET - minPercentPET)
162 |
163 | matCT = (ctnp - minPercentCT) / (maxPercentCT - minPercentCT)
164 |
165 | if not opt.isMultiSource:
166 | matFA = matLPET
167 | matGT = hpetnp
168 |
169 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape
170 | matOut = testOneSubject_aver_res(matFA, matGT, [5, 64, 64], [1, 64, 64], [1, 16, 16], netG, opt.modelPath)
171 | print 'matOut shape: ', matOut.shape
172 | if opt.how2normalize == 6:
173 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
174 | else:
175 | ct_estimated = matOut
176 | ct_estimated[np.where(mrnp==0)] = 0
177 | itspsnr = psnr(ct_estimated, matGT)
178 |
179 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
180 |             print 'gt: ', matGT.dtype, ' shape: ', matGT.shape
181 | print 'psnr = ', itspsnr
182 | volout = sitk.GetImageFromArray(ct_estimated)
183 | volout.SetSpacing(spacing)
184 | volout.SetOrigin(origin)
185 | volout.SetDirection(direction)
186 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(ind) + '.nii.gz')
187 | else:
188 | matFA = matLPET
189 | matGT = hpetnp
190 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape
191 | matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 16, 16], netG, opt.modelPath)
192 | print 'matOut shape: ', matOut.shape
193 | if opt.how2normalize == 6:
194 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
195 | else:
196 | ct_estimated = matOut
197 |
198 | ct_estimated[np.where(mrnp==0)] = 0
199 | itspsnr = psnr(ct_estimated, matGT)
200 |
201 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
202 |             print 'gt: ', matGT.dtype, ' shape: ', matGT.shape
203 | print 'psnr = ', itspsnr
204 | volout = sitk.GetImageFromArray(ct_estimated)
205 | volout.SetSpacing(spacing)
206 | volout.SetOrigin(origin)
207 | volout.SetDirection(direction)
208 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(ind) + '.nii.gz')
209 |
210 | if __name__ == '__main__':
211 | # testGradients()
212 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
213 | main()
214 |
--------------------------------------------------------------------------------
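
The `how2normalize == 6` branch above scales each volume into roughly [0, 1] with robust percentiles, and after inference maps the prediction back with the inverse affine transform (the `matOut * (maxPercentPET - minPercentPET) + minPercentPET` lines). A minimal sketch of that forward/inverse pair, with illustrative names:

```python
import numpy as np

def percentile_minmax(vol, p_hi=99.5, p_lo=0.0):
    # robust min-max scaling, mirroring the how2normalize == 6 branch above
    hi, lo = np.percentile(vol, [p_hi, p_lo])
    return (vol - lo) / (hi - lo), lo, hi

def percentile_minmax_inverse(vol_norm, lo, hi):
    # map a network output back to the original intensity range
    return vol_norm * (hi - lo) + lo
```

Usage would follow the pattern of the script: e.g. `matLPET, lo, hi = percentile_minmax(mrnp)` before inference and `percentile_minmax_inverse(matOut, lo, hi)` before writing the volume.
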
/runTesting_Reconv2.py:
--------------------------------------------------------------------------------
1 | # from __future__ import print_function
2 | import argparse, os
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 | import torch.optim as optim
8 | import torch
9 | import torch.utils.data as data_utils
10 | from Unet2d_pytorch import UNet, ResUNet, UNet_LRes, ResUNet_LRes, Discriminator
11 | from utils import *
12 | # from ganComponents import *
13 | # from nnBuildUnits import CrossEntropy2d
14 | # from nnBuildUnits import computeSampleAttentionWeight
15 | # from nnBuildUnits import adjust_learning_rate
16 | import time
17 | # from dataClean import denoiseImg,denoiseImg_isolation,denoiseImg_closing
18 | import SimpleITK as sitk
19 |
20 | parser = argparse.ArgumentParser(description="PyTorch InfantSeg")
21 |
22 | parser.add_argument("--gpuID", type=int, default=1, help="which GPU to use (sets CUDA_VISIBLE_DEVICES)")
23 | parser.add_argument("--isSegReg", action="store_true", help="is Seg and Reg?", default=False)
24 | parser.add_argument("--isMultiSource", action="store_true", help="is multiple input modality used?", default=False)
25 | parser.add_argument("--whichLoss", type=int, default=1, help="which loss to use: 1. LossL1 (default), 2. lossRTL1, 3. MSE")
26 | parser.add_argument("--whichNet", type=int, default=2, help="which network to use: 1. UNet, 2. ResUNet, 3. UNet_LRes and 4. ResUNet_LRes (default: 2)")
27 | parser.add_argument("--lossBase", type=int, default=1, help="The base to multiply the lossG_G, Default (1)")
28 | parser.add_argument("--batchSize", type=int, default=32, help="training batch size")
29 | parser.add_argument("--numOfChannel_singleSource", type=int, default=5, help="# of channels for a 2D patch for the main modality (Default, 5)")
30 | parser.add_argument("--numOfChannel_allSource", type=int, default=5, help="# of channels for a 2D patch for all the concatenated modalities (Default, 5)")
31 | parser.add_argument("--isResidualEnhancement", action="store_true", help="is residual learning operation enhanced?", default=False)
32 | parser.add_argument("--isViewExpansion", action="store_true", help="is view expanded?", default=True)
33 | parser.add_argument("--isAdLoss", action="store_true", help="is adversarial loss used?", default=True)
34 | parser.add_argument("--isSpatialDropOut", action="store_true", help="is spatial dropout used?", default=False)
35 | parser.add_argument("--isFocalLoss", action="store_true", help="is focal loss used?", default=False)
36 | parser.add_argument("--isSampleImportanceFromAd", action="store_true", help="is sample importance from adversarial network used?", default=False)
37 | parser.add_argument("--dropoutRate", type=float, default=0.25, help="Spatial Dropout Rate. Default=0.25")
38 | parser.add_argument("--lambdaAD", type=float, default=0, help="loss coefficient for AD loss. Default=0")
39 | parser.add_argument("--adImportance", type=float, default=0, help="Sample importance from AD network. Default=0")
40 | parser.add_argument("--isFixedRegions", action="store_true", help="Is the organ regions roughly known?", default=False)
41 | #parser.add_argument("--modelPath", default="/home/niedong/Data4LowDosePET/pytorch_UNet/model/resunet2d_pet_Aug_noNorm_lres_bn_lr5e3_base1_lossL1_0p01_0624_200000.pt", type=str, help="prefix of the to-be-saved model name")
42 | # parser.add_argument("--modelPath", default="/shenlab/lab_stor5/dongnie/brain_mr2ct/modelFiles/resunet2d_dp_brain_BatchAug_sNorm_lres_bn_lr5e3_lrnetD5e3_lrdec_base1_wgan_gp_1107_140000.pt", type=str, help="prefix of the to-be-saved model name")
43 | parser.add_argument("--modelPath", default="/shenlab/lab_stor/dongnie/brats2018/modelFiles/resunet2d_dp_brats_BatchAug_sNorm_bn_lr5e3_lrnetD5e3_lrdec0p5_lrDdec0p05_wgan_gp_1112_200000.pt", type=str, help="prefix of the to-be-saved model name")
44 | # parser.add_argument("--prefixPredictedFN", default="/shenlab/lab_stor5/dongnie/brain_mr2ct/res/testResult/predCT_brain_resunet2d_dp_Aug_sNorm_lres_lrdce_bn_lr5e3_lossL1_1107_14w_", type=str, help="prefix of the to-be-saved predicted filename")
45 | parser.add_argument("--prefixPredictedFN", default="/shenlab/lab_stor/dongnie/brats2018/res/testResult/predBrats_v1_resunet2d_dp_Aug_sNorm_lres_lrdce_bn_lr5e3_lossL1_1112_20w_", type=str, help="prefix of the to-be-saved predicted filename")
46 | parser.add_argument("--how2normalize", type=int, default=6, help="how to normalize the data")
47 | parser.add_argument("--resType", type=int, default=1, help="resType: 0: segmentation map (integer); 1: regression map (continuous); 2: segmentation map + probability map")
48 |
49 | global opt
50 | opt = parser.parse_args()
51 |
52 |
53 | def main():
54 | print opt
55 |
56 | path_test = '/home/niedong/Data4LowDosePET/data_niigz_scale/'
57 | path_test = '/shenlab/lab_stor5/dongnie/brain_mr2ct/original_data/'
58 | path_test = '/shenlab/lab_stor/dongnie/brats2018/TrainData/HGG/Brats18_2013_11_1'
59 |
60 | if opt.whichNet==1:
61 | netG = UNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
62 | elif opt.whichNet==2:
63 | netG = ResUNet(in_channel=opt.numOfChannel_allSource, n_classes=1)
64 | elif opt.whichNet==3:
65 | netG = UNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1)
66 | elif opt.whichNet==4:
67 | netG = ResUNet_LRes(in_channel=opt.numOfChannel_allSource, n_classes=1)
68 |
69 | #netG.apply(weights_init)
70 | netG.cuda()
71 |
72 | checkpoint = torch.load(opt.modelPath)
73 | netG.load_state_dict(checkpoint['model'])
74 |
75 |
76 |     ids = [1,6,11,16,21,26,31,36,41,46] #in one folder, we test the 10 subjects that form the testing set
77 |     ids = [1] #in one folder, we test the 10 subjects that form the testing set
78 |
79 | ids = ['1_QFZ','2_LLQ','3_LMB','4_ZSL','5_CJB','11_TCL','15_WYL','21_PY','25_LYL','31_CZX','35_WLL','41_WQC','45_YXM']
80 | ids = [2,3,4,5,8,9,10,13]
81 | ids = ['Brats18_2013_11_1']
82 | for ind in ids:
83 | start = time.time()
84 |
85 | # mr_test_itk = sitk.ReadImage(os.path.join(path_test,'sub%d_mr.hdr'%ind))#input modality
86 |         # ct_test_itk = sitk.ReadImage(os.path.join(path_test,'sub%d_ct.hdr'%ind))#auxiliary modality
87 | # hpet_test_itk = sitk.ReadImage(os.path.join(path_test, 'sub%d_ct.hdr'%ind))#output modality
88 |
89 | mr_test_itk = sitk.ReadImage(os.path.join(path_test, 'Brats18_2013_11_1_t1ce.nii.gz'))
90 | ct_test_itk = sitk.ReadImage(os.path.join(path_test, 'Brats18_2013_11_1_t2.nii.gz'))
91 |
92 | spacing = mr_test_itk.GetSpacing()
93 | origin = mr_test_itk.GetOrigin()
94 | direction = mr_test_itk.GetDirection()
95 |
96 | mrnp = sitk.GetArrayFromImage(mr_test_itk)
97 | ctnp = sitk.GetArrayFromImage(ct_test_itk)
98 | # hpetnp = sitk.GetArrayFromImage(hpet_test_itk)
99 |
100 | if opt.isMultiSource:
101 | hpet_test_itk = sitk.ReadImage(os.path.join(path_test, '%s_120s_suv.nii.gz' % ind))
102 | hpetnp = sitk.GetArrayFromImage(hpet_test_itk)
103 | else:
104 | hpetnp = ctnp
105 |
106 | ##### specific normalization #####
107 | mu = np.mean(mrnp)
108 | # maxV, minV = np.percentile(mrnp, [99 ,25])
109 | # #mrimg=mrimg
110 | # mrnp = (mrnp-minV)/(maxV-minV)
111 |
112 | # for training data in pelvicSeg
113 | if opt.how2normalize == 1:
114 | maxV, minV = np.percentile(mrnp, [99, 1])
115 | print 'maxV,', maxV, ' minV, ', minV
116 | mrnp = (mrnp - mu) / (maxV - minV)
117 | print 'unique value: ', np.unique(ctnp)
118 |
119 | # for training data in pelvicSeg
120 | if opt.how2normalize == 2:
121 | maxV, minV = np.percentile(mrnp, [99, 1])
122 | print 'maxV,', maxV, ' minV, ', minV
123 | mrnp = (mrnp - mu) / (maxV - minV)
124 | print 'unique value: ', np.unique(ctnp)
125 |
126 | # for training data in pelvicSegRegH5
127 | if opt.how2normalize == 3:
128 | std = np.std(mrnp)
129 | mrnp = (mrnp - mu) / std
130 | print 'maxV,', np.ndarray.max(mrnp), ' minV, ', np.ndarray.min(mrnp)
131 |
132 | if opt.how2normalize == 4:
133 | maxLPET = 149.366742
134 | maxPercentLPET = 7.76
135 | minLPET = 0.00055037
136 | meanLPET = 0.27593288
137 | stdLPET = 0.75747500
138 |
139 | # for rsCT
140 | maxCT = 27279
141 | maxPercentCT = 1320
142 | minCT = -1023
143 | meanCT = -601.1929
144 | stdCT = 475.034
145 |
146 | # for s-pet
147 | maxSPET = 156.675962
148 | maxPercentSPET = 7.79
149 | minSPET = 0.00055037
150 | meanSPET = 0.284224789
151 | stdSPET = 0.7642257
152 |
153 | # matLPET = (mrnp - meanLPET) / (stdLPET)
154 | matLPET = (mrnp - minLPET) / (maxPercentLPET - minLPET)
155 | matCT = (ctnp - meanCT) / stdCT
156 | matSPET = (hpetnp - minSPET) / (maxPercentSPET - minSPET)
157 |
158 | if opt.how2normalize == 5:
159 | # for rsCT
160 | maxCT = 27279
161 | maxPercentCT = 1320
162 | minCT = -1023
163 | meanCT = -601.1929
164 | stdCT = 475.034
165 |
166 | print 'ct, max: ', np.amax(ctnp), ' ct, min: ', np.amin(ctnp)
167 |
168 | # matLPET = (mrnp - meanLPET) / (stdLPET)
169 | matLPET = mrnp
170 | matCT = (ctnp - meanCT) / stdCT
171 | matSPET = hpetnp
172 |
173 | if opt.how2normalize == 6:
174 | maxPercentPET, minPercentPET = np.percentile(mrnp, [99.5, 0])
175 | maxPercentCT, minPercentCT = np.percentile(ctnp, [99.5, 0])
176 | print 'maxPercentPET: ', maxPercentPET, ' minPercentPET: ', minPercentPET, ' maxPercentCT: ', maxPercentCT, 'minPercentCT: ', minPercentCT
177 |
178 | matLPET = (mrnp - minPercentPET) / (maxPercentPET - minPercentPET)
179 | matCT = (ctnp - minPercentCT) / (maxPercentCT - minPercentCT)
180 | if opt.isMultiSource:
181 | matSPET = (hpetnp - minPercentPET) / (maxPercentPET - minPercentPET)
182 |
183 |
184 |
185 | if not opt.isMultiSource:
186 | matFA = matLPET
187 | matGT = hpetnp
188 |
189 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape,' max(matFA): ',np.amax(matFA),' min(matFA): ',np.amin(matFA)
190 | # matOut = testOneSubject_aver_res(matFA, matGT, [5, 64, 64], [1, 64, 64], [1, 16, 16], netG, opt.modelPath)
191 | if opt.whichNet == 3 or opt.whichNet == 4:
192 | matOut = testOneSubject_aver_res(matFA, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], netG,
193 | opt.modelPath)
194 | else:
195 | matOut = testOneSubject_aver(matFA, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], netG,
196 | opt.modelPath)
197 | print 'matOut shape: ', matOut.shape, ' max(matOut): ',np.amax(matOut),' min(matOut): ',np.amin(matOut)
198 | if opt.how2normalize == 6:
199 | # ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
200 | ct_estimated = matOut * (maxPercentCT - minPercentCT) + minPercentCT
201 | else:
202 | ct_estimated = matOut
203 | #ct_estimated[np.where(mrnp==0)] = 0
204 | itspsnr = psnr(ct_estimated, matGT)
205 |
206 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
207 | print 'gt: ', ctnp.dtype, ' shape: ', matGT.shape
208 | print 'psnr = ', itspsnr
209 | volout = sitk.GetImageFromArray(ct_estimated)
210 | volout.SetSpacing(spacing)
211 | volout.SetOrigin(origin)
212 | volout.SetDirection(direction)
213 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(ind) + '.nii.gz')
214 | else:
215 | matFA = matLPET
216 | matGT = hpetnp
217 | print 'matFA shape: ', matFA.shape, ' matGT shape: ', matGT.shape
218 | # matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 16, 16], netG, opt.modelPath)
219 | if opt.whichNet == 3 or opt.whichNet == 4:
220 | matOut = testOneSubject_aver_res_multiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], netG,
221 | opt.modelPath)
222 | else:
223 | matOut = testOneSubject_aver_MultiModal(matFA, matCT, matGT, [5, 64, 64], [1, 64, 64], [1, 32, 32], netG,
224 | opt.modelPath)
225 | print 'matOut shape: ', matOut.shape
226 | if opt.how2normalize == 6:
227 | ct_estimated = matOut * (maxPercentPET - minPercentPET) + minPercentPET
228 | else:
229 | ct_estimated = matOut
230 |
231 | #ct_estimated[np.where(mrnp==0)] = 0
232 | itspsnr = psnr(ct_estimated, matGT)
233 |
234 | print 'pred: ', ct_estimated.dtype, ' shape: ', ct_estimated.shape
235 |             print 'gt: ', matGT.dtype, ' shape: ', matGT.shape
236 | print 'psnr = ', itspsnr
237 | volout = sitk.GetImageFromArray(ct_estimated)
238 | volout.SetSpacing(spacing)
239 | volout.SetOrigin(origin)
240 | volout.SetDirection(direction)
241 | sitk.WriteImage(volout, opt.prefixPredictedFN + '{}'.format(ind) + '.nii.gz')
242 |
243 | if __name__ == '__main__':
244 | # testGradients()
245 | os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpuID)
246 | main()
247 |
--------------------------------------------------------------------------------
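
`testOneSubject_aver_res` / `testOneSubject_aver` are defined in utils.py, which is not part of this excerpt. Judging from the call sites above, they slide a [5, 64, 64] input window over the volume with stride [1, 32, 32] and average the overlapping [1, 64, 64] predictions. Below is a simplified sketch of that overlap-averaging idea under those assumptions; it is not the repo's implementation, which additionally handles volume borders, the residual input, and model loading, and it assumes a CUDA device and a network that maps a 5-channel patch to its centre slice:

```python
import numpy as np
import torch

def sliding_window_average(matFA, net, in_d=5, tile=64, stride=32):
    # predict the centre slice of each [in_d, tile, tile] window and average
    # overlapping predictions; borders are left unprocessed in this sketch
    D, H, W = matFA.shape
    out = np.zeros((D, H, W), dtype=np.float32)
    cnt = np.zeros((D, H, W), dtype=np.float32)
    half = in_d // 2
    for z in range(half, D - half):
        for y in range(0, H - tile + 1, stride):
            for x in range(0, W - tile + 1, stride):
                patch = matFA[z - half:z + half + 1, y:y + tile, x:x + tile]
                inp = torch.from_numpy(patch[None].astype(np.float32)).cuda()
                with torch.no_grad():
                    pred = net(inp).squeeze().cpu().numpy()
                out[z, y:y + tile, x:x + tile] += pred
                cnt[z, y:y + tile, x:x + tile] += 1
    cnt[cnt == 0] = 1          # avoid dividing untouched voxels by zero
    return out / cnt
```
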
/shuffleDataAmongSubjects_2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import os
4 |
5 | '''
6 | Mix data (patches) among subjects by merging several subjects' patches into each output h5 file
7 | input:
8 |     save_dir: folder containing the per-subject h5 files
9 |     savepath: folder to write the merged h5 files to
10 | output:
11 |     h5 files with the same dataset names (dataLPET/dataCT/dataHPET)
12 | '''
13 |
14 |
15 | def shuffleDataAmongSubjects(save_dir, savepath):
16 | # allfilenames = os.listdir(save_dir)
17 | # allfilenames = filter(lambda x: '.h5' in x and 'train' in x, allfilenames)
18 |
19 | nn = 200000
20 | # dataMR = np.zeros([nn, 1, 5, 64, 64], dtype=np.float16)
21 | dataLPET = np.zeros([nn,1, 5, 64, 64], dtype=np.float16)
22 | dataCT = np.zeros([nn,1, 5, 64, 64], dtype=np.float16)
23 | dataHPET = np.zeros([nn, 1, 1, 64, 64], dtype=np.float16)
24 |
25 | allfilenames = os.listdir(save_dir)
26 | # print allfilenames
27 | allfilenames = filter(lambda x: '.h5' in x and 'train' in x, allfilenames)
28 | # print allfilenames
29 | cnt = 0
30 | numInOneSub = 5
31 | batchID = 0
32 | startInd = 0
33 | savefilename = 'train5x64x64_'
34 | for i_file, filename in enumerate(allfilenames):
35 |
36 | with h5py.File(os.path.join(save_dir, filename), 'r+') as h5f:
37 | print '*******path is ', os.path.join(save_dir, filename)
38 | dLPET = h5f['dataLPET'][:]
39 | dCT = h5f['dataCT'][:]
40 | dHPET = h5f['dataHPET'][:]
41 |
42 | unitNum = dLPET.shape[0]
43 | print 'unitNum: ', unitNum, 'dLPET shape: ', dLPET.shape
44 |
45 | dataLPET[startInd: (startInd + unitNum), ...] = dLPET
46 | dataCT[startInd: startInd + unitNum, ...] = dCT
47 | dataHPET[startInd: startInd + unitNum, ...] = dHPET
48 |
49 | startInd = startInd + unitNum
50 |
51 | cnt = cnt + 1
52 |
53 | if cnt == numInOneSub:
54 | batchID = batchID + 1
55 | dataLPET = dataLPET[0:startInd, ...]
56 | dataCT = dataCT[0:startInd, ...]
57 | dataHPET = dataHPET[0:startInd, ...]
58 |
59 | with h5py.File(os.path.join(savepath, savefilename + '{}.h5'.format(batchID)), 'w') as hf:
60 | hf.create_dataset('dataLPET', data=dataLPET)
61 | hf.create_dataset('dataCT', data=dataCT)
62 | hf.create_dataset('dataHPET', data=dataHPET)
63 |
64 | ############ initialization ###############
65 | cnt = 0
66 | startInd = 0
67 |             print 'nn:', nn
68 | 
69 | dataLPET = np.zeros([nn, 1, 5, 64, 64], dtype=np.float16)
70 | dataCT = np.zeros([nn, 1, 5, 64, 64], dtype=np.float16)
71 | dataHPET = np.zeros([nn, 1, 1, 64, 64], dtype=np.float16)
72 |
73 | # mean_train, std_train = 0., 0
74 | batchID = batchID + 1
75 | if startInd != 0:
76 | dataLPET = dataLPET[0:startInd, ...]
77 | dataCT = dataCT[0:startInd, ...]
78 | dataHPET = dataHPET[0:startInd, ...]
79 |
80 | with h5py.File(os.path.join(savepath, savefilename + '{}.h5'.format(batchID)), 'w') as hf:
81 | hf.create_dataset('dataLPET', data=dataLPET)
82 | hf.create_dataset('dataCT', data=dataCT)
83 | hf.create_dataset('dataHPET', data=dataHPET)
84 |
85 | return
86 |
87 |
88 | def main():
89 | path = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/train2D_H5/'
90 | savepath = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/trainBatch2D_H5/'
91 | basePath = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/'
92 | path = basePath + 'train2D_H5/'
93 | savepath = basePath + 'trainBatch2D_H5/'
94 | shuffleDataAmongSubjects(path, savepath)
95 |
96 |
97 | if __name__ == "__main__":
98 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
99 | main()
100 |
--------------------------------------------------------------------------------
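
Despite its name, the function above merges the patches of `numInOneSub` subjects into each output file but does not itself permute patch order within a file. If an explicit shuffle is wanted before writing, one joint permutation keeps the three datasets aligned; a minimal sketch (illustrative, not part of the repo):

```python
import numpy as np

def shuffle_in_unison(dataLPET, dataCT, dataHPET, seed=0):
    # one permutation applied to all arrays along axis 0, so corresponding
    # LPET/CT/HPET patches stay paired
    perm = np.random.RandomState(seed).permutation(dataLPET.shape[0])
    return dataLPET[perm], dataCT[perm], dataHPET[perm]
```
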
/shuffleDataAmongSubjects_3d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import h5py
3 | import os
4 |
5 | '''
6 | Mix data (patches) among subjects by merging several subjects' patches into each output h5 file
7 | input:
8 |     save_dir: folder containing the per-subject h5 files
9 |     savepath: folder to write the merged h5 files to
10 | output:
11 |     h5 files with the same dataset names (dataLPET/dataCT/dataHPET)
12 | '''
13 |
14 | # dFA = [3,168,112] # size of patches of input data
15 | # dSeg = [1,168,112] # size of patches of label data
16 |
17 | dFA = [16, 64, 64] # size of patches of input data
18 | dSeg = [16, 64, 64] # size of patches of label data
19 |
20 |
21 | def shuffleDataAmongSubjects(save_dir, savepath):
22 | # allfilenames = os.listdir(save_dir)
23 | # allfilenames = filter(lambda x: '.h5' in x and 'train' in x, allfilenames)
24 |
25 | nn = 10000
26 | dataLPET = np.zeros([nn,1, 16, 64, 64], dtype=np.float16)
27 | dataCT = np.zeros([nn,1, 16, 64, 64], dtype=np.float16)
28 | dataHPET = np.zeros([nn, 1, 16, 64, 64], dtype=np.float16)
29 | #
30 |
31 | allfilenames = os.listdir(save_dir)
32 | # print allfilenames
33 | allfilenames = filter(lambda x: '.h5' in x and 'train' in x, allfilenames)
34 | # print allfilenames
35 | cnt = 0
36 | numInOneSub = 5
37 | batchID = 0
38 | startInd = 0
39 |
40 | savefilename = 'train5x64x64_'
41 | for i_file, filename in enumerate(allfilenames):
42 |
43 | with h5py.File(os.path.join(save_dir, filename), 'r+') as h5f:
44 | print '*******path is ', os.path.join(save_dir, filename)
45 | dLPET = h5f['dataLPET'][:]
46 | dCT = h5f['dataCT'][:]
47 | dHPET = h5f['dataHPET'][:]
48 |
49 | unitNum = dLPET.shape[0]
50 | print 'unitNum: ', unitNum, 'dLPET shape: ', dLPET.shape
51 |
52 | dataLPET[startInd: (startInd + unitNum), ...] = dLPET
53 | dataCT[startInd: startInd + unitNum, ...] = dCT
54 | dataHPET[startInd: startInd + unitNum, ...] = dHPET
55 |
56 | startInd = startInd + unitNum
57 |
58 | cnt = cnt + 1
59 |
60 | if cnt == numInOneSub:
61 | batchID = batchID + 1
62 | dataLPET = dataLPET[0:startInd, ...]
63 | dataCT = dataCT[0:startInd, ...]
64 | dataHPET = dataHPET[0:startInd, ...]
65 |
66 | with h5py.File(os.path.join(savepath, savefilename + '{}.h5'.format(batchID)), 'w') as hf:
67 | hf.create_dataset('dataLPET', data=dataLPET)
68 | hf.create_dataset('dataCT', data=dataCT)
69 | hf.create_dataset('dataHPET', data=dataHPET)
70 |
71 | ############ initialization ###############
72 | cnt = 0
73 | startInd = 0
74 |                 print 'nn:', nn
75 | 
76 |                 dataLPET = np.zeros([nn, 1, 16, 64, 64], dtype=np.float16)
77 |                 dataCT = np.zeros([nn, 1, 16, 64, 64], dtype=np.float16)
78 |                 dataHPET = np.zeros([nn, 1, 16, 64, 64], dtype=np.float16)  # keep axis order consistent with the initial allocation
79 |
80 | # mean_train, std_train = 0., 0
81 | batchID = batchID + 1
82 | if startInd != 0:
83 | dataLPET = dataLPET[0:startInd, ...]
84 | dataCT = dataCT[0:startInd, ...]
85 | dataHPET = dataHPET[0:startInd, ...]
86 |
87 | with h5py.File(os.path.join(savepath, savefilename + '{}.h5'.format(batchID)), 'w') as hf:
88 | hf.create_dataset('dataLPET', data=dataLPET)
89 | hf.create_dataset('dataCT', data=dataCT)
90 | hf.create_dataset('dataHPET', data=dataHPET)
91 |
92 | return
93 |
94 |
95 | def main():
96 | path = '/home/niedong/Data4LowDosePET/h5Data3D_noNorm/train3D_H5/'
97 | savepath = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/trainBatch3D_H5/'
98 | basePath = '/home/niedong/Data4LowDosePET/h5DataAug_noNorm/'
99 | # path = basePath + 'train2D_H5/'
100 | # savepath = basePath + 'trainBatch2D_H5/'
101 | shuffleDataAmongSubjects(path, savepath)
102 |
103 |
104 | if __name__ == "__main__":
105 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
106 | main()
107 |
--------------------------------------------------------------------------------
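
The testing scripts report quality via `psnr()` from utils.py, which is not part of this listing. For reference, a standard PSNR definition is sketched below; the repo's actual `utils.psnr` may choose a different peak value:

```python
import numpy as np

def psnr_reference(pred, gt):
    # peak signal-to-noise ratio, using the ground truth's dynamic range as peak
    mse = np.mean((pred.astype(np.float64) - gt.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    peak = gt.max() - gt.min()
    return 10.0 * np.log10(peak ** 2 / mse)
```
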