├── .gitattributes
├── .idea
├── deployment.xml
├── lits_torch.iml
├── misc.xml
├── modules.xml
├── remote-mappings.xml
├── webServers.xml
└── workspace.xml
├── LITS_DataSet.py
├── LITS_reader.py
├── Unet.py
├── init_util.py
├── logger.py
├── metrics.py
├── train_val.py
└── util.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/.idea/lits_torch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/remote-mappings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 | true
115 | DEFINITION_ORDER
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 | 1562729157972
290 |
291 |
292 | 1562729157972
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 |
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
530 |
531 |
532 |
533 |
534 |
535 |
536 |
537 |
538 |
539 |
540 |
541 |
542 |
543 |
544 |
545 |
546 |
547 |
548 |
549 |
550 |
551 |
552 |
553 |
554 |
555 |
556 |
557 |
558 |
559 |
560 |
561 |
562 |
563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
572 |
573 |
574 |
575 |
576 |
577 |
578 |
579 |
580 |
581 |
582 |
583 |
584 |
585 |
586 |
587 |
588 |
589 |
590 |
591 |
592 |
593 |
594 |
595 |
596 |
597 |
598 |
599 |
600 |
601 |
602 |
603 |
604 |
605 |
606 |
607 |
608 |
609 |
610 |
611 |
612 |
613 |
614 |
615 |
616 |
617 |
618 |
619 |
620 |
621 |
622 |
623 |
624 |
625 |
626 |
627 |
628 |
629 |
630 |
631 |
632 |
633 |
634 |
635 |
--------------------------------------------------------------------------------
/LITS_DataSet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import transforms as T
3 | import torch
4 | from torch.utils.data import Dataset, DataLoader
5 | import LITS_reader
6 |
7 |
class Lits_DataSet(Dataset):
    """Training split of the LiTS data: item `index` is a mini-batch of
    random 3D sub-crops drawn from the index-th training volume."""

    def __init__(self, crop_size, batch_size, lits_reader, resize_scale):
        self.crop_size = crop_size
        self.batch_size = batch_size
        self.lits_reader = lits_reader
        self.resize_scale = resize_scale

    def __getitem__(self, index):
        data, target = self.lits_reader.next_train_batch_3d_sub_by_index(
            train_batch_size=self.batch_size,
            crop_size=self.crop_size,
            index=index,
            resize_scale=self.resize_scale)
        # move the trailing channel axis to the front: (N, ..., C) -> (N, C, ...)
        data = torch.from_numpy(data.transpose(0, 4, 1, 2, 3))
        target = torch.from_numpy(target.transpose(0, 4, 1, 2, 3))
        return data, target

    def __len__(self):
        # fixed number of training volumes served by this dataset
        return 104
25 |
26 |
class Lits_DataSet_val(Dataset):
    """Validation split of the LiTS data: item `index` is a mini-batch of
    random 3D sub-crops drawn from the index-th validation volume."""

    def __init__(self, crop_size, batch_size, lits_reader, resize_scale):
        self.crop_size = crop_size
        self.batch_size = batch_size
        self.lits_reader = lits_reader
        self.resize_scale = resize_scale

    def __getitem__(self, index):
        batch = self.lits_reader.next_val_batch_3d_sub_by_index(
            val_batch_size=self.batch_size,
            crop_size=self.crop_size,
            index=index,
            resize_scale=self.resize_scale)
        # channel-last -> channel-first, then wrap as tensors
        data, target = (torch.from_numpy(a.transpose(0, 4, 1, 2, 3)) for a in batch)
        return data, target

    def __len__(self):
        # fixed number of validation volumes served by this dataset
        return 13
44 |
45 |
def main():
    """Smoke test: iterate the training DataLoader and print batch shapes."""
    reader = LITS_reader.LITS_reader(data_fix=False)
    dataset = Lits_DataSet([32, 64, 64], 4, reader, resize_scale=0.5)
    loader = DataLoader(dataset=dataset, shuffle=True, num_workers=2)
    for data, mask in loader:
        # the DataLoader adds a leading batch dim of size 1; drop it
        data = torch.squeeze(data, dim=0)
        mask = torch.squeeze(mask, dim=0)
        print(data.shape, mask.shape)


if __name__ == '__main__':
    main()
58 |
--------------------------------------------------------------------------------
/LITS_reader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import SimpleITK as sitk
5 | import random
6 | import util
7 | from scipy import ndimage
8 |
9 | '''
10 | def sitk_read(img_path):
11 | nda = sitk.ReadImage(img_path)
12 | nda = sitk.GetArrayFromImage(nda) #(155,240,240)
13 | zero = np.zeros([5, 240, 240])
14 | nda = np.concatenate([zero, nda], axis=0) #(160,240,240)
15 | nda = nda.transpose(1, 2, 0) #(240,240,160)
16 | return nda
17 | '''
# Intensity window used for normalisation (values outside are clipped).
MIN_BOUND = -1000.0
MAX_BOUND = 400.0

def norm_img(image):
    """Linearly rescale `image` from [MIN_BOUND, MAX_BOUND] to [0, 1], clipping outliers."""
    scaled = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND)
    return np.clip(scaled, 0., 1.)
26 |
27 |
def sitk_read_row(img_path, resize_scale=1):
    """Load a volume from `img_path` as a numpy array (channel/slice axis first),
    isotropically rescaled by `resize_scale`."""
    image = sitk.ReadImage(img_path)
    volume = sitk.GetArrayFromImage(image)  # channel first
    # order=0 (nearest neighbour) so label volumes are not interpolated
    return ndimage.zoom(volume, [resize_scale] * 3, order=0)
34 |
35 |
def make_one_hot_3d(x, n):
    """One-hot encode an integer label volume.

    Args:
        x: 3-D array of class labels (values assumed in [0, n); floats are
           truncated toward zero, matching the original int() conversion).
        n: number of classes.

    Returns:
        float64 array of shape x.shape + (n,) with a 1 at each voxel's class index.
    """
    # Vectorised replacement for the original triple Python loop over every
    # voxel: index an identity matrix by the integer label of each voxel.
    return np.eye(n)[x.astype(np.int64)]
