├── .gitignore
├── .idea
│   ├── CASED-Tensorflow.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── CASED_test.py
├── CASED_train.py
├── LICENSE
├── README.md
├── assests
│   ├── framework.JPG
│   ├── lr.JPG
│   ├── network.JPG
│   ├── nodule.png
│   ├── nodule_label.png
│   ├── patch.png
│   ├── result1.JPG
│   ├── result2.JPG
│   └── stride.png
├── main_test.py
├── main_train.py
├── ops.py
├── preprocessing
│   ├── README
│   │   ├── convert_luna_to_npy_README.md
│   │   └── h5py_patch_README.md
│   ├── all_in_one.py
│   ├── convert_luna_to_npy.py
│   ├── h5py_patch.py
│   └── preprocess_utils.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/.idea/CASED-Tensorflow.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/CASED_test.py:
--------------------------------------------------------------------------------
1 | import time, pickle
2 | from ops import *
3 | from utils import *
4 | from collections import defaultdict
5 | import matplotlib.pyplot as plt
6 | from skimage.util.shape import view_as_blocks as patch_blocks
7 | from math import ceil
8 | from skimage.measure import block_reduce
9 |
10 |
11 | class CASED(object):
12 | def __init__(self, sess, batch_size, checkpoint_dir, result_dir, log_dir):
13 | self.sess = sess
14 | self.dataset_name = 'LUNA16'
15 | self.checkpoint_dir = checkpoint_dir
16 | self.result_dir = result_dir
17 | self.log_dir = log_dir
18 | self.batch_size = batch_size
19 | self.model_name = "CASED" # name for checkpoint
20 |
21 | self.c_dim = 1
22 | self.y_dim = 2 # nodule ? or non_nodule ?
23 | self.block_size = 68
24 |
25 | def cased_network(self, x, reuse=False, scope='CASED_NETWORK'):
26 | with tf.variable_scope(scope, reuse=reuse):
27 | x = conv_layer(x, channels=32, kernel=3, stride=1, layer_name='conv1')
28 | up_conv1 = conv_layer(x, channels=32, kernel=3, stride=1, layer_name='up_conv1')
29 |
30 | x = max_pooling(up_conv1)
31 |
32 | x = conv_layer(x, channels=64, kernel=3, stride=1, layer_name='conv2')
33 | up_conv2 = conv_layer(x, channels=64, kernel=3, stride=1, layer_name='up_conv2')
34 |
35 | x = max_pooling(up_conv2)
36 |
37 | x = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='conv3')
38 | up_conv3 = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='up_conv3')
39 |
40 | x = max_pooling(up_conv3)
41 |
42 | x = conv_layer(x, channels=256, kernel=3, stride=1, layer_name='conv4')
43 | x = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='conv5')
44 |
45 | x = deconv_layer(x, channels=128, kernel=4, stride=2, layer_name='deconv1')
46 | x = copy_crop(crop_layer=up_conv3, in_layer=x)
47 |
48 | x = conv_layer(x, channels=128, kernel=1, stride=1, layer_name='conv6')
49 | x = conv_layer(x, channels=64, kernel=1, stride=1, layer_name='conv7')
50 |
51 | x = deconv_layer(x, channels=64, kernel=4, stride=2, layer_name='deconv2')
52 | x = copy_crop(crop_layer=up_conv2, in_layer=x)
53 |
54 | x = conv_layer(x, channels=64, kernel=1, stride=1, layer_name='conv8')
55 | x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv9')
56 |
57 | x = deconv_layer(x, channels=32, kernel=4, stride=2, layer_name='deconv3')
58 | x = copy_crop(crop_layer=up_conv1, in_layer=x)
59 |
60 | x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv10')
61 | x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv11')
62 |
63 | logits = conv_layer(x, channels=2, kernel=1, stride=1, activation=None, layer_name='conv12')
64 |
65 | x = softmax(logits)
66 |
67 | return logits, x
68 |
69 | def build_model(self):
70 |
71 | bs = None
72 | scan_dims = [None, None, None, self.c_dim]
73 | scan_y_dims = [None, None, None, self.y_dim]
74 |
75 | """ Graph Input """
76 | # images
77 | self.inputs = tf.placeholder(tf.float32, [bs] + scan_dims, name='patch')
78 |
79 | # labels
80 | self.y = tf.placeholder(tf.float32, [bs] + scan_y_dims, name='y') # for loss
81 |
82 | self.logits, self.softmax_logits = self.cased_network(self.inputs)
83 |
84 | """ Loss function """
85 | self.correct_prediction = tf.equal(tf.argmax(self.softmax_logits, -1), tf.argmax(self.y, -1))
86 | self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
87 | self.sensitivity, self.fp_rate = sensitivity(labels=self.y, logits=self.softmax_logits)
88 |
89 | """ Summary """
90 |
91 | c_acc = tf.summary.scalar('acc', self.accuracy)
92 | c_recall = tf.summary.scalar('sensitivity', self.sensitivity)
93 | c_fp = tf.summary.scalar('false_positive', self.fp_rate)
94 | self.c_sum = tf.summary.merge([c_acc, c_recall, c_fp])
95 |
96 | def test(self):
97 | block_size = self.block_size
98 | # initialize all variables
99 | tf.global_variables_initializer().run()
100 |
101 | # saver to save model
102 | self.saver = tf.train.Saver()
103 |
104 | # summary writer
105 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)
106 |
107 | # restore check-point if it exits
108 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
109 |
110 | if could_load:
111 | print(" [*] Load SUCCESS")
112 | else:
113 | print(" [!] Load failed...")
114 |
115 | validation_sub_n = 8
116 | subset_name = 'subset' + str(validation_sub_n)
117 | print(subset_name)
118 | image_paths = glob.glob("/data/jhkim/LUNA16/original/subset" + str(validation_sub_n) + '/*.mhd')
119 | all_scan_num = len(image_paths)
120 | sens_list = []
121 | fps_list = []
122 |
123 | nan_num = 0
124 | cnt = 1
125 |
126 |
127 | MIN_FROC = 0.125
128 | MAX_FROC = 8
129 |
130 | for scan in image_paths:
131 | print('{} / {}'.format(cnt, len(image_paths)))
132 | scan_name = os.path.split(scan)[1].replace('.mhd', '')
133 | scan_npy = '/data2/jhkim/npydata/' + subset_name + '/' + scan_name + '.npy'
134 | label_npy = '/data2/jhkim/npydata/' + subset_name + '/' + scan_name + '.label.npy'
135 |
136 | image = np.transpose(np.load(scan_npy))
137 | label = np.transpose(np.load(label_npy))
138 | print(np.shape(image))
139 |
140 | if np.count_nonzero(label) == 0:
141 | nan_num += 1
142 | cnt += 1
143 | continue
144 |
145 | pad_list = []
146 | for i in range(3):
147 | if np.shape(image)[i] % block_size == 0:
148 | pad_l = 0
149 | pad_r = pad_l
150 | else:
151 | q = (ceil(np.shape(image)[i] / block_size) * block_size) - np.shape(image)[i]
152 |
153 | if q % 2 == 0:
154 | pad_l = q // 2
155 | pad_r = pad_l
156 | else:
157 | pad_l = q // 2
158 | pad_r = pad_l + 1
159 |
160 | pad_list.append(pad_l)
161 | pad_list.append(pad_r)
162 |
163 | image = np.pad(image, pad_width=[[pad_list[0], pad_list[1]], [pad_list[2], pad_list[3]],
164 | [pad_list[4], pad_list[5]]],
165 | mode='constant', constant_values=np.min(image))
166 |
167 | label = np.pad(label, pad_width=[[pad_list[0], pad_list[1]], [pad_list[2], pad_list[3]],
168 | [pad_list[4], pad_list[5]]],
169 | mode='constant', constant_values=np.min(label))
170 |
171 | with open('jh_exclude.pkl', 'rb') as f:
172 | exclude_dict = pickle.load(f, encoding='bytes')
173 |
174 | exclude_coords = exclude_dict[scan_name]
175 | ex_mask = np.ones_like(image)
176 | for ex in exclude_coords:
177 | ex[0] = ex[0] + (pad_list[0] + pad_list[1]) // 2
178 | ex[1] = ex[1] + (pad_list[2] + pad_list[3]) // 2
179 | ex[2] = ex[2] + (pad_list[4] + pad_list[5]) // 2
180 | ex_diameter = ex[3]
181 | if ex_diameter < 0.0:
182 | ex_diameter = 10.0
183 | exclude_position = (ex[0], ex[1], ex[2])
184 | exclude_mask = create_exclude_mask(image.shape, exclude_position, ex_diameter)
185 | ex_mask = exclude_mask if ex_mask is None else np.logical_and(ex_mask, exclude_mask)
186 |
187 | image_blocks = patch_blocks(image, block_shape=(block_size, block_size, block_size))
188 | label_blocks = patch_blocks(label, block_shape=(block_size, block_size, block_size))
189 | ex_mask_blocks = patch_blocks(ex_mask, block_shape=(block_size, block_size, block_size))
190 |
191 | len_x = len(image_blocks)
192 | len_y = len(image_blocks[0])
193 | len_z = len(image_blocks[0, 0])
194 |
195 | result_scan = None
196 | label_scan = None
197 | ex_scan = None
198 | for x_i in range(len_x):
199 | x = None
200 | x_label = None
201 | x_ex = None
202 | for y_i in range(len_y):
203 | y = None
204 | y_label = None
205 | y_ex = None
206 | for z_i in range(len_z):
207 | scan = np.expand_dims(np.expand_dims(image_blocks[x_i, y_i, z_i], axis=-1), axis=0) # 1 68 68 68 1
208 | logit_label = block_reduce(label_blocks[x_i, y_i, z_i], (9, 9, 9), np.max)
209 | logit_ex = block_reduce(ex_mask_blocks[x_i, y_i, z_i], (9, 9, 9), np.min)
210 |
211 | test_feed_dict = {
212 | self.inputs: scan
213 | }
214 |
215 | logits = self.sess.run(
216 | self.softmax_logits, feed_dict=test_feed_dict
217 | ) # [1, 68, 68, 68, 2]
218 | logits_ = np.squeeze(logits, axis=0) # [68,68,68,2]
219 | logits = np.zeros(shape=(logits_.shape[0], logits_.shape[1], logits_.shape[2], 1))
220 | for x_i_, x_v in enumerate(logits_):
221 | for y_i_, y_v in enumerate(x_v):
222 | for z_i_, z_v in enumerate(y_v):
223 | logits[x_i_, y_i_, z_i_] = z_v[1]
224 | logits = np.squeeze(logits, axis=-1) # 68 68 68
225 |
226 | """
227 | [1, 68, 68, 68, 2] -> [68, 68, 68, 2] -> [68, 68, 68]
228 | """
229 |
230 | y = logits if y is None else np.concatenate((y, logits), axis=2) # z concat
231 | y_label = logit_label if y_label is None else np.concatenate((y_label, logit_label), axis=2)
232 | y_ex = logit_ex if y_ex is None else np.concatenate((y_ex, logit_ex), axis=2)
233 |
234 | x = y if x is None else np.concatenate((x, y), axis=1) # y concat
235 | x_label = y_label if x_label is None else np.concatenate((x_label, y_label), axis=1)
236 | x_ex = y_ex if x_ex is None else np.concatenate((x_ex, y_ex), axis=1)
237 |
238 | result_scan = x if result_scan is None else np.concatenate((result_scan, x), axis=0) # x concat
239 | label_scan = x_label if label_scan is None else np.concatenate((label_scan, x_label), axis=0)
240 | ex_scan = x_ex if ex_scan is None else np.concatenate((ex_scan, x_ex), axis=0)
241 | label = label_scan
242 | ex_mask = ex_scan
243 | # print(result) # 3d original size
244 |
245 | with open('jh.pkl', 'rb') as f:
246 | coords_dict = pickle.load(f, encoding='bytes')
247 |
248 | ex_mask = ex_mask.astype(np.float32)
249 | label = label.astype(np.float32)
250 | ex_mask = np.where(ex_mask == 0.0, -10, ex_mask)
251 |
252 | result_scan = result_scan + ex_mask
253 | label = label + ex_mask
254 |
255 | if np.count_nonzero(label == 2.0) == 0:
256 | nan_num += 1
257 | cnt += 1
258 | continue
259 | cnt += 1
260 |
261 | fps, tpr = fp_per_scan(result_scan, label)
262 |
263 | fps_list.append(fps)
264 | sens_list.append(tpr)
265 |
266 | fps_itp = np.linspace(MIN_FROC, MAX_FROC, num=10001)
267 | sens_itp = None
268 | for list_i in range(len(fps_list)):
269 | fps_list[list_i] /= (all_scan_num - nan_num)
270 | sens_list[list_i] /= (all_scan_num - nan_num)
271 | if sens_itp is None:
272 | sens_itp = np.interp(fps_itp, fps_list[list_i], sens_list[list_i])
273 | else:
274 | sens_itp += np.interp(fps_itp, fps_list[list_i], sens_list[list_i])
275 | print(fps_itp)
276 | print(sens_itp)
277 |
278 | @property
279 | def model_dir(self):
280 | return "{}_{}".format(
281 | self.model_name, self.dataset_name)
282 |
283 | def save(self, checkpoint_dir, step):
284 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
285 |
286 | if not os.path.exists(checkpoint_dir):
287 | os.makedirs(checkpoint_dir)
288 |
289 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
290 |
291 | def load(self, checkpoint_dir):
292 | import re
293 | print(" [*] Reading checkpoints...")
294 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
295 |
296 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
297 | if ckpt and ckpt.model_checkpoint_path:
298 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
299 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
300 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0))
301 | print(" [*] Success to read {}".format(ckpt_name))
302 | return True, counter
303 | else:
304 | print(" [*] Failed to find a checkpoint")
305 | return False, 0
306 |
--------------------------------------------------------------------------------
/CASED_train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from heapq import nlargest
3 | from random import uniform
4 |
5 | from ops import *
6 | from utils import *
7 | class CASED(object) :
8 | def __init__(self, sess, epoch, batch_size, test_batch_size, num_gpu, checkpoint_dir, result_dir, log_dir):
9 | self.sess = sess
10 | self.dataset_name = 'LUNA16'
11 | self.checkpoint_dir = checkpoint_dir
12 | self.result_dir = result_dir
13 | self.log_dir = log_dir
14 | self.epoch = epoch
15 | self.batch_size = batch_size
16 | self.test_batch_size = test_batch_size
17 | self.predictor_batch_size = batch_size * 2
18 |
19 | self.model_name = "CASED" # name for checkpoint
20 | self.num_gpu = num_gpu
21 | self.total_subset = 10
22 |
23 | self.x = 68
24 | self.y = 68
25 | self.z = 68
26 | self.c_dim = 1
27 | self.y_dim = 2 # nodule ? or non_nodule ?
28 | self.out_dim = 8
29 |
30 | self.weight_decay = 1e-4
31 | self.lr_decay = 1.0 # not decay
32 | self.learning_rate = 0.01
33 | self.momentum = 0.9
34 | self.M = 10
35 | self.K_fold = True
36 |
37 | def cased_network(self, x, reuse=False, scope='CASED_NETWORK'):
38 | with tf.variable_scope(scope, reuse=reuse) :
39 | # print(np.shape(x))
40 | x = conv_layer(x, channels=32, kernel=3, stride=1, layer_name='conv1')
41 | up_conv1 = conv_layer(x, channels=32, kernel=3, stride=1, layer_name='up_conv1')
42 |
43 | x = max_pooling(up_conv1)
44 |
45 | x = conv_layer(x, channels=64, kernel=3, stride=1, layer_name='conv2')
46 | up_conv2 = conv_layer(x, channels=64, kernel=3, stride=1, layer_name='up_conv2')
47 |
48 | x = max_pooling(up_conv2)
49 |
50 | x = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='conv3')
51 | up_conv3 = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='up_conv3')
52 |
53 | x = max_pooling(up_conv3)
54 |
55 | x = conv_layer(x, channels=256, kernel=3, stride=1, layer_name='conv4')
56 | x = conv_layer(x, channels=128, kernel=3, stride=1, layer_name='conv5')
57 |
58 | x = deconv_layer(x, channels=128, kernel=4, stride=2, layer_name='deconv1')
59 | x = copy_crop(crop_layer=up_conv3, in_layer=x)
60 |
61 | x = conv_layer(x, channels=128, kernel=1, stride=1, layer_name='conv6')
62 | x = conv_layer(x, channels=64, kernel=1, stride=1, layer_name='conv7')
63 |
64 | x = deconv_layer(x, channels=64, kernel=4, stride=2, layer_name='deconv2')
65 | x = copy_crop(crop_layer=up_conv2, in_layer=x)
66 |
67 | x = conv_layer(x, channels=64, kernel=1, stride=1, layer_name='conv8')
68 | x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv9')
69 |
70 | x = deconv_layer(x, channels=32, kernel=4, stride=2, layer_name='deconv3')
71 | x = copy_crop(crop_layer=up_conv1, in_layer=x)
72 |
73 | x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv10')
74 | x = conv_layer(x, channels=32, kernel=1, stride=1, layer_name='conv11')
75 |
76 | logits = conv_layer(x, channels=2, kernel=1, stride=1, activation=None, layer_name='conv12')
77 | x = softmax(logits)
78 |
79 | return logits, x
80 |
81 | def build_model(self):
82 | patch_dims = [None, None, None, self.c_dim]
83 | bs = self.batch_size
84 | p_bs = self.predictor_batch_size * self.num_gpu
85 | y_dims = [None, None, None, self.y_dim]
86 | self.decay_lr = tf.placeholder(tf.float32, name='learning_rate')
87 |
88 | """ Graph Input """
89 | # images
90 | self.inputs = tf.placeholder(tf.float32, [bs] + patch_dims, name='patch')
91 | self.p_inputs = tf.placeholder(tf.float32, [p_bs] + patch_dims, name='p_patch')
92 |
93 | # labels
94 | self.y = tf.placeholder(tf.float32, [bs] + y_dims, name='y') # for loss
95 | self.p_y = tf.placeholder(tf.float32, [p_bs] + y_dims, name='p_y')
96 |
97 | self.logits, self.softmax_logits = self.cased_network(self.inputs)
98 | self.P_logits, self.P_softmax_logits = self.cased_network(self.inputs, reuse=True)
99 |
100 | """ Predictor Loss """
101 | self.x_dict = dict()
102 | for gpu_i in range(self.num_gpu) :
103 | with tf.device('/gpu:%d' % gpu_i) :
104 | index_start = self.predictor_batch_size * gpu_i
105 | index_end = self.predictor_batch_size * (gpu_i + 1)
106 | self.predictor_logits, _ = self.cased_network(self.p_inputs[index_start : index_end], reuse=True)
107 | self.P_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.p_y[index_start : index_end],
108 | logits=self.predictor_logits), axis=[1, 2, 3])
109 | self.x_dict.update(
110 | {index_start + i: self.P_loss[i] for i in range(self.predictor_batch_size)}
111 | )
112 |
113 | """ Loss function """
114 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y, logits=self.logits))
115 | self.l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables()])
116 | self.loss = self.loss + self.l2_loss*self.weight_decay
117 |
118 | self.correct_prediction = tf.equal(tf.argmax(self.softmax_logits, -1), tf.argmax(self.y, -1))
119 | self.P_correct_prediction = tf.equal(tf.argmax(self.P_softmax_logits, -1), tf.argmax(self.y, -1))
120 |
121 | self.accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, tf.float32))
122 | self.P_accuracy = tf.reduce_mean(tf.cast(self.P_correct_prediction, tf.float32))
123 |
124 | self.sensitivity, self.fp_rate = sensitivity(labels=self.y, logits=self.softmax_logits)
125 | self.P_sensitivity, self.P_fp_rate = sensitivity(labels=self.y, logits=self.P_softmax_logits)
126 |
127 | self.optim = tf.train.MomentumOptimizer(learning_rate=self.decay_lr, momentum=self.momentum, use_nesterov=True).minimize(self.loss)
128 |
129 | """ Summary """
130 | c_lr = tf.summary.scalar('cosine_lr', self.decay_lr)
131 | c_loss = tf.summary.scalar('loss', self.loss)
132 | c_acc = tf.summary.scalar('acc', self.accuracy)
133 | c_recall = tf.summary.scalar('sensitivity', self.sensitivity)
134 | c_fp = tf.summary.scalar('false_positive', self.fp_rate)
135 | self.c_sum = tf.summary.merge([c_lr, c_loss, c_acc, c_recall, c_fp])
136 |
137 | def train(self):
138 |
139 | # initialize all variables
140 | tf.global_variables_initializer().run()
141 |
142 | # saver to save model
143 | self.saver = tf.train.Saver()
144 |
145 | # summary writer
146 | self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_name, self.sess.graph)
147 |
148 | # restore check-point if it exits
149 | could_load, checkpoint_counter = self.load(self.checkpoint_dir)
150 | if could_load:
151 | if self.K_fold :
152 | start_epoch = (int)(checkpoint_counter / (self.total_subset - 1))
153 | start_sub_n = checkpoint_counter - start_epoch * (self.total_subset - 1) + 1
154 |
155 | else :
156 | start_epoch = (int)(checkpoint_counter / (self.total_subset))
157 | start_sub_n = checkpoint_counter - start_epoch * (self.total_subset)
158 |
159 | counter = checkpoint_counter
160 | print(" [*] Load SUCCESS")
161 | else:
162 | start_epoch = 0
163 | start_sub_n = 0
164 | counter = 0
165 | print(" [!] Load failed...")
166 |
167 | # loop for epoch
168 | start_time = time.time()
169 | count_temp = 0
170 | # 10 counter = 1 epoch
171 | for epoch in range(start_epoch, self.epoch):
172 | validation_sub_n = (epoch % self.total_subset)
173 | print('validation sub_n : {}'.format(validation_sub_n))
174 | for sub_n in range(start_sub_n, self.total_subset) :
175 | # K fold cross validation ...
176 | if sub_n == validation_sub_n :
177 | continue
178 | train_acc = 0.0
179 | train_recall = 0.0
180 | train_fp = 0.0
181 | nan_num = 0
182 | prob = 1.0
183 | train_lr = self.learning_rate
184 | nodule_patch, all_patch, nodule_y, all_y = prepare_data(sub_n)
185 | M = len(all_patch)
186 | num_batches = M // self.batch_size
187 | print('finish prepare data : ', M)
188 |
189 | predict_batch = self.predictor_batch_size * self.num_gpu
190 | total_predict_index = M // predict_batch
191 |
192 | for idx in range(num_batches):
193 | if idx == int((num_batches * 0.5)) :
194 | train_lr = train_lr * self.lr_decay
195 | print("*** now learning rate : {} ***\n".format(train_lr))
196 |
197 | # train_lr = Snapshot(t=idx, T=num_batches, M=self.M, alpha_zero=self.learning_rate)
198 | each_time = time.time()
199 | p = uniform(0,1)
200 | print('probability M : ', prob)
201 | print('condition P : ', p)
202 | if p <= prob :
203 | random_index = np.random.choice(len(nodule_patch), size=self.batch_size, replace=False)
204 | batch_patch = nodule_patch[random_index]
205 | batch_y = nodule_y[random_index]
206 |
207 | else :
208 |
209 | predict_dict = dict()
210 | for p_idx in range(total_predict_index) :
211 | result = dict()
212 | batch_patch = all_patch[predict_batch*p_idx : predict_batch*(p_idx+1)]
213 | batch_y = all_y[predict_batch*p_idx : predict_batch*(p_idx+1)]
214 | predictor_feed_dict = {
215 | self.p_inputs: batch_patch, self.p_y : batch_y
216 | }
217 | temp_x = self.sess.run(self.x_dict, feed_dict=predictor_feed_dict)
218 |
219 | if p_idx != 0 :
220 | for k,v in temp_x.items() :
221 | new_k = k + predict_batch * p_idx
222 | result[new_k] = v
223 | else :
224 | for k,v in temp_x.items() :
225 | result[k] = v
226 |
227 | predict_dict.update(result)
228 | index = nlargest(self.batch_size, predict_dict, key=predict_dict.get)
229 | predict_dict = {s_idx: predict_dict[s_idx] for s_idx in index}
230 |
231 | g_r_index = list(predict_dict.keys())
232 | batch_patch = all_patch[g_r_index]
233 | batch_y = all_y[g_r_index]
234 |
235 | prob *= pow(1/M, 1/num_batches)
236 | train_feed_dict = {
237 | self.inputs: batch_patch, self.y : batch_y,
238 | self.decay_lr : train_lr
239 | }
240 |
241 | _, summary_str_c, c_loss, c_acc, c_recall, c_fp = self.sess.run(
242 | [self.optim, self.c_sum, self.loss, self.accuracy, self.sensitivity, self.fp_rate],
243 | feed_dict=train_feed_dict)
244 | self.writer.add_summary(summary_str_c, count_temp)
245 | count_temp += 1
246 |
247 | if np.isnan(c_recall) :
248 | train_acc += c_acc
249 | train_fp += c_fp
250 | # train_recall += 0
251 | nan_num += 1
252 | else :
253 | train_acc += c_acc
254 | train_fp += c_fp
255 | train_recall += c_recall
256 | # display training status
257 | print("Epoch: [%2d], Sub_n: [%2d], [%4d/%4d] time: %4.4f, each_time: %4.4f, c_loss: %.8f, c_acc: %.4f, c_recall: %.4f, c_fp: %.4f" \
258 | % (epoch, sub_n, idx, num_batches, time.time() - start_time, time.time() - each_time, c_loss, c_acc, c_recall, c_fp))
259 |
260 | train_acc /= num_batches
261 | train_recall /= (num_batches - nan_num)
262 | train_fp /= num_batches
263 |
264 | summary_train = tf.Summary(value=[tf.Summary.Value(tag='train_accuracy', simple_value=train_acc),
265 | tf.Summary.Value(tag='train_recall', simple_value=train_recall),
266 | tf.Summary.Value(tag='train_fp', simple_value=train_fp)])
267 | self.writer.add_summary(summary_train, counter)
268 |
269 | line = "Epoch: [%2d], Sub_n: [%2d], train_acc: %.4f, train_recall: %.4f, train_fp: %.4f\n" % (epoch, sub_n, train_acc, train_recall, train_fp)
270 | print(line)
271 | with open(os.path.join(self.result_dir, 'train_logs.txt'), 'a') as f:
272 | f.write(line)
273 | # save model
274 | counter += 1
275 | self.save(self.checkpoint_dir, counter)
276 | del nodule_patch
277 | del nodule_y
278 | del all_patch
279 | del all_y
280 | start_sub_n = 0
281 |
282 |
283 |
284 |
285 | @property
286 | def model_dir(self):
287 | return "{}_{}_{}".format(
288 | self.model_name, self.dataset_name,
289 | self.batch_size,)
290 |
291 | def save(self, checkpoint_dir, step):
292 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
293 |
294 | if not os.path.exists(checkpoint_dir):
295 | os.makedirs(checkpoint_dir)
296 |
297 | self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)
298 |
299 | def load(self, checkpoint_dir):
300 | import re
301 | print(" [*] Reading checkpoints...")
302 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir, self.model_name)
303 |
304 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
305 | if ckpt and ckpt.model_checkpoint_path:
306 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
307 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
308 | counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0))
309 | print(" [*] Success to read {}".format(ckpt_name))
310 | return True, counter
311 | else:
312 | print(" [*] Failed to find a checkpoint")
313 | return False, 0
314 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Junho Kim (1993.01.12)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CASED-Tensorflow
2 | Tensorflow implementation of [Curriculum Adaptive Sampling for Extreme Data Imbalance](https://www.researchgate.net/publication/319461093_CASED_Curriculum_Adaptive_Sampling_for_Extreme_Data_Imbalance) with **multi-GPU** support, using [*LUNA16*](https://luna16.grand-challenge.org/)
3 |
4 | ## Preprocessing Tutorial
5 | * [convert_luna_to_npy](/preprocessing/README/convert_luna_to_npy_README.md)
6 | * [create patch](preprocessing/README/h5py_patch_README.md)
7 | ```python
8 | > all_in_one.py = convert_luna_to_npy + create_patch
9 | ```
10 |
11 |
12 | ## Usage for preprocessing
13 | ```python
14 | > python all_in_one.py
15 | ```
16 | * Check `src_root` and `save_path`
17 |
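These paths are hard-coded near the top of the `__main__` block in `preprocessing/all_in_one.py`; the values below are the author's machine-specific defaults, so point them at your own LUNA16 download and output directory before running:

```python
src_root = '/data/jhkim/LUNA16/original'    # must contain subset0..subset9 and CSVFILES/annotations.csv
save_path = '/data2/jhkim/LUNA16/patch/SH/' # where the per-subset .h5 patch files are written
```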
18 | ## Usage for train
19 | ```python
20 | > python main_train.py
21 | ```
22 | * See `main_train.py` for other arguments.
23 |
24 | ## Usage for test
25 | ```python
26 | > python main_test.py
27 | ```
28 |
29 | ## Issue
30 | * *The hyper-parameter information is not listed in the paper, so I'm still testing it.*
31 | * Use ***[Snapshot Ensemble](https://arxiv.org/pdf/1704.00109.pdf)*** (M=10, init_lr=0.1)
32 | * Or fix the learning rate at **0.01**
33 |
34 | 
35 | ```python
36 | def Snapshot(t, T, M, alpha_zero) :
37 | """
38 | t = # of current iteration
39 | T = # of total iteration
40 | M = # of snapshot
41 | alpha_zero = init learning rate
42 | """
43 |
44 | x = (np.pi * (t % (T // M))) / (T // M)
45 | x = np.cos(x) + 1
46 |
47 | lr = (alpha_zero / 2) * x
48 |
49 | return lr
50 | ```
51 |
52 | ## Summary
53 | ### Preprocessing
54 | * Resample
55 | ```bash
56 | > 1.25mm
57 | ```
58 |
59 | * Hounsfield
60 | ```python
61 | > minHU = -1000
62 | > maxHU = 400
63 | ```
64 |
65 | * Zero centering
66 | ```python
67 | > Pixel Mean = 0.25
68 | ```
69 |
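A minimal sketch tying these three steps together, assuming the `resample`, `normalize_planes`, and `zero_center` helpers described in the [preprocessing README](/preprocessing/README/convert_luna_to_npy_README.md) (imported from `preprocess_utils.py` in `all_in_one.py`):

```python
# illustrative only; np_img and spacing come from load_itk_image(mhd_path)
resampled_img, new_spacing = resample(np_img, spacing, new_spacing=[1.25, 1.25, 1.25])
norm_img = zero_center(normalize_planes(resampled_img))
```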
70 | ### Data augmentation
71 | If you want to do augmentation, see this [link](https://github.com/aleju/imgaug)
72 |
73 | * Affine rotate
74 | ```python
75 | -2 to 2 degree
76 | ```
77 |
78 | * Scale
79 | ```python
80 | 0.9 to 1.1
81 | ```
82 |
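A minimal sketch of this augmentation with [imgaug](https://github.com/aleju/imgaug), applied slice-wise to dummy data; the augmenter settings below only illustrate the ranges above and are not the exact pipeline used for the reported results:

```python
import numpy as np
from imgaug import augmenters as iaa

# rotate by -2 to 2 degrees and scale by 0.9 to 1.1
seq = iaa.Sequential([
    iaa.Affine(rotate=(-2, 2), scale=(0.9, 1.1)),
])

images = np.random.randint(0, 255, size=(4, 68, 68, 1), dtype=np.uint8)  # dummy batch of slices
augmented = seq.augment_images(images)
```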
83 | ### Network Architecture
84 | 
85 |
86 | ### Algorithm
87 | 
88 | ```python
89 | p_x = 1.0
90 |
91 | for i in iteration :
92 | p = uniform(0,1)
93 |
94 | if p <= p_x :
95 | g_n_index = np.random.choice(N, size=batch_size, replace=False)
96 | batch_patch = nodule_patch[g_n_index]
97 | batch_y = nodule_patch_y[g_n_index]
98 |
99 | else :
100 | predictor_dict = Predictor(all_patch) # key = index, value = loss
101 | g_r_index = nlargest(batch_size, predictor_dict, key=predictor_dict.get)
102 |
103 | batch_patch = all_patch[g_r_index]
104 | batch_y = all_patch_y[g_r_index]
105 |
106 | p_x *= pow(1/M, 1/iteration)
107 | ```
108 |
109 | ## Result
110 | 
111 |
112 |
113 | ## Author
114 | Junho Kim / [@Lunit](http://lunit.io/)
115 |
--------------------------------------------------------------------------------
/assests/framework.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/framework.JPG
--------------------------------------------------------------------------------
/assests/lr.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/lr.JPG
--------------------------------------------------------------------------------
/assests/network.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/network.JPG
--------------------------------------------------------------------------------
/assests/nodule.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/nodule.png
--------------------------------------------------------------------------------
/assests/nodule_label.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/nodule_label.png
--------------------------------------------------------------------------------
/assests/patch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/patch.png
--------------------------------------------------------------------------------
/assests/result1.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/result1.JPG
--------------------------------------------------------------------------------
/assests/result2.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/result2.JPG
--------------------------------------------------------------------------------
/assests/stride.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taki0112/CASED-Tensorflow/7ac533d2dcecc9da05f7f69610a0acdb307df4c5/assests/stride.png
--------------------------------------------------------------------------------
/main_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import tensorflow as tf
4 | from CASED_test import CASED
5 | from utils import check_folder
6 | from utils import show_all_variables
7 |
8 | """parsing and configuration"""
9 | def parse_args():
10 | desc = "Tensorflow implementation of CASED"
11 | parser = argparse.ArgumentParser(description=desc)
12 | parser.add_argument('--batch_size', type=int, default=1, help='The size of batch')
13 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
14 | help='Directory name to save the checkpoints')
15 | parser.add_argument('--result_dir', type=str, default='results',
16 | help='Directory name to save the generated images')
17 | parser.add_argument('--log_dir', type=str, default='logs_test',
18 | help='Directory name to save training logs')
19 |
20 | return check_args(parser.parse_args())
21 |
22 | """checking arguments"""
23 | def check_args(args):
24 | # --checkpoint_dir
25 | check_folder(args.checkpoint_dir)
26 |
27 | # --result_dir
28 | check_folder(args.result_dir)
29 |
30 | # --log_dir
31 | check_folder(args.log_dir)
32 |
33 |
34 | # --batch_size
35 | try:
36 | assert args.batch_size >= 1
38 | except:
39 | print('batch size must be larger than or equal to one')
40 |
41 | return args
42 |
43 | """main"""
44 | def main():
45 | # parse arguments
46 | args = parse_args()
47 | if args is None:
48 | exit()
49 |
50 | # open session
51 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
52 | model = CASED(sess, batch_size=args.batch_size,
53 | checkpoint_dir=args.checkpoint_dir, result_dir=args.result_dir, log_dir=args.log_dir)
54 |
55 | # build graph
56 | model.build_model()
57 |
58 | # show network architecture
59 | show_all_variables()
60 |
61 | # launch the graph in a session
62 | model.test()
63 | print(" [*] Testing finished!")
64 |
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
--------------------------------------------------------------------------------
/main_train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import tensorflow as tf
4 | from CASED_train import CASED
5 | from utils import check_folder
6 | from utils import show_all_variables
7 |
8 | """parsing and configuration"""
9 | def parse_args():
10 | desc = "Tensorflow implementation of CASED"
11 | parser = argparse.ArgumentParser(description=desc)
12 | parser.add_argument('--epoch', type=int, default=3, help='The number of epochs to run')
13 | parser.add_argument('--batch_size', type=int, default=16, help='The size of batch')
14 | parser.add_argument('--test_batch_size', type=int, default=16, help='The size of test batch')
15 | parser.add_argument('--num_gpu', type=int, default=8, help='# of gpu')
16 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
17 | help='Directory name to save the checkpoints')
18 | parser.add_argument('--result_dir', type=str, default='results',
19 | help='Directory name to save the generated images')
20 | parser.add_argument('--log_dir', type=str, default='logs',
21 | help='Directory name to save training logs')
22 |
23 | return check_args(parser.parse_args())
24 |
25 | """checking arguments"""
26 | def check_args(args):
27 | # --checkpoint_dir
28 | check_folder(args.checkpoint_dir)
29 |
30 | # --result_dir
31 | check_folder(args.result_dir)
32 |
33 | # --log_dir
34 | check_folder(args.log_dir)
35 |
36 | # --epoch
37 | try:
38 | assert args.epoch >= 1
39 | except:
40 | print('number of epochs must be larger than or equal to one')
41 |
42 | # --batch_size
43 | try:
44 | assert args.batch_size >= 1
45 | assert args.test_batch_size >= 1
46 | except:
47 | print('batch size must be larger than or equal to one')
48 |
49 | return args
50 |
51 | """main"""
52 | def main():
53 | # parse arguments
54 | args = parse_args()
55 | if args is None:
56 | exit()
57 |
58 | # open session
59 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
60 | model = CASED(sess, epoch=args.epoch, batch_size=args.batch_size, test_batch_size=args.test_batch_size, num_gpu=args.num_gpu,
61 | checkpoint_dir=args.checkpoint_dir, result_dir=args.result_dir, log_dir=args.log_dir)
62 |
63 | # build graph
64 | model.build_model()
65 |
66 | # show network architecture
67 | show_all_variables()
68 |
69 | # launch the graph in a session
70 | model.train()
71 | print(" [*] Training finished!")
72 |
73 |
74 |
75 | if __name__ == '__main__':
76 | main()
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.layers import variance_scaling_initializer as he_init
3 | from tensorflow.contrib.layers import l2_regularizer
4 |
5 | def conv_layer(x, channels, kernel=3, stride=1, activation='relu', padding='VALID', layer_name='_conv3d') :
6 | with tf.name_scope(layer_name) :
7 | if activation is None :
8 | return tf.layers.conv3d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(),
9 | strides=stride, padding=padding)
10 | else :
11 | return tf.layers.conv3d(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(),
12 | strides=stride, padding=padding, activation=relu)
13 |
14 | def deconv_layer(x, channels, kernel=4, stride=2, padding='VALID', layer_name='_deconv3d') :
15 | with tf.name_scope(layer_name) :
16 | crop = 1
17 | x = tf.layers.conv3d_transpose(inputs=x, filters=channels, kernel_size=kernel, kernel_initializer=he_init(),
18 | strides=stride, padding=padding, use_bias=False)
19 | x = x[:, crop:-crop, crop:-crop, crop:-crop, :]
20 | return x
21 |
22 | def flatten(x) :
23 | return tf.layers.flatten(x)
24 |
25 | def fully_connected(x, unit=2, layer_name='fully') :
26 | with tf.name_scope(layer_name) :
27 | return tf.layers.dense(x, units=unit)
28 |
29 | def max_pooling(x, kernel=2, stride=2, padding='VALID') :
30 | x = tf.layers.max_pooling3d(inputs=x, pool_size=kernel, strides=stride, padding=padding)
31 | x = tf.concat([x, x], axis=-1)
32 | return x
33 |
34 |
35 | def copy_crop(crop_layer, in_layer):
36 | crop = []
37 | for i in range(1, 4):
38 | crop_left = (tf.shape(crop_layer)[i] - tf.shape(in_layer)[i]) // 2
39 | crop.append(crop_left)
40 |
41 | crop_right = tf.cond(tf.equal(tf.shape(crop_layer)[i] - crop_left*2, tf.shape(in_layer)[i]) ,
42 | lambda : crop_left,
43 | lambda : crop_left + 1)
44 | crop.append(crop_right)
45 |
46 | crop_layer = crop_layer[:, crop[0]: -crop[1], crop[2]: -crop[3], crop[4]: -crop[5], :]
47 |
48 | return tf.concat([crop_layer, in_layer], axis=-1)
49 |
50 | def relu(x) :
51 | return tf.nn.relu(x)
52 |
53 | def sigmoid(x) :
54 | return tf.sigmoid(x)
55 |
56 | def softmax(x) :
57 | return tf.nn.softmax(x)
58 |
--------------------------------------------------------------------------------
/preprocessing/README/convert_luna_to_npy_README.md:
--------------------------------------------------------------------------------
1 | # Preprocessing
2 | ```python
3 | > all_in_one.py = convert_luna_to_npy.py + h5py_patch_py
4 | ```
5 |
6 | ## convert_luna_to_npy.py
7 | ### 1. read_csv
8 | * This is the code that reads `annotations.csv`
9 | * **key** is the `series_uid`
10 | * **value** is the `coordinate value (x, y, z order) and diameter`
11 | ```python
12 | def read_csv(filename):
13 | lines = []
14 | with open(filename, 'r') as f:
15 | csvreader = csv.reader(f)
16 | for line in csvreader:
17 | lines.append(line)
18 |
19 | lines = lines[1:] # remove csv headers
20 | annotations_dict = {}
21 | for i in lines:
22 | series_uid, x, y, z, diameter = i
23 | value = {'position':[float(x),float(y),float(z)],
24 | 'diameter':float(diameter)}
25 | if series_uid in annotations_dict.keys():
26 | annotations_dict[series_uid].append(value)
27 | else:
28 | annotations_dict[series_uid] = [value]
29 |
30 | return annotations_dict
31 | ```
32 |
33 | ### 2. load_itk_image
34 | * This code converts `mhd file` to a `numpy array`
35 | * The direction axis is set to `[z, y, x]`
36 | ```python
37 | def load_itk_image(filename):
38 | itkimage = sitk.ReadImage(filename)
39 | numpyImage = sitk.GetArrayFromImage(itkimage)
40 |
41 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))
42 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))
43 |
44 | return numpyImage, numpyOrigin, numpySpacing
45 | ```
46 |
47 | * If you want the direction axis to be `[x, y, z]`, use this code instead
48 | ```python
49 | def load_itk(filename):
50 | itkimage = sitk.ReadImage(filename)
51 | image = np.transpose(sitk.GetArrayFromImage(itkimage))
52 | origin = np.array(itkimage.GetOrigin())
53 | spacing = np.array(itkimage.GetSpacing())
54 | return image, origin, spacing
55 | ```
56 |
57 | ### 3. resampling
58 | * After loading each LUNA16 scan into a numpy array, apply `resample`, `normalize`, and `zero_centering`.
59 | * Each mhd file has a different voxel spacing along the x, y, and z axes (presumably because the scans come from different CT machines).
60 | * `resample` rescales every scan to the same spacing along the x, y, and z axes.
61 | * `OUTPUT_SPACING` is that common spacing (currently `1.25mm`).
62 | ```python
63 | def resample(image, org_spacing, new_spacing=OUTPUT_SPACING):
64 |
65 | resize_factor = org_spacing / new_spacing
66 | new_real_shape = image.shape * resize_factor
67 | new_shape = np.round(new_real_shape)
68 | real_resize_factor = new_shape / image.shape
69 | new_spacing = org_spacing / real_resize_factor
70 |
71 | image = scipy.ndimage.interpolation.zoom(image, real_resize_factor, mode='nearest')
72 |
73 | return image, new_spacing
74 | ```
75 |
76 | ### 4. normalize
77 | * For `normalize`, check the `Hounsfield_Unit` ranges in the table below.
78 | * For LUNA16 nodule detection, the range `-1000 ~ 400` is used.
79 |
80 | 
81 |
82 | ```python
83 | def normalize_planes(npzarray):
84 |
85 | maxHU = 400.
86 | minHU = -1000.
87 | npzarray = (npzarray - minHU) / (maxHU - minHU)
88 | npzarray[npzarray > 1] = 1.
89 | npzarray[npzarray < 0] = 0.
90 | return npzarray
91 | ```
92 |
93 | ### 5. zero centering
94 | * `zero center` shifts the images so that their mean is zero. This usually helps training, though not always.
95 | * For LUNA16, use a pixel mean of `0.25`.
96 | ```python
97 | def zero_center(image):
98 | PIXEL_MEAN = 0.25
99 | image = image - PIXEL_MEAN
100 |
101 | return image
102 | ```
103 |
104 | ### 6. create label
105 | * This is the code that masks the `nodule`
106 | * *example*
107 |
108 |

109 |

110 |
111 |
112 | ```python
113 | def create_label(arr_shape, nodules, new_spacing, coord=False):
114 | """
115 | nodules = list of dict {'position', 'diameter'}
116 | """
117 | def _create_mask(arr_shape, position, diameter):
118 |
119 | z_dim, y_dim, x_dim = arr_shape
120 | z_pos, y_pos, x_pos = position
121 |
122 | z,y,x = np.ogrid[-z_pos:z_dim-z_pos, -y_pos:y_dim-y_pos, -x_pos:x_dim-x_pos]
123 | mask = z**2 + y**2 + x**2 <= int(diameter//2)**2
124 |
125 | return mask
126 |
127 | if coord:
128 | label = []
129 | else:
130 | label = np.zeros(arr_shape, dtype='bool')
131 |
132 | for nodule in nodules:
133 | worldCoord = nodule['position']
134 | worldCoord = np.asarray([worldCoord[2],worldCoord[1],worldCoord[0]])
135 |
136 | # new_spacing came from resample
137 | voxelCoord = compute_coord(worldCoord, origin, new_spacing)
138 | voxelCoord = [int(i) for i in voxelCoord]
139 |
140 | diameter = nodule['diameter']
141 | diameter = diameter / new_spacing[1]
142 |
143 | if coord:
144 | label.append(voxelCoord + [diameter])
145 | else:
146 | mask = _create_mask(arr_shape, voxelCoord, diameter)
147 | label = np.logical_or(label, mask)
148 |
149 | return label
150 | ```
151 |
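`compute_coord` is defined in `preprocess_utils.py` and is not shown in this README; a minimal sketch of the world-to-voxel conversion it performs (the same computation as `world_2_voxel` in the h5py README), assuming `world_coord`, `origin`, and `spacing` are all ordered `[z, y, x]`:

```python
import numpy as np

def compute_coord(world_coord, origin, spacing):
    # world coordinates are in mm; subtract the scan origin and divide by
    # the voxel spacing to get (fractional) voxel indices
    stretched_voxel_coord = np.absolute(np.asarray(world_coord) - np.asarray(origin))
    return stretched_voxel_coord / np.asarray(spacing)
```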
152 | ### Next Step
153 | * [h5py_patch](https://github.com/taki0112/CASED-Tensorflow/blob/master/preprocessing/README/h5py_patch_README.md)
154 |
155 | ## Author
156 | Junho Kim / [@Lunit](http://lunit.io/)
157 |
--------------------------------------------------------------------------------
/preprocessing/README/h5py_patch_README.md:
--------------------------------------------------------------------------------
1 | # Preprocessing
2 | ```python
3 | > all_in_one.py = convert_luna_to_npy.py + h5py_patch_py
4 | ```
5 |
6 | ## h5py_patch
7 | * This code extracts patches from each image so that the network can be trained patch by patch.
8 |
9 | ### 1. read npy file
10 | * The reason for transpose is to make the coordinate directions `[x, y, z]`
11 | ```python
12 | image = np.transpose(np.load(mhd_to_npy))
13 | label = np.transpose(np.load(create_label_npy))
14 | ```
15 |
16 | ### 2. padding the image
17 | ```python
18 | offset = patch_size // 2
19 | stride = 8
20 | move = offset // stride
21 |
22 | non_pad = image
23 | non_label_pad = label
24 |
25 | non_pad = np.pad(non_pad, offset, 'constant', constant_values=np.min(non_pad))
26 | non_label_pad = np.pad(non_label_pad, offset, 'constant', constant_values=np.min(non_label_pad))
27 |
28 | image = np.pad(image, offset + (stride * move), 'constant', constant_values=np.min(image))
29 | label = np.pad(label, offset + (stride * move), 'constant', constant_values=np.min(label))
30 |
31 | ```
32 |
33 | * Each patch is ***centered on the nodule*** and then extracted.
34 | * If a nodule lies near the edge of the image, the full patch cannot be extracted, so the image is padded first.
35 | * `patch size` = 68, i.e. each patch is `68 * 68 * 68`
36 | * `offset` = padding size
37 |
38 |
39 |
40 | * `stride` = how far the patch center is shifted away from the nodule
41 | * The center can be shifted `8 times (4 * 2)` along each of the ***x, y, z, xy, xz, yz, and xyz*** directions, giving *56 shifted patches* (see the `patch_stride` sketch after the stride figure below).
42 | * Therefore, each nodule point yields 56 + 1 = `57` patches.
43 | * `stride` = 8, `move` = 4
44 |
45 | 
46 |
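`patch_stride` itself lives in `preprocess_utils.py` and is not shown in this README; a minimal sketch of the idea, assuming the `get_patch` helper from step 4 below and the `stride` / `move` values above (7 shift directions × 4 moves × 2 signs = 56 shifted patches, plus the centered one):

```python
import numpy as np

def patch_stride(image, coords, offset, patch_list, patch_flag=True, stride=8, move=4):
    get_patch(image, coords, offset, patch_list, patch_flag)  # the centered patch
    # shift directions: x, y, z, xy, xz, yz, xyz
    directions = [(1, 0, 0), (0, 1, 0), (0, 0, 1),
                  (1, 1, 0), (1, 0, 1), (0, 1, 1), (1, 1, 1)]
    for d in directions:
        for m in range(1, move + 1):
            for sign in (1, -1):
                shifted = np.asarray(coords) + sign * m * stride * np.asarray(d)
                get_patch(image, shifted, offset, patch_list, patch_flag)
```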
47 | ### 3. get the coordinates of the nodule
48 | ```python
49 | def world_2_voxel(world_coord, origin, spacing):
50 | stretched_voxel_coord = np.absolute(world_coord - origin)
51 | voxel_coord = stretched_voxel_coord / spacing
52 | return voxel_coord
53 | ```
54 |
55 | ### 4. extract patches that contain a nodule and patches that do not
56 | * the number of non-nodule patches = 3 * the number of nodule patches
57 | * For the label patch, downsize with `max-pooling` from `skimage`, since the output size of the network (U-Net) is 8 * 8 * 8
58 | * Don't use `scipy.resize`, because its interpolation can make isolated pixels with value 1 (True) disappear
59 |
60 | ```python
61 | def get_patch(image, coords, offset, patch_list, patch_flag=True):
62 | xyz = image[int(coords[0] - offset): int(coords[0] + offset),
63 | int(coords[1] - offset): int(coords[1] + offset),
64 | int(coords[2] - offset): int(coords[2] + offset)]
65 |
66 | if patch_flag:
67 | output = np.expand_dims(xyz, axis=-1)
68 | else: # label
69 | # resize xyz
70 | xyz = skimage.measure.block_reduce(xyz, (9, 9, 9), np.max)
71 | output = np.expand_dims(xyz, axis=-1)
72 |
73 | output = indices_to_one_hot(output.astype(np.int32), 2)
74 | output = np.reshape(output, (label_size, label_size, label_size, 2))
75 | output = output.astype(np.float32)
76 |
77 | patch_list.append(output)
78 | ```
79 |
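`indices_to_one_hot` is another helper from the preprocessing utilities that is not shown here; a minimal sketch, assuming the input holds integer class indices (0 or 1) and `nb_classes=2`:

```python
import numpy as np

def indices_to_one_hot(data, nb_classes):
    # flatten the integer labels and take one identity-matrix row per voxel,
    # producing a (num_voxels, nb_classes) one-hot array
    targets = np.asarray(data).reshape(-1)
    return np.eye(nb_classes)[targets]
```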
80 | ### 5. create the patch with h5py
81 | * See [link](https://www.safaribooksonline.com/library/view/python-and-hdf5/9781491944981/ch04.html) for reasons why I chose `lzf type`
82 | ```python
83 | with h5py.File(save_path + 'subset' + str(i) + '.h5', 'w') as hf:
84 | hf.create_dataset('nodule', data=nodule[:], compression='lzf')
85 | hf.create_dataset('label_nodule', data=nodule_label[:], compression='lzf')
86 |
87 | hf.create_dataset('non_nodule', data=non_nodule[:], compression='lzf')
88 | hf.create_dataset('label_non_nodule', data=non_nodule_label[:], compression='lzf')
89 | ```
90 |
91 | ### 6. data load
92 | * Use multiprocessing.Pool
93 | * `process_num` works best as a power of 2.
94 | * To check how many CPUs are available on Linux, run `grep -c processor /proc/cpuinfo`.
95 | ```python
96 | from multiprocessing import Pool
97 |
98 | def nodule_hf(idx):
99 | with h5py.File(image_patch, 'r') as hf:
100 | nodule = hf['nodule'][idx:idx + get_data_num]
101 | return nodule
102 |
103 | process_num = 32
104 | get_data_num = 64
105 |
106 | with h5py.File(image_patch, 'r') as fin:
107 | nodule_range = range(0, len(fin['nodule']), get_data_num)
108 |
109 | pool = Pool(processes = process_num)
110 | pool_nodule = pool.map(nodule_hf, nodule_range)
111 | pool.close()
112 |
113 | nodule = []
114 |
115 | for p in pool_nodule :
116 | nodule.extend(p)
117 | ```
118 |
119 | ## Author
120 | Junho Kim / [@Lunit](http://lunit.io/)
121 |
--------------------------------------------------------------------------------
/preprocessing/all_in_one.py:
--------------------------------------------------------------------------------
1 | from preprocessing.preprocess_utils import *
2 | from random import randint
3 | import h5py
4 | import numpy as np
5 | OUTPUT_SPACING = [1.25, 1.25, 1.25]
6 | patch_size = 68
7 | label_size = 8
8 | stride = 8
9 | move = (patch_size // 2) // stride # 4
10 |
11 |
12 | if __name__=="__main__":
13 |
14 | from matplotlib.patches import Circle
15 |
16 | coord = False
17 |
18 | dst_spacing = OUTPUT_SPACING
19 | src_root = '/data/jhkim/LUNA16/original'
20 | #dst_root = '/lunit/data/LUNA16/npydata'
21 |
22 | save_path = '/data2/jhkim/LUNA16/patch/SH/'
23 |
24 | # if not os.path.exists(dst_root):
25 | # os.makedirs(dst_root)
26 |
27 | annotation_csv = os.path.join(src_root,'CSVFILES/annotations.csv')
28 | annotations = read_csv(annotation_csv)
29 |
30 | src_mhd = []
31 | coord_dict = {}
32 | for sub_n in range(10) :
33 | idx = 1
34 | sub_n_str = 'subset' + str(sub_n)
35 | image_paths = glob.glob(os.path.join(src_root,sub_n_str,'*.mhd'))
36 |
37 | nodule = []
38 | non_nodule = []
39 | nodule_label = []
40 | non_nodule_label = []
41 |
42 | for i in image_paths:
43 | filename = os.path.split(i)[-1]
44 | series_uid = os.path.splitext(filename)[0]
45 |
46 | subset_num = i.split('/')[-2]
47 | # dst_subset_path = os.path.join(dst_root,subset_num)
48 | #
49 | # if not os.path.exists(dst_subset_path):
50 | # os.makedirs(dst_subset_path)
51 |
52 | np_img, origin, spacing = load_itk_image(i)
53 |
54 | resampled_img, new_spacing = resample(np_img, spacing, dst_spacing)
55 | resampled_img_shape = resampled_img.shape
56 |
57 | norm_img = normalize_planes(resampled_img)
58 | norm_img = zero_center(norm_img)
59 | norm_img_shape = norm_img.shape
60 | nodule_coords = []
61 | try:
62 | nodule_coords = annotations[series_uid]
63 | label = create_label(resampled_img_shape, nodule_coords, origin, new_spacing, coord=coord)
64 | except:
65 | if coord:
66 | label = []
67 | else:
68 | label = np.zeros(resampled_img_shape, dtype='bool')
69 |
70 | image = np.transpose(norm_img)
71 | label = np.transpose(label)
72 | #origin = np.transpose(origin)
73 | #new_spacing = np.transpose(new_spacing)
74 |
75 | #np.save(os.path.join(dst_subset_path,series_uid+'.npy'), norm_img)
76 | #np.save(os.path.join(dst_subset_path,series_uid+'.label.npy'), label)
77 | #coord_dict[os.path.join(dst_subset_path,series_uid+'.npy')] = label
78 |
79 | # padding
80 | offset = patch_size // 2
81 |
82 | non_pad = image
83 | non_label_pad = label
84 |
85 | non_pad = np.pad(non_pad, offset, 'constant', constant_values=np.min(non_pad))
86 | non_label_pad = np.pad(non_label_pad, offset, 'constant', constant_values=np.min(non_label_pad))
87 |
88 | image = np.pad(image, offset + (stride * move), 'constant', constant_values=np.min(image))
89 | label = np.pad(label, offset + (stride * move), 'constant', constant_values=np.min(label))
90 |
91 | nodule_list = []
92 | nodule_label_list = []
93 |
94 | for nodule_coord in nodule_coords :
95 | worldCoord = nodule_coord['position']
96 | worldCoord = np.asarray([worldCoord[2], worldCoord[1], worldCoord[0]])
97 |
98 | # new_spacing came from resample
99 | voxelCoord = compute_coord(worldCoord, origin, new_spacing)
100 | voxelCoord = np.asarray([int(i) + offset + (stride * move) for i in voxelCoord])
101 |                 voxelCoord = voxelCoord[::-1]  # (z, y, x) -> (x, y, z) to match the transposed image; np.transpose is a no-op on a 1-D array
102 |
103 | patch_stride(image, voxelCoord, offset, nodule_list) # x,y,z, xy,xz,yz, xyz ... get stride patch
104 | patch_stride(label, voxelCoord, offset, nodule_label_list, patch_flag=False)
105 |
106 | #print(series_uid)
107 |
108 | nodule_num = len(nodule_list)
109 |
110 | non_nodule_list = []
111 | non_nodule_label_list = []
112 | x_coords = non_pad.shape[0] - offset - 1
113 | y_coords = non_pad.shape[1] - offset - 1
114 | z_coords = non_pad.shape[2] - offset - 1
115 |
116 | while len(non_nodule_list) < 3 * nodule_num:
117 | rand_x = randint(offset, x_coords)
118 | rand_y = randint(offset, y_coords)
119 | rand_z = randint(offset, z_coords)
120 |
121 | coords = np.array([rand_x, rand_y, rand_z])
122 |
123 | get_patch(non_pad, coords, offset, non_nodule_list)
124 | get_patch(non_label_pad, coords, offset, non_nodule_label_list, patch_flag=False)
125 |
126 | nodule.extend(nodule_list)
127 | non_nodule.extend(non_nodule_list)
128 |
129 | nodule_label.extend(nodule_label_list)
130 | non_nodule_label.extend(non_nodule_label_list)
131 |
132 | print(sub_n_str + ' / ' + str(idx) + ' / ' + str(len(image_paths)))
133 | print('nodule : ', np.shape(nodule))
134 | print('nodule_label : ', np.shape(nodule_label))
135 | print('non-nodule : ', np.shape(non_nodule))
136 | print('non-nodule_label : ', np.shape(non_nodule_label))
137 | idx += 1
138 |
139 | np.random.seed(0)
140 | np.random.shuffle(nodule)
141 | np.random.seed(0)
142 | np.random.shuffle(nodule_label)
143 |
144 | np.random.seed(0)
145 | np.random.shuffle(non_nodule)
146 | np.random.seed(0)
147 | np.random.shuffle(non_nodule_label)
148 |
149 | train_nodule_len = int(len(nodule) * 0.7)
150 | train_non_nodule_len = int(len(non_nodule) * 0.7)
151 |
152 | with h5py.File(save_path + 'subset' + str(sub_n) + '.h5', 'w') as hf:
153 | hf.create_dataset('nodule', data=nodule[:], compression='lzf')
154 | hf.create_dataset('label_nodule', data=nodule_label[:], compression='lzf')
155 |
156 | hf.create_dataset('non_nodule', data=non_nodule[:], compression='lzf')
157 | hf.create_dataset('label_non_nodule', data=non_nodule_label[:], compression='lzf')
158 |
159 |
160 |
161 | # with h5py.File(save_path + 'subset' + str(sub_n) + '.h5', 'w') as hf:
162 | # hf.create_dataset('nodule', data=nodule[:train_nodule_len], compression='lzf')
163 | # hf.create_dataset('label_nodule', data=nodule_label[:train_nodule_len], compression='lzf')
164 | #
165 | # hf.create_dataset('non_nodule', data=non_nodule[:train_non_nodule_len], compression='lzf')
166 | # hf.create_dataset('label_non_nodule', data=non_nodule_label[:train_non_nodule_len], compression='lzf')
167 | #
168 | # with h5py.File(save_path + 't_subset' + str(sub_n) + '.h5', 'w') as hf:
169 | # hf.create_dataset('nodule', data=nodule[train_nodule_len:], compression='lzf')
170 | # hf.create_dataset('label_nodule', data=nodule_label[train_nodule_len:], compression='lzf')
171 | #
172 | # hf.create_dataset('non_nodule', data=non_nodule[train_non_nodule_len:], compression='lzf')
173 | # hf.create_dataset('label_non_nodule', data=non_nodule_label[train_non_nodule_len:], compression='lzf')
174 |
175 |
176 |
--------------------------------------------------------------------------------
/preprocessing/convert_luna_to_npy.py:
--------------------------------------------------------------------------------
1 | import SimpleITK as sitk
2 | import numpy as np
3 | import csv
4 | import os, glob
5 | from PIL import Image
6 | import matplotlib.pyplot as plt
7 | import pickle
8 |
9 | import scipy.ndimage
10 | from skimage import measure, morphology
11 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
12 | from time import time
13 |
14 | OUTPUT_SPACING = [1.25, 1.25, 1.25]
15 | def load_itk_image(filename):
16 | itkimage = sitk.ReadImage(filename)
17 | numpyImage = sitk.GetArrayFromImage(itkimage)
18 |
19 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))
20 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))
21 |
22 | return numpyImage, numpyOrigin, numpySpacing
23 |
24 | def read_csv(filename):
25 | lines = []
26 | with open(filename, 'r') as f:
27 | csvreader = csv.reader(f)
28 | for line in csvreader:
29 | lines.append(line)
30 |
31 | lines = lines[1:] # remove csv headers
32 | annotations_dict = {}
33 | for i in lines:
34 | series_uid, x, y, z, diameter = i
35 | value = {'position':[float(x),float(y),float(z)],
36 | 'diameter':float(diameter)}
37 | if series_uid in annotations_dict.keys():
38 | annotations_dict[series_uid].append(value)
39 | else:
40 | annotations_dict[series_uid] = [value]
41 |
42 | return annotations_dict
43 |
44 | def compute_coord(worldCoord, origin, spacing):
45 | stretchedVoxelCoord = np.absolute(worldCoord - origin)
46 | voxelCoord = stretchedVoxelCoord / spacing
47 | return voxelCoord
48 |
49 | def normalize_planes(npzarray):
50 | #maxHU = 600.
51 | #minHU = -1200.
52 | maxHU = 400.
53 | minHU = -1000.
54 | npzarray = (npzarray - minHU) / (maxHU - minHU)
55 | npzarray[npzarray>1] = 1.
56 | npzarray[npzarray<0] = 0.
57 | return npzarray
58 |
59 | def zero_center(image):
60 | PIXEL_MEAN = 0.25
61 | image = image - PIXEL_MEAN
62 |
63 | return image
64 |
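# Resample the scan to (approximately) the requested isotropic spacing. Because
# the target shape is rounded to whole voxels, the spacing actually achieved
# differs slightly from OUTPUT_SPACING; the adjusted spacing is returned and is
# what every later world-to-voxel conversion should use.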
65 | def resample(image, org_spacing, new_spacing=OUTPUT_SPACING):
66 |
67 | resize_factor = org_spacing / new_spacing
68 | new_real_shape = image.shape * resize_factor
69 | new_shape = np.round(new_real_shape)
70 | real_resize_factor = new_shape / image.shape
71 | new_spacing = org_spacing / real_resize_factor
72 |
73 | image = scipy.ndimage.interpolation.zoom(image, real_resize_factor, mode='nearest')
74 |
75 | return image, new_spacing
76 |
 77 | def create_label(arr_shape, nodules, origin, new_spacing, coord=False):
78 | """
79 | nodules = list of dict {'position', 'diameter'}
80 | """
81 | def _create_mask(arr_shape, position, diameter):
82 |
83 | z_dim, y_dim, x_dim = arr_shape
84 | z_pos, y_pos, x_pos = position
85 |
86 | z,y,x = np.ogrid[-z_pos:z_dim-z_pos, -y_pos:y_dim-y_pos, -x_pos:x_dim-x_pos]
87 | mask = z**2 + y**2 + x**2 <= int(diameter//2)**2
88 |
89 | return mask
90 |
91 | if coord:
92 | label = []
93 | else:
94 | label = np.zeros(arr_shape, dtype='bool')
95 |
96 | for nodule in nodules:
97 | worldCoord = nodule['position']
98 | worldCoord = np.asarray([worldCoord[2],worldCoord[1],worldCoord[0]])
99 |
100 | # new_spacing came from resample
101 | voxelCoord = compute_coord(worldCoord, origin, new_spacing)
102 | voxelCoord = [int(i) for i in voxelCoord]
103 |
104 | diameter = nodule['diameter']
105 | diameter = diameter / new_spacing[1]
106 |
107 | if coord:
108 | label.append(voxelCoord + [diameter])
109 | else:
110 | mask = _create_mask(arr_shape, voxelCoord, diameter)
111 | label = np.logical_or(label, mask)
112 |
113 | return label
114 |
115 |
116 | def plot(image, label, z_idx):
117 |
118 | fig = plt.figure()
119 | ax = fig.add_subplot(1,2,1)
120 | ax.imshow(image[z_idx,:,:],cmap='gray')
121 |
122 | ax = fig.add_subplot(1,2,2)
123 | ax.imshow(label[z_idx,:,:],cmap='gray')
124 |
125 | fig.show()
126 |
127 | if __name__=="__main__":
128 |
129 | from matplotlib.patches import Circle
130 |
131 | coord = True
132 | exclude_flag = True
133 |
134 | dst_spacing = OUTPUT_SPACING
135 | #src_root = '/lunit/data/LUNA16/rawdata'
136 | dst_root = '/data2/jhkim/npydata'
137 |
138 | src_root = '/data/jhkim/LUNA16/original'
139 |
140 |
141 | if not os.path.exists(dst_root):
142 | os.makedirs(dst_root)
143 |
144 | if exclude_flag :
145 | annotation_csv = os.path.join(src_root,'CSVFILES/annotations_excluded.csv')
146 | else :
147 | annotation_csv = os.path.join(src_root,'CSVFILES/annotations.csv')
148 |
149 | annotations = read_csv(annotation_csv)
150 |
151 | src_mhd = []
152 | coord_dict = {}
153 | all_file = len(glob.glob(os.path.join(src_root,'subset[0-9]','*.mhd')))
154 | cnt = 1
155 |
156 | for i in glob.glob(os.path.join(src_root,'subset[0-9]','*.mhd')):
157 | st = time()
158 | filename = os.path.split(i)[-1]
159 | series_uid = os.path.splitext(filename)[0]
160 |
161 | subset_num = i.split('/')[-2]
162 | dst_subset_path = os.path.join(dst_root,subset_num)
163 |
164 | if not os.path.exists(dst_subset_path):
165 | os.makedirs(dst_subset_path)
166 |
167 | np_img, origin, spacing = load_itk_image(i)
168 |
169 | resampled_img, new_spacing = resample(np_img, spacing, dst_spacing)
170 | resampled_img_shape = resampled_img.shape
171 |
172 | norm_img = normalize_planes(resampled_img)
173 | norm_img = zero_center(norm_img)
174 |
175 | try:
176 | nodules = annotations[series_uid]
177 |         label = create_label(resampled_img_shape, nodules, origin, new_spacing, coord=coord)
178 |     except KeyError:
179 | if coord:
180 | label = []
181 | else:
182 | label = np.zeros(resampled_img_shape, dtype='bool')
183 |
184 | # np.save(os.path.join(dst_subset_path,series_uid+'.npy'), norm_img)
185 | # np.save(os.path.join(dst_subset_path,series_uid+'.label.npy'), label)
186 |
187 | coord_dict[series_uid] = label
188 | with open('exclude.pkl', 'wb') as f:
189 | pickle.dump(coord_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
190 | print('{} / {} / {}'.format(cnt, all_file, time() - st))
191 | cnt += 1
192 |
193 |
194 |
195 |
--------------------------------------------------------------------------------
/preprocessing/h5py_patch.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from random import randint
4 | import h5py
5 | import numpy as np
6 | import pandas as pd
7 | from utils import load_itk, world_2_voxel
8 | import scipy.ndimage
9 | from skimage.measure import block_reduce
10 |
11 | OUTPUT_SPACING = [1.25, 1.25, 1.25]
12 | patch_size = 68
13 | label_size = 8
14 | stride = 8
15 |
16 | move = (patch_size // 2) // stride # 4
 17 | annotations = pd.read_csv('/data/jhkim/LUNA16/CSVFILES/annotations.csv')  # TODO: a dict keyed by seriesuid may be faster than filtering this DataFrame per scan
18 |
19 |
20 | def indices_to_one_hot(data, nb_classes):
21 | """Convert an iterable of indices to one-hot encoded labels."""
22 | targets = np.array(data).reshape(-1)
23 | return np.eye(nb_classes)[targets]
24 |
25 |
26 | def normalize(image):
27 | maxHU = 400.
28 | minHU = -1000.
29 |
30 | image = (image - minHU) / (maxHU - minHU)
31 | image[image > 1] = 1.
32 | image[image < 0] = 0.
33 | return image
34 |
35 |
36 | def zero_center(image):
37 | PIXEL_MEAN = 0.25
38 | image = image - PIXEL_MEAN
39 |
40 | return image
41 |
42 |
43 | def get_patch(image, coords, offset, nodule_list, patch_flag=True):
44 | xyz = image[int(coords[0] - offset): int(coords[0] + offset), int(coords[1] - offset): int(coords[1] + offset),
45 | int(coords[2] - offset): int(coords[2] + offset)]
46 |
47 | if patch_flag:
48 | output = np.expand_dims(xyz, axis=-1)
49 | else:
 50 |         # downsample the label patch to the 8x8x8 output grid
 51 |         """
 52 |         xyz = scipy.ndimage.zoom(input=xyz, zoom=1/8, order=1) # order=1 : trilinear
 53 |         xyz = np.where(xyz > 0, 1.0, 0.0)
 54 |         """
 55 |         xyz = block_reduce(xyz, (9, 9, 9), np.max)  # 9x9x9 max-pooling: 68^3 -> 8^3
56 | output = np.expand_dims(xyz, axis=-1)
57 |
58 | output = indices_to_one_hot(output.astype(np.int32), 2)
59 | output = np.reshape(output, (label_size, label_size, label_size, 2))
60 | output = output.astype(np.float32)
61 |
62 | # print('------------------')
63 | # print(output)
64 |
65 | # print(output)
66 | # print(np.shape(output))
67 |
68 | nodule_list.append(output)
69 |
70 |
 71 | def patch_stride(image, coords, offset, nodule_list, patch_flag=True):
 72 |     # Around each nodule centre, collect the centre patch plus patches shifted
 73 |     # by +/- (stride * idx), idx = 1..move, along every axis and diagonal:
 74 |     # 7 directions * 2 * move = 56 shifted patches + 1 centre patch.
 75 |     directions = [(1, 0, 0), (0, 1, 0), (0, 0, 1),   # x, y, z
 76 |                   (1, 1, 0), (1, 0, 1), (0, 1, 1),   # xy, xz, yz
 77 |                   (1, 1, 1)]                         # xyz
 78 | 
 79 |     # work on a copy so the shifts never modify the caller's coordinate array
 80 |     original_coords = np.asarray(coords).copy()
 81 | 
 82 |     # center
 83 |     get_patch(image, original_coords, offset, nodule_list, patch_flag)
 84 | 
 85 |     # shifted patches
 86 |     for direction in directions:
 87 |         for idx in range(1, move + 1):
 88 |             shift = np.asarray(direction) * stride * idx
 89 |             get_patch(image, original_coords + shift, offset, nodule_list, patch_flag)
 90 |             get_patch(image, original_coords - shift, offset, nodule_list, patch_flag)
162 |
163 |
164 | def process_image(image_path, annotations, nodule, non_nodule, nodule_label, non_nodule_label):
165 | image, origin, spacing = load_itk(image_path) # 512 512 119
166 | image_name = os.path.split(image_path)[1].replace('.mhd', '')
167 |
168 | subset_name = image_path.split('/')[-2]
169 | SH_path = '/data2/jhkim/npydata/' + subset_name + '/' + image_name + '.npy'
170 | label_name = '/data2/jhkim/npydata/' + subset_name + '/' + image_name + '.label.npy'
171 |
172 | # calculate resize factor
173 | resize_factor = spacing / OUTPUT_SPACING
174 | new_real_shape = image.shape * resize_factor
175 | new_shape = np.round(new_real_shape)
176 | real_resize = new_shape / image.shape
177 | new_spacing = spacing / real_resize
178 |
179 | image = np.transpose(np.load(SH_path))
180 | label = np.transpose(np.load(label_name))
181 |
182 | # image = normalize(image)
183 | # image = zero_center(image)
184 |
185 | # padding
186 | offset = patch_size // 2
187 |
188 | non_pad = image
189 | non_label_pad = label
190 |
191 | non_pad = np.pad(non_pad, offset, 'constant', constant_values=np.min(non_pad))
192 | non_label_pad = np.pad(non_label_pad, offset, 'constant', constant_values=np.min(non_label_pad))
193 |
194 | image = np.pad(image, offset + (stride * move), 'constant', constant_values=np.min(image))
195 | label = np.pad(label, offset + (stride * move), 'constant', constant_values=np.min(label))
196 |
197 |
198 | indices = annotations[annotations['seriesuid'] == image_name].index
199 |
200 | nodule_list = []
201 | nodule_label_list = []
202 | for i in indices:
203 | row = annotations.iloc[i]
204 | world_coords = np.array([row.coordX, row.coordY, row.coordZ])
205 |
206 | coords = np.floor(world_2_voxel(world_coords, origin, new_spacing)) + offset + (stride * move) # center
207 | patch_stride(image, coords, offset, nodule_list) # x,y,z, xy,xz,yz, xyz ... get stride patch
208 | patch_stride(label, coords, offset, nodule_label_list, patch_flag=False)
209 |
210 | nodule_num = len(nodule_list)
211 |
212 | non_nodule_list = []
213 | non_nodule_label_list = []
214 | x_coords = non_pad.shape[0] - offset - 1
215 | y_coords = non_pad.shape[1] - offset - 1
216 | z_coords = non_pad.shape[2] - offset - 1
217 |
218 | while len(non_nodule_list) < 3 * nodule_num:
219 | rand_x = randint(offset, x_coords)
220 | rand_y = randint(offset, y_coords)
221 | rand_z = randint(offset, z_coords)
222 |
223 | coords = np.array([rand_x, rand_y, rand_z])
224 |
225 | get_patch(non_pad, coords, offset, non_nodule_list)
226 | get_patch(non_label_pad, coords, offset, non_nodule_label_list, patch_flag=False)
227 |
228 | nodule.extend(nodule_list)
229 | non_nodule.extend(non_nodule_list)
230 | nodule_label.extend(nodule_label_list)
231 | non_nodule_label.extend(non_nodule_label_list)
232 |
233 | print('nodule : ', np.shape(nodule))
234 | print('nodule_label : ', np.shape(nodule_label))
235 | print('non-nodule : ', np.shape(non_nodule))
236 | print('non-nodule_label : ', np.shape(non_nodule_label))
237 |
238 |
239 | for i in range(10):
240 | image_paths = glob.glob("/data/jhkim/LUNA16/original/subset" + str(i) + '/*.mhd')
241 | nodule = []
242 | non_nodule = []
243 | nodule_label = []
244 | non_nodule_label = []
245 |
246 | idx = 1
247 | flag = 1
248 | save_path = '/data2/jhkim/LUNA16/patch/SH/'
249 |
250 | for image_path in image_paths:
251 | print('subset' + str(i) + ' / ' + str(idx) + ' / ' + str(len(image_paths)))
252 | process_image(image_path, annotations, nodule, non_nodule, nodule_label, non_nodule_label)
253 | idx += 1
254 |
255 | np.random.seed(0)
256 | np.random.shuffle(nodule)
257 | np.random.seed(0)
258 | np.random.shuffle(nodule_label)
259 |
260 | np.random.seed(0)
261 | np.random.shuffle(non_nodule)
262 | np.random.seed(0)
263 | np.random.shuffle(non_nodule_label)
264 |
265 | train_nodule_len = int(len(nodule) * 0.7)
266 | train_non_nodule_len = int(len(non_nodule) * 0.7)
267 |
268 | # print(np.shape(nodule))
269 | # print(np.shape(non_nodule))
270 |
271 | # print(np.shape(nodule[:train_nodule_len]))
272 |
273 | with h5py.File(save_path + 'subset' + str(i) + '.h5', 'w') as hf:
274 | hf.create_dataset('nodule', data=nodule[:], compression='lzf')
275 | hf.create_dataset('label_nodule', data=nodule_label[:], compression='lzf')
276 |
277 | hf.create_dataset('non_nodule', data=non_nodule[:], compression='lzf')
278 | hf.create_dataset('label_non_nodule', data=non_nodule_label[:], compression='lzf')
279 |
280 | # with h5py.File(save_path + 't_subset' + str(i) + '.h5', 'w') as hf:
281 | # hf.create_dataset('nodule', data=nodule[train_nodule_len:], compression='lzf')
282 | # hf.create_dataset('label_nodule', data=nodule_label[train_nodule_len:], compression='lzf')
283 | #
284 | # hf.create_dataset('non_nodule', data=non_nodule[train_non_nodule_len:], compression='lzf')
285 | # hf.create_dataset('label_non_nodule', data=non_nodule_label[train_non_nodule_len:], compression='lzf')
286 |
--------------------------------------------------------------------------------
/preprocessing/preprocess_utils.py:
--------------------------------------------------------------------------------
1 | import SimpleITK as sitk
2 | import numpy as np
3 | import csv
4 | import os, glob
5 | from PIL import Image
6 | import matplotlib.pyplot as plt
7 |
8 | import scipy.ndimage
9 | from skimage import measure, morphology
10 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
11 | from skimage.measure import block_reduce
12 |
13 | OUTPUT_SPACING = [1.25, 1.25, 1.25]
14 | patch_size = 68
15 | label_size = 8
16 | stride = 8
17 |
18 | move = (patch_size // 2) // stride # 4
19 |
20 | def load_itk_image(filename):
21 | itkimage = sitk.ReadImage(filename)
22 | numpyImage = sitk.GetArrayFromImage(itkimage)
23 |
24 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin())))
25 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))
26 |
27 | return numpyImage, numpyOrigin, numpySpacing
28 |
29 | def read_csv(filename):
30 | lines = []
31 | with open(filename, 'r') as f:
32 | csvreader = csv.reader(f)
33 | for line in csvreader:
34 | lines.append(line)
35 |
36 | lines = lines[1:] # remove csv headers
37 | annotations_dict = {}
38 | for i in lines:
39 | series_uid, x, y, z, diameter = i
40 | value = {'position':[float(x),float(y),float(z)],
41 | 'diameter':float(diameter)}
42 | if series_uid in annotations_dict.keys():
43 | annotations_dict[series_uid].append(value)
44 | else:
45 | annotations_dict[series_uid] = [value]
46 |
47 | return annotations_dict
48 |
49 | def compute_coord(worldCoord, origin, spacing):
50 | stretchedVoxelCoord = np.absolute(worldCoord - origin)
51 | voxelCoord = stretchedVoxelCoord / spacing
52 | return voxelCoord
53 |
54 | def normalize_planes(npzarray):
55 | maxHU = 400. # 600
56 | minHU = -1000. # -1200
57 |
58 | npzarray = (npzarray - minHU) / (maxHU - minHU)
59 | npzarray[npzarray>1] = 1.
60 | npzarray[npzarray<0] = 0.
61 | return npzarray
62 |
63 | def zero_center(image):
64 | PIXEL_MEAN = 0.25
65 | image = image - PIXEL_MEAN
66 |
67 | return image
68 |
69 | def indices_to_one_hot(data, nb_classes):
70 | """Convert an iterable of indices to one-hot encoded labels."""
71 | targets = np.array(data).reshape(-1)
72 | return np.eye(nb_classes)[targets]
73 |
74 | def get_patch(image, coords, offset, nodule_list, patch_flag=True):
75 | xyz = image[int(coords[0] - offset): int(coords[0] + offset), int(coords[1] - offset): int(coords[1] + offset),
76 | int(coords[2] - offset): int(coords[2] + offset)]
77 |
78 | if patch_flag:
79 | output = np.expand_dims(xyz, axis=-1)
 80 |         # print(coords, np.shape(output))  # debug output; disabled
81 | else:
 82 |         # downsample the label patch to the 8x8x8 output grid
 83 |         """
 84 |         xyz = scipy.ndimage.zoom(input=xyz, zoom=1/8, order=1) # order=1 : trilinear
 85 |         xyz = np.where(xyz > 0, 1.0, 0.0)
 86 |         """
 87 |         xyz = block_reduce(xyz, (9, 9, 9), np.max)  # 9x9x9 max-pooling: 68^3 -> 8^3
88 | output = np.expand_dims(xyz, axis=-1)
89 |
90 | output = indices_to_one_hot(output.astype(np.int32), 2)
91 | output = np.reshape(output, (label_size, label_size, label_size, 2))
92 | output = output.astype(np.float32)
93 |
94 | # print('------------------')
95 | # print(output)
96 |
97 | # print(output)
98 | # print(np.shape(output))
99 |
100 | nodule_list.append(output)
101 |
102 | def patch_stride(image, coords, offset, nodule_list, patch_flag=True):
103 |     # Around each nodule centre, collect the centre patch plus patches shifted
104 |     # by +/- (stride * idx), idx = 1..move, along every axis and diagonal:
105 |     # 7 directions * 2 * move = 56 shifted patches + 1 centre patch.
106 |     directions = [(1, 0, 0), (0, 1, 0), (0, 0, 1),   # x, y, z
107 |                   (1, 1, 0), (1, 0, 1), (0, 1, 1),   # xy, xz, yz
108 |                   (1, 1, 1)]                         # xyz
109 | 
110 |     # work on a copy so the shifts never modify the caller's coordinate array
111 |     original_coords = np.asarray(coords).copy()
112 | 
113 |     # center
114 |     get_patch(image, original_coords, offset, nodule_list, patch_flag)
115 | 
116 |     # shifted patches
117 |     for direction in directions:
118 |         for idx in range(1, move + 1):
119 |             shift = np.asarray(direction) * stride * idx
120 |             get_patch(image, original_coords + shift, offset, nodule_list, patch_flag)
121 |             get_patch(image, original_coords - shift, offset, nodule_list, patch_flag)
193 |
194 | def resample(image, org_spacing, new_spacing):
195 |
196 | resize_factor = org_spacing / new_spacing
197 | new_real_shape = image.shape * resize_factor
198 | new_shape = np.round(new_real_shape)
199 | real_resize_factor = new_shape / image.shape
200 | new_spacing = org_spacing / real_resize_factor
201 |
202 | image = scipy.ndimage.interpolation.zoom(image, real_resize_factor, mode='nearest')
203 |
204 | return image, new_spacing
205 |
206 | def create_label(arr_shape, nodules, origin, new_spacing, coord=False):
207 | """
208 | nodules = list of dict {'position', 'diameter'}
209 | """
210 |
211 | def _create_mask(arr_shape, position, diameter):
212 |
213 | z_dim, y_dim, x_dim = arr_shape
214 | z_pos, y_pos, x_pos = position
215 |
216 | z,y,x = np.ogrid[-z_pos:z_dim-z_pos, -y_pos:y_dim-y_pos, -x_pos:x_dim-x_pos]
217 | mask = z**2 + y**2 + x**2 <= int(diameter//2)**2
218 |
219 | return mask
220 |
221 | if coord:
222 | label = []
223 | else:
224 | label = np.zeros(arr_shape, dtype='bool')
225 |
226 | for nodule in nodules:
227 | worldCoord = nodule['position']
228 | worldCoord = np.asarray([worldCoord[2],worldCoord[1],worldCoord[0]])
229 |
230 | # new_spacing came from resample
231 | voxelCoord = compute_coord(worldCoord, origin, new_spacing)
232 | voxelCoord = [int(i) for i in voxelCoord]
233 |
234 | diameter = nodule['diameter']
235 | diameter = diameter / new_spacing[1]
236 |
237 | if coord:
238 | label.append(voxelCoord + [diameter])
239 | else:
240 | mask = _create_mask(arr_shape, voxelCoord, diameter)
241 | label = np.logical_or(label, mask)
242 |
243 | return label
244 |
245 |
246 | def plot(image, label, z_idx):
247 |
248 | fig = plt.figure()
249 | ax = fig.add_subplot(1,2,1)
250 | ax.imshow(image[z_idx,:,:],cmap='gray')
251 |
252 | ax = fig.add_subplot(1,2,2)
253 | ax.imshow(label[z_idx,:,:],cmap='gray')
254 |
255 | fig.show()
256 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import SimpleITK as sitk
3 | import tensorflow as tf
4 | import os, h5py, glob
5 | import tensorflow.contrib.slim as slim
6 | from multiprocessing import Pool
7 | from sklearn.metrics import roc_curve, auc
8 | import matplotlib.pyplot as plt
9 | from matplotlib.ticker import FixedFormatter
10 | import matplotlib
11 |
12 |
13 | seed = 0
14 | file = ""
15 |
16 | process_num = 32
17 | get_data_num = 64
18 |
19 |
20 | def load_itk(filename):
21 | itkimage = sitk.ReadImage(filename)
22 | image = np.transpose(sitk.GetArrayFromImage(itkimage))
23 | origin = np.array(itkimage.GetOrigin())
24 | spacing = np.array(itkimage.GetSpacing())
25 | return image, origin, spacing
26 |
27 |
28 | def world_2_voxel(world_coord, origin, spacing):
29 | stretched_voxel_coord = np.absolute(world_coord - origin)
30 | voxel_coord = stretched_voxel_coord / spacing
31 | return voxel_coord
32 |
33 |
34 | def voxel_2_world(voxel_coord, origin, spacing):
35 | stretched_voxel_coord = voxel_coord * spacing
36 | world_coord = stretched_voxel_coord + origin
37 | return world_coord
38 |
39 |
40 | def show_all_variables():
41 | model_vars = tf.trainable_variables()
42 | slim.model_analyzer.analyze_vars(model_vars, print_info=True)
43 |
44 |
45 | def check_folder(log_dir):
46 | if not os.path.exists(log_dir):
47 | os.makedirs(log_dir)
48 | return log_dir
49 |
50 |
51 | def nodule_hf(idx):
52 | with h5py.File(file, 'r') as hf:
53 | nodule = hf['nodule'][idx:idx + get_data_num]
54 | # print(np.shape(nodule))
55 | return nodule
56 |
57 |
58 | def non_nodule_hf(idx):
59 | with h5py.File(file, 'r') as hf:
60 | non_nodule = hf['non_nodule'][idx:idx + get_data_num]
61 | return non_nodule
62 |
63 | def label_nodule_hf(idx):
64 | with h5py.File(file, 'r') as hf:
65 | nodule = hf['label_nodule'][idx:idx + get_data_num]
66 | # print(np.shape(nodule))
67 | return nodule
68 |
69 |
70 | def label_non_nodule_hf(idx):
71 | with h5py.File(file, 'r') as hf:
72 | non_nodule = hf['label_non_nodule'][idx:idx + get_data_num]
73 | return non_nodule
74 |
75 |
76 | def prepare_data(sub_n):
77 | global file
78 | dir = '/data2/jhkim/LUNA16/patch/SH/subset'
79 | file = ''.join([dir, str(sub_n), '.h5'])
80 |
81 | with h5py.File(file, 'r') as fin:
82 | nodule_range = range(0, len(fin['nodule']), get_data_num)
83 | non_nodule_range = range(0, len(fin['non_nodule']), get_data_num)
84 | label_nodule_range = range(0, len(fin['label_nodule']), get_data_num)
85 | label_non_nodule_range = range(0, len(fin['label_non_nodule']), get_data_num)
86 |
87 |
88 | pool = Pool(processes=process_num)
89 | pool_nodule = pool.map(nodule_hf, nodule_range)
90 | pool_non_nodule = pool.map(non_nodule_hf, non_nodule_range)
91 | pool_label_nodule = pool.map(label_nodule_hf, label_nodule_range)
92 | pool_label_non_nodule = pool.map(label_non_nodule_hf, label_non_nodule_range)
93 |
94 | pool.close()
95 |
96 | # print(np.shape(nodule[0]))
97 | nodule = []
98 | non_nodule = []
99 |
100 | nodule_y = []
101 | non_nodule_y = []
102 |
103 | for p in pool_nodule:
104 | nodule.extend(p)
105 | for p in pool_non_nodule:
106 | non_nodule.extend(p)
107 |
108 | for p in pool_label_nodule:
109 | nodule_y.extend(p)
110 | for p in pool_label_non_nodule:
111 | non_nodule_y.extend(p)
112 |
113 | nodule = np.asarray(nodule)
114 | non_nodule = np.asarray(non_nodule)
115 | nodule_y = np.asarray(nodule_y)
116 | non_nodule_y = np.asarray(non_nodule_y)
117 |
118 | #print(np.shape(nodule_y))
119 | #print(np.shape(non_nodule_y))
120 |
121 | all_y = np.concatenate([nodule_y, non_nodule_y], axis=0)
122 | all_patch = np.concatenate([nodule, non_nodule], axis=0)
123 |
124 | # seed = 777
125 | # np.random.seed(seed)
126 | # np.random.shuffle(nodule)
127 | # np.random.seed(seed)
128 | # np.random.shuffle(nodule_y)
129 | #
130 | # np.random.seed(seed)
131 | # np.random.shuffle(all_patch)
132 | # np.random.seed(seed)
133 | # np.random.shuffle(all_y)
134 |
135 |
136 | return nodule, all_patch, nodule_y, all_y
137 |
138 | def validation_data(sub_n) :
139 | global file
140 | dir = '/data2/jhkim/LUNA16/patch/SH/subset'
141 | file = ''.join([dir, str(sub_n), '.h5'])
142 |
143 | with h5py.File(file, 'r') as fin:
144 | nodule_range = range(0, len(fin['nodule']), get_data_num)
145 | non_nodule_range = range(0, len(fin['non_nodule']), get_data_num)
146 | label_nodule_range = range(0, len(fin['label_nodule']), get_data_num)
147 | label_non_nodule_range = range(0, len(fin['label_non_nodule']), get_data_num)
148 |
149 |
150 | pool = Pool(processes=process_num)
151 | pool_nodule = pool.map(nodule_hf, nodule_range)
152 | pool_non_nodule = pool.map(non_nodule_hf, non_nodule_range)
153 | pool_label_nodule = pool.map(label_nodule_hf, label_nodule_range)
154 | pool_label_non_nodule = pool.map(label_non_nodule_hf, label_non_nodule_range)
155 |
156 | pool.close()
157 |
158 | # print(np.shape(nodule[0]))
159 | nodule = []
160 | non_nodule = []
161 |
162 | nodule_y = []
163 | non_nodule_y = []
164 |
165 | for p in pool_nodule:
166 | nodule.extend(p)
167 | for p in pool_non_nodule:
168 | non_nodule.extend(p)
169 |
170 | for p in pool_label_nodule:
171 | nodule_y.extend(p)
172 | for p in pool_label_non_nodule:
173 | non_nodule_y.extend(p)
174 |
175 | nodule = np.asarray(nodule)
176 | non_nodule = np.asarray(non_nodule)
177 | nodule_y = np.asarray(nodule_y)
178 | non_nodule_y = np.asarray(non_nodule_y)
179 |
180 | all_y = np.concatenate([nodule_y, non_nodule_y], axis=0)
181 | all_patch = np.concatenate([nodule, non_nodule], axis=0)
182 |
183 | # seed = 777
184 | # np.random.seed(seed)
185 | # np.random.shuffle(all_patch)
186 | # np.random.seed(seed)
187 | # np.random.shuffle(all_y)
188 |
189 |
190 | return all_patch, all_y
191 |
192 | def test_data(sub_n):
193 | global file
194 | dir = '/data2/jhkim/LUNA16/patch/SH/t_subset'
195 | file = ''.join([dir, str(sub_n), '.h5'])
196 |
197 | with h5py.File(file, 'r') as fin:
198 | nodule_range = range(0, len(fin['nodule']), get_data_num)
199 | non_nodule_range = range(0, len(fin['non_nodule']), get_data_num)
200 | label_nodule_range = range(0, len(fin['label_nodule']), get_data_num)
201 | label_non_nodule_range = range(0, len(fin['label_non_nodule']), get_data_num)
202 |
203 |
204 | pool = Pool(processes=process_num)
205 | pool_nodule = pool.map(nodule_hf, nodule_range)
206 | pool_non_nodule = pool.map(non_nodule_hf, non_nodule_range)
207 | pool_label_nodule = pool.map(label_nodule_hf, label_nodule_range)
208 | pool_label_non_nodule = pool.map(label_non_nodule_hf, label_non_nodule_range)
209 |
210 | pool.close()
211 |
212 | # print(np.shape(nodule[0]))
213 | nodule = []
214 | non_nodule = []
215 |
216 | nodule_y = []
217 | non_nodule_y = []
218 |
219 | for p in pool_nodule:
220 | nodule.extend(p)
221 | for p in pool_non_nodule:
222 | non_nodule.extend(p)
223 |
224 | for p in pool_label_nodule:
225 | nodule_y.extend(p)
226 | for p in pool_label_non_nodule:
227 | non_nodule_y.extend(p)
228 |
229 | nodule = np.asarray(nodule)
230 | non_nodule = np.asarray(non_nodule)
231 | nodule_y = np.asarray(nodule_y)
232 | non_nodule_y = np.asarray(non_nodule_y)
233 |
234 | all_y = np.concatenate([nodule_y, non_nodule_y], axis=0)
235 | all_patch = np.concatenate([nodule, non_nodule], axis=0)
236 |
237 | # seed = 777
238 | # np.random.seed(seed)
239 | # np.random.shuffle(all_patch)
240 | # np.random.seed(seed)
241 | # np.random.shuffle(all_y)
242 |
243 |
244 | return all_patch, all_y
245 |
246 |
247 | def sensitivity(logits, labels):
248 | predictions = tf.argmax(logits, axis=-1)
249 | actuals = tf.argmax(labels, axis=-1)
250 |
251 |
252 | nodule_actuals = tf.ones_like(actuals)
253 | non_nodule_actuals = tf.zeros_like(actuals)
254 | nodule_predictions = tf.ones_like(predictions)
255 | non_nodule_predictions = tf.zeros_like(predictions)
256 |
257 | tp_op = tf.reduce_sum(
258 | tf.cast(
259 | tf.logical_and(
260 | tf.equal(actuals, nodule_actuals),
261 | tf.equal(predictions, nodule_predictions)
262 | ),
263 | tf.float32
264 | )
265 | )
266 |
267 | tn_op = tf.reduce_sum(
268 | tf.cast(
269 | tf.logical_and(
270 | tf.equal(actuals, non_nodule_actuals),
271 | tf.equal(predictions, non_nodule_predictions)
272 | ),
273 | tf.float32
274 | )
275 | )
276 |
277 | fp_op = tf.reduce_sum(
278 | tf.cast(
279 | tf.logical_and(
280 | tf.equal(actuals, non_nodule_actuals),
281 | tf.equal(predictions, nodule_predictions)
282 | ),
283 | tf.float32
284 | )
285 | )
286 |
287 | fn_op = tf.reduce_sum(
288 | tf.cast(
289 | tf.logical_and(
290 | tf.equal(actuals, nodule_actuals),
291 | tf.equal(predictions, non_nodule_predictions)
292 | ),
293 | tf.float32
294 | )
295 | )
296 |
297 | false_positive_rate = fp_op / (fp_op + tn_op)
298 |
299 | recall = tp_op / (tp_op + fn_op)
300 |
301 | return recall, false_positive_rate
302 |
303 | def Snapshot(t, T, M, alpha_zero) :
304 | """
305 |
306 |     t = current iteration
307 |     T = total number of iterations
308 |     M = number of snapshots
309 |     alpha_zero = initial learning rate
310 |
311 | """
312 |
313 | x = (np.pi * (t % (T // M))) / (T // M)
314 | x = np.cos(x) + 1
315 |
316 | lr = (alpha_zero / 2) * x
317 |
318 | return lr
319 |
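# Example (hypothetical numbers): with T = 100000 iterations, M = 5 snapshots and
# alpha_zero = 0.1, the learning rate restarts every T // M = 20000 iterations:
#   Snapshot(0,     100000, 5, 0.1)  -> 0.1    (start of a cycle)
#   Snapshot(10000, 100000, 5, 0.1)  -> 0.05   (halfway through a cycle)
#   Snapshot(19999, 100000, 5, 0.1)  -> ~0.0   (end of a cycle)
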
320 | def indices_to_one_hot(data, nb_classes):
321 | """Convert an iterable of indices to one-hot encoded labels."""
322 | targets = np.array(data).reshape(-1)
323 | return np.eye(nb_classes)[targets]
324 |
325 | def create_exclude_mask(arr_shape, position, diameter):
326 | x_dim, y_dim, z_dim = arr_shape
327 | x_pos, y_pos, z_pos = position
328 |
329 | x, y, z = np.ogrid[-x_pos:x_dim - x_pos, -y_pos:y_dim - y_pos, -z_pos:z_dim - z_pos]
330 | mask = x ** 2 + y ** 2 + z ** 2 > int(diameter // 2) ** 2
331 |
332 | return mask
333 |
334 | def fp_per_scan(logit, label) :
335 | logit = np.reshape(logit, -1)
336 | label = np.reshape(label, -1)
337 |
338 | logit = logit[logit >= 0]
339 | label = label[label >= 0]
340 |
341 | logit = np.where(logit >= 1.0, logit-1, logit)
342 | label = np.where(label >= 1.0, label-1, label)
343 |
344 | fpr, tpr, th = roc_curve(label, logit, pos_label=1.0)
345 | negative_samples = np.count_nonzero(label == 0.0)
346 | fps = fpr * negative_samples
347 |
348 |
349 | """
350 | mean_sens = np.mean(sens_list)
351 | matplotlib.use('Agg')
352 |
353 | ax = plt.gca()
354 | plt.plot(fps_itp, sens_itp)
355 | # https://matplotlib.org/devdocs/api/_as_gen/matplotlib.pyplot.grid.html
356 | plt.xlim(MIN_FROC, MAX_FROC)
357 | plt.ylim(0, 1.1)
358 | plt.xlabel('Average number of false positives per scan')
359 | plt.ylabel('Sensitivity')
360 | # plt.legend(loc='lower right')
361 | # plt.legend(loc=9)
362 | plt.title('Average sensitivity = %.4f' % (mean_sens))
363 |
364 | plt.xscale('log', basex=2)
365 | ax.xaxis.set_major_formatter(FixedFormatter(fp_list))
366 |
367 | ax.xaxis.set_ticks(fp_list)
368 | ax.yaxis.set_ticks(np.arange(0, 1.1, 0.1))
369 | plt.grid(b=True, linestyle='dotted', which='both')
370 | plt.tight_layout()
371 |
372 | # plt.show()
373 | plt.savefig('result.png', bbox_inches=0, dpi=300)
374 | """
375 |
376 | return np.asarray(fps), np.asarray(tpr)
--------------------------------------------------------------------------------