├── .gitignore
├── .idea
│   ├── Pytorch-IterativeFCN.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── __pycache__
│   ├── iterativeFCN.cpython-36.pyc
│   └── model.cpython-36.pyc
├── data
│   ├── __pycache__
│   │   ├── data_augmentation.cpython-36.pyc
│   │   ├── dataset.cpython-36.pyc
│   │   └── preprocessing.cpython-36.pyc
│   ├── data_augmentation.py
│   ├── dataset.py
│   └── preprocessing.py
├── eval.py
├── imgs
│   ├── example_empty.png
│   ├── example_normal.png
│   ├── model.png
│   └── result.png
├── instance_segmentation.py
├── iterativeFCN.py
├── requirements.txt
├── test
│   ├── test_dataset.py
│   └── test_iterativeFCN_summary.py
├── train.py
├── utils
│   ├── __pycache__
│   │   ├── metrics.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── metrics.py
│   └── utils.py
└── weights
    └── IterativeFCN_pretrained.pth
/.gitignore:
--------------------------------------------------------------------------------
/crop_isotropic_dataset
/CSI_dataset
/isotropic_dataset
/test/samples
/checkpoints
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Iterative fully convolutional neural networks for automatic vertebra segmentation
This is a PyTorch implementation of the paper [Iterative fully convolutional neural networks for automatic vertebra segmentation](https://openreview.net/forum?id=S1NnlZnjG), accepted at MIDL 2018. The paper provides an automatic mechanism for precise vertebra segmentation in CT images. I created this project to polish up my knowledge of deep learning for medical imaging. An updated version with a similar structure, [Iterative fully convolutional neural networks for automatic vertebra segmentation and identification](https://arxiv.org/abs/1804.04383), appeared in 2019; because of its computational cost, I decided to implement the earlier 2018 version.

## Model
This is the model illustration from the paper. The model has a shape similar to [3D U-Net](https://arxiv.org/abs/1606.06650), but with a constant number of channels in every layer and an extra branch for classification. It takes 2 inputs: an image patch and the corresponding instance memory patch. The instance memory reminds the model which vertebrae are already segmented, so that it segments the first "unsegmented vertebra" and the vertebrae are segmented one by one.
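A minimal NumPy sketch of how the two inputs can be combined into one 2-channel volume, mirroring the `torch.cat` / `torch.unsqueeze` steps in `instance_segmentation.py`. The helper `make_input` is hypothetical, and it assumes the patch lies fully inside the volume:

```python
import numpy as np


def make_input(img, ins, center, patch_size=128):
    """Stack an image patch and its instance-memory patch into shape (1, 2, D, H, W).

    Hypothetical helper: `center` is a (z, y, x) voxel index and the patch is
    assumed to lie completely inside the volume.
    """
    half = patch_size // 2
    z, y, x = center
    region = (slice(z - half, z + half), slice(y - half, y + half), slice(x - half, x + half))
    # channel 0: CT intensities, channel 1: binary memory of already segmented vertebrae
    patch = np.stack([img[region], ins[region]], axis=0).astype(np.float32)
    return patch[np.newaxis]  # add the batch dimension (batch size 1)


img = np.random.rand(160, 160, 160)
ins = np.zeros_like(img)
print(make_input(img, ins, center=(80, 64, 64)).shape)  # (1, 2, 128, 128, 128)
```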

## Dataset and Pre-processing

### 1. Dataset
I chose one of the datasets used in the paper: the spine segmentation challenge from CSI 2014. The dataset can be obtained as Dataset 2 on [SpineWeb](http://spineweb.digitalimaginggroup.ca/spineweb/index.php?n=Main.Datasets#Dataset_2.3A_Spine_and_Vertebrae_Segmentation).

### 2. Data preprocessing
The preprocessing steps for each CT image and its corresponding mask (both train and test sets) are:
* **Resample the images and masks to an isotropic resolution (1 mm × 1 mm × 1 mm).**
* **Calculate the weight penalty coefficients for each image via a distance transform.**
* **Crop the images and masks to remove the vertebrae that have no labels in the masks.**
* **Prepare the training patches, including "image patches", "instance memory patches", "mask patches" and "weight patches".**
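The two numeric steps above (isotropic resampling and the weight penalty) can be sketched as follows. `compute_distance_weight_matrix` follows the function of the same name in `data/preprocessing.py` (α = 1, β = 8, ω = 6, giving w = α + β·exp(−d²/ω²)); `isotropic_size` is a hypothetical helper that only reproduces the output-size arithmetic of `isotropic_resampler`:

```python
import numpy as np
from scipy import ndimage


def isotropic_size(orig_size, orig_spacing, new_spacing=(1.0, 1.0, 1.0)):
    """Voxel grid size after resampling: ceil(size * spacing / new_spacing) per axis."""
    return [int(np.ceil(n * s / t)) for n, s, t in zip(orig_size, orig_spacing, new_spacing)]


def compute_distance_weight_matrix(mask, alpha=1, beta=8, omega=6):
    """w = alpha + beta * exp(-d^2 / omega^2), with d the distance to the mask border."""
    mask = np.asarray(mask)
    # at every voxel one of the two EDTs is zero, so the sum is the distance to the
    # nearest voxel on the other side of the foreground/background border
    d = ndimage.distance_transform_edt(mask > 0) + ndimage.distance_transform_edt(mask == 0)
    return np.asarray(alpha + beta * np.exp(-(d ** 2) / omega ** 2), dtype='float32')


print(isotropic_size((512, 512, 120), (0.3125, 0.3125, 1.25)))  # [160, 160, 150]
```

Voxels close to a vertebra border get weights near α + β, while voxels far from any border fall back to α, which is what makes boundary errors more expensive in the loss.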

### 3. Illustration of training patches
A normal (non-empty) set of training patches is shown below:

Since the model uses a sliding window to segment the vertebrae, it must also learn to produce an empty prediction when there is no vertebra in the patch, or when all vertebrae have already been segmented and recorded in the instance memory:
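How an empty target can arise is easiest to see in code. This is a hypothetical sketch, not the rule implemented in `utils.extract_random_patch` (which is not shown here): it assumes one integer id per vertebra and, as an illustrative choice, treats the smallest remaining id as the "first" unsegmented vertebra:

```python
import numpy as np


def next_target(label_patch, ins_patch):
    """Return the mask of the first vertebra that is not yet in instance memory.

    Hypothetical rule: `label_patch` holds one integer id per vertebra (0 = background),
    `ins_patch` is 1 where a vertebra has already been segmented, and "first" means
    the smallest remaining id. Returns an all-zero mask (an empty training example)
    when no unsegmented vertebra is left in the patch.
    """
    remaining = np.unique(label_patch[(label_patch > 0) & (ins_patch == 0)])
    if remaining.size == 0:
        return np.zeros(label_patch.shape, dtype=np.uint8)
    return (label_patch == remaining[0]).astype(np.uint8)
```

Once every vertebra in the patch is marked in the memory, the target is all zeros, which is exactly the case the model must learn to answer with an empty prediction.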

## Training Detail
I apply the same settings as suggested in the paper:
* **Batch size = 1 due to GPU memory limitations.**
* **Adam with learning rate = 1e-3.**
* **Data augmentation via elastic deformation, Gaussian blur, Gaussian noise and random crop along the z-axis.**
* **An empty-mask training example every 5th iteration.**
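A compact sketch of the elastic deformation used for augmentation, assuming SciPy's `gaussian_filter` and `map_coordinates`. It follows the same idea as `elastic_transform` in `data/data_augmentation.py` but warps one volume at a time; `elastic_deform` and its `seed` argument are hypothetical:

```python
import numpy as np
from scipy.ndimage import gaussian_filter, map_coordinates


def elastic_deform(volume, alpha, sigma, order=1, seed=None):
    """Warp a 3-D volume with a smooth random displacement field (Simard et al., 2003).

    alpha scales the displacements (in voxels) and sigma smooths them. Use order=0
    for label maps so they keep their original values.
    """
    rng = np.random.RandomState(seed)
    # identity sampling grid in array-axis order (z, y, x)
    grid = np.meshgrid(*[np.arange(n) for n in volume.shape], indexing='ij')
    # add one smoothed uniform(-1, 1) displacement field per axis
    coords = [g + gaussian_filter(rng.rand(*volume.shape) * 2 - 1, sigma, mode='constant', cval=0) * alpha
              for g in grid]
    return map_coordinates(volume, coords, order=order, mode='reflect')
```

Calling it with the same `seed` on an image (`order=1`) and its mask (`order=0`) applies the same displacement field to both, which keeps image and labels aligned.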

I trained this model on Google Colab, whose GPU memory (12 GB) is similar to that of an NVIDIA TITAN X. The provided [pretrained weights](https://github.com/leohsuofnthu/Pytorch-IterativeFCN/tree/master/weights) were trained for only about 25,000 iterations. The learning rate was 1e-3 for iterations 1 to 10,000, 1e-4 for iterations 10,001 to 20,000, and 1e-5 for the remaining iterations, which differs from the paper, where 1e-3 is used for the whole training.
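The step schedule above, written as a function of the 1-based iteration count, plus a hypothetical check for the "every 5th iteration" empty example (the actual rule lives in `utils.extract_random_patch`, which receives the dataset's running counter and `empty_interval=5`):

```python
def learning_rate(iteration):
    """Learning rate for a 1-based iteration: 1e-3, then 1e-4 after 10000, then 1e-5 after 20000."""
    if iteration <= 10000:
        return 1e-3
    if iteration <= 20000:
        return 1e-4
    return 1e-5


def is_empty_iteration(iteration, empty_interval=5):
    """Hypothetical rule: every `empty_interval`-th sample is an empty-mask example."""
    return iteration % empty_interval == 0


print([learning_rate(i) for i in (1, 10000, 10001, 20001)])  # [0.001, 0.001, 0.0001, 1e-05]
```

In a PyTorch loop, the value can be applied by setting `group['lr']` for each `group` in `optimizer.param_groups` before the optimizer step.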

## Segmentation Result
The following are some segmentation results on both train and test data.

### (1) Visual Result

### (2) Average Dice Coefficient
| This repo | Paper |
| ------------- | ------------- |
| 0.918 | 0.958 |

P.S. No refinement techniques for pre- or post-processing are used in this repo.
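For reference, a sketch of the Dice coefficient, 2|A ∩ B| / (|A| + |B|), and one common way to average it over vertebrae by matching each ground-truth vertebra to the predicted label it overlaps most. This matching rule is an assumption for illustration; `eval.py` scores each vertebra differently (it only looks at the prediction inside the ground-truth vertebra), so the numbers need not agree with the table:

```python
import numpy as np


def dice(a, b):
    """Dice coefficient 2 * |A & B| / (|A| + |B|) of two binary masks (0.0 if both are empty)."""
    a, b = np.asarray(a, dtype=bool), np.asarray(b, dtype=bool)
    denom = a.sum() + b.sum()
    if denom == 0:
        return 0.0
    return 2.0 * np.logical_and(a, b).sum() / denom


def mean_vertebra_dice(label, pred):
    """Average Dice over ground-truth vertebrae, each matched to its most-overlapping predicted label."""
    scores = []
    for v in np.unique(label):
        if v == 0:
            continue  # skip the background
        gt = label == v
        ids, counts = np.unique(pred[gt], return_counts=True)
        keep = ids > 0
        ids, counts = ids[keep], counts[keep]
        if ids.size == 0:
            scores.append(0.0)  # vertebra missed completely
            continue
        best = ids[np.argmax(counts)]
        scores.append(dice(gt, pred == best))
    return float(np.mean(scores))
```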

## Usage
### Set Up the Environment
A `requirements.txt` is provided in the repo:
```bash
pip install -r requirements.txt
```

### Preprocess the CSI Dataset
```bash
python -m data.preprocessing --dataset 'the root path of CSI dataset'
```

### Start Training
```bash
python train.py --dataset 'the directory of preprocessed CSI dataset'
```

### Instance Segmentation
```bash
python instance_segmentation.py --test_dir 'the directory of test images' --weights 'pretrained weights'
```
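The centre update at the heart of `instance_segmentation.py` can be sketched as below: the centre of mass of the predicted patch mask is mapped to global (z, y, x) coordinates, clamped so that the next patch stays inside the volume, and compared with the previous centre using the `--sigma` threshold (default 2). `update_center` and `converged` are hypothetical names:

```python
import numpy as np
from scipy import ndimage


def update_center(S, patch_center, img_shape, patch_size=128):
    """Map the centre of mass of a predicted patch mask S to global (z, y, x) and clamp it."""
    half = patch_size // 2
    local = [int(c) for c in ndimage.center_of_mass(S)]
    # global = patch centre - half patch + position inside the patch
    glob = [p - half + c for p, c in zip(patch_center, local)]
    # keep the next patch fully inside the volume
    return [int(min(max(g, half), n - half)) for g, n in zip(glob, img_shape)]


def converged(c_now, c_prev, sigma=2):
    """Centres within sigma voxels of each other on every axis are treated as converged."""
    return all(abs(a - b) <= sigma for a, b in zip(c_now, c_prev))
```

The patch is re-centred until two consecutive centres agree within `sigma`; only then is the predicted mask written to the output and to the instance memory.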

### Evaluate the Dice Coefficient against Labels
```bash
python eval.py --label_dir 'directory of test labels' --pred_dir 'the directory of predicted segmentations'
```

## Authors

* **HSU, CHIH-CHAO** - *Professional Machine Learning Master Student at [Mila](https://mila.quebec/)*

## Reference
Thanks to the following sources for their information, and to the paper authors for their kind answers:

* https://www.youtube.com/watch?v=0we-WooGqxw
* https://gist.github.com/erniejunior/601cdf56d2b424757de5
* https://github.com/SimpleITK/SimpleITK/issues/561
--------------------------------------------------------------------------------
/__pycache__/iterativeFCN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/__pycache__/iterativeFCN.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/data_augmentation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/data/__pycache__/data_augmentation.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/data/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/preprocessing.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/data/__pycache__/preprocessing.cpython-36.pyc
--------------------------------------------------------------------------------
/data/data_augmentation.py:
--------------------------------------------------------------------------------
import random

import numpy as np
from scipy import ndimage
from scipy.ndimage import gaussian_filter, map_coordinates
from skimage.transform import resize


def elastic_transform(image, mask, ins, weight, alpha, sigma, random_state=None):
    """Elastic deformation of images as described in [Simard2003]_.
    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.

    Modified from: https://gist.github.com/erniejunior/601cdf56d2b424757de5
    """
    if random_state is None:
        random_state = np.random.RandomState(None)

    shape = image.shape
    # smooth random displacement field for each axis, scaled by alpha (in voxels)
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dz = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha

    # indexing='ij' keeps the grid in array-axis order, so non-cubic volumes work as well
    x, y, z = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), np.arange(shape[2]), indexing='ij')
    indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1)), np.reshape(z + dz, (-1, 1))

    # the same displacement is applied to all volumes; label maps use order=0
    # (nearest neighbour) so that they keep their original values
    distorted_image = map_coordinates(image, indices, order=1, mode='reflect')
    distorted_mask = map_coordinates(mask, indices, order=0, mode='reflect')
    distorted_ins = map_coordinates(ins, indices, order=0, mode='reflect')
    distorted_weight = map_coordinates(weight, indices, order=1, mode='reflect')

    return (distorted_image.reshape(image.shape), distorted_mask.reshape(mask.shape),
            distorted_ins.reshape(ins.shape), distorted_weight.reshape(weight.shape))


def gaussian_blur(image):
    return gaussian_filter(image, sigma=1)


def gaussian_noise(image):
    mean = 0
    var = 0.1
    sigma = var ** 0.5
    # additive zero-mean Gaussian noise with standard deviation 50 * sqrt(0.1)
    gauss = 50 * np.random.normal(mean, sigma, image.shape)
    return image + gauss


def rotate(image, ins, gt, weight):
    # rotate all volumes by the same random multiple of 90 degrees in the (axis 1, axis 2) plane
    d = random.choice([90, 180, 270])
    rotate_img = ndimage.rotate(image, d, (1, 2), reshape=False)
    # label maps use nearest-neighbour interpolation
    rotate_ins = ndimage.rotate(ins, d, (1, 2), reshape=False, order=0)
    rotate_gt = ndimage.rotate(gt, d, (1, 2), reshape=False, order=0)
    rotate_weight = ndimage.rotate(weight, d, (1, 2), reshape=False)
    return rotate_img, rotate_ins, rotate_gt, rotate_weight


def random_crop(image, ins, gt, weight, depth=80):
    """Crop a random slab of `depth` slices along z and resize every volume back to 128^3."""
    out_shape = (128, 128, 128)
    # randint is inclusive on both ends, so start + depth <= image.shape[0]
    start = random.randint(0, image.shape[0] - depth)

    image = crop_z(image, start, start + depth)
    ins = crop_z(ins, start, start + depth)
    gt = crop_z(gt, start, start + depth)
    weight = crop_z(weight, start, start + depth)

    crop_img = resize(image, out_shape, order=1, preserve_range=True)
    crop_ins = resize(ins, out_shape, order=0, preserve_range=True)
    crop_gt = resize(gt, out_shape, order=0, preserve_range=True)
    crop_weight = resize(weight, out_shape, order=1, preserve_range=True)

    return crop_img, crop_ins, crop_gt, crop_weight


def crop_z(arr, start, end):
    return arr[start:end]
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
import os

from torch.utils.data import Dataset
import SimpleITK as sitk

from utils.utils import extract_random_patch


class CSIDataset(Dataset):
    """MICCAI 2014 Spine Challenge (CSI 2014) dataset"""

    def __init__(self,
                 dataset_path,
                 subset='train',
                 empty_interval=5,
                 flag_patch_norm=False,
                 flag_linear=False,
                 linear_att=1.0,
                 offset=1000.0):

        # running sample counter, passed to extract_random_patch together with empty_interval;
        # with DataLoader(num_workers > 0) every worker process keeps its own copy
        self.idx = 1
        self.empty_interval = empty_interval
        self.flag_linear = flag_linear
        self.flag_patch_norm = flag_patch_norm

        self.dataset_path = dataset_path
        self.subset = subset
        self.linear_att = linear_att
        self.offset = offset

        self.img_path = os.path.join(dataset_path, subset, 'img')
        self.mask_path = os.path.join(dataset_path, subset, 'seg')
        self.weight_path = os.path.join(dataset_path, subset, 'weight')

        self.img_names = [f for f in os.listdir(self.img_path) if f.endswith('.mhd')]

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        mask_name = img_name.split('.')[0] + '_label.mhd'
        weight_name = img_name.split('.')[0] + '_weight.nrrd'

        img_file = os.path.join(self.img_path, img_name)
        mask_file = os.path.join(self.mask_path, mask_name)
        weight_file = os.path.join(self.weight_path, weight_name)

        img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))
        mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))
        weight = sitk.GetArrayFromImage(sitk.ReadImage(weight_file))

        # Linear transformation from the 12-bit reconstruction values to HU;
        # depends on the original data (CSI values range from 0 to 4095).
        if self.flag_linear:
            img = img * self.linear_att - self.offset

        # extract a training patch
        img_patch, ins_patch, gt_patch, weight_patch, c_label = extract_random_patch(
            img, mask, weight, self.idx, self.subset, self.empty_interval)

        if self.flag_patch_norm:
            # zero-mean, unit-variance normalisation of the image patch
            img_patch = (img_patch - img_patch.mean()) / img_patch.std()

        self.idx += 1

        return img_patch, ins_patch, gt_patch, weight_patch, c_label
--------------------------------------------------------------------------------
/data/preprocessing.py:
--------------------------------------------------------------------------------
import os
import re
import argparse
import logging
from pathlib import Path

import numpy as np
from scipy import ndimage
import SimpleITK as sitk

logging.basicConfig(level=logging.INFO)


# resample the CT images to isotropic
def isotropic_resampler(input_path, output_path):
    raw_img = sitk.ReadImage(input_path)
    new_spacing = [1, 1, 1]

    resampler = sitk.ResampleImageFilter()
    # nearest-neighbour interpolation (used here for both images and label maps)
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    resampler.SetOutputDirection(raw_img.GetDirection())
    resampler.SetOutputOrigin(raw_img.GetOrigin())
    resampler.SetOutputSpacing(new_spacing)

    # new size per axis = ceil(orig_size * orig_spacing / new_spacing)
    orig_size = np.array(raw_img.GetSize(), dtype=int)
    orig_spacing = raw_img.GetSpacing()
    new_size = np.array([x * (y / z) for x, y, z in zip(orig_size, orig_spacing, new_spacing)])
    new_size = [int(s) for s in np.ceil(new_size)]  # image dimensions are integers
    resampler.SetSize(new_size)

    isotropic_img = resampler.Execute(raw_img)
    sitk.WriteImage(isotropic_img, output_path, True)


# Functions for cropping
def z_mid(mask, chosen_vert):
    """Return the middle index along axis 0 (z) of the given vertebra label."""
    indices = np.nonzero(mask == chosen_vert)
    lower = [np.min(i) for i in indices]
    upper = [np.max(i) for i in indices]

    return int((lower[0] + upper[0]) / 2)


def findZRange(img, mask):
    # list available vertebrae; verts[0] is the background label 0
    verts = np.unique(mask)

    vert_low = verts[1]
    vert_up = verts[-1]

    # z range between the mid-slices of the smallest and the largest vertebra label
    z_range = [z_mid(mask, vert_low), z_mid(mask, vert_up)]
    logging.info('Range of Z axis %s' % z_range)
    return z_range


def crop_unref_vert(path, out_path, subset):
    img_path = os.path.join(path, subset, 'img')
    mask_path = os.path.join(path, subset, 'seg')
    weight_path = os.path.join(path, subset, 'weight')
    img_names = [f for f in os.listdir(img_path) if f.endswith('.mhd')]

    for img_name in img_names:
        logging.info('Cropping non-reference vertebrae of %s' % img_name)
        mask_name = img_name.split('.')[0] + '_label.mhd'
        weight_name = img_name.split('.')[0] + '_weight.nrrd'

        img_file = os.path.join(img_path, img_name)
        mask_file = os.path.join(mask_path, mask_name)
        weight_file = os.path.join(weight_path, weight_name)

        img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))
        mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))
        weight = sitk.GetArrayFromImage(sitk.ReadImage(weight_file))

        z_range = findZRange(img, mask)

        # keep slices z_range[0] .. z_range[1] - 1 of image, mask and weight
        sitk.WriteImage(sitk.GetImageFromArray(img[z_range[0]:z_range[1], :, :]),
                        os.path.join(out_path, subset, 'img', img_name), True)
        sitk.WriteImage(sitk.GetImageFromArray(mask[z_range[0]:z_range[1], :, :]),
                        os.path.join(out_path, subset, 'seg', mask_name), True)
        sitk.WriteImage(sitk.GetImageFromArray(weight[z_range[0]:z_range[1], :, :]),
                        os.path.join(out_path, subset, 'weight', weight_name), True)


# calculate the weight via distance transform
def compute_distance_weight_matrix(mask, alpha=1, beta=8, omega=6):
    """
    Code from author : Dr.Lessman (nikolas.lessmann@radboudumc.nl)
    """
    mask = np.asarray(mask)
    # distance of every voxel to the foreground/background border
    distance_to_border = ndimage.distance_transform_edt(mask > 0) + ndimage.distance_transform_edt(mask == 0)
    # weights = alpha + beta * exp(-d^2 / omega^2)
    weights = alpha + beta * np.exp(-(distance_to_border ** 2 / omega ** 2))
    return np.asarray(weights, dtype='float32')


def calculate_weight(isotropic_path, subset):
    mask_path = os.path.join(isotropic_path, subset, 'seg')
    weight_path = os.path.join(isotropic_path, subset, 'weight')

    Path(mask_path).mkdir(parents=True, exist_ok=True)
    Path(weight_path).mkdir(parents=True, exist_ok=True)

    for f in [f for f in os.listdir(mask_path) if f.endswith('.mhd')]:
        seg_mask = sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(mask_path, f)))
        weight = compute_distance_weight_matrix(seg_mask)
        sitk.WriteImage(sitk.GetImageFromArray(weight), os.path.join(weight_path, f.split('_')[0] + '_weight.nrrd'),
                        True)
        logging.info("Calculating weight of %s" % f)


def create_folders(root, subsets, folders):
    for subset in subsets:
        for f in folders:
            Path(os.path.join(root, subset, f)).mkdir(parents=True, exist_ok=True)


def main():
    parser = argparse.ArgumentParser(description='iterativeFCN')
    parser.add_argument('--dataset', type=str, default='./CSI_dataset', help='root path of CSI dataset')
    parser.add_argument('--output_isotropic', type=str, default='./isotropic_dataset',
                        help='output path for isotropic images')
    parser.add_argument('--output_crop', type=str, default='./crop_isotropic_dataset',
                        help='output path for crop samples')
    parser.add_argument('--split_ratio', type=float, default=0.8, help='ratio of train/test')
    args = parser.parse_args()

    # split data into train test folder
    folders = ['img', 'seg', 'weight']
    subsets = ['train', 'test']
    create_folders(args.output_isotropic, subsets, folders)
    create_folders(args.output_crop, subsets, folders)

    # resample the CSI dataset to isotropic dataset
    files = [x for x in os.listdir(os.path.join(args.dataset)) if 'raw' not in x]
    for f in files:
        case_id = re.findall(r'\d+', f)[0]
        logging.info('Resampling ' + f + '...')
        # files holds one image header and one label header per case, so len(files) / 2 is the
        # number of cases; case ids below split_ratio of that number go to the training set
        if int(case_id) < int(len(files) / 2 * args.split_ratio):
            if '_label' in f:
                file_output = os.path.join(args.output_isotropic, 'train/seg', f)
            else:
                file_output = os.path.join(args.output_isotropic, 'train/img', f)
        else:
            if '_label' in f:
                file_output = os.path.join(args.output_isotropic, 'test/seg', f)
            else:
                file_output = os.path.join(args.output_isotropic, 'test/img', f)

        isotropic_resampler(os.path.join(args.dataset, f), file_output)

    # Pre-calculate the weight
    calculate_weight(args.output_isotropic, 'train')
    calculate_weight(args.output_isotropic, 'test')

    # Crop the image to remove the vertebrae that are not labeled in ground truth
    crop_unref_vert(args.output_isotropic, args.output_crop, 'train')
    crop_unref_vert(args.output_isotropic, args.output_crop, 'test')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import os
import argparse
import logging

import numpy as np
import SimpleITK as sitk
from medpy.metric.binary import dc

logging.basicConfig(level=logging.INFO)


def main():
    parser = argparse.ArgumentParser(description='Iterative Fully Convolutional Network')
    parser.add_argument('--label_dir', type=str, default='./crop_isotropic_dataset/test/seg',
                        help='folder of test labels')
    parser.add_argument('--pred_dir', type=str, default='./pred',
                        help='folder of predicted masks')
    args = parser.parse_args()

    # sort both listings so that each label is paired with the prediction of the same case
    # (assuming both folders use the same case-name prefixes)
    labels = sorted(os.path.join(args.label_dir, x) for x in os.listdir(args.label_dir) if 'raw' not in x)
    preds = sorted(os.path.join(args.pred_dir, x) for x in os.listdir(args.pred_dir) if 'raw' not in x)

    n = 0
    sum_dc = 0.
    for label_file, pred_file in zip(labels, preds):
        logging.info("Process %s and %s" % (pred_file, label_file))
        label = sitk.GetArrayFromImage(sitk.ReadImage(label_file))
        pred = sitk.GetArrayFromImage(sitk.ReadImage(pred_file))
        for i in np.unique(label):
            if i == 0:
                continue  # skip the background label
            # binarised vertebra i and the binarised prediction at the same voxels;
            # predicted voxels outside vertebra i are not taken into account
            l = label[label == i] > 0
            p = pred[label == i] > 0
            sum_dc += dc(p, l)
            n += 1

    logging.info("Average Dice Coefficient for %s individual vertebrae test : %s" % (n, sum_dc / n))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/imgs/example_empty.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/imgs/example_empty.png
--------------------------------------------------------------------------------
/imgs/example_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/imgs/example_normal.png
--------------------------------------------------------------------------------
/imgs/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/imgs/model.png
--------------------------------------------------------------------------------
/imgs/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/imgs/result.png
--------------------------------------------------------------------------------
/instance_segmentation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import argparse
4 | from pathlib import Path
5 |
6 | import torch
7 | import numpy as np
8 | from scipy import ndimage
9 |
10 | import SimpleITK as sitk
11 | from iterativeFCN import IterativeFCN
12 |
13 | logging.basicConfig(level=logging.INFO)
14 |
15 |
16 | def extract(img, x, y, z, patch_size):
17 | offset = int(patch_size / 2)
18 | return img[z - offset:z + offset, y - offset:y + offset, x - offset:x + offset]
19 |
20 |
21 | def instance_segmentation(model, img_name, patch_size, sigma_x, lim_alternate_times, n_min, output_path):
22 | step = int(patch_size / 2)
23 | img = sitk.GetArrayFromImage(sitk.ReadImage(img_name))
24 | ins = np.zeros_like(img)
25 | mask = np.zeros_like(img)
26 | img_shape = img.shape
27 |
28 | # slide window with initial center coord
29 | patch_size = 128
30 | z = int(img.shape[0] - (patch_size / 2))
31 | y = int(patch_size / 2)
32 | x = int(patch_size / 2)
33 |
34 | # x_t-1, x_t
35 | c_now = [z, y, x]
36 | c_prev = [0, 0, 0]
37 |
38 | label = 100
39 | iters = 0
40 | ii = 0
41 | # slide window check
42 | logging.info('Start Instance Segmentation')
43 | while True:
44 |
45 | logging.info('(Z, Y, X) Now: (%s, %s, %s)' % (z, y, x))
46 | if abs(x - patch_size / 2) < sigma_x and abs(y - patch_size / 2) < sigma_x and abs(
47 | z - patch_size / 2) < sigma_x:
48 | break
49 |
50 | # extract patch and instance memory
51 | img_patch = torch.tensor(np.expand_dims(extract(img, x, y, z, 128), axis=0))
52 | ins_patch = torch.tensor(np.expand_dims(extract(ins, x, y, z, 128), axis=0))
53 |
54 | input_patch = torch.cat((img_patch, ins_patch))
55 | input_patch = torch.unsqueeze(input_patch, dim=0)
56 |
57 | with torch.no_grad():
58 | S, C = model(input_patch.float().to('cuda'))
59 |
60 | S = torch.squeeze(S.round().to('cpu')).numpy()
61 | vol = np.count_nonzero(S)
62 |
63 | ii += 1
64 | # check if vol > min_threshold
65 | if vol > n_min:
66 | c_prev[0] = c_now[0]
67 | c_prev[1] = c_now[1]
68 | c_prev[2] = c_now[2]
69 |
70 | center = ndimage.measurements.center_of_mass(S)
71 | center = [int(center[0]), int(center[1]), int(center[2])]
72 | logging.info('Center relative to patch:%s' % center)
73 |
74 | c_now[0] = z + (patch_size / 2) - (patch_size - center[0])
75 | c_now[1] = y - (patch_size / 2) + center[1]
76 | c_now[2] = x - (patch_size / 2) + center[2]
77 | logging.info('Global Center:%s' % c_now)
78 |
79 | # correction to be in-frame
80 | if (c_now[0] + patch_size / 2) > img.shape[0]:
81 | c_now[0] = img.shape[0] - (patch_size / 2)
82 |
83 | elif (c_now[0] - patch_size / 2) < 0:
84 | c_now[0] = (patch_size / 2)
85 |
86 | if (c_now[1] + patch_size / 2) > img.shape[1]:
87 | c_now[1] = img.shape[1] - (patch_size / 2)
88 |
89 | elif (c_now[1] - patch_size / 2) < 0:
90 | c_now[1] = (patch_size / 2)
91 |
92 | if (c_now[2] + patch_size / 2) > img.shape[2]:
93 | c_now[2] = img.shape[2] - (patch_size / 2)
94 |
95 | elif (c_now[2] - patch_size / 2) < 0:
96 | c_now[2] = (patch_size / 2)
97 |
98 | c_now[0] = int(c_now[0])
99 | c_now[1] = int(c_now[1])
100 | c_now[2] = int(c_now[2])
101 | logging.info('Modified center:%s' % c_now)
102 | logging.info('Prev center %s' % c_prev)
103 |
104 | if abs(c_now[0] - c_prev[0]) > sigma_x or abs(c_now[1] - c_prev[1]) > sigma_x or abs(
105 | c_now[2] - c_prev[2]) > sigma_x:
106 | iters += 1
107 | logging.info('Not converge iterations %s' % iters)
108 |
109 | if iters == lim_alternate_times:
110 | logging.info('iteration:%s' % lim_alternate_times)
111 | # pick avg and dim as converge
112 | c_now[0] = int((c_now[0] + c_prev[0]) / 2)
113 | c_now[1] = int((c_now[1] + c_prev[0]) / 2)
114 | c_now[2] = int((c_now[2] + c_prev[0]) / 2)
115 |
116 | logging.info('converge and seg')
117 | iters = 0
118 | # converge, update ins and mask
119 | z_low = int(c_now[0] - (patch_size / 2))
120 | z_up = int(c_now[0] + (patch_size / 2))
121 | y_low = int(c_now[1] - (patch_size / 2))
122 | y_up = int(c_now[1] + (patch_size / 2))
123 | x_low = int(c_now[2] - (patch_size / 2))
124 | x_up = int(c_now[2] + (patch_size / 2))
125 |
126 | r = S > 0
127 | ins[z_low:z_up, y_low:y_up, x_low:x_up][r] = 1
128 | mask[z_low:z_up, y_low:y_up, x_low:x_up][r] = label
129 |
130 | label += 100
131 | logging.info("seg {}th verts complete!!".format(label))
132 | else:
133 | logging.info('converge and seg')
134 | iters = 0
135 |
136 | # converge, update ins and mask
137 | z_low = int(c_now[0] - (patch_size / 2))
138 | z_up = int(c_now[0] + (patch_size / 2))
139 | y_low = int(c_now[1] - (patch_size / 2))
140 | y_up = int(c_now[1] + (patch_size / 2))
141 | x_low = int(c_now[2] - (patch_size / 2))
142 | x_up = int(c_now[2] + (patch_size / 2))
143 |
144 | r = S > 0
145 | ins[z_low:z_up, y_low:y_up, x_low:x_up][r] = 1
146 | mask[z_low:z_up, y_low:y_up, x_low:x_up][r] = label
147 |
148 | label += 100
149 | logging.info("seg {}th verts complete!!".format(label))
150 |
151 | # same patch analyze again, center remain
152 | z = c_now[0]
153 | y = c_now[1]
154 | x = c_now[2]
155 | else:
156 | logging.info('slide window')
157 | # continue slide windows
158 | if x + step > img_shape[2]:
159 | x = int(patch_size / 2)
160 | if y + step > img_shape[1]:
161 | y = int(patch_size / 2)
162 | z = z - step
163 | else:
164 | y = y + step
165 | else:
166 | x = x + step
167 |
168 | logging.info('Finish Segmentation!')
169 | sitk.WriteImage(sitk.GetImageFromArray(mask), output_path, True)
170 |
171 |
172 | def main():
173 | parser = argparse.ArgumentParser(description='Iterative Fully Convolutional Network')
174 | parser.add_argument('--test_dir', type=str, default='./crop_isotropic_dataset/test/img',
175 | help='folder of test images')
176 | parser.add_argument('--output_dir', type=str, default='./pred',
177 | help='folder of pred masks')
178 | parser.add_argument('--weights', type=str, default='./weights/IterativeFCN_best_train.pth',
179 | help='trained weights of model')
180 | parser.add_argument('--patch_size', type=int, default=128,
181 | help='patch_size of the model')
182 | parser.add_argument('--sigma', type=int, default=2,
183 | help='sigma value passed to instance_segmentation()')
184 | parser.add_argument('--min_vol', type=int, default=1000,
185 | help='min volume threshold')
186 | parser.add_argument('--max_alter', type=int, default=20,
187 | help='max alternation of 2 centers')
188 | args = parser.parse_args()
189 |
190 | # Create FCN
191 | logging.info('Create Model and Loading Pretrained Weights')
192 | model = IterativeFCN().to('cuda')
193 | model.load_state_dict(torch.load(args.weights))
194 |
195 | # list the test images
196 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
197 | test_imgs = [x for x in os.listdir(os.path.join(args.test_dir)) if 'raw' not in x]
198 | for img in test_imgs:
199 | logging.info("Processing image: %s", img)
200 | output_path = os.path.join(args.output_dir, img.split('.')[0]+'_pred.nrrd')
201 | instance_segmentation(model, os.path.join(args.test_dir, img), args.patch_size, args.sigma, args.max_alter, args.min_vol, output_path)
202 |
203 |
204 | if __name__ == '__main__':
205 | main()
206 |
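The sliding-window branch above advances the patch center in raster order: x first, then y, then z downwards by `step`. The sketch below isolates that update rule so the visiting order can be checked on a toy volume. `next_center` and `sliding_window_centers` are hypothetical helpers; the starting center and the stop condition (`z < patch_size / 2`) are assumptions, because the loop's initialization and termination lie outside this excerpt. Mirroring the excerpt, a center can reach `img_shape[2]` itself, since the wrap only triggers once `x + step` exceeds the extent.

```python
from typing import Iterator, Sequence, Tuple

Center = Tuple[int, int, int]


def next_center(z: int, y: int, x: int, img_shape: Sequence[int],
                patch_size: int, step: int) -> Center:
    """One step of the x -> y -> z raster update used by the sliding-window branch."""
    half = int(patch_size / 2)
    if x + step > img_shape[2]:
        x = half  # wrap x back to the first column
        if y + step > img_shape[1]:
            y = half  # wrap y as well and move one step down in z
            z = z - step
        else:
            y = y + step
    else:
        x = x + step
    return z, y, x


def sliding_window_centers(img_shape: Sequence[int], patch_size: int,
                           step: int) -> Iterator[Center]:
    """Yield every visited center in order (start and stop are ASSUMPTIONS)."""
    half = int(patch_size / 2)
    # ASSUMPTION: begin at the top slab; stop once z falls below half a patch
    z, y, x = img_shape[0] - half, half, half
    while z >= half:
        yield z, y, x
        z, y, x = next_center(z, y, x, img_shape, patch_size, step)
```

On an 8x8x8 toy volume with `patch_size=4` and `step=2`, this visits 16 centers per z slab (x and y each take 2, 4, 6, 8) for z = 6, 4, 2.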
--------------------------------------------------------------------------------
/iterativeFCN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class IterativeFCN(nn.Module):
6 | """
7 | Structure of Iterative FCN Model
8 |
9 | Still need to convert to enable parallel training
10 | """
11 |
12 | def consecutive_conv(self, in_channels, out_channels):
13 | return nn.Sequential(
14 | nn.Conv3d(in_channels, out_channels, 3, padding=1),
15 | nn.ReLU(inplace=True),
16 | nn.BatchNorm3d(out_channels),
17 | nn.Conv3d(out_channels, out_channels, 3, padding=1),
18 | nn.ReLU(inplace=True),
19 | nn.BatchNorm3d(out_channels))
20 |
21 | def __init__(self, num_channels=64):
22 | super(IterativeFCN, self).__init__()
23 |
24 | self.conv_initial = self.consecutive_conv(2, num_channels)
25 |
26 | self.conv_final = self.consecutive_conv(num_channels, 1)
27 |
28 | self.conv_rest = self.consecutive_conv(num_channels, num_channels)
29 |
30 | self.conv_up = self.consecutive_conv(num_channels * 2, num_channels)
31 |
32 | self.contract = nn.MaxPool3d(2, stride=2)
33 |
34 | self.expand = nn.Upsample(scale_factor=2)
35 |
36 | self.dense = nn.Linear(num_channels, 1)
37 |
38 | def forward(self, x):
39 | # 2*128*128*128 to 64*128*128*128
40 | x_128 = self.conv_initial(x)
41 |
42 | # 64*128*128*128 to 64*64*64*64
43 | x_128 = self.conv_rest(x_128)
44 | x_64 = self.contract(x_128)
45 |
46 | # 64*64*64*64 to 64*32*32*32
47 | x_64 = self.conv_rest(x_64)
48 | x_32 = self.contract(x_64)
49 |
50 | # 64*32*32*32 to 64*16*16*16
51 | x_32 = self.conv_rest(x_32)
52 | x_16 = self.contract(x_32)
53 |
54 | # bottleneck: 64*16*16*16 stays 64*16*16*16 (no pooling here)
55 | x_16 = self.conv_rest(x_16)
56 |
57 | # upsampling path
58 | u_32 = self.expand(x_16)
59 | u_32 = self.conv_up(torch.cat((x_32, u_32), 1))
60 |
61 | u_64 = self.expand(u_32)
62 | u_64 = self.conv_up(torch.cat((x_64, u_64), 1))
63 |
64 | u_128 = self.expand(u_64)
65 | u_128 = self.conv_up(torch.cat((x_128, u_128), 1))
66 |
67 | u_128 = self.conv_final(u_128)
68 |
69 | # classification path
70 | x_8 = self.conv_rest(self.contract(x_16))
71 |
72 | x_4 = self.conv_rest((self.contract(x_8)))
73 |
74 | x_2 = self.conv_rest(self.contract(x_4))
75 |
76 | x_1 = self.contract(x_2)
77 |
78 | seg = torch.sigmoid(u_128)
79 | cls = torch.sigmoid(self.dense(torch.flatten(x_1)))
80 |
81 | return seg, cls
82 |
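The shape comments in `forward` can be checked without a GPU. This pure-Python sketch (a hypothetical helper, not part of the repo) replays only the spatial sizes: `MaxPool3d(2, stride=2)` halves each side (flooring odd sizes), `Upsample(scale_factor=2)` doubles it, and `torch.cat` needs each upsampled tensor to match its skip connection. It also reflects that `torch.flatten(x_1)` flattens the batch dimension too, so `nn.Linear(num_channels, 1)` only fits when the batch size is 1 and `x_1` is 1x1x1. It flags incompatibilities rather than reproducing the exact error PyTorch would raise (pooling a side of 1, for example, is expected to fail in PyTorch before the dense layer is reached).

```python
def trace_iterative_fcn_shapes(patch_size, num_channels=64, batch_size=1):
    """Return per-tensor spatial sizes plus two compatibility flags."""
    def pool(n):
        # MaxPool3d(2, stride=2): each side is halved, odd sizes are floored
        return n // 2

    sizes = {"x_128": patch_size}
    sizes["x_64"] = pool(sizes["x_128"])
    sizes["x_32"] = pool(sizes["x_64"])
    sizes["x_16"] = pool(sizes["x_32"])

    # Upsample(scale_factor=2) doubles each side; torch.cat needs it to equal the skip tensor
    skips_match = all(
        2 * sizes[deep] == sizes[skip]
        for skip, deep in (("x_32", "x_16"), ("x_64", "x_32"), ("x_128", "x_64"))
    )

    # classification path: three pool + conv stages, then one final pool
    sizes["x_8"] = pool(sizes["x_16"])
    sizes["x_4"] = pool(sizes["x_8"])
    sizes["x_2"] = pool(sizes["x_4"])
    sizes["x_1"] = pool(sizes["x_2"])

    # torch.flatten(x_1) also flattens the batch, so the dense layer sees
    # batch_size * num_channels * side**3 features
    flat_len = batch_size * num_channels * sizes["x_1"] ** 3
    dense_ok = flat_len == num_channels
    return sizes, skips_match, dense_ok
```

For the default 128-voxel patch, `trace_iterative_fcn_shapes(128)` reports matching skips and a compatible dense layer; with `batch_size=2` the dense check fails, which is consistent with the training loaders being built with `batch_size=1`.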
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.1.1
2 | numpy==1.16.5
3 | scipy==1.3.1
4 | torchsummary==1.5.1
5 | MedPy==0.4.0
6 | torch==1.2.0
7 | scikit_image==0.15.0
8 | SimpleITK==1.2.4
9 |
--------------------------------------------------------------------------------
/test/test_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import SimpleITK as sitk
3 | from pathlib import Path
4 | from data.dataset import CSIDataset
5 | from torch.utils.data import Dataset, DataLoader
6 |
7 | crop_img = '../crop_isotropic_dataset'
8 | batch_size = 1
9 |
10 | train_dataset = CSIDataset(crop_img)
11 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
12 |
13 | img_patch, ins_patch, gt_patch, weight, c_label = next(iter(train_dataloader))
14 |
15 | img_patch = torch.squeeze(img_patch)
16 | ins_patch = torch.squeeze(ins_patch)
17 | gt_patch = torch.squeeze(gt_patch)
18 | weight = torch.squeeze(weight)
19 |
20 | assert img_patch.shape == (128, 128, 128)
21 | assert ins_patch.shape == (128, 128, 128)
22 | assert gt_patch.shape == (128, 128, 128)
23 | assert weight.shape == (128, 128, 128)
24 |
25 | # store patches for visualization
26 | Path('./samples/').mkdir(parents=True, exist_ok=True)
27 | sitk.WriteImage(sitk.GetImageFromArray(img_patch.numpy()), './samples/img.nrrd', True)
28 | sitk.WriteImage(sitk.GetImageFromArray(gt_patch.numpy()), './samples/gt.nrrd', True)
29 | sitk.WriteImage(sitk.GetImageFromArray(ins_patch.numpy()), './samples/ins.nrrd', True)
30 | sitk.WriteImage(sitk.GetImageFromArray(weight.numpy()), './samples/wei.nrrd', True)
--------------------------------------------------------------------------------
/test/test_iterativeFCN_summary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchsummary import summary
3 | from iterativeFCN import IterativeFCN
4 |
5 | # Smoke test: print a layer-by-layer summary of a small IterativeFCN
6 | model = IterativeFCN(num_channels=11)
7 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
8 | model.to(device)
9 | summary(model, (2, 128, 128, 128), device=device)
10 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import numpy as np
5 | import logging
6 |
7 | import matplotlib.pyplot as plt
8 | import torch
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.utils.data import DataLoader
12 |
13 | from data.dataset import CSIDataset
14 | from utils.metrics import DiceCoeff, Segloss
15 | from iterativeFCN import IterativeFCN
16 |
17 | logging.basicConfig(
18 | format='%(asctime)s : %(levelname)s : %(message)s',
19 | level=logging.INFO,
20 | datefmt="%Y-%m-%d %H:%M:%S"
21 | )
22 |
23 |
24 | def train_single(model, device, img_patch, ins_patch, gt_patch, weight, c_label, optimizer):
25 | torch.cuda.empty_cache()
26 |
27 | model.train()
28 | correct = 0
29 |
30 | # convert data to float, just in case
31 | img_patch = img_patch.float()
32 | ins_patch = ins_patch.float()
33 | gt_patch = gt_patch.float()
34 | weight = weight.float()
35 | c_label = c_label.float()
36 |
37 | # reset gradients from the previous step
38 | optimizer.zero_grad()
39 |
40 | # concatenate the img_patch and ins_patch
41 | input_patch = torch.cat((img_patch, ins_patch), dim=1)
42 | input_patch, gt_patch, weight, c_label = input_patch.to(device), gt_patch.to(device), weight.to(device), c_label.to(
43 | device)
44 |
45 | S, C = model(input_patch.float())
46 |
47 | # calculate dice coefficient
48 | pred = torch.round(S).detach()
49 | train_dice_coef = DiceCoeff(pred, gt_patch.detach())
50 |
51 | # calculate total loss
52 | lamda = 0.1
53 | FP, FN = Segloss(S, gt_patch, weight)
54 | s_loss = lamda * FP + FN
55 | c_loss = F.binary_cross_entropy(torch.unsqueeze(C, dim=0), c_label)
56 | train_loss = s_loss + c_loss
57 |
58 | logging.info("train_dice_coef: %s, S Loss: %s, C Loss: %s" % (train_dice_coef, s_loss.item(), c_loss.item()))
59 |
60 | if C.round() == c_label:
61 | correct = 1
62 |
63 | # optimize the parameters
64 | train_loss.backward()
65 | optimizer.step()
66 |
67 | return train_loss.item(), correct, train_dice_coef
68 |
69 |
70 | def test_single(model, device, img_patch, ins_patch, gt_patch, weight, c_label):
71 | torch.cuda.empty_cache()
72 |
73 | model.eval()
74 | correct = 0
75 |
76 | img_patch = img_patch.float()
77 | ins_patch = ins_patch.float()
78 | gt_patch = gt_patch.float()
79 | weight = weight.float()
80 | c_label = c_label.float()
81 |
82 | input_patch = torch.cat((img_patch, ins_patch), dim=1)
83 | input_patch, gt_patch, weight, c_label = input_patch.to(device), gt_patch.to(device), weight.to(device), c_label.to(
84 | device)
85 |
86 | with torch.no_grad():
87 | S, C = model(input_patch.float())
88 |
89 | # calculate dice coefficient
90 | pred = torch.round(S).detach()
91 | test_dice_coef = DiceCoeff(pred, gt_patch.detach())
92 |
93 | # calculate total loss
94 | lamda = 0.1
95 | FP, FN = Segloss(S, gt_patch, weight)
96 | s_loss = lamda * FP + FN
97 | c_loss = F.binary_cross_entropy(torch.unsqueeze(C, dim=0), c_label)
98 |
99 | logging.info("test_dice_coef: %s, S Loss: %s, C Loss: %s" % (test_dice_coef, s_loss.item(), c_loss.item()))
100 |
101 | if C.round() == c_label:
102 | correct = 1
103 |
104 | test_loss = s_loss + c_loss
105 |
106 | return test_loss.item(), correct, test_dice_coef
107 |
108 |
109 | if __name__ == "__main__":
110 | # Version of Pytorch
111 | logging.info("Pytorch Version:%s" % torch.__version__)
112 |
113 | # Training args
114 | parser = argparse.ArgumentParser(description='Iterative Fully Convolutional Network')
115 | parser.add_argument('--dataset', type=str, default='./crop_isotropic_dataset',
116 | help='path of processed dataset')
117 | parser.add_argument('--weight', type=str, default='./weights',
118 | help='folder where the best model weights are saved')
119 | parser.add_argument('--checkpoints', type=str, default='./checkpoints',
120 | help='folder where training snapshots are saved')
121 | parser.add_argument('--resume', action='store_true',
122 | help='resume training by loading last snapshot')
123 | parser.add_argument('--batch-size', type=int, default=1, metavar='N',
124 | help='input batch size for training (default: 1)')
125 | parser.add_argument('--test-batch-size', type=int, default=1, metavar='N',
126 | help='input batch size for testing (default: 1)')
127 | parser.add_argument('--iterations', type=int, default=20, metavar='N',
128 | help='number of iterations to train (default: 20)')
129 | parser.add_argument('--log_interval', type=int, default=2, metavar='N',
130 | help='number of iterations between logs (default: 2)')
131 | parser.add_argument('--eval_iters', type=int, default=1, metavar='N',
132 | help='number of validation patches per evaluation (default: 1)')
133 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
134 | help='learning rate (default: 1e-3)')
135 | parser.add_argument('--seed', type=int, default=1, metavar='S',
136 | help='random seed (default: 1)')
137 | parser.add_argument('--save-model', action='store_true', default=True,
138 | help='For Saving the current Model')
139 | args = parser.parse_args()
140 |
141 | # set random seed for reproducibility
142 | torch.manual_seed(args.seed)
143 | np.random.seed(args.seed)
144 |
145 | # Use GPU if it is available
146 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
147 |
148 | # Create model and check if we want to resume training
149 | model = IterativeFCN(num_channels=4).to(device)
150 |
151 | batch_size = args.batch_size
152 | batch_size_valid = batch_size
153 |
154 | train_dataset = CSIDataset(args.dataset, subset='train')
155 | test_dataset = CSIDataset(args.dataset, subset='test')
156 |
157 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
158 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
159 |
160 | # optimizer
161 | optimizer = optim.Adam(model.parameters(), lr=args.lr)
162 |
163 | train_loss, test_loss = [], []
164 | train_acc, test_acc = [], []
165 | train_dice, test_dice = [], []
166 | best_train_loss, best_test_dice = float('inf'), 0.
167 |
168 | total_iteration = args.iterations
169 | train_interval = args.log_interval
170 | eval_interval = args.eval_iters
171 |
172 | iteration = 1
173 |
174 | if args.resume:
175 | logging.info("Resume Training: Load states from latest checkpoint.")
176 | checkpoint = torch.load(os.path.join(args.checkpoints, 'latest_checkpoints.pth'))
177 | model.load_state_dict(checkpoint['model_state_dict'])
178 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
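179 | iteration = checkpoint['iteration'] + 1
180 | train_loss, test_loss = checkpoint['train_loss'], checkpoint['test_loss']
181 | train_acc, test_acc = checkpoint['train_acc'], checkpoint['test_acc']
182 | train_dice, test_dice = checkpoint['train_dice'], checkpoint['test_dice']
183 | best_train_loss, best_test_dice = checkpoint['best_train_loss'], checkpoint['best_test_dice']
184 |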
185 | # Start Training
186 | while iteration < args.iterations + 1:
187 | start_time = time.time()
188 | epoch_train_dice = []
189 | epoch_test_dice = []
190 | epoch_train_loss = []
191 | epoch_test_loss = []
192 | epoch_train_accuracy = 0.
193 | epoch_test_accuracy = 0.
194 | correct_train_count = 0
195 | correct_test_count = 0
196 |
197 | img_patch, ins_patch, gt_patch, weight, c_label = next(iter(train_loader))
198 | t_loss, t_c, t_dice = train_single(model, device, img_patch, ins_patch, gt_patch, weight, c_label, optimizer)
199 | epoch_train_loss.append(t_loss)
200 | epoch_train_dice.append(t_dice)
201 | correct_train_count += t_c
202 |
203 | if iteration > 1 and iteration % args.log_interval == 0:
204 | avg_train_loss = sum(epoch_train_loss) / len(epoch_train_loss)
205 | avg_train_dice = (sum(epoch_train_dice) / len(epoch_train_dice)) * 100
206 | epoch_train_accuracy = (correct_train_count / train_interval) * 100
207 |
208 | logging.info('Iter {}-{}: \t Loss: {:.6f}\t acc: {:.6f}%\t dice: {:.6f}%'.format(
209 | iteration - args.log_interval,
210 | iteration,
211 | avg_train_loss,
212 | epoch_train_accuracy,
213 | avg_train_dice))
214 |
215 | if avg_train_loss < best_train_loss:
216 | best_train_loss = avg_train_loss
217 | logging.info('--- Saving model at Avg Train Dice:{:.2f}% ---'.format(avg_train_dice))
218 | torch.save(model.state_dict(), os.path.join(args.weight, 'IterativeFCN_best_train.pth'))
219 |
220 | # validation process
221 | for i in range(args.eval_iters):
222 | img_patch, ins_patch, gt_patch, weight, c_label = next(iter(test_loader))
223 | v_loss, v_c, v_dice = test_single(model, device, img_patch, ins_patch, gt_patch, weight, c_label)
224 | epoch_test_loss.append(v_loss)
225 | epoch_test_dice.append(v_dice)
226 | correct_test_count += v_c
227 |
228 | avg_test_loss = sum(epoch_test_loss) / len(epoch_test_loss)
229 | avg_test_dice = (sum(epoch_test_dice) / len(epoch_test_dice)) * 100
230 | epoch_test_accuracy = (correct_test_count / eval_interval) * 100
231 |
232 | logging.info('Iter {}-{} eval: \t Loss: {:.6f}\t acc: {:.6f}%\t dice: {:.6f}%'.format(
233 | iteration - args.log_interval,
234 | iteration,
235 | avg_test_loss,
236 | epoch_test_accuracy,
237 | avg_test_dice))
238 |
239 | if avg_test_dice > best_test_dice:
240 | best_test_dice = avg_test_dice
241 | logging.info('--- Saving model at Avg Test Dice:{:.2f}% ---'.format(avg_test_dice))
242 | torch.save(model.state_dict(), os.path.join(args.weight, 'IterativeFCN_best_valid.pth'))
243 |
244 | train_loss.append(epoch_train_loss)
245 | test_loss.append(epoch_test_loss)
246 | train_acc.append(epoch_train_accuracy)
247 | test_acc.append(epoch_test_accuracy)
248 |
249 | # save snapshot for resume training
250 | logging.info('--- Saving snapshot ---')
251 | torch.save({
252 | 'iteration': iteration,
253 | 'model_state_dict': model.state_dict(),
254 | 'optimizer_state_dict': optimizer.state_dict(),
255 | 'train_loss': train_loss,
256 | 'test_loss': test_loss,
257 | 'train_acc': train_acc,
258 | 'test_acc': test_acc,
259 | 'train_dice': train_dice,
260 | 'test_dice': test_dice,
261 | 'best_train_loss': best_train_loss,
262 | 'best_test_dice': best_test_dice}, os.path.join(args.checkpoints, 'latest_checkpoints.pth'))
263 |
264 | logging.info("--- %s seconds ---" % (time.time() - start_time))
265 |
266 | iteration += 1
267 |
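`train.py` draws each batch with `next(iter(train_loader))`, which builds a new DataLoader iterator on every call; with `shuffle=True` that also draws a fresh random permutation of the dataset each time. A common alternative, sketched here as a hypothetical helper `infinite_batches`, keeps one iterator and starts a new pass only when the previous one is exhausted. Note the sampling changes slightly: every dataset index is visited once per pass instead of being drawn independently at each step.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def infinite_batches(loader: Iterable[T]) -> Iterator[T]:
    """Yield batches forever, starting a new pass over `loader` when one ends."""
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            # guard against spinning forever on an empty loader
            raise ValueError("loader produced no batches")
```

In the training script it could be created once before the loop with `batches = infinite_batches(train_loader)`, and each step would then call `img_patch, ins_patch, gt_patch, weight, c_label = next(batches)`.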
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/utils/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/utils/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from medpy.metric.binary import assd, dc
4 |
5 |
6 | def Segloss(pred, target, weight):
7 | FP = torch.sum(weight * (1 - target) * pred)
8 | FN = torch.sum(weight * (1 - pred) * target)
9 | return FP, FN
10 |
11 |
12 | def DiceCoeff(pred, gt):
13 | pred = pred.to('cpu').numpy()
14 | gt = gt.to('cpu').numpy()
15 |
16 | # if gt is empty, score the inverted masks (dc returns 0 when both masks are empty)
17 | if np.count_nonzero(gt) == 0:
18 | gt = gt + 1
19 | pred = 1 - pred
20 |
21 | return dc(pred, gt)
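A NumPy sketch of the quantities `train_single` combines can make the loss easier to sanity-check by hand: `Segloss` returns weighted false-positive and false-negative sums, the segmentation loss is `0.1 * FP + FN`, and the classification loss is binary cross-entropy on the completeness probability. The helper names are hypothetical. `bce` clips the probability to avoid `log(0)`, whereas PyTorch's `F.binary_cross_entropy` clamps the log terms at -100, so the two differ only at extreme probabilities. `dice_with_empty_inverse` reproduces the inversion trick in `DiceCoeff` with plain NumPy instead of MedPy's `dc`, and expects already-rounded binary masks.

```python
import numpy as np


def seg_loss_np(pred, target, weight):
    """Weighted false-positive and false-negative sums, mirroring Segloss."""
    fp = np.sum(weight * (1 - target) * pred)
    fn = np.sum(weight * (1 - pred) * target)
    return fp, fn


def bce(prob, label, eps=1e-12):
    """Binary cross-entropy for one probability (clipped to avoid log(0))."""
    p = np.clip(prob, eps, 1 - eps)
    return float(-(label * np.log(p) + (1 - label) * np.log(1 - p)))


def total_loss(pred, target, weight, c_prob, c_label, lam=0.1):
    """lam * FP + FN + BCE, the combination used in train_single."""
    fp, fn = seg_loss_np(pred, target, weight)
    return lam * fp + fn + bce(c_prob, c_label)


def dice_with_empty_inverse(pred, gt):
    """Dice score that scores inverted masks when gt is empty (binary inputs)."""
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    if not gt.any():
        # an empty prediction on an empty gt now scores 1 instead of 0
        gt, pred = ~gt, ~pred
    denom = pred.sum() + gt.sum()
    return 0.0 if denom == 0 else 2.0 * np.logical_and(pred, gt).sum() / denom
```

With `pred = [0.8, 0.2, 0.6, 0.0]`, `target = [1, 0, 0, 1]` and `weight = [1, 1, 2, 1]`, FP is 0.2 + 2 * 0.6 = 1.4 and FN is 0.2 + 1.0 = 1.2, so the segmentation term is 0.14 + 1.2 = 1.34.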
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from data.data_augmentation import elastic_transform, gaussian_blur, gaussian_noise, random_crop
4 |
5 |
6 | def force_inside_img(x, patch_size, img_shape, axis=2):
7 | x_low = int(x - patch_size / 2)
8 | x_up = int(x + patch_size / 2)
9 | if x_low < 0:
10 | x_up -= x_low
11 | x_low = 0
12 | elif x_up > img_shape[axis]:
13 | x_low -= (x_up - img_shape[axis])
14 | x_up = img_shape[axis]
15 | return x_low, x_up
16 |
17 |
18 | def extract_random_patch(img, mask, weight, i, subset, empty_interval=5, patch_size=128):
19 | flag_empty = False
20 |
21 | # list available vertebrae
22 | verts = np.unique(mask)
23 | chosen_vert = verts[random.randint(1, len(verts) - 1)]
24 |
25 | # create the corresponding instance memory and ground truth
26 | ins_memory = np.copy(mask)
27 | ins_memory[ins_memory <= chosen_vert] = 0
28 | ins_memory[ins_memory > 0] = 1
29 |
30 | gt = np.copy(mask)
31 | gt[gt != chosen_vert] = 0
32 | gt[gt > 0] = 1
33 |
34 | # send an empty-mask sample every empty_interval iterations
35 | if i % empty_interval == 0:
36 | patch_center = [np.random.randint(0, s) for s in img.shape]
37 | x = patch_center[2]
38 | y = patch_center[1]
39 | z = patch_center[0]
40 |
41 | # for instance memory
42 | gt = np.copy(mask)
43 | flag_empty = True
44 | else:
45 | indices = np.nonzero(mask == chosen_vert)
46 | lower = [np.min(i) for i in indices]
47 | upper = [np.max(i) for i in indices]
48 | # random center of patch
49 | x = random.randint(lower[2], upper[2])
50 | y = random.randint(lower[1], upper[1])
51 | z = random.randint(lower[0], upper[0])
52 |
53 | # keep each patch axis inside the image, clamped against that axis's own size
54 | x_low, x_up = force_inside_img(x, patch_size, img.shape, axis=2)
55 | y_low, y_up = force_inside_img(y, patch_size, img.shape, axis=1)
56 | z_low, z_up = force_inside_img(z, patch_size, img.shape, axis=0)
57 |
58 | # crop the patch
59 | img_patch = img[z_low:z_up, y_low:y_up, x_low:x_up]
60 | ins_patch = ins_memory[z_low:z_up, y_low:y_up, x_low:x_up]
61 | gt_patch = gt[z_low:z_up, y_low:y_up, x_low:x_up]
62 | weight_patch = weight[z_low:z_up, y_low:y_up, x_low:x_up]
63 |
64 | # if the label is empty mask
65 | if flag_empty:
66 | ins_patch = np.copy(gt_patch)
67 | ins_patch[ins_patch > 0] = 1
68 | gt_patch = np.zeros_like(ins_patch)
69 | weight_patch = np.ones_like(ins_patch)
70 |
71 | # Randomly on-the-fly Data Augmentation
72 | # 50% chance elastic deformation
73 | if subset == 'train':
74 | if np.random.rand() > 0.5:
75 | img_patch, gt_patch, ins_patch, weight_patch = elastic_transform(img_patch, gt_patch, ins_patch,
76 | weight_patch, alpha=20, sigma=5)
77 | # 50% chance gaussian blur
78 | if np.random.rand() > 0.5:
79 | img_patch = gaussian_blur(img_patch)
80 | # 50% chance gaussian noise
81 | if np.random.rand() > 0.5:
82 | img_patch = gaussian_noise(img_patch)
83 |
84 | # 50% random crop along z-axis
85 | if np.random.rand() > 0.5:
86 | img_patch, ins_patch, gt_patch, weight_patch = random_crop(img_patch, ins_patch, gt_patch
87 | , weight_patch)
88 |
89 | # decide label of completeness(partial or complete)
90 | vol = np.count_nonzero(gt == 1)
91 | sample_vol = np.count_nonzero(gt_patch == 1)
92 | c_label = 0 if float(sample_vol / (vol + 0.0001)) < 0.98 else 1
93 |
94 | img_patch = np.expand_dims(img_patch, axis=0)
95 | ins_patch = np.expand_dims(ins_patch, axis=0)
96 | gt_patch = np.expand_dims(gt_patch, axis=0)
97 | weight_patch = np.expand_dims(weight_patch, axis=0)
98 | c_label = np.expand_dims(c_label, axis=0)
99 |
100 | return img_patch, ins_patch, gt_patch, weight_patch, c_label
101 |
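Two pieces of `extract_random_patch` are easy to test in isolation: clamping a patch so it stays inside the volume, and the completeness label (`1` when the patch holds at least 98% of the chosen vertebra's voxels). The sketch below re-implements both as hypothetical helpers `patch_bounds` and `completeness_label`; `patch_bounds` clamps every axis against that axis's own extent, taken from `shape`. Like the original, it assumes `patch_size` does not exceed the volume size along any axis.

```python
import numpy as np


def patch_bounds(center, patch_size, shape):
    """Return (low, up) per axis, shifting the window back inside the volume."""
    bounds = []
    for c, dim in zip(center, shape):
        low = int(c - patch_size / 2)
        up = int(c + patch_size / 2)
        if low < 0:
            # window starts before the volume: slide it forward
            up -= low
            low = 0
        elif up > dim:
            # window ends past the volume: slide it backward
            low -= up - dim
            up = dim
        bounds.append((low, up))
    return bounds


def completeness_label(gt, gt_patch, threshold=0.98):
    """1 if the patch contains at least `threshold` of the vertebra's voxels."""
    vol = np.count_nonzero(gt == 1)
    sample_vol = np.count_nonzero(gt_patch == 1)
    return 0 if sample_vol / (vol + 0.0001) < threshold else 1
```

For a 100x200x300 volume and a 64-voxel patch centered at (10, 50, 290), the z window slides forward to (0, 64) and the x window slides back to (236, 300), while y stays at (18, 82).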
--------------------------------------------------------------------------------
/weights/IterativeFCN_pretrained.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leohsuofnthu/Pytorch-IterativeFCN/c9a1094bb4cb26ff23b3a11fdda3abbd53cd7ad7/weights/IterativeFCN_pretrained.pth
--------------------------------------------------------------------------------