44 |
45 |
class LITS_reader:
    """Reader for the LiTS (Liver Tumour Segmentation) dataset.

    Handles one-off preprocessing of the raw volumes (`fix_data`), the
    train/val/test split files, and generation of batches of random 3D
    crops for training, validation and testing.
    """

    def __init__(self, data_init=False, data_fix=False):
        # data_fix: re-run the raw-volume preprocessing first.
        # data_init: regenerate the train/val/test split files.
        self.row_root_path = '/root/userfolder/PY/data/LITS/'
        self.data_root_path = '/root/userfolder/PY/data/LITS/fixed/'
        if data_fix:
            self.fix_data()
        if data_init:
            self.init_data()

        # each list holds the volume ids (as strings) of one split
        self.train_name_list = self.load_file_name_list(self.data_root_path + "train_name_list.txt")
        self.val_name_list = self.load_file_name_list(self.data_root_path + "val_name_list.txt")
        self.test_name_list = self.load_file_name_list(self.data_root_path + "test_name_list.txt")

        self.n_train_file = len(self.train_name_list)
        self.n_val_file = len(self.val_name_list)
        self.n_test_file = len(self.test_name_list)

        # cursors used by the sequential next_*_batch methods
        self.train_batch_index = 0
        self.val_batch_index = 0
        self.test_batch_index = 0

        # number of segmentation classes (one-hot depth)
        self.n_labels = 3

    def write_train_val_test_name_list(self):
        """Shuffle volume ids 0..130 and write an 8/1/1 train/val/test split to disk."""
        data_name_list = np.zeros([131], dtype='int32')
        for i in range(131):
            data_name_list[i] = i
        # data_name_list = os.listdir(self.data_root_path + "/")
        random.shuffle(data_name_list)
        length = len(data_name_list)
        n_train_file = int(length / 10 * 8)
        n_val_file = int(length / 10 * 1)
        train_name_list = data_name_list[0:n_train_file]
        val_name_list = data_name_list[n_train_file:(n_train_file + n_val_file)]
        test_name_list = data_name_list[(n_train_file + n_val_file):len(data_name_list)]
        self.write_name_list(train_name_list, "train_name_list.txt")
        self.write_name_list(val_name_list, "val_name_list.txt")
        self.write_name_list(test_name_list, "test_name_list.txt")

    def write_name_list(self, name_list, file_name):
        """Write one volume id per line to data_root_path/file_name."""
        f = open(self.data_root_path + file_name, 'w')
        for i in range(len(name_list)):
            f.write(str(name_list[i]) + "\n")
        f.close()

    def init_data(self):
        # regenerate the split files
        self.write_train_val_test_name_list()

    def fix_data(self):
        """Preprocess raw CT volumes into the 'fixed' directory.

        Clamps intensities to [lower, upper], crops each volume to the liver
        region (plus an axial margin), and writes the cropped CT and label
        volumes to data_root_path.
        """
        upper = 200
        lower = -200
        expand_slice = 20  # number of slices to expand outward along the axial direction
        size = 48  # number of slices to sample
        stride = 3  # sampling stride
        # NOTE(review): stride, down_scale and slice_thickness are unused in this method
        down_scale = 0.5
        slice_thickness = 2

        for ct_file in os.listdir(self.row_root_path + 'data/'):
            print(ct_file)
            # read the CT volume and the gold-standard segmentation into memory
            ct = sitk.ReadImage(os.path.join(self.row_root_path + 'data/', ct_file), sitk.sitkInt16)
            ct_array = sitk.GetArrayFromImage(ct)

            seg = sitk.ReadImage(os.path.join(self.row_root_path + 'label/', ct_file.replace('volume', 'segmentation')),
                                 sitk.sitkInt8)
            seg_array = sitk.GetArrayFromImage(seg)

            print(ct_array.shape, seg_array.shape)

            # merge the liver and tumour labels into one (only used to locate the liver below)
            seg_array[seg_array > 0] = 1

            # clamp grey values outside the threshold window
            ct_array[ct_array > upper] = upper
            ct_array[ct_array < lower] = lower

            # find the first and last slice containing liver, then expand outward
            z = np.any(seg_array, axis=(1, 2))
            start_slice, end_slice = np.where(z)[0][[0, -1]]

            # expand by expand_slice in each direction, clamped to the volume bounds
            if start_slice - expand_slice < 0:
                start_slice = 0
            else:
                start_slice -= expand_slice

            if end_slice + expand_slice >= seg_array.shape[0]:
                end_slice = seg_array.shape[0] - 1
            else:
                end_slice += expand_slice

            print(str(start_slice) + '--' + str(end_slice))
            # if fewer than `size` slices remain, skip this volume (such data is rare)
            if end_slice - start_slice + 1 < size:
                print('!!!!!!!!!!!!!!!!')
                print(ct_file, 'too little slice')
                print('!!!!!!!!!!!!!!!!')
                continue

            ct_array = ct_array[start_slice:end_slice + 1, :, :]
            # re-extract the label array from `seg` so the original (unmerged) labels are saved
            seg_array = sitk.GetArrayFromImage(seg)
            seg_array = seg_array[start_slice:end_slice + 1, :, :]

            new_ct = sitk.GetImageFromArray(ct_array)
            new_seg = sitk.GetImageFromArray(seg_array)

            sitk.WriteImage(new_ct, os.path.join(self.data_root_path + 'data/', ct_file))
            sitk.WriteImage(new_seg,
                            os.path.join(self.data_root_path + 'label/', ct_file.replace('volume', 'segmentation')))

    def load_file_name_list(self, file_path):
        """Read a split file and return its non-empty lines (volume ids) as strings."""
        file_name_list = []
        with open(file_path, 'r') as file_to_read:
            while True:
                lines = file_to_read.readline().strip()  # read one full line
                if not lines:
                    break
                    pass
                file_name_list.append(lines)
                pass
        return file_name_list

    def get_np_data_3d(self, data_name, resize_scale=1):
        """Load the (image, label) numpy volumes for one volume id; the image is intensity-normalised."""
        data_np = sitk_read_row(self.data_root_path + 'data/' + 'volume-' + data_name + '.nii',
                                resize_scale=resize_scale)
        data_np=norm_img(data_np)
        label_np = sitk_read_row(self.data_root_path + 'label/' + 'segmentation-' + data_name + '.nii',
                                 resize_scale=resize_scale)

        return data_np, label_np

    def next_train_batch_3d_sub_by_index(self, train_batch_size, crop_size, index,resize_scale=1):
        """Return a batch of random crops from the index-th training volume.

        Returns (imgs, labels) with shapes
        (batch, *crop_size, 1) and (batch, *crop_size, n_labels).
        """
        train_imgs = np.zeros([train_batch_size, crop_size[0], crop_size[1], crop_size[2], 1])
        train_labels = np.zeros([train_batch_size, crop_size[0], crop_size[1], crop_size[2], self.n_labels])
        img, label = self.get_np_data_3d(self.train_name_list[index],resize_scale=resize_scale)
        for i in range(train_batch_size):
            sub_img, sub_label = util.random_crop_3d(img, label, crop_size)

            sub_img = sub_img[:, :, :, np.newaxis]
            sub_label_onehot = make_one_hot_3d(sub_label, self.n_labels)

            train_imgs[i] = sub_img
            train_labels[i] = sub_label_onehot

        return train_imgs, train_labels

    def next_train_batch_3d_sub(self, train_batch_size, crop_size):
        """Sequentially sample a batch of random crops from the next training
        volume, advancing (and wrapping) the internal cursor."""
        self.n_train_steps_per_epoch = self.n_train_file // 1
        train_imgs = np.zeros([train_batch_size, crop_size[0], crop_size[1], crop_size[2], 1])
        train_labels = np.zeros([train_batch_size, crop_size[0], crop_size[1], crop_size[2], self.n_labels])
        if self.train_batch_index >= self.n_train_steps_per_epoch:
            self.train_batch_index = 0
        img, label = self.get_np_data_3d(self.train_name_list[self.train_batch_index])
        for i in range(train_batch_size):
            sub_img, sub_label = util.random_crop_3d(img, label, crop_size)
            '''
            num=0
            num_0=0
            num_1=0
            num_2=0
            for z in range(sub_label.shape[0]):
                for x in range(sub_label.shape[1]):
                    for c in range(sub_label.shape[2]):
                        if sub_label[z][x][c]!=0:
                            num+=1
                        if sub_label[z][x][c]==0:
                            num_0+=1
                        if sub_label[z][x][c]==1:
                            num_1+=1
                        if sub_label[z][x][c]==2:
                            num_2+=1
            print('-----')
            print(num)
            print(num_0)
            print(num_1)
            print(num_2)
            print('-----')
            '''
            sub_img = sub_img[:, :, :, np.newaxis]
            sub_label_onehot = make_one_hot_3d(sub_label, self.n_labels)
            '''
            num = 0
            num_0 = 0
            num_1 = 0
            num_2 = 0
            for z in range(sub_label.shape[0]):
                for x in range(sub_label.shape[1]):
                    for c in range(sub_label.shape[2]):
                        if sub_label_onehot[z][x][c][0] == 1:
                            num_0 += 1
                        if sub_label_onehot[z][x][c][1] == 1:
                            num_1 += 1
                        if sub_label_onehot[z][x][c][2] == 1:
                            num_2 += 1
            print('-----')
            print(num)
            print(num_0)
            print(num_1)
            print(num_2)
            print('-----')
            '''
            train_imgs[i] = sub_img
            train_labels[i] = sub_label_onehot

        self.train_batch_index += 1
        return train_imgs, train_labels

    def next_val_batch_3d_sub_by_index(self, val_batch_size, crop_size, index,resize_scale=1):
        """Return a batch of random crops from the index-th validation volume
        (same shapes as next_train_batch_3d_sub_by_index)."""
        val_imgs = np.zeros([val_batch_size, crop_size[0], crop_size[1], crop_size[2], 1])
        val_labels = np.zeros([val_batch_size, crop_size[0], crop_size[1], crop_size[2], self.n_labels])
        img, label = self.get_np_data_3d(self.val_name_list[index],resize_scale=resize_scale)
        for i in range(val_batch_size):
            sub_img, sub_label = util.random_crop_3d(img, label, crop_size)

            sub_img = sub_img[:, :, :, np.newaxis]
            sub_label_onehot = make_one_hot_3d(sub_label, self.n_labels)

            val_imgs[i] = sub_img
            val_labels[i] = sub_label_onehot

        return val_imgs, val_labels

    def next_val_batch_3d_sub(self, val_batch_size, crop_size):
        """Sequentially sample a batch of random crops from the next validation
        volume, advancing (and wrapping) the internal cursor."""
        self.n_val_steps_per_epoch = self.n_val_file // 1
        val_imgs = np.zeros([val_batch_size, crop_size[0], crop_size[1], crop_size[2], 1])
        val_labels = np.zeros([val_batch_size, crop_size[0], crop_size[1], crop_size[2], self.n_labels])
        if self.val_batch_index >= self.n_val_steps_per_epoch:
            self.val_batch_index = 0
        img, label = self.get_np_data_3d(self.val_name_list[self.val_batch_index])
        for i in range(val_batch_size):
            sub_img, sub_label = util.random_crop_3d(img, label, crop_size)

            sub_img = sub_img[:, :, :, np.newaxis]
            sub_label_onehot = make_one_hot_3d(sub_label, self.n_labels)

            val_imgs[i] = sub_img
            val_labels[i] = sub_label_onehot

        self.val_batch_index += 1
        return val_imgs, val_labels

    def next_test_img(self):
        """Return the next full test volume (at resize_scale=0.5) with batch and
        channel dimensions added, advancing (and wrapping) the test cursor."""
        self.n_test_steps_per_epoch = self.n_test_file // 1
        if self.test_batch_index >= self.n_test_steps_per_epoch:
            self.test_batch_index = 0
        img, label = self.get_np_data_3d(self.test_name_list[self.test_batch_index], resize_scale=0.5)

        img = img[np.newaxis, :, :, :, np.newaxis]
        label = make_one_hot_3d(label, self.n_labels)
        label = label[np.newaxis, :]

        self.test_batch_index += 1

        return img, label

    def next_train_batch_3d(self, train_batch_size):
        # placeholder: not implemented
        return None
