├── DataManager.py
├── README.md
├── customDataset.py
├── dataset
│   ├── imagesTr
│   ├── imagesTs
│   └── labelsTr
├── lossFuncs.py
├── main.py
├── make_graph.py
├── miscs_promise12.py
├── plot.py
├── plot_contours.py
├── train.py
├── utils.py
└── vnet.py
/DataManager.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import listdir
3 | from os.path import isfile, join, splitext
4 |
5 | import numpy as np
6 | import SimpleITK as sitk
7 |
8 | # dataset isotropically scaled to 1x1x1.5mm, volume resized to 128x128x64
9 | # The datasets were first normalised using the N4 bias field correction function of the ANTs framework
10 |
11 |
12 | class DataManager(object):
13 | # params=None
14 | # srcFolder=None
15 | # resultsDir=None
16 | #
17 | # fileList=None
18 | # gtList=None
19 | #
20 | # sitkImages=None
21 | # sitkGT=None
22 | # meanIntensityTrain = None
23 |
24 | def __init__(self,imageFolder, GTFolder, resultsDir, parameters):
25 | self.params = parameters
26 | self.imageFolder = imageFolder
27 | self.GTFolder = GTFolder
28 | self.resultsDir = resultsDir
29 |
30 | def createImageFileList(self):
31 | self.imageFileList = [f for f in listdir(self.imageFolder) if isfile(join(self.imageFolder, f)) and '.DS_Store' not in f and '._' not in f and '.raw' not in f]
32 | print('imageFileList: ' + str(self.imageFileList))
33 |
34 | def createGTFileList(self):
35 | self.GTFileList = [f for f in listdir(self.GTFolder) if isfile(
36 | join(self.GTFolder, f)) and '.DS_Store' not in f and '._' not in f and '.raw' not in f]
37 | print('GTFileList: ' + str(self.GTFileList))
38 |
39 | def loadImages(self):
40 | self.sitkImages=dict()
41 | rescalFilt=sitk.RescaleIntensityImageFilter()
42 | rescalFilt.SetOutputMaximum(1)
43 | rescalFilt.SetOutputMinimum(0)
44 |
45 | stats = sitk.StatisticsImageFilter()
46 | m = 0.
47 |
48 | for f in self.imageFileList:
49 | id = f.split('.')[0]
50 | self.sitkImages[id]=rescalFilt.Execute(sitk.Cast(sitk.ReadImage(join(self.imageFolder, f)),sitk.sitkFloat32))
51 | stats.Execute(self.sitkImages[id])
52 | m += stats.GetMean()
53 |
54 | self.meanIntensityTrain=m/len(self.sitkImages)
55 |
56 |
57 | def loadGT(self):
58 | self.sitkGT=dict()
59 |
60 | for f in self.GTFileList:
61 | id = f.split('.')[0]
62 | self.sitkGT[id]=sitk.Cast(sitk.ReadImage(join(self.GTFolder, f))>0.5,sitk.sitkFloat32)
63 |
64 | def loadTrainingData(self):
65 | self.createImageFileList()
66 | self.createGTFileList()
67 | self.loadImages()
68 | self.loadGT()
69 |
70 | def loadTestingData(self):
71 | self.createImageFileList()
72 | self.createGTFileList()
73 | self.loadImages()
74 | self.loadGT()
75 |
76 | def loadInferData(self):
77 | self.createImageFileList()
78 | self.loadImages()
79 |
80 | def getNumpyImages(self):
81 | dat = self.getNumpyData(self.sitkImages,sitk.sitkLinear)
82 |
83 | for key in dat.keys(): # https://github.com/faustomilletari/VNet/blob/master/VNet.py, line 147. For standardization?
84 | mean = np.mean(dat[key][dat[key]>0]) # why restrict to >0? By Chao.
85 | std = np.std(dat[key][dat[key]>0])
86 |
87 | dat[key] -= mean
88 | dat[key] /=std
89 |
90 | return dat
91 |
92 |
93 | def getNumpyGT(self):
94 | dat = self.getNumpyData(self.sitkGT,sitk.sitkLinear)
95 |
96 | for key in dat:
97 | dat[key] = (dat[key]>0.5).astype(dtype=np.float32)
98 |
99 | return dat
100 |
101 |
102 | def getNumpyData(self,dat,method):
103 | ret=dict()
104 | for key in dat:
105 | ret[key] = np.zeros([self.params['VolSize'][0], self.params['VolSize'][1], self.params['VolSize'][2]], dtype=np.float32)
106 |
107 | img=dat[key]
108 |
109 | #we rotate the image according to its transformation using the direction and according to the final spacing we want
110 | factor = np.asarray(img.GetSpacing()) / [self.params['dstRes'][0], self.params['dstRes'][1],
111 | self.params['dstRes'][2]]
112 |
113 | factorSize = np.asarray(img.GetSize() * factor, dtype=float)
114 |
115 | newSize = np.max([factorSize, self.params['VolSize']], axis=0)
116 |
117 | newSize = newSize.astype(dtype='int')
118 |
119 | T=sitk.AffineTransform(3)
120 | T.SetMatrix(img.GetDirection())
121 |
122 | resampler = sitk.ResampleImageFilter()
123 | resampler.SetReferenceImage(img)
124 | resampler.SetOutputSpacing([self.params['dstRes'][0], self.params['dstRes'][1], self.params['dstRes'][2]])
125 | resampler.SetSize(newSize.tolist())
126 | resampler.SetInterpolator(method)
127 | if self.params['normDir']:
128 | resampler.SetTransform(T.GetInverse())
129 |
130 | imgResampled = resampler.Execute(img)
131 |
132 |
133 | imgCentroid = np.asarray(newSize, dtype=float) / 2.0
134 |
135 | imgStartPx = (imgCentroid - self.params['VolSize'] / 2.0).astype(dtype='int')
136 |
137 | regionExtractor = sitk.RegionOfInterestImageFilter()
138 | size_2_set = self.params['VolSize'].astype(dtype='int')
139 | regionExtractor.SetSize(size_2_set.tolist())
140 | regionExtractor.SetIndex(imgStartPx.tolist())
141 |
142 | imgResampledCropped = regionExtractor.Execute(imgResampled)
143 |
144 | ret[key] = np.transpose(sitk.GetArrayFromImage(imgResampledCropped).astype(dtype=float), [2, 1, 0])
145 |
146 | return ret
147 |
148 |
149 | def writeResultsFromNumpyLabel(self,result,key, resultTag, ext, resultDir):
150 | '''
151 | :param result: predicted mask
152 | :param key: sample id
153 |         :return: resamples the predicted mask (e.g. a binary mask of size 96x96x48) back onto the original image grid (e.g. a volume of size 320x320x20) and writes out a final mask of the same size as the original image.
154 | '''
155 | img = self.sitkImages[key] # original image
156 | print("original img shape{}".format(img.GetSize()))
157 |
158 | toWrite = sitk.Image(img.GetSize()[0],img.GetSize()[1],img.GetSize()[2],sitk.sitkFloat32)
159 |
160 | factor = np.asarray(img.GetSpacing()) / [self.params['dstRes'][0], self.params['dstRes'][1],
161 | self.params['dstRes'][2]]
162 |
163 | factorSize = np.asarray(img.GetSize() * factor, dtype=float)
164 |
165 | newSize = np.max([factorSize, self.params['VolSize']], axis=0)
166 |
167 | newSize = newSize.astype(dtype=int)
168 |
169 | T = sitk.AffineTransform(3)
170 | T.SetMatrix(img.GetDirection())
171 |
172 | resampler = sitk.ResampleImageFilter()
173 | resampler.SetReferenceImage(img)
174 | resampler.SetOutputSpacing([self.params['dstRes'][0], self.params['dstRes'][1], self.params['dstRes'][2]])
175 | resampler.SetSize(newSize.tolist())
176 | resampler.SetInterpolator(sitk.sitkNearestNeighbor)
177 |
178 | if self.params['normDir']:
179 | resampler.SetTransform(T.GetInverse())
180 |
181 | toWrite = resampler.Execute(toWrite)
182 |
183 | imgCentroid = np.asarray(newSize, dtype=float) / 2.0
184 |
185 | imgStartPx = (imgCentroid - self.params['VolSize'] / 2.0).astype(dtype=int)
186 |
187 | for dstX, srcX in zip(range(0, result.shape[0]), range(imgStartPx[0],int(imgStartPx[0]+self.params['VolSize'][0]))):
188 | for dstY, srcY in zip(range(0, result.shape[1]), range(imgStartPx[1], int(imgStartPx[1]+self.params['VolSize'][1]))):
189 | for dstZ, srcZ in zip(range(0, result.shape[2]), range(imgStartPx[2], int(imgStartPx[2]+self.params['VolSize'][2]))):
190 | try:
191 | toWrite.SetPixel(int(srcX),int(srcY),int(srcZ),float(result[dstX,dstY,dstZ]))
192 |                     except:
193 |                         pass  # voxels that fall outside the resampled volume are skipped
194 |
195 |
196 | resampler.SetOutputSpacing([img.GetSpacing()[0], img.GetSpacing()[1], img.GetSpacing()[2]])
197 | resampler.SetSize(img.GetSize())
198 |
199 | if self.params['normDir']:
200 | resampler.SetTransform(T)
201 |
202 | toWrite = resampler.Execute(toWrite)
203 |
204 | thfilter=sitk.BinaryThresholdImageFilter()
205 | thfilter.SetInsideValue(1)
206 | thfilter.SetOutsideValue(0)
207 | thfilter.SetLowerThreshold(0.5)
208 | toWrite = thfilter.Execute(toWrite)
209 |
210 | #connected component analysis (better safe than sorry)
211 |
212 | cc = sitk.ConnectedComponentImageFilter()
213 | toWritecc = cc.Execute(sitk.Cast(toWrite,sitk.sitkUInt8))
214 |
215 | arrCC=np.transpose(sitk.GetArrayFromImage(toWritecc).astype(dtype=float), [2, 1, 0])
216 |
217 | lab=np.zeros(int(np.max(arrCC)+1),dtype=float)
218 |
219 | for i in range(1,int(np.max(arrCC)+1)):
220 | lab[i]=np.sum(arrCC==i)
221 |
222 | activeLab=np.argmax(lab)
223 |
224 | toWrite = (toWritecc==activeLab)
225 |
226 | toWrite = sitk.Cast(toWrite,sitk.sitkUInt8)
227 |
228 | writer = sitk.ImageFileWriter()
229 |
230 | #print join(self.resultsDir, filename + '_result' + ext)
231 | writer.SetFileName(join(resultDir, key + resultTag + ext))
232 | writer.Execute(toWrite)
233 |
234 |
--------------------------------------------------------------------------------
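
The header comment in DataManager.py notes that the datasets were first normalised with the N4 bias field correction from the ANTs framework; that step is not part of this repo. As a rough stand-in (not the authors' exact preprocessing), SimpleITK ships the same N4 algorithm. A minimal sketch, with an illustrative PROMISE12 case name:

```python
# Sketch: N4 bias field correction with SimpleITK, standing in for the ANTs
# step mentioned in DataManager.py. 'Case00.mhd' is an illustrative file name.
import SimpleITK as sitk

image = sitk.ReadImage('dataset/imagesTr/Case00.mhd', sitk.sitkFloat32)
mask = sitk.OtsuThreshold(image, 0, 1, 200)  # rough foreground mask so background air is ignored
corrector = sitk.N4BiasFieldCorrectionImageFilter()
corrected = corrector.Execute(image, mask)
sitk.WriteImage(corrected, 'dataset/imagesTr/Case00_n4.mhd')
```
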
/README.md:
--------------------------------------------------------------------------------
1 | # A PyTorch implementation of V-Net
2 |
3 | This V-Net code is a [PyTorch](http://pytorch.org/) implementation of the paper
4 | [V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation](https://arxiv.org/abs/1606.04797)
5 | by Fausto Milletari, Nassir Navab, and Seyed-Ahmad Ahmadi. The data preprocessing code is based on the official version [faustomilletari/VNet](https://github.com/faustomilletari/VNet), while the V-Net model was built on [mattmacy/vnet.pytorch](https://github.com/mattmacy/vnet.pytorch).
6 |
7 | To apply the code, just modify the lines from "main.py", "DataManager.py" and "train.py" where marked as "require customization"
8 |
9 | Average dice coefficient for the 30 test cases: 0.887.
10 | PROMISE12 challenge score: 85.67
11 | Rank: #23 on Sep 11, 2018.
12 |
--------------------------------------------------------------------------------
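
Going by the argument parser in main.py and the positional arguments expected by plot.py, a typical run looks roughly like the sketch below. The flags and the results path are illustrative (train.py stamps the experiment directory as results/vnet.base.<task>.<datestamp>), not values prescribed by the repo.

```bash
# train (and test, when a test split exists) with the defaults wired into main.py
python main.py --opt adam --batchsize 2 --numIterations 1000 --nEpochs 1

# plot the recorded Dice/error curves for an experiment directory
python plot.py Iteration results/vnet.base.promise12.20180901_1623
```
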
/customDataset.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import torch
3 | import torch.utils.data as data
4 |
5 | import numpy as np
6 |
7 | class customDataset(data.Dataset):
8 | '''
9 | For medical segmentation decathlon.
10 | '''
11 |
12 | def __init__(self, mode, images, GT, transform=None, GT_transform=None):
13 | if images is None:
14 | raise(RuntimeError("images must be set"))
15 | self.mode = mode
16 | self.images = images
17 | self.GT = GT
18 | self.transform = transform
19 | self.GT_transform = GT_transform
20 |
21 | def __getitem__(self, index):
22 | """
23 | Args:
24 | index(int): Index
25 | Returns:
26 |             tuple: (image, GT) where GT is the ground-truth segmentation mask
27 | """
28 | if self.mode == "train":
29 | # keys = list(self.images.keys())
30 | # id = keys[index]
31 | # because of data augmentation, train images are stored in a 4-d array, with first d as sample index.
32 | image = self.images[index]
33 | # print("image shape from DataManager shown in PROMISE12:" + str(image.shape)) # e.g. 96,96,48.
34 | image = np.transpose(image,[2,1,0]) # added by Chao
35 | image = np.expand_dims(image, axis=0)
36 | # print("expanded image dims:{}".format(str(image.shape)))
37 | # pdb.set_trace()
38 | image = image.astype(np.float32)
39 | if self.transform is not None:
40 | image = torch.from_numpy(image)
41 | # image = self.transform(image)
42 |
43 | GT = self.GT[index]
44 | GT = np.transpose(GT, [2, 1, 0])
45 | if self.GT_transform is not None:
46 | GT = self.GT_transform(GT)
47 | return image, GT
48 | elif self.mode == "test":
49 | keys = list(self.images.keys())
50 | id = keys[index]
51 | image = self.images[id]
52 | image = np.transpose(image, [2, 1, 0]) # added by Chao
53 | image = np.expand_dims(image, axis=0)
54 | # print("expanded image dims:{}".format(str(image.shape)))
55 | # pdb.set_trace()
56 | image = image.astype(np.float32)
57 | if self.transform is not None:
58 | image = torch.from_numpy(image)
59 | # image = self.transform(image)
60 |
61 | GT = self.GT[id+'_segmentation'] # require customization
62 | GT = np.transpose(GT, [2, 1, 0])
63 | if self.GT_transform is not None:
64 | GT = self.GT_transform(GT)
65 | return image, GT, id
66 | elif self.mode == "infer":# added by Chao
67 | keys = list(self.images.keys())
68 | id = keys[index]
69 | image = self.images[id]
70 | # print("image shape from DataManager shown in PROMISE12:" + str(image.shape))
71 | image = np.transpose(image,[2,1,0]) # added by Chao
72 | image = np.expand_dims(image, axis=0)
73 | image = image.astype(np.float32)
74 | return image, id
75 |
76 | def __len__(self):
77 | return len(self.images)
78 |
--------------------------------------------------------------------------------
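
train.py wires this class to a torch DataLoader; the minimal sketch below mimics its 'test' path with random arrays standing in for what DataManager.getNumpyImages()/getNumpyGT() would return. The '<id>' / '<id>_segmentation' key convention follows train_test_split in train.py; the data itself is made up.

```python
# Sketch: driving customDataset in 'test' mode with toy data.
import numpy as np
from torch.utils.data import DataLoader

import customDataset

images = {'Case00': np.random.rand(128, 128, 64).astype(np.float32)}
labels = {'Case00_segmentation': (np.random.rand(128, 128, 64) > 0.5).astype(np.float32)}

testSet = customDataset.customDataset(mode='test', images=images, GT=labels)
testLoader = DataLoader(testSet, batch_size=1, shuffle=True)
for image, GT, case_id in testLoader:
    # image: [1, 1, 64, 128, 128] (batch, channel, z, y, x); GT: [1, 64, 128, 128]
    print(case_id[0], image.shape, GT.shape)
```
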
/lossFuncs.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | import torch
4 | from torch.autograd import Function
5 | from itertools import repeat
6 | import numpy as np
7 |
8 | # Intersection = dot(A, B)
9 | # Union = dot(A, A) + dot(B, B)
10 | # The Dice coefficient is defined as
11 | #     2 * intersection / union
12 | #
13 | # The derivative is 2[(union * target - 2 * intersect * input) / union^2]
14 |
15 | class DiceLoss(Function):
16 | '''
17 | Compute energy based on dice coefficient.
18 | Aims to maximize dice coefficient.
19 | '''
20 | def __init__(self, *args, **kwargs):
21 | pass
22 |
23 |     def forward(self, input, target, save=True):  # it seems the official V-Net sums the Dice coefficients over a minibatch, though how that interacts with the backward gradients is unclear. by Chao. Here the batch sum is returned in forward and per-sample gradients are used in backward.
24 | # input shape: softmax output. shape is [batch_size, 2 (background and foreground), z*y*x]??? by Chao.
25 | # target shape: [batch_size, z, y, x]?? by Chao.
26 | # print("Loss forward:\ninput shape:{}; target shape:{}".format(input.shape, target.shape))
27 |
28 | eps = 0.00001
29 |
30 | # reshape target
31 | # pdb.set_trace()
32 |         b, z, y, x = target.shape  # b: batch size, z: depth, y: height, x: width
33 | target_ = target.view(b, -1)
34 |
35 | # result_ = torch.zeros(input.shape[0], input.shape[2])
36 | # target_ = torch.zeros(input.shape[0], input.shape[2])
37 |
38 | # _, result_ = input.max(1)
39 | # for i in range(input.shape[0]):
40 | # result_[i, :] = input[i, :, :].argmax(0) # by Chao
41 |         result_ = input.argmax(1)  # dim 1 (the two-class dim) has length 2; collapse it by labelling each voxel with the class of highest probability. by Chao.
42 |
43 | # result_ = torch.squeeze(result_) # will do harm when batch_size=1
44 |
45 | # if input.is_cuda:
46 | # result = torch.cuda.FloatTensor(result_.size())
47 | # target = torch.cuda.FloatTensor(target_.size())
48 | # else:
49 | # result = torch.FloatTensor(result_.size())
50 | # target = torch.FloatTensor(target_.size())
51 | # result.copy_(result_)
52 | # self.target_.copy_(target)
53 | # target = self.target_
54 | if input.is_cuda: # by Chao.
55 | result = result_.type(torch.cuda.FloatTensor)
56 | target = target_.type(torch.cuda.FloatTensor)
57 | else:
58 | result = result_.type(torch.FloatTensor)
59 | target = target_.type(torch.FloatTensor)
60 |
61 | if save:
62 | self.save_for_backward(result, target)
63 |
64 | self.intersect = torch.zeros(input.shape[0])
65 | self.union = torch.zeros(input.shape[0])
66 | dice = torch.zeros(input.shape[0])
67 | if input.is_cuda:
68 | self.intersect = self.intersect.cuda()
69 | self.union = self.union.cuda()
70 | dice = dice.cuda()
71 | for i in range(input.shape[0]):
72 | self.intersect[i] = torch.dot(result[i, :], target[i, :])
73 | # binary values so sum the same as sum of squares
74 | result_sum = torch.sum(result[i, :])
75 | target_sum = torch.sum(target[i, :])
76 | self.union[i] = result_sum + target_sum
77 |
78 | # the target volume can be empty - so we still want to
79 | # end up with a score of 1 if the result is 0/0
80 | dice[i] = 2*self.intersect[i] / (self.union[i] + eps)
81 | print('union: {}\t intersect: {}\t dice_coefficient: {:.7f}'.format(str(self.union[i]), str(self.intersect[i]), dice[i])) # target_sum: {:.0f} pred_sum: {:.0f}; target_sum, result_sum,
82 |
83 | # intersect = torch.dot(result, target)
84 | # # binary values so sum the same as sum of squares
85 | # result_sum = torch.sum(result)
86 | # target_sum = torch.sum(target)
87 | # union = result_sum + target_sum
88 | #
89 | # # the target volume can be empty - so we still want to
90 | # # end up with a score of 1 if the result is 0/0
91 | # dice = 2*intersect / (union + eps)
92 |
93 | # batch mean dice
94 | sumDice = torch.sum(dice)
95 |
96 | out = torch.FloatTensor(1).fill_(sumDice)
97 | if input.is_cuda:
98 | out = out.cuda() # added by Chao.
99 | return out
100 |
101 | def backward(self, grad_output): # Update the weights of the network, typically using a simple update rule: weight = weight - learning_rate * gradient (refer: https://seba-1511.github.io/tutorials/beginner/blitz/neural_networks_tutorial.html )
102 | # print("grad_output:{}".format(grad_output))
103 | # why fix grad_output:tensor([1.])??? By Chao.
104 | input, target = self.saved_tensors
105 | intersect, union = self.intersect, self.union
106 |
107 | grad_input = torch.zeros(target.shape[0], 2, target.shape[1])
108 | if input.is_cuda:
109 | grad_input = grad_input.cuda()
110 | # pdb.set_trace()
111 | for i in range(input.shape[0]):
112 | part1 = torch.div(target[i,:], union[i])
113 | part2_2 = intersect[i] / (union[i] * union[i])
114 | part2 = torch.mul(input[i,:], part2_2)
115 | dDice = torch.add(torch.mul(part1, 2), torch.mul(part2, -4))
116 | if input.is_cuda:
117 | dDice = dDice.cuda()
118 | grad_input[i,0,:] = torch.mul(dDice, grad_output[0])
119 | grad_input[i,1,:] = torch.mul(dDice, -grad_output[0])
120 |
121 | return grad_input, None # Return None for the gradient of values that don’t actually need gradients
122 |
123 | def dice_loss(input, target):
124 | return DiceLoss()(input, target)
125 |
126 | def dice_error(input, target):
127 | eps = 0.00001
128 | _, result_ = input.max(1)
129 | result_ = torch.squeeze(result_)
130 | if input.is_cuda:
131 | result = torch.cuda.FloatTensor(result_.size())
132 | target_ = torch.cuda.FloatTensor(target.size())
133 | else:
134 | result = torch.FloatTensor(result_.size())
135 | target_ = torch.FloatTensor(target.size())
136 | result.copy_(result_.data)
137 | target_.copy_(target.data)
138 | target = target_
139 | intersect = torch.dot(result, target)
140 |
141 | result_sum = torch.sum(result)
142 | target_sum = torch.sum(target)
143 | union = result_sum + target_sum
144 | intersect = np.max([eps, intersect])
145 | # the target volume can be empty - so we still want to
146 | # end up with a score of 1 if the result is 0/0
147 | dice = 2*intersect / (union + eps)
148 | # print('union: {:.3f}\t intersect: {:.6f}\t target_sum: {:.0f} IoU: result_sum: {:.0f} IoU {:.7f}'.format(
149 | # union, intersect, target_sum, result_sum, 2*IoU))
150 | return dice
151 |
--------------------------------------------------------------------------------
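
As a worked example of the forward pass above, here is the same Dice arithmetic on toy tensors (a sketch that bypasses the autograd Function; all numbers are made up):

```python
# Sketch: the Dice coefficient exactly as DiceLoss.forward computes it,
# on a toy batch of one sample with four voxels.
import torch

# softmax output over 2 classes: shape [batch, 2, voxels]
input = torch.tensor([[[0.9, 0.2, 0.8, 0.6],    # P(background) per voxel
                       [0.1, 0.8, 0.2, 0.4]]])  # P(foreground) per voxel
target = torch.tensor([[0., 1., 0., 1.]])       # flattened ground-truth labels

result = input.argmax(1).float()             # hard prediction: [[0., 1., 0., 0.]]
intersect = torch.dot(result[0], target[0])  # 1.0
union = result[0].sum() + target[0].sum()    # 1 + 2 = 3.0
dice = 2 * intersect / (union + 1e-5)        # ~0.6667
print(dice.item())
```
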
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import sys
4 | import os
5 | import argparse
6 |
7 | import numpy as np
8 |
9 | import train
10 |
11 | basePath = os.getcwd()
12 |
13 | params = dict()
14 | params['DataManagerParams'] = dict()
15 | params['ModelParams'] = dict()
16 |
17 | # params of the algorithm
18 | # params['ModelParams']['numcontrolpoints'] = 2 # for B-spline free-form deformation??? what are the method details? By Chao.
19 | params['ModelParams']['sigma'] = 15 # used to produce randomly deformed images in data augmentation
20 | # params['ModelParams']['device'] = 0
21 | # params['ModelParams']['snapshot'] = 0
22 | params['ModelParams']['task'] = 'promise12'
23 | params['ModelParams']['dirTrainImage'] = os.path.join(basePath,'dataset/imagesTr')  # if 'dirTestImage' is empty, this is a dataset that will later be split into train and test sets; otherwise it is the train set only.
24 | params['ModelParams']['dirTrainLabel'] = os.path.join(basePath,'dataset/labelsTr')
25 | params['ModelParams']['dirTestImage'] = '' # path to test images
26 | params['ModelParams']['dirTestLabel'] = '' # path to test labels
27 | # params['ModelParams']['testProp'] = 0.2 # if 'dirTestImage' or 'dirTestLabel' is empty, split 'dirTrainImage' and 'dirTrainLabel' into train and test
28 | params['ModelParams']['dirInferImage'] = os.path.join(basePath,'dataset/imagesTs') # used for inference, usually no labels provided.
29 | params['ModelParams']['dirResult'] = os.path.join(basePath,'results') # where we need to save the results (relative to the base path)
30 | # params['ModelParams']['dirSnapshots'] = os.path.join(basePath,'Models/MRI_cinque_snapshots/') # where to save the models while training
31 | params['ModelParams']['nProc'] = 4 # the number of threads to do data augmentation
32 |
33 |
34 | #params of the DataManager
35 | # params['DataManagerParams']['dstRes'] = np.asarray([1,1,1.5],dtype=float)
36 | # params['DataManagerParams']['VolSize'] = np.asarray([128, 128, 64],dtype=int)
37 | params['DataManagerParams']['normDir'] = False  # whether to rotate the volume according to the transformation in its mhd header. Not recommended.
38 |
39 | print('\n+preset parameters:\n' + str(params))
40 |
41 |
42 | # parse sys.argv
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument('--numcontrolpoints', type=int, default=2) # for B-spline free-form deformation??? what are the method details? By Chao.
45 | parser.add_argument('--testProp', type=float, default=0.2) # if 'dirTestImage' or 'dirTestLabel' is empty, split 'dirTrainImage' and 'dirTrainLabel' into train and test
46 | parser.add_argument('--dstRes', type=str, default='[1, 1, 1.5]')
47 | parser.add_argument('--VolSize', type=str, default='[128, 128, 64]')
48 |
49 | parser.add_argument('--batchsize', type=int, default=2)
50 | parser.add_argument('--numIterations', type=int, default=1000)  # the number of iterations; https://github.com/faustomilletari/VNet trains by iteration count, with only one epoch run.
51 | parser.add_argument('--baseLR', type=float, default=0.0001)  # initial learning rate
52 | parser.add_argument('--momentum', type=float, default=0.99)
53 | parser.add_argument('--weight_decay', '--wd', default=1e-8, type=float,
54 | metavar='W', help='weight decay (default: 1e-8)')
55 | parser.add_argument('--stepsize', type=int, default=20000)
56 | parser.add_argument('--gamma', type=float, default=0.1)
57 |
58 | parser.add_argument('--seed', type=int, default=1)
59 | parser.add_argument('--opt', type=str, default='adam',
60 | choices=('sgd', 'adam', 'rmsprop'))
61 |
62 | parser.add_argument('--dice', action='store_true', default=True)
63 | parser.add_argument('--gpu_ids', type=int, default=1) # what if multiple gpu ids? use list? by Chao.
64 | parser.add_argument('--nEpochs', type=int, default=1) # line "dataQueue_tmp = dataQueue" in train.py is not working for epoch=2 and so on. Why? By Chao.
65 | parser.add_argument('--xLabel', type=str, default='Iteration', help='x-axis label for training performance transition curve, accepts "Epoch" or "Iteration"')
66 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
67 | help='manual epoch number (useful on restarts)')
68 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
69 | help='path to latest checkpoint (default: none)')
70 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
71 | help='evaluate model on validation set')
72 |
73 | parser.add_argument('--no-cuda', action='store_true', default=False)
74 |
75 | args = parser.parse_args()
76 |
77 | print('\n+sys arguments:\n' + str(args))
78 |
79 | # load the dataset, train, and test (i.e. write the predicted masks for the test data as .mhd files)
80 | train.main(params, args)
81 |
82 |
--------------------------------------------------------------------------------
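
train.py turns the --dstRes and --VolSize strings into arrays with eval(); for these inputs ast.literal_eval is a behaviourally identical but safer drop-in. A sketch (this is not what the repo currently calls):

```python
# Sketch: parsing the list-valued CLI flags the way train.py does,
# but via ast.literal_eval instead of bare eval.
import ast

import numpy as np

dstRes = np.asarray(ast.literal_eval('[1, 1, 1.5]'), dtype=float)
VolSize = np.asarray(ast.literal_eval('[128, 128, 64]'), dtype=int)
print(dstRes, VolSize)  # [1.  1.  1.5] [128 128  64]
```
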
/make_graph.py:
--------------------------------------------------------------------------------
1 | # derived from https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py
2 | from collections import namedtuple
3 | from distutils.version import LooseVersion
4 | from graphviz import Digraph
5 | import torch
6 | from torch.autograd import Variable
7 | from subprocess import check_call
8 | from os.path import splitext
9 |
10 |
11 | Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op'))
12 |
13 |
14 | def make_dot(fname, var, params=None):# by Chao
15 | """ Produces Graphviz representation of PyTorch autograd graph.
16 |
17 | Blue nodes are the Variables that require grad, orange are Tensors
18 | saved for backward in torch.autograd.Function
19 |
20 | Args:
21 | var: output Variable
22 |         params: dict of (name, Variable) to add names to nodes that
23 |             require grad (TODO: make optional)
24 | """
25 | if params is not None:
26 | assert all(isinstance(p, Variable) for p in params.values())
27 | param_map = {id(v): k for k, v in params.items()}
28 |
29 | node_attr = dict(style='filled',
30 | shape='box',
31 | align='left',
32 | fontsize='12',
33 | ranksep='0.1',
34 | height='0.2')
35 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
36 | seen = set()
37 |
38 | def size_to_str(size):
39 | return '(' + (', ').join(['%d' % v for v in size]) + ')'
40 |
41 | output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
42 |
43 | def add_nodes(var):
44 | if var not in seen:
45 | if torch.is_tensor(var):
46 | # note: this used to show .saved_tensors in pytorch0.2, but stopped
47 | # working as it was moved to ATen and Variable-Tensor merged
48 | dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
49 | elif hasattr(var, 'variable'):
50 | u = var.variable
51 | name = param_map[id(u)] if params is not None else ''
52 | node_name = '%s\n %s' % (name, size_to_str(u.size()))
53 | dot.node(str(id(var)), node_name, fillcolor='lightblue')
54 | elif var in output_nodes:
55 | dot.node(str(id(var)), str(type(var).__name__), fillcolor='darkolivegreen1')
56 | else:
57 | dot.node(str(id(var)), str(type(var).__name__))
58 | seen.add(var)
59 | if hasattr(var, 'next_functions'):
60 | for u in var.next_functions:
61 | if u[0] is not None:
62 | dot.edge(str(id(u[0])), str(id(var)))
63 | add_nodes(u[0])
64 | if hasattr(var, 'saved_tensors'):
65 | for t in var.saved_tensors:
66 | dot.edge(str(id(t)), str(id(var)))
67 | add_nodes(t)
68 |
69 | # handle multiple outputs
70 | if isinstance(var, tuple):
71 | for v in var:
72 | add_nodes(v.grad_fn)
73 | else:
74 | add_nodes(var.grad_fn)
75 |
76 | resize_graph(dot)
77 | dot.save(fname) # by Chao.
78 | filename, ext = splitext(fname)
79 | check_call(['dot','-Tpng', fname, '-o', filename+'.png'])
80 | # return dot # by Chao.
81 |
82 |
83 | # For traces
84 |
85 | def replace(name, scope):
86 | return '/'.join([scope[name], name])
87 |
88 |
89 | def parse(graph):
90 | scope = {}
91 | for n in graph.nodes():
92 | inputs = [i.uniqueName() for i in n.inputs()]
93 | for i in range(1, len(inputs)):
94 | scope[inputs[i]] = n.scopeName()
95 |
96 | uname = next(n.outputs()).uniqueName()
97 | assert n.scopeName() != '', '{} has empty scope name'.format(n)
98 | scope[uname] = n.scopeName()
99 | scope['0'] = 'input'
100 |
101 | nodes = []
102 | for n in graph.nodes():
103 | attrs = {k: n[k] for k in n.attributeNames()}
104 | attrs = str(attrs).replace("'", ' ')
105 | inputs = [replace(i.uniqueName(), scope) for i in n.inputs()]
106 | uname = next(n.outputs()).uniqueName()
107 | nodes.append(Node(**{'name': replace(uname, scope),
108 | 'op': n.kind(),
109 | 'inputs': inputs,
110 | 'attr': attrs}))
111 |
112 | for n in graph.inputs():
113 | uname = n.uniqueName()
114 | if uname not in scope.keys():
115 | scope[uname] = 'unused'
116 | nodes.append(Node(**{'name': replace(uname, scope),
117 | 'op': 'Parameter',
118 | 'inputs': [],
119 | 'attr': str(n.type())}))
120 |
121 | return nodes
122 |
123 |
124 | def make_dot_from_trace(trace):
125 | """ Produces graphs of torch.jit.trace outputs
126 |
127 | Example:
128 | >>> trace, = torch.jit.trace(model, args=(x,))
129 | >>> dot = make_dot_from_trace(trace)
130 | """
131 | # from tensorboardX
132 | if LooseVersion(torch.__version__) >= LooseVersion("0.4.1"):
133 | torch.onnx._optimize_trace(trace, torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
134 | elif LooseVersion(torch.__version__) >= LooseVersion("0.4"):
135 | torch.onnx._optimize_trace(trace, False)
136 | else:
137 | torch.onnx._optimize_trace(trace)
138 | graph = trace.graph()
139 | list_of_nodes = parse(graph)
140 |
141 | node_attr = dict(style='filled',
142 | shape='box',
143 | align='left',
144 | fontsize='12',
145 | ranksep='0.1',
146 | height='0.2')
147 |
148 | dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
149 |
150 | for node in list_of_nodes:
151 | dot.node(node.name, label=node.name.replace('/', '\n'))
152 | if node.inputs:
153 | for inp in node.inputs:
154 | dot.edge(inp, node.name)
155 |
156 | resize_graph(dot)
157 |
158 | return dot
159 |
160 |
161 | def resize_graph(dot, size_per_element=0.15, min_size=12):
162 | """Resize the graph according to how much content it contains.
163 |
164 | Modify the graph in place.
165 | """
166 | # Get the approximate number of nodes and edges
167 | num_rows = len(dot.body)
168 | content_size = num_rows * size_per_element
169 | size = max(min_size, content_size)
170 | size_str = str(size) + "," + str(size)
171 | dot.graph_attr.update(size=size_str)
172 |
--------------------------------------------------------------------------------
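
train.py carries a commented-out call to make_dot on the loss; below is a self-contained sketch with a toy model and a hypothetical output name. It needs the graphviz Python package plus the dot binary on PATH.

```python
# Sketch: rendering an autograd graph with make_graph.make_dot.
import torch
import torch.nn as nn

import make_graph

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(2, 4)
loss = model(x).sum()

# writes toy_graph.dot, then renders toy_graph.png via the `dot` executable
make_graph.make_dot('toy_graph.dot', loss, params=dict(model.named_parameters()))
```
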
/miscs_promise12.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import listdir
3 | from os.path import isfile, join, splitext, basename
4 | import csv
5 |
6 | import numpy as np
7 | import SimpleITK as sitk
8 | from glob import glob
9 |
10 | msdPath = '/Volumes/17816861965/msd'
11 |
12 | # taskF = 'Task02_Heart'
13 | # collect NIfTI header metadata for every MSD training image, one dict per file
14 | meta = []
15 |
16 | for taskF in listdir(msdPath):
17 |     for f in glob(join(msdPath, taskF, 'imagesTr', '*.nii.gz')):
18 |         # f = '/Volumes/17816861965/msd/Task02_Heart/imagesTr/la_003.nii.gz'
19 |         sitk_image = sitk.ReadImage(f)
20 |         row = {'task': taskF, 'file': basename(f)}
21 |         for key in ['bitpix', 'datatype', 'dim[0]', 'dim[1]', 'dim[2]', 'dim[3]', 'dim[4]', 'dim[5]', 'dim[6]', 'dim[7]', 'dim_info', 'pixdim[0]', 'pixdim[1]', 'pixdim[2]', 'pixdim[3]', 'pixdim[4]', 'pixdim[5]', 'pixdim[6]', 'pixdim[7]', 'scl_inter', 'scl_slope', 'srow_x', 'srow_y', 'srow_z']:  # part of sitk_image.GetMetaDataKeys()
22 |             row[key] = sitk_image.GetMetaData(key)
23 |         meta.append(row)
24 |         # stats = sitk.StatisticsImageFilter()
25 |         # stats.Execute(sitk_image)  # sitk::ERROR: Pixel type: 32-bit float is not supported in 4D, or SimpleITK compiled with SimpleITK_4D_IMAGES set to OFF.
26 |         # row['intensity_max'] = stats.GetMaximum()
27 |         # row['intensity_min'] = stats.GetMinimum()
28 |
29 |
30 |
31 |
32 |
33 | def get_size_spacing(imageFolder, resultFolder, resultTag):
34 | imagesTrList = glob(imageFolder + '*.mhd')
35 | trainF = open(os.path.join(resultFolder, resultTag + '_size_spacing.csv'), 'w')
36 | trainF.write('{}, {}, {}, {}, {}, {}, {}\n'.format('id', 'width', 'height', 'depth', 'pixel0', 'pixel1', 'pixel2'))
37 | for i in imagesTrList:
38 | sitk_image = sitk.ReadImage(i)
39 | width, height, depth = sitk_image.GetSize()
40 | pixel0, pixel1, pixel2 = sitk_image.GetSpacing()
41 | trainF.write('{}, {}, {}, {}, {}, {}, {}\n'.format(basename(i).split('.')[0], width, height, depth, pixel0, pixel1, pixel2))
42 | trainF.close()
43 |
44 |
45 | get_size_spacing('/Users/messi/PycharmProjects/promise12/dataset/imagesTr/', '/Users/messi/PycharmProjects/promise12/dataset/', 'imagesTr')
46 | get_size_spacing('/Users/messi/PycharmProjects/promise12/dataset/imagesTs/', '/Users/messi/PycharmProjects/promise12/dataset/', 'imagesTs')
--------------------------------------------------------------------------------
/plot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import os
5 | import numpy as np
6 | import pandas as pd
7 |
8 | import matplotlib as mpl
9 | mpl.use('Agg')
10 | import matplotlib.pyplot as plt
11 | plt.style.use('bmh')
12 | from matplotlib import rcParams
13 | rcParams.update({'figure.autolayout':True})
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('xLabel', type=str)
18 | parser.add_argument('expDir', type=str)
19 | args = parser.parse_args()
20 |
21 | xLabel = args.xLabel
22 | expDir = args.expDir
23 |
24 | trainP = os.path.join(expDir, 'train.csv')
25 | trainData = pd.read_csv(trainP, header=None)
26 |
27 | trainI, trainDice, trainErr = trainData.iloc[:,[0]], trainData.iloc[:,[1]], trainData.iloc[:,[2]]
28 |
29 | fig, ax = plt.subplots(1, 1, figsize=(6, 5))
30 | plt.plot(trainI, trainDice, label='Train')
31 | # plt.plot(range(len(testI)), testDice, label='Test')
32 | plt.xlabel(xLabel)
33 | plt.ylabel('Dice coefficient')
34 | plt.legend()
35 | ax.set_yscale('linear')
36 | dice_fname = os.path.join(expDir, 'dice.png')
37 | plt.savefig(dice_fname)
38 | print('Created {}'.format(dice_fname))
39 |
40 | fig, ax = plt.subplots(1, 1, figsize=(6, 5))
41 | plt.plot(trainI, trainErr, label='Train')
42 | # plt.plot(range(len(testI)), testErr, label='Test')
43 | plt.xlabel(xLabel)
44 | plt.ylabel('Error')
45 | ax.set_yscale('linear')
46 | plt.legend()
47 | err_fname = os.path.join(expDir, 'error.png')
48 | plt.savefig(err_fname)
49 | print('Created {}'.format(err_fname))
50 |
51 |     dice_err_fname = os.path.join(expDir, 'dice-error.png')
52 |     os.system('convert +append {} {} {}'.format(dice_fname, err_fname, dice_err_fname))  # requires ImageMagick's `convert` on PATH
53 | print('Created {}'.format(dice_err_fname))
54 |
55 | if __name__ == '__main__':
56 | main()
57 |
--------------------------------------------------------------------------------
/plot_contours.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | from os.path import basename
5 |
6 | import SimpleITK as sitk
7 | import numpy as np
8 | import glob
9 | from skimage import measure, draw
10 | from PIL import Image
11 |
12 | import matplotlib as mpl
13 | mpl.use('Agg')  # non-interactive backend; must be selected before pyplot is imported
14 | import matplotlib.pyplot as plt
15 | plt.style.use('bmh')  # plot style
16 |
17 | from matplotlib import rcParams
18 | rcParams.update({'figure.autolayout':True})
19 |
20 |
21 | def plot_contours(imagePath, gtPath, predPath, expFileName, expDir):
22 |
23 | sitk_image = sitk.ReadImage(imagePath)
24 | image = np.transpose(sitk.GetArrayFromImage(sitk_image), [2, 1, 0]) # image.shape 320, 320, 130
25 | sitk_gt = sitk.ReadImage(gtPath)
26 | gt = np.transpose(sitk.GetArrayFromImage(sitk_gt), [2, 1, 0]) # gt.shape 320, 320, 130
27 | sitk_pred = sitk.ReadImage(predPath)
28 | pred = np.transpose(sitk.GetArrayFromImage(sitk_pred), [2, 1, 0]) # pred.shape ???
29 | print('\nimage shape:{}\ngt shape:{}\npred shape:{}'.format(image.shape, gt.shape, pred.shape))
30 |
31 | plt.figure(figsize=(20,40)) # need to custom for how many images to show
32 | gtPositions = [i for i in range(image.shape[2]) if np.max(gt[...,i])==1]
33 |     maskPositions = [i for i in range(image.shape[2]) if np.max(pred[...,i])==1]
34 |     start = None  # first slice where gt and prediction overlap
35 |     if gtPositions[0] in maskPositions:
36 |         start = gtPositions[0]
37 |     elif maskPositions[0] in gtPositions:
38 |         start = maskPositions[0]
39 |     if start is not None:
40 | for i in range(start, start + 2):
41 | plt.subplot(2, 1, i + 1 - start)
42 | gt_contour = measure.find_contours(gt[..., i], 0.5)[0]
43 | pred_contour = measure.find_contours(pred[..., i], 0.5)[0]
44 | plt.imshow(image[..., i], cmap="gray")
45 | plt.plot(gt_contour[:, 1], gt_contour[:, 0], c='r', linewidth=4)
46 | plt.plot(pred_contour[:, 1], pred_contour[:, 0], c='g', linewidth=4)
47 |
48 | plt.axis('off')
49 | fname = os.path.join(expDir, expFileName)
50 | plt.savefig(fname) # need to be above plt.show()
51 | plt.show()
52 |
53 | Image.open(fname).rotate(270, expand=True).save(
54 |             fname)  # rotate so the final image has a normal viewing orientation
55 | print('Created contoured image for {}'.format(expFileName))
56 | else:
57 |         print('OMG!!: ground truth and predicted mask do not overlap for {}'.format(expFileName))
58 |
59 |
60 | ####################################################################
61 | resultDir = '/Users/messi/Downloads/results/vnet.base.promise12.20180901_1623_RMSprop/'
62 | cases = [basename(i).split('_')[0] for i in glob.glob(resultDir+'*test*.mhd')]
63 |
64 | for case in cases:
65 | imagePath = '/Users/messi/PycharmProjects/promise12/dataset/imagesTr/{}.mhd'.format(case)
66 | gtPath = '/Users/messi/PycharmProjects/promise12/dataset/labelsTr/{}_segmentation.mhd'.format(case)
67 |     predPath = os.path.join(resultDir, '{}_tested.mhd'.format(case))
68 | expDir = '/Users/messi/Downloads/'
69 | expFileName = '{}_adam.png'.format(case)
70 |
71 | plot_contours(imagePath, gtPath, predPath, expFileName, expDir)
72 |
73 |
74 | ##################################################################
75 | ##################################################################
76 | # for test data which has no gt
77 | def plot_infer_contours(imagePath, predPath, expFileName, expDir):
78 |
79 | sitk_image = sitk.ReadImage(imagePath)
80 | image = np.transpose(sitk.GetArrayFromImage(sitk_image), [2, 1, 0]) # image.shape 320, 320, 130
81 | sitk_pred = sitk.ReadImage(predPath)
82 | pred = np.transpose(sitk.GetArrayFromImage(sitk_pred), [2, 1, 0]) # pred.shape ???
83 | print('\nimage shape:{}\npred shape:{}'.format(image.shape, pred.shape))
84 |
85 | plt.figure(figsize=(20,40)) # need to custom for how many images to show
86 |     maskPositions = [i for i in range(image.shape[2]) if np.max(pred[...,i])==1]
87 |     start = maskPositions[0] if maskPositions else None
88 |
89 |     if start is not None:
90 | for i in range(start, start + 2):
91 | plt.subplot(2, 1, i + 1 - start)
92 | pred_contour = measure.find_contours(pred[..., i], 0.5)[0]
93 | plt.imshow(image[..., i], cmap="gray")
94 | plt.plot(pred_contour[:, 1], pred_contour[:, 0], c='g', linewidth=4)
95 |
96 | plt.axis('off')
97 | fname = os.path.join(expDir, expFileName)
98 | plt.savefig(fname) # need to be above plt.show()
99 | plt.show()
100 |
101 |         Image.open(fname).rotate(270, expand=True).save(
102 |             fname)  # rotate so the final image has a normal viewing orientation
103 |         print('Created contoured image for {}'.format(expFileName))
104 |     else:
105 |         print('OMG!!: no predicted mask slices found for {}'.format(expFileName))
106 |
107 |
108 | ####################################################################
109 | resultDir = '/Users/messi/Downloads/results/vnet.base.promise12.20180906_1419/'
110 | cases = [basename(i).split('_')[0] for i in glob.glob(resultDir+'*infer*.mhd')]
111 |
112 | for case in cases:
113 | imagePath = os.path.join('/Users/messi/PycharmProjects/promise12/dataset/imagesTs/', '{}.mhd'.format(case))
114 | predPath = os.path.join(resultDir, '{}_inferred.mhd'.format(case))
115 | expDir = '/Users/messi/Downloads/results/'
116 | expFileName = '{}_adam.png'.format(case)
117 |
118 | plot_infer_contours(imagePath, predPath, expFileName, expDir)
119 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | from os.path import splitext
4 | from random import sample
5 | import time
6 | from multiprocessing import Process, Queue
7 | import pdb
8 |
9 | import shutil
10 | import setproctitle
11 | import numpy as np
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.autograd import Variable
17 | from torch.utils.data import DataLoader
18 | import torchvision.transforms as transforms
19 |
20 | import lossFuncs
21 | import utils as utils
22 |
23 | import vnet
24 | import DataManager as DM
25 | import customDataset
26 | import make_graph
27 |
28 | def weights_init(m):
29 | classname = m.__class__.__name__
30 | if classname.find('Conv3d') != -1:
31 | nn.init.kaiming_normal_(m.weight)
32 | m.bias.data.zero_()
33 |
34 | def datestr():
35 | now = time.gmtime()
36 | return '{}{:02}{:02}_{:02}{:02}'.format(now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min)
37 |
38 |
39 | def save_checkpoint(state, path, prefix, filename='checkpoint.pth.tar'):
40 | prefix_save = os.path.join(path, prefix)
41 | name = prefix_save + '_' + filename
42 | torch.save(state, name)
43 |
44 | def train_test_split(images, labels, test_proportion):
45 | # images and labels are both dict().
46 | # pdb.set_trace()
47 | keys = list(images.keys())
48 | size = len(keys)
49 | test_keys = sample(keys, int(test_proportion*size))
50 | test_images = {i: images[i] for i in keys if i in test_keys}
51 | test_labels = {i+'_segmentation': labels[i+'_segmentation'] for i in keys if i in test_keys} # require customization
52 | train_images = {i: images[i] for i in keys if i not in test_keys}
53 | train_labels = {i+'_segmentation': labels[i+'_segmentation'] for i in keys if i not in test_keys} # require customization
54 | return train_images, train_labels, test_images, test_labels
55 |
56 | def dataAugmentation(params, args, dataQueue, numpyImages, numpyGT):
57 |
58 | nr_iter = args.numIterations # params['ModelParams']['numIterations']
59 | batchsize = args.batchsize # params['ModelParams']['batchsize']
60 |
61 | # pdb.set_trace()
62 | keysIMG = list(numpyImages.keys())
63 |
64 | nr_iter_dataAug = nr_iter*batchsize
65 | np.random.seed(1)
66 | whichDataList = np.random.randint(len(keysIMG), size=int(nr_iter_dataAug/params['ModelParams']['nProc']))
67 | np.random.seed(11)
68 | whichDataForMatchingList = np.random.randint(len(keysIMG), size=int(nr_iter_dataAug/params['ModelParams']['nProc']))
69 |
70 | for whichData,whichDataForMatching in zip(whichDataList,whichDataForMatchingList):
71 |
72 | currImgKey = keysIMG[whichData]
73 | currGtKey = keysIMG[whichData] + '_segmentation' # require customization. This is for PROMISE12 data.
74 | # print("keysIMG type:{}\nkeysIMG:{}".format(type(keysIMG),str(keysIMG)))
75 | # print("whichData:{}".format(whichData))
76 | # pdb.set_trace()
77 | # currImgKey = keysIMG[whichData]
78 | # currGtKey = keysIMG[whichData] # for MSD data.
79 |
80 |         # data augmentation through histogram matching across different examples...
81 | ImgKeyMatching = keysIMG[whichDataForMatching]
82 |
83 | defImg = numpyImages[currImgKey]
84 | defLab = numpyGT[currGtKey]
85 |
86 | defImg = utils.hist_match(defImg, numpyImages[ImgKeyMatching]) # why do histogram matching for all images? By Chao.
87 |
88 | if(np.random.rand(1)[0]>0.5): #do not apply deformations always, just sometimes
89 | defImg, defLab = utils.produceRandomlyDeformedImage(defImg, defLab, args.numcontrolpoints, params['ModelParams']['sigma'])
90 |
91 | dataQueue.put(tuple((defImg, defLab)))
92 |
93 | def adjust_opt(optAlg, optimizer, iteration):
94 |     if optAlg == 'sgd':
95 |         if iteration < 150:
96 |             lr = 1e-1
97 |         elif iteration == 150:
98 |             lr = 1e-2
99 |         elif iteration == 225:
100 |             lr = 1e-3
101 |         else:
102 |             return
103 |
104 | for param_group in optimizer.param_groups:
105 | param_group['lr'] = lr
106 |
107 |
108 | def train_dice(args, epoch, iteration, model, trainLoader, optimizer, trainF):
109 | model.train()
110 | nProcessed = 0
111 | batch_size = len(trainLoader.dataset)
112 | for batch_idx, output in enumerate(trainLoader):
113 |         data, target = output  # data shape: [batch_size, channels, z, y, x]; target shape: [batch_size, z, y, x]
114 | # pdb.set_trace()
115 | if args.cuda:
116 | data, target = data.cuda(), target.cuda()
117 |
118 | data = Variable(data)
119 | target = Variable(target)
120 | # data, target = Variable(data), Variable(target)
121 | optimizer.zero_grad()
122 | output = model(data) # output shape[batch_size, 2, z*y*x]
123 | # print("data shape:{}\noutput shape:{}\ntarget shape:{}".format(data.shape, output.shape, target.shape))
124 | loss = lossFuncs.dice_loss(output, target)
125 | # make_graph.make_dot(os.path.join(resultDir, 'promise_net_graph.dot'), loss)
126 | loss.backward()
127 | optimizer.step()
128 | nProcessed += len(data)
129 | diceOvBatch = loss.data[0]/batch_size # loss.data[0] is sum of dice coefficient over a mini-batch. By Chao.
130 | err = 100.*(1. - diceOvBatch)
131 |
132 | if np.mod(iteration, 10) == 0:
133 |         print('\nFor training: epoch: {} iteration: {} \tdice_coefficient over batch: {:.4f}\tError: {:.4f}\n'.format(epoch, iteration, diceOvBatch, err))
134 |
135 | return diceOvBatch, err
136 |
137 |
138 | def test_dice(dataManager, args, epoch, model, testLoader, testF, resultDir):
139 | '''
140 | :param dataManager: contains self.sitkImages which is a dict of test sitk images or all sitk images including test sitk images.
141 | :param args:
142 | :param epoch:
143 | :param model:
144 | :param testLoader:
145 | :param testF: path to file recording test results.
146 | :return:
147 | '''
148 | model.eval()
149 | test_dice = 0
150 | incorrect = 0
151 | # assume single GPU/batch_size =1
152 | # pdb.set_trace()
153 | for batch_idx, data in enumerate(testLoader):
154 | data, target, id = data
155 | # print("testing with {}".format(id[0]))
156 | if args.cuda:
157 | data, target = data.cuda(), target.cuda()
158 | data = Variable(data)
159 | target = Variable(target)
160 | output = model(data)
161 | dice = lossFuncs.dice_loss(output, target).data[0]
162 | test_dice += dice
163 | incorrect += (1. - dice)
164 |
165 | # pdb.set_trace()
166 |         _, _, z, y, x = data.shape  # drop the batch and channel dims to get the 3-d shape. by Chao.
167 | output = output[0,...] # assume batch_size = 1
168 | _, output = output.max(0)
169 | output = output.view(z, y, x)
170 | output = output.cpu()
171 | # In numpy, an array is indexed in the opposite order (z,y,x) while sitk will generate the sitk image in (x,y,z). (refer: http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/01_Image_Basics.html)
172 | output = output.numpy()
173 | output = np.transpose(output, [2,1,0]) # change to simpleITK order (x, y, z)
174 | # pdb.set_trace()
175 | print("save predicted label for test{}".format(id[0]))
176 | dataManager.writeResultsFromNumpyLabel(output, id[0], '_tested_epoch{}'.format(epoch), '.mhd', resultDir) # require customization
177 | testF.write('{},{},{},{}\n'.format(epoch, id[0], dice, 1-dice))
178 |
179 | nTotal = len(testLoader)
180 | test_dice /= nTotal # loss function already averages over batch size
181 | err = 100.*incorrect/nTotal
182 | # if np.mod(iteration, 10) == 0:
183 | # print('\nFor testing: iteration:{}\tAverage Dice Coeff: {:.4f}\tError:{:.4f}\n'.format(iteration, test_dice, err))
184 |
185 |     # testF.write('{},{},{}\n'.format('average', test_dice, err))
186 | testF.flush()
187 |
188 |
189 | def inference(dataManager, args, loader, model, resultDir):
190 | model.eval()
191 | # assume single GPU / batch size 1
192 | # pdb.set_trace()
193 | for batch_idx, data in enumerate(loader):
194 | data, id = data
195 | # pdb.set_trace()
196 | # convert names to batch tensor
197 | if args.cuda:
198 | data.pin_memory()
199 | data = data.cuda()
200 | with torch.no_grad():
201 | data = Variable(data)
202 | output = model(data)
203 |
204 |         _, _, z, y, x = data.shape  # drop the batch and channel dims to get the 3-d shape. by Chao.
205 | output = output[0,...] # assume batch_size=1
206 | _, output = output.max(0)
207 | output = output.view(z, y, x)
208 | output = output.cpu()
209 | # In numpy, an array is indexed in the opposite order (z,y,x) while sitk will generate the sitk image in (x,y,z). (refer: http://insightsoftwareconsortium.github.io/SimpleITK-Notebooks/Python_html/01_Image_Basics.html)
210 | output = output.numpy()
211 | output = np.transpose(output, [2,1,0]) # change to simpleITK order (x, y, z)
212 | # pdb.set_trace()
213 | print("save predicted label for inference {}".format(id[0]))
214 | dataManager.writeResultsFromNumpyLabel(output, id[0], '_inferred', '.mhd', resultDir) # require customization
215 |
216 |
217 | ## main method
218 | def main(params, args):
219 | best_prec1 = 100. # accuracy? by Chao
220 | epochs = args.nEpochs
221 | nr_iter = args.numIterations # params['ModelParams']['numIterations']
222 | batch_size = args.batchsize # params['ModelParams']['batchsize']
223 | resultDir = 'results/vnet.base.{}.{}'.format(params['ModelParams']['task'], datestr())
224 |
225 | weight_decay = args.weight_decay
226 | setproctitle.setproctitle(resultDir)
227 | if os.path.exists(resultDir):
228 | shutil.rmtree(resultDir)
229 | os.makedirs(resultDir, exist_ok=True)
230 |
231 | args.cuda = not args.no_cuda and torch.cuda.is_available()
232 |
233 | torch.manual_seed(args.seed)
234 | if args.cuda:
235 | torch.cuda.manual_seed(args.seed)
236 |
237 | print("build vnet")
238 | model = vnet.VNet(elu=False, nll=False)
239 | gpu_ids = args.gpu_ids
240 | # torch.cuda.set_device(gpu_ids) # why do I have to add this line? It seems the below line is useless to apply GPU devices. By Chao.
241 | # model = nn.parallel.DataParallel(model, device_ids=[gpu_ids])
242 | model = nn.parallel.DataParallel(model)
243 |
244 | if args.resume:
245 | if os.path.isfile(args.resume):
246 | print("=> loading checkpoint '{}'".format(args.resume))
247 | checkpoint = torch.load(args.resume)
248 | args.start_epoch = checkpoint['epoch']
249 | best_prec1 = checkpoint['best_prec1']
250 | model.load_state_dict(checkpoint['state_dict'])
251 | print("=> loaded checkpoint '{}' (epoch {})"
252 | .format(args.evaluate, checkpoint['epoch']))
253 | else:
254 | print("=> no checkpoint found at '{}'".format(args.resume))
255 | else:
256 | model.apply(weights_init)
257 |
258 |
259 | train = train_dice
260 | test = test_dice
261 |
262 | print(' + Number of params: {}'.format(
263 | sum([p.data.nelement() for p in model.parameters()])))
264 | if args.cuda:
265 | model = model.cuda()
266 |
267 | # transform
268 | trainTransform = transforms.Compose([
269 | transforms.ToTensor()
270 | ])
271 | testTransform = transforms.Compose([
272 | transforms.ToTensor()
273 | ])
274 |
275 | if args.opt == 'sgd':
276 | optimizer = optim.SGD(model.parameters(), lr=args.baseLR,
277 | momentum=args.momentum, weight_decay=weight_decay) # params['ModelParams']['baseLR']
278 | elif args.opt == 'adam':
279 | optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)
280 | elif args.opt == 'rmsprop':
281 | optimizer = optim.RMSprop(model.parameters(), weight_decay=weight_decay)
282 |
283 |
284 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
285 |
286 | # pdb.set_trace()
287 | DataManagerParams = {'dstRes':np.asarray(eval(args.dstRes), dtype=float), 'VolSize':np.asarray(eval(args.VolSize), dtype=int), 'normDir':params['DataManagerParams']['normDir']}
288 |
289 | if params['ModelParams']['dirTestImage']: # a test image directory is given, so dedicated test files exist.
290 | print("\nloading training set")
291 | dataManagerTrain = DM.DataManager(params['ModelParams']['dirTrainImage'],
292 | params['ModelParams']['dirTrainLabel'],
293 | params['ModelParams']['dirResult'],
294 | DataManagerParams)
295 | dataManagerTrain.loadTrainingData() # required
296 | train_images = dataManagerTrain.getNumpyImages()
297 | train_labels = dataManagerTrain.getNumpyGT()
298 |
299 | print("\nloading test set")
300 | dataManagerTest = DM.DataManager(params['ModelParams']['dirTestImage'], params['ModelParams']['dirTestLabel'],
301 | params['ModelParams']['dirResult'],
302 | DataManagerParams)
303 | dataManagerTest.loadTestingData() # required
304 | test_images = dataManagerTest.getNumpyImages()
305 | test_labels = dataManagerTest.getNumpyGT()
306 |
307 | testSet = customDataset.customDataset(mode='test', images=test_images, GT=test_labels, transform=testTransform)
308 | testLoader = DataLoader(testSet, batch_size=1, shuffle=True, **kwargs)
309 |
310 | elif args.testProp: # 'dirTestImage' is not given but 'testProp' is: a single dataset was provided, so split it into train and test sets.
311 | print('\nloading dataset, will split into train and test')
312 | dataManager = DM.DataManager(params['ModelParams']['dirTrainImage'],
313 | params['ModelParams']['dirTrainLabel'],
314 | params['ModelParams']['dirResult'],
315 | DataManagerParams)
316 | dataManager.loadTrainingData() # required
317 | numpyImages = dataManager.getNumpyImages()
318 | numpyGT = dataManager.getNumpyGT()
319 | # pdb.set_trace()
320 |
321 | train_images, train_labels, test_images, test_labels = train_test_split(numpyImages, numpyGT, args.testProp)
322 | testSet = customDataset.customDataset(mode='test', images=test_images, GT=test_labels, transform=testTransform)
323 | testLoader = DataLoader(testSet, batch_size=1, shuffle=True, **kwargs)
324 |
325 | else: # neither 'dirTestImage' nor 'testProp' is given: the single provided dataset is used entirely as the train set.
326 | print('\nloading only train dataset')
327 | dataManager = DM.DataManager(params['ModelParams']['dirTrainImage'],
328 | params['ModelParams']['dirTrainLabel'],
329 | params['ModelParams']['dirResult'],
330 | DataManagerParams)
331 | dataManager.loadTrainingData() # required
332 | train_images = dataManager.getNumpyImages()
333 | train_labels = dataManager.getNumpyGT()
334 |
335 | test_images = None
336 | test_labels = None
337 | testSet = None
338 | testLoader = None
339 |
340 | if params['ModelParams']['dirTestImage']:
341 | dataManager_toTestFunc = dataManagerTest
342 | else:
343 | dataManager_toTestFunc = dataManager
344 |
345 | ### start data augmentation for train_images/train_labels and load the augmented data with multiprocessing
346 | dataQueue = Queue(30) # holds at most 30 augmented volumes at a time
347 | dataPreparation = [None] * params['ModelParams']['nProc']
348 |
349 | # processes creation
350 | for proc in range(0, params['ModelParams']['nProc']):
351 | dataPreparation[proc] = Process(target=dataAugmentation,
352 | args=(params, args, dataQueue, train_images, train_labels))
353 | dataPreparation[proc].daemon = True
354 | dataPreparation[proc].start()
355 |
356 | batchData = np.zeros((batch_size, DataManagerParams['VolSize'][0],
357 | DataManagerParams['VolSize'][1],
358 | DataManagerParams['VolSize'][2]), dtype=float)
359 | batchLabel = np.zeros((batch_size, DataManagerParams['VolSize'][0],
360 | DataManagerParams['VolSize'][1],
361 | DataManagerParams['VolSize'][2]), dtype=float)
362 |
363 | trainF = open(os.path.join(resultDir, 'train.csv'), 'w')
364 | testF = open(os.path.join(resultDir, 'test.csv'), 'w')
365 |
366 | for epoch in range(1, epochs+1):
367 | dataQueue_tmp = dataQueue # NOTE: this is only another reference to the same queue, not a copy; once the augmentation processes stop producing, later epochs block here. Not working from epoch = 2 onward. By Chao.
368 | diceOvBatch = 0
369 | err = 0
370 | for iteration in range(1, nr_iter + 1):
371 | # adjust_opt(args.opt, optimizer, iteration)
372 | if args.opt == 'sgd':
373 | if np.mod(iteration, args.stepsize) == 0:
374 | for param_group in optimizer.param_groups:
375 | param_group['lr'] *= args.gamma
376 |
377 | for i in range(batch_size):
378 | [defImg, defLab] = dataQueue_tmp.get()
379 |
380 | batchData[i, :, :, :] = defImg.astype(dtype=np.float32)
381 | batchLabel[i, :, :, :] = (defLab > 0.5).astype(dtype=np.float32)
382 |
383 | trainSet = customDataset.customDataset(mode='train', images=batchData, GT=batchLabel,
384 | transform=trainTransform)
385 | trainLoader = DataLoader(trainSet, batch_size=batch_size, shuffle=True, **kwargs)
386 |
387 | diceOvBatch_tmp, err_tmp = train(args, epoch, iteration, model, trainLoader, optimizer, trainF)
388 |
389 | if args.xLabel == 'Iteration':
390 | trainF.write('{},{},{}\n'.format(iteration, diceOvBatch_tmp, err_tmp))
391 | trainF.flush()
392 | elif args.xLabel == 'Epoch':
393 | diceOvBatch += diceOvBatch_tmp
394 | err += err_tmp
395 | if args.xLabel == 'Epoch':
396 | trainF.write('{},{},{}\n'.format(epoch, diceOvBatch/nr_iter, err/nr_iter))
397 | trainF.flush()
398 |
399 | if np.mod(epoch, epochs) == 0: # by default, save a checkpoint only at the last epoch
400 | save_checkpoint({'epoch': epoch,
401 | 'state_dict': model.state_dict(),
402 | 'best_prec1': best_prec1}, path=resultDir, prefix="vnet_epoch{}".format(epoch))
403 | if epoch == epochs and testLoader:
404 | test(dataManager_toTestFunc, args, epoch, model, testLoader, testF, resultDir) # by Chao.
405 |
406 | os.system('./plot.py {} {} &'.format(args.xLabel, resultDir))
407 |
408 | trainF.close()
409 | testF.close()
410 |
411 | # inference: write predicted masks for the inference data as .mhd files
412 | if params['ModelParams']['dirInferImage'] != '':
413 | print("loading inference data")
414 | dataManagerInfer = DM.DataManager(params['ModelParams']['dirInferImage'], None,
415 | params['ModelParams']['dirResult'],
416 | DataManagerParams)
417 | dataManagerInfer.loadInferData() # required; loadInferData must be implemented in DataManager. By Chao.
418 | numpyImages = dataManagerInfer.getNumpyImages()
419 |
420 | inferSet = customDataset.customDataset(mode='infer', images=numpyImages, GT=None, transform=testTransform)
421 | inferLoader = DataLoader(inferSet, batch_size=1, shuffle=True, **kwargs)
422 | inference(dataManagerInfer, args, inferLoader, model, resultDir)
423 |
424 |
--------------------------------------------------------------------------------
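A note on the axis handling in `inference()` above: numpy indexes a volume as (z, y, x), while SimpleITK reports image size as (x, y, z), which is why the prediction is transposed before it is written out. Below is a minimal, self-contained sketch of that convention; the array `pred` is a hypothetical stand-in, not a variable from this repo.

import numpy as np
import SimpleITK as sitk

pred = np.zeros((64, 128, 128), dtype=np.uint8)  # numpy order: (z, y, x)
img = sitk.GetImageFromArray(pred)               # SimpleITK reverses the axis order
print(img.GetSize())                             # (128, 128, 64), i.e. (x, y, z)

pred_xyz = np.transpose(pred, [2, 1, 0])         # same reordering as in inference()
print(pred_xyz.shape)                            # (128, 128, 64)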
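The learning-rate handling in the training loop above is a manual step decay: every `args.stepsize` iterations the SGD learning rate is multiplied by `args.gamma`. The same schedule can be expressed with `torch.optim.lr_scheduler.StepLR`; the sketch below uses a placeholder model and assumed values, not code from this repo.

import torch
import torch.optim as optim

model = torch.nn.Linear(4, 2)  # stand-in for the V-Net
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

for iteration in range(1, 61):
    optimizer.step()    # training step elided
    scheduler.step()    # lr: 1e-3 -> 1e-4 after 20 iters -> 1e-5 after 40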
/utils.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import matplotlib.pyplot as plt
4 | import SimpleITK as sitk
5 | from skimage import measure, morphology
6 | import scipy.ndimage
7 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
8 | import numpy as np
9 | import pydicom
10 |
11 |
12 | Z_MAX = None
13 | Y_MAX = None
14 | X_MAX = None
15 | vox_spacing = None
16 | shape_max = None
17 |
18 | def init_dims3D(z, y, x, spacing):
19 | global Z_MAX, Y_MAX, X_MAX, vox_spacing, shape_max
20 | vox_spacing = spacing
21 | Z_MAX, Y_MAX, X_MAX = z, y, x
22 | shape_max = (z, y, x)
23 |
24 | def debug_img(img):
25 | plt.hist(img.flatten(), bins=80, color='c')
26 | plt.xlabel("Hounsfield Units (HU)")
27 | plt.ylabel("Frequency")
28 | plt.show()
29 |
30 | def plot_3d(image, threshold=-300):
31 | # Position the scan upright,
32 | # so the head of the patient would be at the top facing the camera
33 | p = image.transpose(2,1,0)
34 |
35 | #p = image
36 |
37 | verts, faces = measure.marching_cubes(p, threshold)
38 |
39 | fig = plt.figure(figsize=(10, 10))
40 | ax = fig.add_subplot(111, projection='3d')
41 |
42 | # Fancy indexing: `verts[faces]` to generate a collection of triangles
43 | mesh = Poly3DCollection(verts[faces], alpha=0.70)
44 | face_color = [0.45, 0.45, 0.75]
45 | mesh.set_facecolor(face_color)
46 | ax.add_collection3d(mesh)
47 |
48 | ax.set_xlim(0, p.shape[0])
49 | ax.set_ylim(0, p.shape[1])
50 | ax.set_zlim(0, p.shape[2])
51 |
52 | plt.show()
53 |
54 | def npz_save(name, obj):
55 | keys = list(obj.keys())
56 | values = list(obj.values())
57 | np.savez(name+".npz", keys=keys, values=values)
58 |
59 | def npz_save_compressed(name, obj):
60 | keys = list(obj.keys())
61 | values = list(obj.values())
62 | np.savez_compressed(name+"_compressed.npz", keys=keys, values=values)
63 |
64 | def npz_load(filename):
65 | npzfile = np.load(filename+".npz")
66 | keys = npzfile["keys"]
67 | values = npzfile["values"]
68 | return dict(zip(keys, values))
69 |
70 | def npz_load_compressed(filename):
71 | npzfile = np.load(filename+"_compressed.npz")
72 | keys = npzfile["keys"]
73 | values = npzfile["values"]
74 | return dict(zip(keys, values))
75 |
76 | def copy_slice_centered(dst, src, dim):
77 | if dim <= Y_MAX:
78 | x_start = int((X_MAX - dim) / 2)
79 | y_start = int((Y_MAX - dim) / 2)
80 | for y in range(dim):
81 | for x in range(dim):
82 | dst[y_start+y][x_start+x] = src[y][x]
83 | elif dim <= X_MAX:
84 | x_start = int((X_MAX - dim) / 2)
85 | y_start = int((dim - Y_MAX) / 2)
86 | for y in range(Y_MAX):
87 | for x in range(dim):
88 | dst[y][x_start+x] = src[y_start+y][x]
89 | else:
90 | x_start = int((dim - X_MAX) / 2)
91 | y_start = int((dim - Y_MAX) / 2)
92 | for y in range(Y_MAX):
93 | for x in range(X_MAX):
94 | dst[y][x] = src[y_start+y][x_start+x]
95 |
96 | def copy_normalized(src, dtype=np.int16):
97 | src_shape = np.shape(src)
98 | if src_shape == shape_max:
99 | return src
100 |
101 | (z_axis, y_axis, x_axis) = src_shape
102 | print(src_shape)
103 | assert x_axis == y_axis
104 | new_img = np.full(shape_max, np.min(src), dtype=dtype)
105 | if z_axis < Z_MAX:
106 | start = int((Z_MAX - z_axis) / 2)
107 | for i in range(z_axis):
108 | copy_slice_centered(new_img[start + i], src[i], x_axis)
109 | else:
110 | start = int((z_axis - Z_MAX) / 2)
111 | for i in range(Z_MAX):
112 | copy_slice_centered(new_img[i], src[start+i], x_axis)
113 | return new_img
114 |
115 | def truncate(image, min_bound, max_bound):
116 | image[image < min_bound] = min_bound
117 | image[image > max_bound] = max_bound
118 | return image
119 |
120 | def hist_match(source, template):
121 | """
122 | Adjust the pixel values of a grayscale image such that its histogram
123 | matches that of a target image
124 |
125 | Arguments:
126 | -----------
127 | source: np.ndarray
128 | Image to transform; the histogram is computed over the flattened
129 | array
130 | template: np.ndarray
131 | Template image; can have different dimensions from source
132 | Returns:
133 | -----------
134 | matched: np.ndarray
135 | The transformed output image
136 | """
137 |
138 | oldshape = source.shape
139 | source = source.ravel()
140 | template = template.ravel()
141 |
142 | # get the set of unique pixel values and their corresponding indices and
143 | # counts
144 | s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
145 | return_counts=True)
146 | t_values, t_counts = np.unique(template, return_counts=True)
147 |
148 | # take the cumsum of the counts and normalize by the number of pixels to
149 | # get the empirical cumulative distribution functions for the source and
150 | # template images (maps pixel value --> quantile)
151 | s_quantiles = np.cumsum(s_counts).astype(np.float64)
152 | s_quantiles /= s_quantiles[-1]
153 | t_quantiles = np.cumsum(t_counts).astype(np.float64)
154 | t_quantiles /= t_quantiles[-1]
155 |
156 | # interpolate linearly to find the pixel values in the template image
157 | # that correspond most closely to the quantiles in the source image
158 | #interp_t_values = np.zeros_like(source,dtype=float)
159 | interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
160 |
161 | return interp_t_values[bin_idx].reshape(oldshape)
162 |
163 | def sitk_show(nda, title=None, margin=0.0, dpi=40):
164 | figsize = (1 + margin) * nda.shape[0] / dpi, (1 + margin) * nda.shape[1] / dpi
165 |
166 | extent = (0, nda.shape[1], nda.shape[0], 0)
167 | fig = plt.figure(figsize=figsize, dpi=dpi)
168 | ax = fig.add_axes([margin, margin, 1 - 2*margin, 1 - 2*margin])
169 |
170 | plt.set_cmap("gray")
171 | for k in range(0,nda.shape[2]):
172 | print("printing slice "+str(k))
173 | ax.imshow(np.squeeze(nda[:,:,k]),extent=extent,interpolation=None)
174 | plt.draw()
175 | plt.pause(0.1)
176 | #plt.waitforbuttonpress()
177 |
178 | def computeQualityMeasures(lP,lT):
179 | quality=dict()
180 | labelPred=sitk.GetImageFromArray(lP, isVector=False)
181 | labelTrue=sitk.GetImageFromArray(lT, isVector=False)
182 | hausdorffcomputer=sitk.HausdorffDistanceImageFilter()
183 | hausdorffcomputer.Execute(labelTrue>0.5,labelPred>0.5)
184 | quality["avgHausdorff"]=hausdorffcomputer.GetAverageHausdorffDistance()
185 | quality["Hausdorff"]=hausdorffcomputer.GetHausdorffDistance()
186 |
187 | dicecomputer=sitk.LabelOverlapMeasuresImageFilter()
188 | dicecomputer.Execute(labelTrue>0.5,labelPred>0.5)
189 | quality["dice"]=dicecomputer.GetDiceCoefficient()
190 |
191 | return quality
192 |
193 |
194 | def produceRandomlyDeformedImage(image, label, numcontrolpoints, stdDef):
195 | sitkImage=sitk.GetImageFromArray(image, isVector=False)
196 | sitklabel=sitk.GetImageFromArray(label, isVector=False)
197 |
198 | transformDomainMeshSize = [numcontrolpoints] * sitkImage.GetDimension()
199 |
200 | tx = sitk.BSplineTransformInitializer(sitkImage, transformDomainMeshSize)
201 |
202 |
203 | params = tx.GetParameters()
204 |
205 | paramsNp=np.asarray(params,dtype=float)
206 | paramsNp = paramsNp + np.random.randn(paramsNp.shape[0])*stdDef
207 |
208 | paramsNp[0:int(len(params)/3)] = 0 # remove z deformations: the resolution along z is too coarse to deform safely
209 |
210 | params=tuple(paramsNp)
211 | tx.SetParameters(params)
212 |
213 | resampler = sitk.ResampleImageFilter()
214 | resampler.SetReferenceImage(sitkImage)
215 | resampler.SetInterpolator(sitk.sitkLinear)
216 | resampler.SetDefaultPixelValue(0)
217 | resampler.SetTransform(tx)
218 |
219 | # resampler.SetDefaultPixelValue(0) # duplicate of the call above
220 | outimgsitk = resampler.Execute(sitkImage)
221 | outlabsitk = resampler.Execute(sitklabel)
222 |
223 | outimg = sitk.GetArrayFromImage(outimgsitk)
224 | outimg = outimg.astype(dtype=np.float32)
225 |
226 | outlbl = sitk.GetArrayFromImage(outlabsitk)
227 | outlbl = (outlbl>0.5).astype(dtype=np.float32)
228 |
229 | return outimg,outlbl
230 |
231 |
232 | def resample_volume(img, spacing_old, spacing_new, bounds=None):
233 | (z_axis, y_axis, x_axis) = np.shape(img)
234 | print('img: {} old spacing: {} new spacing: {}'.format(np.shape(img), spacing_old, spacing_new))
235 | resize_factor = np.array(spacing_old) / spacing_new
236 | new_shape = np.round(np.shape(img) * resize_factor)
237 | real_resize_factor = new_shape / np.shape(img)
238 | img_rescaled = scipy.ndimage.zoom(img, real_resize_factor, mode='nearest').astype(np.int16)
239 | img_array_normalized = copy_normalized(img_rescaled)
240 | img_tmp = img_array_normalized.copy()
241 | # determine what the mean will be on the anticipated value range
242 | mu, var = 0., 0.
243 | if bounds is not None:
244 | min_bound, max_bound = bounds
245 | img_tmp = truncate(img_tmp, min_bound, max_bound)
246 | mu = np.mean(img_tmp)
247 | var = np.var(img_tmp)
248 | return (img_array_normalized, mu, var)
249 |
250 |
251 | def save_image(img_arr, path):
252 | itk_img = sitk.GetImageFromArray(img_arr, isVector=False)
253 | sitk.WriteImage(itk_img, path)
254 |
255 |
256 | def get_subvolume(target, bounds):
257 | (zs, ze), (ys, ye), (xs, xe) = bounds
258 | return np.squeeze(target)[zs:ze, ys:ye, xs:xe]
259 |
260 |
261 | def partition_image(image, partition):
262 | z_p, y_p, x_p = partition
263 | z, y, x = np.shape(np.squeeze(image))
264 | z_incr, y_incr, x_incr = z // z_p, y // y_p, x // x_p
265 | assert z % z_p == 0
266 | assert y % y_p == 0
267 | assert x % x_p == 0
268 | image_list = []
269 | for zi in range(z_p):
270 | zstart = zi*z_incr
271 | zend = zstart + z_incr
272 | for yi in range(y_p):
273 | ystart = yi*y_incr
274 | yend = ystart + y_incr
275 | for xi in range(x_p):
276 | xstart = xi*x_incr
277 | xend = xstart + x_incr
278 | subvolume = get_subvolume(image, ((zstart, zend), (ystart, yend), (xstart, xend)))
279 | subvolume = subvolume.reshape((1, 1, z_incr, y_incr, x_incr))
280 | image_list.append(subvolume)
281 | return image_list
282 |
283 |
284 | def merge_image(image_list, partition):
285 | z_p, y_p, x_p = partition
286 | shape = np.array(np.shape(image_list[0]), dtype=np.int32)
287 | z, y, x = 0, 0, 0
288 | z, y, x = shape * partition
289 | i = 0
290 | z_list = []
291 | for zi in range(z_p):
292 | y_list = []
293 | for yi in range(y_p):
294 | x_list = []
295 | for xi in range(x_p):
296 | x_list.append(image_list[i])
297 | i += 1
298 | y_list.append(np.concatenate(x_list, axis=2))
299 | z_list.append(np.concatenate(y_list, axis=1))
300 | return np.concatenate(z_list)
301 |
--------------------------------------------------------------------------------
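As a quick usage sketch for `hist_match()` above (synthetic data; assumes `utils.py` is importable from the working directory): matching a random image's histogram to a template shifts its intensity distribution toward the template's.

import numpy as np
from utils import hist_match  # this file

rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(64, 64))    # image to transform
template = rng.normal(5.0, 2.0, size=(64, 64))  # target histogram

matched = hist_match(source, template)
print(matched.shape)   # (64, 64), same shape as source
print(matched.mean())  # roughly template.mean(), i.e. about 5.0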
/vnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import pdb
6 |
7 | def passthrough(x, **kwargs):
8 | return x
9 |
10 | def ELUCons(elu, nchan):
11 | if elu:
12 | return nn.ELU(inplace=True)
13 | else:
14 | return nn.PReLU(nchan)
15 |
16 | # normalization between sub-volumes is necessary
17 | # for good performance
18 | class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
19 | def _check_input_dim(self, input):
20 | if input.dim() != 5:
21 | raise ValueError('expected 5D input (got {}D input)'
22 | .format(input.dim()))
23 | # super()._check_input_dim(input) # do not call the parent: _BatchNorm._check_input_dim raises NotImplementedError, so the 5-D check above must stand alone. By Chao.
24 | # super(ContBatchNorm3d, self)._check_input_dim(input)
25 |
26 | def forward(self, input):
27 | self._check_input_dim(input)
28 | return F.batch_norm(
29 | input, self.running_mean, self.running_var, self.weight, self.bias,
30 | True, self.momentum, self.eps)
31 |
32 |
33 | class LUConv(nn.Module):
34 | def __init__(self, nchan, elu):
35 | super(LUConv, self).__init__()
36 | self.relu1 = ELUCons(elu, nchan)
37 | self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
38 | self.bn1 = ContBatchNorm3d(nchan)
39 |
40 | def forward(self, x):
41 | out = self.relu1(self.bn1(self.conv1(x)))
42 | return out
43 |
44 |
45 | def _make_nConv(nchan, depth, elu):
46 | layers = []
47 | for _ in range(depth):
48 | layers.append(LUConv(nchan, elu))
49 | return nn.Sequential(*layers)
50 |
51 |
52 | class InputTransition(nn.Module):
53 | def __init__(self, outChans, elu):
54 | super(InputTransition, self).__init__()
55 | self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
56 | self.bn1 = ContBatchNorm3d(16)
57 | self.relu1 = ELUCons(elu, 16)
58 |
59 | def forward(self, x):
60 | # do we want a PRELU here as well?
61 | # print("x before InputTransition shape:"+str(x.shape))
62 | out = self.bn1(self.conv1(x))
63 | # the input has one channel; replicate it 16 times along the channel dimension so the residual addition with "out" is well-defined. By Chao
64 | x16 = torch.cat((x, x, x, x, x, x, x, x,
65 | x, x, x, x, x, x, x, x), 1) # concatenate along dim=1 (channels) so x16 has the same shape as "out" for torch.add(). By Chao.
66 | # print("x16 shape:"+str(x16.shape))
67 | # pdb.set_trace()
68 | out = self.relu1(torch.add(out, x16))
69 | return out
70 |
71 |
72 | class DownTransition(nn.Module):
73 | def __init__(self, inChans, nConvs, elu, dropout=False):
74 | super(DownTransition, self).__init__()
75 | outChans = 2*inChans
76 | self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
77 | self.bn1 = ContBatchNorm3d(outChans)
78 | self.do1 = passthrough
79 | self.relu1 = ELUCons(elu, outChans)
80 | self.relu2 = ELUCons(elu, outChans)
81 | if dropout:
82 | self.do1 = nn.Dropout3d()
83 | self.ops = _make_nConv(outChans, nConvs, elu)
84 |
85 | def forward(self, x):
86 | down = self.relu1(self.bn1(self.down_conv(x)))
87 | out = self.do1(down)
88 | out = self.ops(out)
89 | out = self.relu2(torch.add(out, down))
90 | return out
91 |
92 |
93 | class UpTransition(nn.Module):
94 | def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
95 | super(UpTransition, self).__init__()
96 | self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
97 | self.bn1 = ContBatchNorm3d(outChans // 2)
98 | self.do1 = passthrough
99 | self.do2 = nn.Dropout3d()
100 | self.relu1 = ELUCons(elu, outChans // 2)
101 | self.relu2 = ELUCons(elu, outChans)
102 | if dropout:
103 | self.do1 = nn.Dropout3d()
104 | self.ops = _make_nConv(outChans, nConvs, elu)
105 |
106 | def forward(self, x, skipx):
107 | out = self.do1(x)
108 | skipxdo = self.do2(skipx)
109 | out = self.relu1(self.bn1(self.up_conv(out)))
110 | xcat = torch.cat((out, skipxdo), 1)
111 | out = self.ops(xcat)
112 | out = self.relu2(torch.add(out, xcat))
113 | return out
114 |
115 |
116 | class OutputTransition(nn.Module):
117 | def __init__(self, inChans, elu, nll):
118 | super(OutputTransition, self).__init__()
119 | self.conv1 = nn.Conv3d(inChans, 2, kernel_size=5, padding=2)
120 | self.bn1 = ContBatchNorm3d(2)
121 | self.conv2 = nn.Conv3d(2, 2, kernel_size=1)
122 | self.relu1 = ELUCons(elu, 2)
123 | if nll:
124 | self.softmax = F.log_softmax
125 | else:
126 | self.softmax = F.softmax
127 |
128 | def forward(self, x):
129 | # convolve 32 down to 2 channels
130 | out = self.relu1(self.bn1(self.conv1(x)))
131 | out = self.conv2(out)
132 | # print("out shape before softmax:"+str(out.shape))
133 |
134 | # flatten z, y, x
135 | b, c, z, y, x = out.shape # b: batch size, c: channels (2, the output channels of conv2), z: depth, y: height, x: width
136 | out = out.view(b, c, -1)
137 |
138 | # pdb.set_trace()
139 | res = self.softmax(out, dim=1)
140 |
141 | # make channels the last axis
142 | # out = out.permute(0, 2, 3, 4, 1).contiguous()
143 | # out = out.view(out.numel() // 2, 2)
144 | # out = self.softmax(out,dim=1)
145 | return res
146 |
147 |
148 | class VNet(nn.Module):
149 | # the number of convolutions in each layer corresponds
150 | # to what is in the actual prototxt, not the intent
151 | def __init__(self, elu=True, nll=False):
152 | super(VNet, self).__init__()
153 | self.in_tr = InputTransition(16, elu)
154 | self.down_tr32 = DownTransition(16, 1, elu)
155 | self.down_tr64 = DownTransition(32, 2, elu)
156 | self.down_tr128 = DownTransition(64, 3, elu, dropout=True)
157 | self.down_tr256 = DownTransition(128, 2, elu, dropout=True)
158 | self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True)
159 | self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True)
160 | self.up_tr64 = UpTransition(128, 64, 1, elu)
161 | self.up_tr32 = UpTransition(64, 32, 1, elu)
162 | self.out_tr = OutputTransition(32, elu, nll)
163 |
164 | # The network topology as described in the diagram
165 | # in the VNet paper
166 | # def __init__(self):
167 | # super(VNet, self).__init__()
168 | # self.in_tr = InputTransition(16)
169 | # # the number of convolutions in each layer corresponds
170 | # # to what is in the actual prototxt, not the intent
171 | # self.down_tr32 = DownTransition(16, 2)
172 | # self.down_tr64 = DownTransition(32, 3)
173 | # self.down_tr128 = DownTransition(64, 3)
174 | # self.down_tr256 = DownTransition(128, 3)
175 | # self.up_tr256 = UpTransition(256, 3)
176 | # self.up_tr128 = UpTransition(128, 3)
177 | # self.up_tr64 = UpTransition(64, 2)
178 | # self.up_tr32 = UpTransition(32, 1)
179 | # self.out_tr = OutputTransition(16)
180 | def forward(self, x):
181 | # print("x before model shape:"+str(x.shape))
182 | out16 = self.in_tr(x)
183 | out32 = self.down_tr32(out16)
184 | out64 = self.down_tr64(out32)
185 | out128 = self.down_tr128(out64)
186 | # print("out16(in_tr) shape:"+str(out16.shape))
187 | # print("out32(down_tr32) shape:"+str(out32.shape))
188 | # print("out64(down_tr64) shape:"+str(out64.shape))
189 | # print("out128(down_tr128) shape:"+str(out128.shape))
190 | out256 = self.down_tr256(out128)
191 | # print("out256(down_tr256) shape:"+str(out256.shape))
192 | out = self.up_tr256(out256, out128)
193 | # print("up_tr256 shape:"+str(out.shape))
194 | out = self.up_tr128(out, out64)
195 | # print("up_tr128 shape:"+str(out.shape))
196 | out = self.up_tr64(out, out32)
197 | # print("up_tr64 shape:"+str(out.shape))
198 | out = self.up_tr32(out, out16)
199 | # print("up_tr32 shape:"+str(out.shape))
200 | out = self.out_tr(out)
201 | # print("out_tr shape:"+str(out.shape))
202 | return out
203 |
--------------------------------------------------------------------------------
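A minimal smoke test for the network above (assumes `vnet.py` is importable; the input size is a placeholder). Each DownTransition halves every spatial axis, so all three dimensions must be divisible by 16; the output keeps 2 class channels over the flattened voxels.

import torch
import vnet

model = vnet.VNet(elu=False, nll=False)
model.eval()                          # disable the Dropout3d layers
x = torch.randn(1, 1, 64, 128, 128)   # (batch, channel, z, y, x)
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 2, 1048576]), i.e. (batch, 2 classes, z*y*x)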