303 |
304 |
def main():
    """Smoke test: draw one validation batch and print its shapes."""
    reader = LITS_reader(data_fix=False)
    img, label = reader.next_val_batch_3d_sub(8, [32, 64, 64])
    print(img.shape)
    print(label.shape)


if __name__ == '__main__':
    main()
314 |
--------------------------------------------------------------------------------
/Unet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torchvision import datasets, transforms
8 |
9 |
10 | # 未测试
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate: global average pool, a two-conv bottleneck,
    then a sigmoid whose output rescales the input channel-wise.

    Args:
        in_channels: channels of the input tensor.
        out_channels: width of the bottleneck (squeeze) layer.
        net_mode: '2d' or '3d' (selects pooling/conv flavour).
    """

    def __init__(self, in_channels, out_channels, net_mode='2d'):
        super(SEBlock, self).__init__()

        if net_mode == '2d':
            self.gap = nn.AdaptiveAvgPool2d(1)
            conv = nn.Conv2d
        elif net_mode == '3d':
            self.gap = nn.AdaptiveAvgPool3d(1)
            conv = nn.Conv3d
        else:
            self.gap = None
            conv = None

        # squeeze: in -> out, excite: out -> in so the gate matches the input's
        # channel count for the elementwise product in forward().
        # (The original declared conv2 as conv(in, out, 1), which crashed
        # whenever in_channels != out_channels.)
        self.conv1 = conv(in_channels, out_channels, 1)
        self.conv2 = conv(out_channels, in_channels, 1)

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        inpu = x
        x = self.gap(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmoid(x)

        # channel-wise rescaling of the input
        return inpu * x
40 |
41 | # 未测试
class DenseBlock(nn.Module):
    """DenseNet-style block: each stage sees the concatenation of the input and
    all earlier stage outputs, passed through a 1x1 bottleneck then a 3x3 conv.

    Args:
        channels: channel count of the input (and of every stage output).
        conv_num: number of dense stages.
        net_mode: '2d' or '3d'.
    """

    def __init__(self, channels, conv_num, net_mode='2d'):
        super(DenseBlock, self).__init__()
        self.conv_num = conv_num
        if net_mode == '2d':
            conv = nn.Conv2d
        elif net_mode == '3d':
            conv = nn.Conv3d
        else:
            conv = None

        self.relu = nn.ReLU()
        # nn.ModuleList (not plain lists) so the sub-convs are registered with
        # the module; the original also iterated `for i in conv_num` over an
        # int, which raised TypeError at construction.
        self.conv_list = nn.ModuleList()
        self.bottle_conv_list = nn.ModuleList()
        for i in range(conv_num):
            # 1x1 bottleneck over the (i+1) concatenated feature maps, then 3x3
            self.bottle_conv_list.append(conv(channels * (i + 1), channels * 4, 1))
            self.conv_list.append(conv(channels * 4, channels, 3, padding=1))

    def forward(self, x):
        res_x = [x]

        for i in range(self.conv_num):
            inputs = torch.cat(res_x, dim=1)
            x = self.bottle_conv_list[i](inputs)
            x = self.relu(x)
            x = self.conv_list[i](x)
            x = self.relu(x)
            res_x.append(x)

        # output of the last stage (same channel count as the input)
        return x
75 |
76 |
77 |
class ResBlock(nn.Module):
    """Basic residual block: conv-bn-relu-conv-bn plus a shortcut, then ReLU.

    Fixes over the original: the stride is applied only to the first conv (the
    original applied it to both, downsampling the main path by stride**2 while
    the shortcut downsampled by stride, so any stride != 1 crashed on the
    residual add), and a 1x1 projection shortcut is used whenever the channel
    count or the spatial size changes. Behaviour for stride=1 is unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, net_mode='2d'):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if net_mode == '2d':
            conv = nn.Conv2d
            bn = nn.BatchNorm2d
        elif net_mode == '3d':
            conv = nn.Conv3d
            bn = nn.BatchNorm3d
        else:
            conv = None
            bn = None

        self.conv1 = conv(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = bn(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # second conv keeps the resolution; only conv1 downsamples
        self.conv2 = conv(out_channels, out_channels, 3, stride=1, padding=1)
        self.bn2 = bn(out_channels)

        # projection shortcut whenever identity would not match in shape
        self.use_res_conv = in_channels != out_channels or stride != 1
        if self.use_res_conv:
            self.res_conv = conv(in_channels, out_channels, 1, stride=stride)

    def forward(self, x):
        res = self.res_conv(x) if self.use_res_conv else x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        out = out + res
        out = self.relu(out)

        return out
117 |
118 |
class Up(nn.Module):
    """Decoder stage: upsample the deeper feature map by 2x, concatenate it
    with the skip connection, then fuse with `conv_block`.

    Args:
        down_in_channels: channels of the deeper (to-be-upsampled) map.
        in_channels: channels of the skip-connection map.
        out_channels: output channels of the fusion conv.
        conv_block: block factory called as conv_block(in, out, net_mode=...).
        interpolation: True -> nn.Upsample; False -> transposed conv.
        net_mode: '2d' or '3d'.
    """

    def __init__(self, down_in_channels, in_channels, out_channels, conv_block, interpolation=True, net_mode='2d'):
        super(Up, self).__init__()

        if net_mode == '2d':
            inter_mode = 'bilinear'
            trans_conv = nn.ConvTranspose2d
        elif net_mode == '3d':
            inter_mode = 'trilinear'
            trans_conv = nn.ConvTranspose3d
        else:
            inter_mode = None
            trans_conv = None

        if interpolation:
            self.up = nn.Upsample(scale_factor=2, mode=inter_mode, align_corners=True)
        else:
            self.up = trans_conv(down_in_channels, down_in_channels, 2, stride=2)

        # honour the injected conv_block (the original ignored this argument
        # and hard-wired RecombinationBlock)
        self.conv = conv_block(in_channels + down_in_channels, out_channels, net_mode=net_mode)

    def forward(self, down_x, x):
        up_x = self.up(down_x)

        # concatenate upsampled features with the skip connection along channels
        x = torch.cat((up_x, x), dim=1)

        x = self.conv(x)

        return x
148 |
149 |
class Down(nn.Module):
    """Encoder stage: feature conv followed by 2x max-pool downsampling.

    forward returns (pre-pool features, pooled features) so the pre-pool map
    can be reused as a skip connection by the decoder.
    """

    def __init__(self, in_channels, out_channels, conv_block, net_mode='2d'):
        super(Down, self).__init__()
        if net_mode == '2d':
            maxpool = nn.MaxPool2d
        elif net_mode == '3d':
            maxpool = nn.MaxPool3d
        else:
            maxpool = None

        # honour the injected conv_block (the original ignored this argument
        # and hard-wired RecombinationBlock)
        self.conv = conv_block(in_channels, out_channels, net_mode=net_mode)

        self.down = maxpool(2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        out = self.down(x)

        return x, out
169 |
170 |
class SegSEBlock(nn.Module):
    """Spatial SE gate: a dilated 3x3 bottleneck conv (channels / rate) followed
    by a 1x1 expansion back to `in_channels` and a sigmoid, producing a
    per-position attention map with the same channel count as the input."""

    def __init__(self, in_channels, rate=2, net_mode='2d'):
        super(SegSEBlock, self).__init__()

        if net_mode == '2d':
            conv = nn.Conv2d
        elif net_mode == '3d':
            conv = nn.Conv3d
        else:
            conv = None

        self.in_channels = in_channels
        self.rate = rate
        # with the default rate=2, dilation=2 plus padding=2 keeps spatial size
        self.dila_conv = conv(self.in_channels, self.in_channels // self.rate, 3, padding=2, dilation=self.rate)
        self.conv1 = conv(self.in_channels // self.rate, self.in_channels, 1)

    def forward(self, input):
        gate = self.dila_conv(input)
        gate = self.conv1(gate)
        return torch.sigmoid(gate)
193 |
194 |
class RecombinationBlock(nn.Module):
    """Recombination block: 1x1 expansion -> (optional BN, ReLU6, conv, SegSE
    gating) -> 1x1 zoom back to `out_channels`, summed with a 1x1 skip
    projection of the input."""

    def __init__(self, in_channels, out_channels, batch_normalization=True, kernel_size=3, net_mode='2d'):
        super(RecombinationBlock, self).__init__()

        if net_mode == '2d':
            conv, bn = nn.Conv2d, nn.BatchNorm2d
        elif net_mode == '3d':
            conv, bn = nn.Conv3d, nn.BatchNorm3d
        else:
            conv, bn = None, None

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bach_normalization = batch_normalization  # (sic) original attribute name kept
        self.kerenl_size = kernel_size  # (sic) original attribute name kept
        self.rate = 2
        # the main path is widened by `rate` before being zoomed back down
        self.expan_channels = self.out_channels * self.rate

        self.expansion_conv = conv(self.in_channels, self.expan_channels, 1)
        self.skip_conv = conv(self.in_channels, self.out_channels, 1)
        self.zoom_conv = conv(self.out_channels * self.rate, self.out_channels, 1)

        self.bn = bn(self.expan_channels)
        self.norm_conv = conv(self.expan_channels, self.expan_channels, self.kerenl_size, padding=1)

        self.segse_block = SegSEBlock(self.expan_channels, net_mode=net_mode)

    def forward(self, input):
        x = self.expansion_conv(input)

        # single processing pass (the original wrapped this in `for i in range(1)`)
        if self.bach_normalization:
            x = self.bn(x)
        x = nn.ReLU6()(x)
        x = self.norm_conv(x)

        # gate the features with the SegSE attention map
        x = x * self.segse_block(x)

        x = self.zoom_conv(x)

        # residual-style combination with the projected input
        return x + self.skip_conv(input)
244 |
245 |
class UNet(nn.Module):
    """U-Net with four encoder/decoder stages built from `conv_block`;
    forward returns per-class softmax probabilities over the channel dim."""

    def __init__(self, in_channels, filter_num_list, class_num, conv_block=RecombinationBlock, net_mode='2d'):
        super(UNet, self).__init__()

        if net_mode == '2d':
            conv = nn.Conv2d
        elif net_mode == '3d':
            conv = nn.Conv3d
        else:
            conv = None

        # stem: lift the input to 16 channels
        self.inc = conv(in_channels, 16, 1)

        # encoder: four downsampling stages
        self.down1 = Down(16, filter_num_list[0], conv_block=conv_block, net_mode=net_mode)
        self.down2 = Down(filter_num_list[0], filter_num_list[1], conv_block=conv_block, net_mode=net_mode)
        self.down3 = Down(filter_num_list[1], filter_num_list[2], conv_block=conv_block, net_mode=net_mode)
        self.down4 = Down(filter_num_list[2], filter_num_list[3], conv_block=conv_block, net_mode=net_mode)

        # bottleneck
        self.bridge = conv_block(filter_num_list[3], filter_num_list[4], net_mode=net_mode)

        # decoder: four upsampling stages with skip connections
        self.up1 = Up(filter_num_list[4], filter_num_list[3], filter_num_list[3], conv_block=conv_block,
                      net_mode=net_mode)
        self.up2 = Up(filter_num_list[3], filter_num_list[2], filter_num_list[2], conv_block=conv_block,
                      net_mode=net_mode)
        self.up3 = Up(filter_num_list[2], filter_num_list[1], filter_num_list[1], conv_block=conv_block,
                      net_mode=net_mode)
        self.up4 = Up(filter_num_list[1], filter_num_list[0], filter_num_list[0], conv_block=conv_block,
                      net_mode=net_mode)

        # per-position classifier
        self.class_conv = conv(filter_num_list[0], class_num, 1)

    def forward(self, input):
        x = self.inc(input)

        skip1, x = self.down1(x)
        skip2, x = self.down2(x)
        skip3, x = self.down3(x)
        skip4, x = self.down4(x)

        x = self.bridge(x)

        x = self.up1(x, skip4)
        x = self.up2(x, skip3)
        x = self.up3(x, skip2)
        x = self.up4(x, skip1)

        logits = self.class_conv(x)

        # class probabilities along the channel dimension
        return nn.Softmax(1)(logits)
308 |
309 |
310 | '''
311 | def main():
312 | torch.cuda.set_device(1)
313 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch v0.4.0
314 | model = UNet(1,[32,48,64,96,128],3,net_mode='3d').to(device)
315 | x=torch.rand(4,1,64,96,96)
316 | x=x.to(device)
317 | model.forward(x)
318 |
319 | if __name__ == '__main__':
320 | main()
321 | '''
322 |
--------------------------------------------------------------------------------
/init_util.py:
--------------------------------------------------------------------------------
1 | from torch.nn import init
2 |
3 |
def weights_init_normal(m):
    """Initialise `m` in place (intended for net.apply): N(0, 0.02) weights for
    Conv/Linear layers; N(1, 0.02) weights and zero bias for BatchNorm layers.

    Uses the in-place `normal_`/`constant_` initialisers — the non-underscore
    names used originally were the deprecated pre-0.4 API and are gone from
    modern torch.
    """
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
14 |
15 |
def weights_init_xavier(m):
    """Initialise `m` in place (intended for net.apply): Xavier-normal weights
    for Conv/Linear layers; N(1, 0.02) weights and zero bias for BatchNorm.

    Uses the in-place `xavier_normal_`/`normal_`/`constant_` initialisers —
    the non-underscore names used originally were the deprecated pre-0.4 API
    and are gone from modern torch.
    """
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.xavier_normal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
26 |
27 |
def weights_init_kaiming(m):
    """Initialise `m` in place (intended for net.apply): Kaiming-normal
    (fan-in) weights for Conv/Linear layers; N(1, 0.02) weights and zero bias
    for BatchNorm.

    Uses the in-place `kaiming_normal_`/`normal_`/`constant_` initialisers —
    the non-underscore names used originally were the deprecated pre-0.4 API
    and are gone from modern torch.
    """
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
38 |
39 |
def weights_init_orthogonal(m):
    """Initialise `m` in place (intended for net.apply): orthogonal weights for
    Conv/Linear layers; N(1, 0.02) weights and zero bias for BatchNorm.

    Uses the in-place `orthogonal_`/`normal_`/`constant_` initialisers — the
    non-underscore names used originally were the deprecated pre-0.4 API and
    are gone from modern torch.
    """
    classname = m.__class__.__name__
    #print(classname)
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
    elif classname.find('BatchNorm') != -1:
        init.normal_(m.weight.data, 1.0, 0.02)
        init.constant_(m.bias.data, 0.0)
50 |
51 |
def init_weights(net, init_type='normal'):
    """Apply the chosen weight-initialisation scheme to every submodule of `net`.

    Raises NotImplementedError for an unknown `init_type`.
    """
    initialisers = {
        'normal': weights_init_normal,
        'xavier': weights_init_xavier,
        'kaiming': weights_init_kaiming,
        'orthogonal': weights_init_orthogonal,
    }
    if init_type not in initialisers:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    net.apply(initialisers[init_type])
64 |
65 |
def adjust_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group with `lr`."""
    for group in optimizer.param_groups:
        group.update(lr=lr)
70 |
def print_network(net):
    """Print the network structure followed by its total parameter count."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
2 | import tensorflow as tf
3 | import numpy as np
4 | import scipy.misc
5 |
6 | try:
7 | from StringIO import StringIO # Python 2.7
8 | except ImportError:
9 | from io import BytesIO # Python 3.x
10 |
11 |
class Logger(object):
    """TensorBoard logger for scalars, images and histograms.

    NOTE(review): built on the TensorFlow 1.x summary API
    (tf.summary.FileWriter / tf.Summary) and scipy.misc.toimage — confirm the
    pinned tensorflow/scipy versions before reuse; neither API exists in
    current releases.
    """

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        self.writer.add_summary(summary, step)

    def image_summary(self, tag, images, step):
        """Log a list of images."""

        img_summaries = []
        for i, img in enumerate(images):
            # Write the image to a string
            try:
                s = StringIO()
            except:
                s = BytesIO()
            scipy.misc.toimage(img).save(s, format="png")

            # Create an Image object
            img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
                                       height=img.shape[0],
                                       width=img.shape[1])
            # Create a Summary value; each image gets its own "tag/i" name
            img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

        # Create and write Summary
        summary = tf.Summary(value=img_summaries)
        self.writer.add_summary(summary, step)

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""

        # Create a histogram using numpy
        counts, bin_edges = np.histogram(values, bins=bins)

        # Fill the fields of the histogram proto
        hist = tf.HistogramProto()
        hist.min = float(np.min(values))
        hist.max = float(np.max(values))
        hist.num = int(np.prod(values.shape))
        hist.sum = float(np.sum(values))
        hist.sum_squares = float(np.sum(values ** 2))

        # Drop the start of the first bin
        bin_edges = bin_edges[1:]

        # Add bin edges and counts
        for edge in bin_edges:
            hist.bucket_limit.append(edge)
        for c in counts:
            hist.bucket.append(c)

        # Create and write Summary
        summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
        self.writer.add_summary(summary, step)
        self.writer.flush()
72 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
def cross_entropy_2D(input, target, weight=None, size_average=True):
    """Cross-entropy loss for 4-D (N, C, H, W) logits.

    Args:
        input: raw logits of shape (N, C, H, W).
        target: integer class labels with N*H*W elements total.
        weight: optional per-class rescaling weights for nll_loss.
        size_average: if True (default), return the mean loss per target
            element; otherwise the summed loss.
    """
    n, c, h, w = input.size()
    log_p = F.log_softmax(input, dim=1)
    # Move the class axis last, then flatten to (N*H*W, C) for nll_loss.
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.view(target.numel())
    # The size_average/reduce kwargs of nll_loss are deprecated; request a
    # summed loss via reduction= and normalize below if asked to.
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        loss /= float(target.numel())
    return loss
15 |
16 |
def cross_entropy_3D(input, target, weight=None, size_average=True):
    """Cross-entropy loss for 5-D (N, C, H, W, S) logits.

    Args:
        input: raw logits of shape (N, C, H, W, S).
        target: integer class labels with N*H*W*S elements total.
        weight: optional per-class rescaling weights for nll_loss.
        size_average: if True (default), return the mean loss per target
            element; otherwise the summed loss.
    """
    n, c, h, w, s = input.size()
    log_p = F.log_softmax(input, dim=1)
    # Move the class axis last, then flatten to (N*H*W*S, C) for nll_loss.
    log_p = log_p.transpose(1, 2).transpose(2, 3).transpose(3, 4).contiguous().view(-1, c)
    target = target.view(target.numel())
    # The size_average/reduce kwargs of nll_loss are deprecated; request a
    # summed loss via reduction= and normalize below if asked to.
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        loss /= float(target.numel())
    return loss
26 |
27 |
class SoftDiceLoss(nn.Module):
    """Soft (differentiable) Dice loss over per-sample sigmoid probabilities.

    Logits are squashed with a sigmoid, each sample is flattened, and
    1 - mean(dice score) is returned.
    """

    def __init__(self, weight=None, size_average=True):
        # weight/size_average are accepted for interface compatibility but
        # are not used by this loss.
        super(SoftDiceLoss, self).__init__()

    def forward(self, logits, targets):
        num = targets.size(0)
        smooth = 1

        # F.sigmoid is deprecated; torch.sigmoid is the supported spelling.
        probs = torch.sigmoid(logits)
        m1 = probs.view(num, -1)
        m2 = targets.view(num, -1)
        intersection = (m1 * m2)

        # NOTE(review): smooth is added inside the factor of 2 in the
        # numerator (2*(inter+smooth)), not the usual 2*inter + smooth, so
        # a perfect prediction does not yield exactly zero loss.  Kept
        # as-is to preserve training behavior.
        score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        score = 1 - score.sum() / num
        return score
44 |
45 |
class DiceMean(nn.Module):
    """Mean smoothed Dice coefficient across the channels of 5-D
    (N, C, D, H, W) prediction/target tensors."""

    def __init__(self):
        super(DiceMean, self).__init__()

    def forward(self, logits, targets):
        n_classes = logits.size(1)

        total = 0
        for c in range(n_classes):
            pred_c = logits[:, c, :, :, :]
            true_c = targets[:, c, :, :, :]
            overlap = torch.sum(pred_c * true_c)
            denom = torch.sum(pred_c) + torch.sum(true_c)
            # +1 smoothing keeps empty classes from dividing by zero.
            total = total + (2. * overlap + 1) / (denom + 1)
        return total / n_classes
60 |
61 |
class DiceMeanLoss(nn.Module):
    """Loss form of the mean smoothed Dice coefficient: 1 - mean dice over
    the channels of 5-D (N, C, D, H, W) prediction/target tensors."""

    def __init__(self):
        super(DiceMeanLoss, self).__init__()

    def forward(self, logits, targets):
        n_classes = logits.size(1)

        total = 0
        for c in range(n_classes):
            pred_c = logits[:, c, :, :, :]
            true_c = targets[:, c, :, :, :]
            overlap = torch.sum(pred_c * true_c)
            denom = torch.sum(pred_c) + torch.sum(true_c)
            # +1 smoothing keeps empty classes from dividing by zero.
            total = total + (2. * overlap + 1) / (denom + 1)
        return 1 - total / n_classes
76 |
77 |
class WeightDiceLoss(nn.Module):
    """Dice loss with per-class weights derived from target class volumes.

    Classes absent from the batch get weight 0; present classes are
    weighted by a mix of their own volume and the other classes' volumes.
    """

    def __init__(self):
        super(WeightDiceLoss, self).__init__()

    def forward(self, logits, targets):
        # Per-class voxel counts of the (assumed one-hot) targets.
        num_sum = torch.sum(targets, dim=(0, 2, 3, 4))
        # Build the weight vector on the inputs' device with one slot per
        # class, instead of the original hard-coded 3-element .cuda()
        # tensor that crashed on CPU and on class counts other than 3.
        w = torch.zeros(targets.size(1), device=targets.device, dtype=num_sum.dtype)
        for i in range(targets.size(1)):
            if (num_sum[i] < 1):
                # Class absent from this batch: exclude it from the loss.
                w[i] = 0
            else:
                # NOTE(review): i-1 / i-2 wrap via negative indexing; for
                # 3 classes this mixes the other two classes' volumes with
                # a down-weighted own volume.  Kept as in the original.
                w[i] = (0.1 * num_sum[i] + num_sum[i - 1] + num_sum[i - 2] + 1) / (torch.sum(num_sum) + 1)
        # (debug print(w) removed)
        inter = w * torch.sum(targets * logits, dim=(0, 2, 3, 4))
        inter = torch.sum(inter)

        union = w * torch.sum(targets + logits, dim=(0, 2, 3, 4))
        union = torch.sum(union)

        return 1 - 2. * inter / union
99 |
100 |
def dice(logits, targets, class_index):
    """Smoothed Dice coefficient for one channel of 5-D (N, C, D, H, W)
    prediction/target tensors."""
    pred = logits[:, class_index, :, :, :]
    true = targets[:, class_index, :, :, :]
    overlap = torch.sum(pred * true)
    total = torch.sum(pred) + torch.sum(true)
    # +1 smoothing avoids division by zero for empty classes.
    return (2. * overlap + 1) / (total + 1)
106 |
107 |
def T(logits, targets):
    """Total target volume for channel 2 (logits is ignored)."""
    return targets[:, 2, :, :, :].sum()
110 |
111 |
def P(logits, targets):
    """Total predicted volume for channel 2 (targets is ignored)."""
    return logits[:, 2, :, :, :].sum()
114 |
115 |
def TP(logits, targets):
    """Overlap (true-positive volume) between prediction and target on
    channel 2."""
    return (targets[:, 2, :, :, :] * logits[:, 2, :, :, :]).sum()
118 |
--------------------------------------------------------------------------------
/train_val.py:
--------------------------------------------------------------------------------
1 | import LITS_DataSet
2 | import torch
3 | import argparse
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import LITS_reader
8 | import metrics
9 | from torch.utils.data import Dataset, DataLoader
10 | from Unet import UNet,ResBlock,RecombinationBlock
11 | import logger
12 | import init_util
13 |
14 |
def val(model, val_loader, device, epoch, val_dict, logger):
    """Run one validation pass and record the batch-mean loss/dice scores
    in val_dict and on the TensorBoard logger."""
    model.eval()
    totals = {'loss': 0.0, 'dice0': 0.0, 'dice1': 0.0, 'dice2': 0.0}
    with torch.no_grad():
        for data, target in val_loader:
            # The dataset already yields a full batch; drop the extra
            # leading axis added by the DataLoader.
            data = torch.squeeze(data, dim=0).float().to(device)
            target = torch.squeeze(target, dim=0).float().to(device)
            output = model(data)

            totals['loss'] += float(metrics.DiceMeanLoss()(output, target))
            totals['dice0'] += float(metrics.dice(output, target, 0))
            totals['dice1'] += float(metrics.dice(output, target, 1))
            totals['dice2'] += float(metrics.dice(output, target, 2))

    n_batches = len(val_loader)
    for key in totals:
        totals[key] /= n_batches

    for key in ('loss', 'dice0', 'dice1', 'dice2'):
        val_dict[key].append(float(totals[key]))
        logger.scalar_summary('val_' + key, totals[key], epoch)
    print('\nVal set: Average loss: {:.6f}, dice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\t\n'.format(
        totals['loss'], totals['dice0'], totals['dice1'], totals['dice2']))
54 |
55 |
def train(model, train_loader, device, optimizer, epoch, train_dict, logger):
    """Train *model* for one epoch and record epoch-mean metrics.

    Args:
        model: network called as model(data); put into .train() mode.
        train_loader: iterable of (data, target) batches.
        device: torch device the batch tensors are moved to.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index, used for console/TensorBoard logging.
        train_dict: dict with 'loss'/'dice0'/'dice1'/'dice2' lists; the
            epoch means are appended to them.
        logger: object exposing scalar_summary(tag, value, step).

    Fix: the original initialised running totals to 0 but then overwrote
    them on every batch, so the values recorded for the epoch were really
    those of the last batch only.  Metrics are now accumulated and
    averaged over all batches, mirroring val().
    """
    model.train()
    train_loss = 0
    train_dice0 = 0
    train_dice1 = 0
    train_dice2 = 0
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        # The dataset already yields a full batch; drop the extra leading
        # axis added by the DataLoader.
        data = torch.squeeze(data, dim=0)
        target = torch.squeeze(target, dim=0)
        data, target = data.float(), target.float()
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = metrics.DiceMeanLoss()(output, target)
        loss.backward()
        optimizer.step()

        dice0 = metrics.dice(output, target, 0)
        dice1 = metrics.dice(output, target, 1)
        dice2 = metrics.dice(output, target, 2)
        train_loss += float(loss)
        train_dice0 += float(dice0)
        train_dice1 += float(dice1)
        train_dice2 += float(dice2)
        print(
            'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tdice0: {:.6f}\tdice1: {:.6f}\tdice2: {:.6f}\tT: {:.6f}\tP: {:.6f}\tTP: {:.6f}'.format(
                epoch, batch_idx, len(train_loader),
                100. * batch_idx / len(train_loader), loss.item(),
                dice0, dice1, dice2,
                metrics.T(output, target), metrics.P(output, target), metrics.TP(output, target)))

    # Average over the epoch (guard against an empty loader).
    if num_batches > 0:
        train_loss /= num_batches
        train_dice0 /= num_batches
        train_dice1 /= num_batches
        train_dice2 /= num_batches

    train_dict['loss'].append(float(train_loss))
    train_dict['dice0'].append(float(train_dice0))
    train_dict['dice1'].append(float(train_dice1))
    train_dict['dice2'].append(float(train_dice2))

    logger.scalar_summary('train_loss', train_loss, epoch)
    logger.scalar_summary('train_dice0', train_dice0, epoch)
    logger.scalar_summary('train_dice1', train_dice1, epoch)
    logger.scalar_summary('train_dice2', train_dice2, epoch)
100 |
101 |
def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    decayed = args.lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = decayed
107 |
108 |
if __name__ == '__main__':
    # Entry point: train a 3-D U-Net on the LiTS dataset and validate
    # after every epoch.
    # torch.cuda.set_device(3)
    parser = argparse.ArgumentParser(description='PyTorch LIST')
    # NOTE(review): the default is 50 but the help text still says 10.
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    # NOTE(review): --save-model is parsed but never consulted below.
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # PyTorch v0.4.0
    # 3-D U-Net: 1 input channel, encoder widths [32..128], 3 output classes.
    model = UNet(1, [32, 48, 64, 96, 128], 3, net_mode='3d',conv_block=RecombinationBlock).to(device)
    init_util.print_network(model)

    # NOTE(review): device_ids=[0, 1] assumes at least two CUDA devices;
    # this will fail on CPU-only or single-GPU machines — confirm.
    model = nn.DataParallel(model, device_ids=[0, 1])  # multi-GPU

    reader = LITS_reader.LITS_reader(data_fix=False)
    # Arguments are presumably patch size [16, 96, 96], batch size 12 and
    # a 0.5 scale/ratio — verify against Lits_DataSet's signature.
    train_set = LITS_DataSet.Lits_DataSet([16, 96, 96], 12, reader,0.5)
    val_set = LITS_DataSet.Lits_DataSet_val([16, 96, 96], 12, reader,0.5)
    train_loader=DataLoader(dataset=train_set,shuffle=True)
    val_loader=DataLoader(dataset=val_set,shuffle=True)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # Per-epoch metric histories; also mirrored to TensorBoard below.
    train_dict = {'loss': [], 'dice0': [], 'dice1': [], 'dice2': []}
    val_dict = {'loss': [], 'dice0': [], 'dice1': [], 'dice2': []}

    logger = logger.Logger('./log')
    for epoch in range(1, args.epochs + 1):
        adjust_learning_rate(optimizer, epoch, args)
        train(model, train_loader, device, optimizer, epoch, train_dict, logger)
        val(model, val_loader, device, epoch, val_dict, logger)
        # Overwrites the same file every epoch with the full module object
        # (not just a state_dict).
        torch.save(model, 'model')
145 |
146 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 |
def random_crop_2d(img, label, crop_size):
    """Take the same random crop_size window out of a 2-D image/label pair.

    Returns (img_crop, label_crop), or None when the image is smaller than
    crop_size along either axis.
    """
    max_x = img.shape[0] - crop_size[0]
    max_y = img.shape[1] - crop_size[1]
    if max_x < 0 or max_y < 0:
        return None

    x0 = random.randint(0, max_x)
    y0 = random.randint(0, max_y)
    # Use one shared window so image and label stay aligned.
    window = (slice(x0, x0 + crop_size[0]), slice(y0, y0 + crop_size[1]))
    return img[window], label[window]
19 |
20 |
def random_crop_3d(img, label, crop_size):
    """Take the same random crop_size window out of a 3-D volume/label pair.

    Returns (img_crop, label_crop), or None when the volume is smaller
    than crop_size along any axis.
    """
    max_x = img.shape[0] - crop_size[0]
    max_y = img.shape[1] - crop_size[1]
    max_z = img.shape[2] - crop_size[2]
    if max_x < 0 or max_y < 0 or max_z < 0:
        return None

    x0 = random.randint(0, max_x)
    y0 = random.randint(0, max_y)
    z0 = random.randint(0, max_z)
    # Use one shared window so volume and label stay aligned.
    window = (slice(x0, x0 + crop_size[0]),
              slice(y0, y0 + crop_size[1]),
              slice(z0, z0 + crop_size[2]))
    return img[window], label[window]
38 |
39 |
--------------------------------------------------------------------------------