├── LICENSE
├── README.md
├── brats_split_list
│   ├── test_list_cv0.txt
│   ├── train_list_cv0.txt
│   ├── train_list_cv0_12samples.txt
│   ├── train_list_cv0_24samples.txt
│   ├── train_list_cv0_3samples1.txt
│   ├── train_list_cv0_6samples1.txt
│   ├── train_list_cv0_96samples.txt
│   └── validation_list_cv0.txt
├── configs
│   ├── brats_boundseg.yml
│   ├── brats_cpcseg.yml
│   ├── brats_encdec.yml
│   ├── brats_semi_cpcseg.yml
│   ├── brats_semi_vaeseg.yml
│   └── brats_vaeseg.yml
├── examples
│   ├── normalize.py
│   ├── split_kfold.py
│   └── train.py
├── img
│   └── architecture.png
├── requirements.txt
└── src
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── generate_msd_edge.py
    │   ├── msd_bound.py
    │   └── preprocess.py
    ├── functions
    │   ├── array
    │   │   ├── __init__.py
    │   │   ├── resize_images_3d.py
    │   │   └── resize_images_3d_nearest_neighbor.py
    │   ├── evaluation
    │   │   ├── __init__.py
    │   │   └── dice_coefficient.py
    │   └── loss
    │       ├── __init__.py
    │       ├── boundary_bce.py
    │       ├── cpc_loss.py
    │       ├── dice_loss.py
    │       ├── focal_loss.py
    │       ├── generalized_dice_loss.py
    │       └── mixed_dice_loss.py
    ├── links
    │   ├── __init__.py
    │   └── model
    │       ├── __init__.py
    │       ├── resize_images_3d.py
    │       └── vaeseg.py
    ├── training
    │   ├── __init__.py
    │   ├── extensions
    │   │   ├── __init__.py
    │   │   ├── boundseg_evaluator.py
    │   │   ├── cpcseg_evaluator.py
    │   │   ├── encdec_seg_evaluator.py
    │   │   └── vaeseg_evaluator.py
    │   └── updaters
    │       ├── __init__.py
    │       ├── boundseg_updater.py
    │       ├── cpcseg_updater.py
    │       ├── encdec_seg_updater.py
    │       └── vaeseg_updater.py
    └── utils
        ├── config.py
        ├── encode_one_hot_vector.py
        └── setup_helpers.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Preferred Networks, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Label-Efficient Multi-Task Segmentation using Contrastive Learning
2 |
3 | This repository contains the Chainer implementation of the code used for the following paper:
4 | > [1] **Label-Efficient Multi-Task Segmentation using Contrastive Learning**
5 | >Junichiro Iwasawa, Yuichiro Hirano, Yohei Sugawara
6 | > Accepted to MICCAI BrainLes 2020 workshop, [preprint on arXiv](https://arxiv.org/abs/2009.11160).
7 | >
8 | >
9 | > **Abstract:** *Obtaining annotations for 3D medical images is expensive and time-consuming, despite its importance for automating segmentation tasks. Although multi-task learning is considered an effective method for training segmentation models using small amounts of annotated data, a systematic understanding of various subtasks is still lacking. In this study, we propose a multi-task segmentation model with a contrastive learning based subtask and compare its performance with other multi-task models, varying the number of labeled data for training. We further extend our model so that it can utilize unlabeled data through the regularization branch in a semi-supervised manner. We experimentally show that our proposed method outperforms other multi-task methods including the state-of-the-art fully supervised model when the amount of annotated data is limited.*
10 |
11 | Disclaimer: PFN provides no warranty or support for this implementation. Use it at your own risk.
12 |
13 | ## Dependencies
14 |
15 | We have tested this code using:
16 |
17 | - Ubuntu 16.04.6
18 | - Python 3.7
19 | - CUDA 9.0
20 | - cuDNN 7.0
21 |
22 | The full list of Python packages for the code is given in `requirements.txt`. These can be installed using:
23 |
24 | ```bash
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ## Usage
29 |
30 | ![Schematic of the model architecture](img/architecture.png)
31 |
32 |
33 |
34 | The schematic image for the model is given above. To compare the effect of different regularization branches, we implemented four different models: the encoder-decoder alone (EncDec), EncDec with a VAE branch (VAEseg), EncDec with a boundary attention branch (Boundseg), and EncDec with a CPC branch (CPCseg). We have also implemented the semi-supervised VAEseg (ssVAEseg) and semi-supervised CPCseg (ssCPCseg) to investigate the effect of utilizing unlabeled data. The brain tumor dataset from the 2016 and 2017 Brain Tumor Image Segmentation ([BraTS](http://dx.doi.org/10.1109/TMI.2014.2377694)) Challenges was used for evaluation.
35 |
36 | ### Preparation
37 |
38 | #### Normalization
39 |
40 | The models were trained on images normalized to zero mean and unit standard deviation. To normalize the original image data in the BraTS dataset (`imagesTr`), run the command below, which creates a directory of normalized files (`imagesTr_normalized`).
41 |
42 | ```bash
43 | python examples/normalize.py --root_path PATH_TO_BRATS_DATA/imagesTr
44 | ```
45 |
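The actual preprocessing lives in `src/datasets/preprocess.py`; the per-volume transformation it applies can be sketched as follows (the helper name, signature, and `eps` guard here are illustrative, not the repo's API):

```python
import numpy as np


def normalize_volume(vol, eps=1e-8):
    """Zero-mean, unit-standard-deviation normalization of one 3D volume.

    `eps` guards against division by zero for constant volumes.
    """
    vol = vol.astype(np.float32)
    return (vol - vol.mean()) / (vol.std() + eps)


vol = np.random.RandomState(0).rand(4, 4, 4)
out = normalize_volume(vol)
# out now has (approximately) zero mean and unit standard deviation
```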
46 | #### Generating boundary labels for boundseg
47 |
48 | The Boundseg model takes the boundary labels of the dataset as an additional input. To generate the boundary labels, run
49 |
50 | ```bash
51 | python src/datasets/generate_msd_edge.py
52 | ```
53 |
54 | after modifying `label_path` in the file.
55 |
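As `src/datasets/generate_msd_edge.py` shows, boundary labels are obtained by convolving the (one-hot) segmentation labels with a 3D Laplacian kernel: the response is non-zero exactly at voxels whose 6-neighborhood crosses a class boundary. A minimal pure-NumPy sketch of the same idea (the repo uses a GPU `Convolution3D` instead; the function names here are illustrative):

```python
import numpy as np


def laplacian3d(x):
    """6-connectivity 3D Laplacian via shifted zero-padded copies."""
    x = x.astype(np.float32)
    p = np.pad(x, 1, mode='constant')
    return (p[:-2, 1:-1, 1:-1] + p[2:, 1:-1, 1:-1]
            + p[1:-1, :-2, 1:-1] + p[1:-1, 2:, 1:-1]
            + p[1:-1, 1:-1, :-2] + p[1:-1, 1:-1, 2:] - 6.0 * x)


def label_to_edge(mask):
    # Voxels with a non-zero Laplacian response lie on (or next to) a boundary.
    return (np.abs(laplacian3d(mask)) > 0).astype(np.uint8)


mask = np.zeros((8, 8, 8), dtype=np.uint8)
mask[2:6, 2:6, 2:6] = 1          # a solid cube
edge = label_to_edge(mask)       # hollow shell around the cube's surface
```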
56 | #### Setting up the training configs
57 |
58 | - Before running the model, please specify the paths for `image_path`, `label_path`, (and `edge_path` for Boundseg) in the config files within `configs/`.
59 |
60 | - To change the number of labeled samples used for training, rewrite `train_list_path` in the config file. For example, to use only 6 labeled samples, set `train_list_path: ./brats_split_list/train_list_cv0_6samples1.txt`.
61 |
62 | ### Training
63 |
64 | #### EncDec
65 |
66 | ```bash
67 | mpiexec -n 8 python3 examples/train.py \
68 | --config configs/brats_encdec.yml \
69 | --out results
70 | ```
71 |
72 | #### VAEseg
73 |
74 | ```bash
75 | mpiexec -n 8 python3 examples/train.py \
76 | --config configs/brats_vaeseg.yml \
77 | --out results
78 | ```
79 |
80 | #### Boundseg
81 |
82 | ```bash
83 | mpiexec -n 8 python3 examples/train.py \
84 | --config configs/brats_boundseg.yml \
85 | --out results
86 | ```
87 |
88 | #### CPCseg
89 |
90 | ```bash
91 | mpiexec -n 8 python3 examples/train.py \
92 | --config configs/brats_cpcseg.yml \
93 | --out results
94 | ```
95 |
96 | #### ssVAEseg
97 |
98 | ```bash
99 | mpiexec -n 8 python3 examples/train.py \
100 | --config configs/brats_semi_vaeseg.yml \
101 | --out results
102 | ```
103 |
104 | #### ssCPCseg
105 |
106 | ```bash
107 | mpiexec -n 8 python3 examples/train.py \
108 | --config configs/brats_semi_cpcseg.yml \
109 | --out results
110 | ```
111 |
112 | ## Results
113 |
114 | The models, trained with 6 labeled samples (`train_list_cv0_6samples1.txt`), were evaluated on the test dataset (`brats_split_list/test_list_cv0.txt`) using the mean dice score and the 95th percentile Hausdorff distance ([[1](#Label-Efficient-Multi-Task-Segmentation-using-Contrastive-Learning)], Table 1). For more detailed results, please see the paper [[1](#Label-Efficient-Multi-Task-Segmentation-using-Contrastive-Learning)].
115 |
116 | | | Dice (ET) | Dice (TC) | Dice (WT) | Hausdorff (ET) | Hausdorff (TC) | Hausdorff (WT) |
117 | |:-|:---:|:--------:|:-----------------------:|:-----------------:|:-----------------------:|:-----------------:|
118 | |EncDec |0.8412 |0.8383 |0.9144 | 11.4697 |20.12 |24.3726
119 | |VAEseg |0.8234 |0.8036 |0.8998 |14.3467 |22.4926 |17.9775
120 | |Boundseg |0.8356 |0.8378 |0.9041 |17.0323 |27.2128 |25.8112
121 | |CPCseg |0.8374 |0.8386 |0.9057 |10.2839 |14.9661 |15.0633
122 | |ssVAEseg |0.8626 |0.8425 |0.9131 |9.1966 | **12.5302** |14.8056
123 | |ssCPCseg | **0.8873** | **0.8761** | **0.9151** | **8.7092** |16.0947 | **12.3962**
124 |
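The dice score in the table is the standard overlap measure 2|A∩B| / (|A| + |B|) between predicted and ground-truth masks; the repo's implementation lives in `src/functions/evaluation/dice_coefficient.py`. A minimal sketch (function name and `eps` smoothing term are illustrative):

```python
import numpy as np


def dice_score(pred, gt, eps=1e-7):
    """Dice coefficient 2|A∩B| / (|A| + |B|) for two binary masks."""
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return 2.0 * inter / (pred.sum() + gt.sum() + eps)


a = np.zeros((4, 4, 4), dtype=bool); a[:2] = True   # 32 voxels
b = np.zeros((4, 4, 4), dtype=bool); b[1:3] = True  # 32 voxels, 16 shared
score = dice_score(a, b)  # 2 * 16 / (32 + 32) = 0.5
```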
125 | ## License
126 |
127 | [MIT License](LICENSE)
128 |
--------------------------------------------------------------------------------
/brats_split_list/test_list_cv0.txt:
--------------------------------------------------------------------------------
1 | BRATS_072
2 | BRATS_207
3 | BRATS_102
4 | BRATS_294
5 | BRATS_265
6 | BRATS_152
7 | BRATS_427
8 | BRATS_209
9 | BRATS_418
10 | BRATS_377
11 | BRATS_344
12 | BRATS_234
13 | BRATS_095
14 | BRATS_261
15 | BRATS_348
16 | BRATS_448
17 | BRATS_035
18 | BRATS_148
19 | BRATS_069
20 | BRATS_191
21 | BRATS_482
22 | BRATS_256
23 | BRATS_381
24 | BRATS_400
25 | BRATS_055
26 | BRATS_065
27 | BRATS_408
28 | BRATS_449
29 | BRATS_404
30 | BRATS_161
31 | BRATS_166
32 | BRATS_469
33 | BRATS_124
34 | BRATS_280
35 | BRATS_401
36 | BRATS_421
37 | BRATS_406
38 | BRATS_335
39 | BRATS_248
40 | BRATS_129
41 | BRATS_199
42 | BRATS_125
43 | BRATS_099
44 | BRATS_063
45 | BRATS_410
46 | BRATS_442
47 | BRATS_366
48 | BRATS_048
49 | BRATS_300
50 |
--------------------------------------------------------------------------------
/brats_split_list/train_list_cv0.txt:
--------------------------------------------------------------------------------
1 | BRATS_446
2 | BRATS_415
3 | BRATS_165
4 | BRATS_293
5 | BRATS_121
6 | BRATS_435
7 | BRATS_038
8 | BRATS_140
9 | BRATS_305
10 | BRATS_197
11 | BRATS_452
12 | BRATS_375
13 | BRATS_177
14 | BRATS_093
15 | BRATS_439
16 | BRATS_432
17 | BRATS_281
18 | BRATS_451
19 | BRATS_126
20 | BRATS_247
21 | BRATS_385
22 | BRATS_361
23 | BRATS_275
24 | BRATS_412
25 | BRATS_198
26 | BRATS_282
27 | BRATS_262
28 | BRATS_307
29 | BRATS_468
30 | BRATS_082
31 | BRATS_016
32 | BRATS_374
33 | BRATS_172
34 | BRATS_341
35 | BRATS_220
36 | BRATS_271
37 | BRATS_117
38 | BRATS_419
39 | BRATS_425
40 | BRATS_338
41 | BRATS_008
42 | BRATS_122
43 | BRATS_360
44 | BRATS_382
45 | BRATS_229
46 | BRATS_359
47 | BRATS_470
48 | BRATS_163
49 | BRATS_445
50 | BRATS_150
51 | BRATS_295
52 | BRATS_264
53 | BRATS_353
54 | BRATS_340
55 | BRATS_175
56 | BRATS_428
57 | BRATS_235
58 | BRATS_014
59 | BRATS_213
60 | BRATS_403
61 | BRATS_376
62 | BRATS_326
63 | BRATS_079
64 | BRATS_484
65 | BRATS_164
66 | BRATS_086
67 | BRATS_096
68 | BRATS_393
69 | BRATS_136
70 | BRATS_073
71 | BRATS_001
72 | BRATS_013
73 | BRATS_106
74 | BRATS_456
75 | BRATS_066
76 | BRATS_169
77 | BRATS_402
78 | BRATS_196
79 | BRATS_257
80 | BRATS_329
81 | BRATS_236
82 | BRATS_395
83 | BRATS_173
84 | BRATS_396
85 | BRATS_123
86 | BRATS_186
87 | BRATS_044
88 | BRATS_313
89 | BRATS_194
90 | BRATS_171
91 | BRATS_223
92 | BRATS_023
93 | BRATS_279
94 | BRATS_347
95 | BRATS_330
96 | BRATS_333
97 | BRATS_182
98 | BRATS_002
99 | BRATS_215
100 | BRATS_268
101 | BRATS_352
102 | BRATS_159
103 | BRATS_355
104 | BRATS_143
105 | BRATS_369
106 | BRATS_046
107 | BRATS_481
108 | BRATS_398
109 | BRATS_091
110 | BRATS_454
111 | BRATS_417
112 | BRATS_443
113 | BRATS_349
114 | BRATS_059
115 | BRATS_414
116 | BRATS_115
117 | BRATS_224
118 | BRATS_342
119 | BRATS_137
120 | BRATS_253
121 | BRATS_110
122 | BRATS_283
123 | BRATS_378
124 | BRATS_321
125 | BRATS_409
126 | BRATS_087
127 | BRATS_273
128 | BRATS_478
129 | BRATS_025
130 | BRATS_088
131 | BRATS_158
132 | BRATS_424
133 | BRATS_222
134 | BRATS_206
135 | BRATS_118
136 | BRATS_433
137 | BRATS_092
138 | BRATS_050
139 | BRATS_230
140 | BRATS_397
141 | BRATS_193
142 | BRATS_277
143 | BRATS_350
144 | BRATS_101
145 | BRATS_430
146 | BRATS_331
147 | BRATS_301
148 | BRATS_440
149 | BRATS_068
150 | BRATS_081
151 | BRATS_237
152 | BRATS_058
153 | BRATS_285
154 | BRATS_363
155 | BRATS_018
156 | BRATS_303
157 | BRATS_183
158 | BRATS_076
159 | BRATS_336
160 | BRATS_390
161 | BRATS_108
162 | BRATS_184
163 | BRATS_019
164 | BRATS_394
165 | BRATS_231
166 | BRATS_320
167 | BRATS_179
168 | BRATS_030
169 | BRATS_138
170 | BRATS_324
171 | BRATS_216
172 | BRATS_011
173 | BRATS_447
174 | BRATS_258
175 | BRATS_100
176 | BRATS_109
177 | BRATS_370
178 | BRATS_372
179 | BRATS_131
180 | BRATS_168
181 | BRATS_105
182 | BRATS_422
183 | BRATS_187
184 | BRATS_200
185 | BRATS_318
186 | BRATS_241
187 | BRATS_202
188 | BRATS_132
189 | BRATS_367
190 | BRATS_189
191 | BRATS_471
192 | BRATS_479
193 | BRATS_483
194 | BRATS_103
195 | BRATS_356
196 | BRATS_120
197 | BRATS_328
198 | BRATS_036
199 | BRATS_085
200 | BRATS_388
201 | BRATS_284
202 | BRATS_411
203 | BRATS_005
204 | BRATS_392
205 | BRATS_021
206 | BRATS_015
207 | BRATS_228
208 | BRATS_245
209 | BRATS_480
210 | BRATS_365
211 | BRATS_060
212 | BRATS_210
213 | BRATS_208
214 | BRATS_437
215 | BRATS_343
216 | BRATS_116
217 | BRATS_135
218 | BRATS_274
219 | BRATS_154
220 | BRATS_371
221 | BRATS_444
222 | BRATS_431
223 | BRATS_296
224 | BRATS_064
225 | BRATS_139
226 | BRATS_345
227 | BRATS_251
228 | BRATS_017
229 | BRATS_047
230 | BRATS_218
231 | BRATS_051
232 | BRATS_204
233 | BRATS_386
234 | BRATS_057
235 | BRATS_225
236 | BRATS_097
237 | BRATS_312
238 | BRATS_195
239 | BRATS_034
240 | BRATS_111
241 | BRATS_233
242 | BRATS_354
243 | BRATS_062
244 | BRATS_465
245 | BRATS_467
246 | BRATS_029
247 | BRATS_053
248 | BRATS_286
249 | BRATS_083
250 | BRATS_311
251 | BRATS_089
252 | BRATS_461
253 | BRATS_314
254 | BRATS_040
255 | BRATS_387
256 | BRATS_380
257 | BRATS_337
258 | BRATS_212
259 | BRATS_056
260 | BRATS_327
261 | BRATS_289
262 | BRATS_473
263 | BRATS_246
264 | BRATS_466
265 | BRATS_219
266 | BRATS_332
267 | BRATS_334
268 | BRATS_155
269 | BRATS_104
270 | BRATS_438
271 | BRATS_266
272 | BRATS_453
273 | BRATS_181
274 | BRATS_045
275 | BRATS_291
276 | BRATS_211
277 | BRATS_477
278 | BRATS_144
279 | BRATS_141
280 | BRATS_290
281 | BRATS_190
282 | BRATS_269
283 | BRATS_098
284 | BRATS_156
285 | BRATS_170
286 | BRATS_071
287 | BRATS_310
288 | BRATS_319
289 | BRATS_297
290 | BRATS_420
291 | BRATS_263
292 | BRATS_379
293 | BRATS_039
294 | BRATS_239
295 | BRATS_007
296 | BRATS_145
297 | BRATS_240
298 | BRATS_407
299 | BRATS_436
300 | BRATS_476
301 | BRATS_304
302 | BRATS_373
303 | BRATS_003
304 | BRATS_112
305 | BRATS_317
306 | BRATS_176
307 | BRATS_153
308 | BRATS_351
309 | BRATS_078
310 | BRATS_260
311 | BRATS_037
312 | BRATS_383
313 | BRATS_288
314 | BRATS_475
315 | BRATS_180
316 | BRATS_042
317 | BRATS_205
318 | BRATS_302
319 | BRATS_157
320 | BRATS_107
321 | BRATS_149
322 | BRATS_458
323 | BRATS_043
324 | BRATS_322
325 | BRATS_278
326 | BRATS_315
327 | BRATS_214
328 | BRATS_457
329 | BRATS_339
330 | BRATS_255
331 | BRATS_028
332 | BRATS_167
333 | BRATS_250
334 | BRATS_004
335 | BRATS_272
336 | BRATS_054
337 | BRATS_134
338 | BRATS_306
339 | BRATS_033
340 | BRATS_128
341 | BRATS_270
342 | BRATS_460
343 | BRATS_192
344 | BRATS_244
345 | BRATS_242
346 | BRATS_299
347 | BRATS_188
348 | BRATS_462
349 | BRATS_292
350 | BRATS_308
351 | BRATS_022
352 | BRATS_459
353 | BRATS_472
354 | BRATS_323
355 | BRATS_031
356 | BRATS_287
357 | BRATS_113
358 | BRATS_464
359 | BRATS_090
360 | BRATS_203
361 | BRATS_061
362 | BRATS_405
363 | BRATS_389
364 | BRATS_384
365 | BRATS_074
366 | BRATS_114
367 | BRATS_298
368 | BRATS_067
369 | BRATS_346
370 | BRATS_049
371 | BRATS_226
372 | BRATS_127
373 | BRATS_426
374 | BRATS_249
375 | BRATS_130
376 | BRATS_010
377 | BRATS_357
378 | BRATS_077
379 | BRATS_358
380 | BRATS_399
381 | BRATS_276
382 | BRATS_119
383 | BRATS_146
384 | BRATS_474
385 | BRATS_254
386 | BRATS_080
387 | BRATS_232
--------------------------------------------------------------------------------
/brats_split_list/train_list_cv0_12samples.txt:
--------------------------------------------------------------------------------
1 | BRATS_446
2 | BRATS_415
3 | BRATS_165
4 | BRATS_293
5 | BRATS_121
6 | BRATS_435
7 | BRATS_038
8 | BRATS_140
9 | BRATS_305
10 | BRATS_197
11 | BRATS_452
12 | BRATS_375
--------------------------------------------------------------------------------
/brats_split_list/train_list_cv0_24samples.txt:
--------------------------------------------------------------------------------
1 | BRATS_446
2 | BRATS_415
3 | BRATS_165
4 | BRATS_293
5 | BRATS_121
6 | BRATS_435
7 | BRATS_038
8 | BRATS_140
9 | BRATS_305
10 | BRATS_197
11 | BRATS_452
12 | BRATS_375
13 | BRATS_177
14 | BRATS_093
15 | BRATS_439
16 | BRATS_432
17 | BRATS_281
18 | BRATS_451
19 | BRATS_126
20 | BRATS_247
21 | BRATS_385
22 | BRATS_361
23 | BRATS_275
24 | BRATS_412
--------------------------------------------------------------------------------
/brats_split_list/train_list_cv0_3samples1.txt:
--------------------------------------------------------------------------------
1 | BRATS_446
2 | BRATS_415
3 | BRATS_165
--------------------------------------------------------------------------------
/brats_split_list/train_list_cv0_6samples1.txt:
--------------------------------------------------------------------------------
1 | BRATS_446
2 | BRATS_415
3 | BRATS_165
4 | BRATS_293
5 | BRATS_121
6 | BRATS_435
--------------------------------------------------------------------------------
/brats_split_list/train_list_cv0_96samples.txt:
--------------------------------------------------------------------------------
1 | BRATS_446
2 | BRATS_415
3 | BRATS_165
4 | BRATS_293
5 | BRATS_121
6 | BRATS_435
7 | BRATS_038
8 | BRATS_140
9 | BRATS_305
10 | BRATS_197
11 | BRATS_452
12 | BRATS_375
13 | BRATS_177
14 | BRATS_093
15 | BRATS_439
16 | BRATS_432
17 | BRATS_281
18 | BRATS_451
19 | BRATS_126
20 | BRATS_247
21 | BRATS_385
22 | BRATS_361
23 | BRATS_275
24 | BRATS_412
25 | BRATS_198
26 | BRATS_282
27 | BRATS_262
28 | BRATS_307
29 | BRATS_468
30 | BRATS_082
31 | BRATS_016
32 | BRATS_374
33 | BRATS_172
34 | BRATS_341
35 | BRATS_220
36 | BRATS_271
37 | BRATS_117
38 | BRATS_419
39 | BRATS_425
40 | BRATS_338
41 | BRATS_008
42 | BRATS_122
43 | BRATS_360
44 | BRATS_382
45 | BRATS_229
46 | BRATS_359
47 | BRATS_470
48 | BRATS_163
49 | BRATS_445
50 | BRATS_150
51 | BRATS_295
52 | BRATS_264
53 | BRATS_353
54 | BRATS_340
55 | BRATS_175
56 | BRATS_428
57 | BRATS_235
58 | BRATS_014
59 | BRATS_213
60 | BRATS_403
61 | BRATS_376
62 | BRATS_326
63 | BRATS_079
64 | BRATS_484
65 | BRATS_164
66 | BRATS_086
67 | BRATS_096
68 | BRATS_393
69 | BRATS_136
70 | BRATS_073
71 | BRATS_001
72 | BRATS_013
73 | BRATS_106
74 | BRATS_456
75 | BRATS_066
76 | BRATS_169
77 | BRATS_402
78 | BRATS_196
79 | BRATS_257
80 | BRATS_329
81 | BRATS_236
82 | BRATS_395
83 | BRATS_173
84 | BRATS_396
85 | BRATS_123
86 | BRATS_186
87 | BRATS_044
88 | BRATS_313
89 | BRATS_194
90 | BRATS_171
91 | BRATS_223
92 | BRATS_023
93 | BRATS_279
94 | BRATS_347
95 | BRATS_330
96 | BRATS_333
--------------------------------------------------------------------------------
/brats_split_list/validation_list_cv0.txt:
--------------------------------------------------------------------------------
1 | BRATS_325
2 | BRATS_364
3 | BRATS_413
4 | BRATS_041
5 | BRATS_094
6 | BRATS_174
7 | BRATS_217
8 | BRATS_012
9 | BRATS_147
10 | BRATS_026
11 | BRATS_185
12 | BRATS_450
13 | BRATS_441
14 | BRATS_368
15 | BRATS_151
16 | BRATS_052
17 | BRATS_259
18 | BRATS_133
19 | BRATS_227
20 | BRATS_070
21 | BRATS_423
22 | BRATS_084
23 | BRATS_178
24 | BRATS_316
25 | BRATS_142
26 | BRATS_201
27 | BRATS_309
28 | BRATS_267
29 | BRATS_020
30 | BRATS_463
31 | BRATS_006
32 | BRATS_416
33 | BRATS_362
34 | BRATS_238
35 | BRATS_162
36 | BRATS_391
37 | BRATS_455
38 | BRATS_160
39 | BRATS_252
40 | BRATS_434
41 | BRATS_009
42 | BRATS_024
43 | BRATS_221
44 | BRATS_075
45 | BRATS_243
46 | BRATS_032
47 | BRATS_429
48 | BRATS_027
49 |
--------------------------------------------------------------------------------
/configs/brats_boundseg.yml:
--------------------------------------------------------------------------------
1 | dataset_name: msd_bound
2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized
3 | label_path: /PATH_TO_LABELS/labelsTr
4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges
5 | edge_label: True
6 | image_file_format: npz
7 | label_file_format: nii.gz
8 | train_list_path: ./brats_split_list/train_list_cv0.txt
9 | validation_list_path: ./brats_split_list/validation_list_cv0.txt
10 | test_list_path: ./brats_split_list/test_list_cv0.txt
11 | crop_size: (160,192,128)
12 | random_flip: True
13 | in_channels: 4
14 | nb_labels: 4
15 | print_each_dc: True
16 | val_batchsize: 1
17 | structseg_nb_copies: 1
18 | shift_intensity: 0.1
19 | segmentor_name: boundseg
20 | vaeseg_norm: 'GroupNormalization'
21 | seg_lossfun: dice_loss_plus_cross_entropy
22 | init_lr: 0.0001
23 | epoch: 300
24 | snapshot_interval: 100
25 | random_scale: True
26 | eval_interval: 1
27 | is_brats: True
28 | nested_label: True
29 |
--------------------------------------------------------------------------------
/configs/brats_cpcseg.yml:
--------------------------------------------------------------------------------
1 | dataset_name: msd_bound
2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized
3 | label_path: /PATH_TO_LABELS/labelsTr
4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges
5 | edge_label: False
6 | image_file_format: npz
7 | label_file_format: nii.gz
8 | train_list_path: ./brats_split_list/train_list_cv0.txt
9 | validation_list_path: ./brats_split_list/validation_list_cv0.txt
10 | test_list_path: ./brats_split_list/test_list_cv0.txt
11 | crop_size: (144,144,128)
12 | random_flip: True
13 | in_channels: 4
14 | nb_labels: 4
15 | print_each_dc: True
16 | val_batchsize: 1
17 | structseg_nb_copies: 1
18 | shift_intensity: 0.1
19 | segmentor_name: cpcseg
20 | vaeseg_norm: 'GroupNormalization'
21 | seg_lossfun: dice_loss_plus_cross_entropy
22 | cpc_vaeseg_cpc_loss_weight: 0.001 # 0.0005
23 | init_lr: 0.0001
24 | epoch: 300
25 | cpc_pattern: 'updown'
26 | snapshot_interval: 100
27 | random_scale: True
28 | eval_interval: 5
29 | is_brats: True
30 | nested_label: True
31 |
--------------------------------------------------------------------------------
/configs/brats_encdec.yml:
--------------------------------------------------------------------------------
1 | dataset_name: msd_bound
2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized
3 | label_path: /PATH_TO_LABELS/labelsTr
4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges
5 | edge_label: False
6 | image_file_format: npz
7 | label_file_format: nii.gz
8 | train_list_path: ./brats_split_list/train_list_cv0.txt
9 | validation_list_path: ./brats_split_list/validation_list_cv0.txt
10 | test_list_path: ./brats_split_list/test_list_cv0.txt
11 | crop_size: (160,192,128)
12 | random_flip: True
13 | in_channels: 4
14 | nb_labels: 4
15 | print_each_dc: True
16 | val_batchsize: 1
17 | structseg_nb_copies: 1
18 | shift_intensity: 0.1
19 | segmentor_name: encdec_seg
20 | vaeseg_norm: 'GroupNormalization'
21 | seg_lossfun: dice_loss_plus_cross_entropy
22 | init_lr: 0.0001
23 | epoch: 300
24 | snapshot_interval: 100
25 | random_scale: True
26 | eval_interval: 5
27 | is_brats: True
28 | nested_label: True
29 |
--------------------------------------------------------------------------------
/configs/brats_semi_cpcseg.yml:
--------------------------------------------------------------------------------
1 | dataset_name: msd_bound
2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized
3 | label_path: /PATH_TO_LABELS/labelsTr
4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges
5 | edge_label: False
6 | image_file_format: npz
7 | label_file_format: nii.gz
8 | train_list_path: ./brats_split_list/train_list_cv0.txt
9 | ignore_path: ./brats_split_list/train_list_cv0_3samples1.txt
10 | validation_list_path: ./brats_split_list/validation_list_cv0.txt
11 | test_list_path: ./brats_split_list/test_list_cv0.txt
12 | crop_size: (144,144,128)
13 | random_flip: True
14 | in_channels: 4
15 | nb_labels: 4
16 | print_each_dc: True
17 | val_batchsize: 1
18 | structseg_nb_copies: 1
19 | shift_intensity: 0.1
20 | segmentor_name: cpcseg
21 | vaeseg_norm: 'GroupNormalization'
22 | seg_lossfun: dice_loss_plus_cross_entropy
23 | cpc_vaeseg_cpc_loss_weight: 0.001 # 0.0005
24 | init_lr: 0.0001
25 | epoch: 300
26 | cpc_pattern: 'updown'
27 | snapshot_interval: 100
28 | random_scale: True
29 | eval_interval: 2
30 | is_brats: True
31 | report_interval: 10
32 | vae_idle_weight: 1
33 | nested_label: True
34 |
--------------------------------------------------------------------------------
/configs/brats_semi_vaeseg.yml:
--------------------------------------------------------------------------------
1 | dataset_name: msd_bound
2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized
3 | label_path: /PATH_TO_LABELS/labelsTr
4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges
5 | edge_label: False
6 | image_file_format: npz
7 | label_file_format: nii.gz
8 | train_list_path: ./brats_split_list/train_list_cv0.txt
9 | ignore_path: ./brats_split_list/train_list_cv0_3samples1.txt
10 | validation_list_path: ./brats_split_list/validation_list_cv0.txt
11 | test_list_path: ./brats_split_list/test_list_cv0.txt
12 | crop_size: (160,192,128)
13 | random_flip: True
14 | in_channels: 4
15 | nb_labels: 4
16 | print_each_dc: True
17 | val_batchsize: 1
18 | structseg_nb_copies: 1
19 | shift_intensity: 0.1
20 | segmentor_name: vaeseg
21 | vaeseg_norm: 'GroupNormalization'
22 | seg_lossfun: dice_loss_plus_cross_entropy
23 | vaeseg_rec_loss_weight: 0.1
24 | vaeseg_kl_loss_weight: 0.1
25 | init_lr: 0.0001
26 | epoch: 300
27 | snapshot_interval: 100
28 | random_scale: True
29 | eval_interval: 5
30 | is_brats: True
31 | report_interval: 10
32 | vae_idle_weight: 1
33 | nested_label: True
34 |
--------------------------------------------------------------------------------
/configs/brats_vaeseg.yml:
--------------------------------------------------------------------------------
1 | dataset_name: msd_bound
2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized
3 | label_path: /PATH_TO_LABELS/labelsTr
4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges
5 | edge_label: False
6 | image_file_format: npz
7 | label_file_format: nii.gz
8 | train_list_path: ./brats_split_list/train_list_cv0.txt
9 | validation_list_path: ./brats_split_list/validation_list_cv0.txt
10 | test_list_path: ./brats_split_list/test_list_cv0.txt
11 | crop_size: (160,192,128)
12 | random_flip: True
13 | in_channels: 4
14 | nb_labels: 4
15 | print_each_dc: True
16 | val_batchsize: 1
17 | structseg_nb_copies: 1
18 | shift_intensity: 0.1
19 | segmentor_name: vaeseg
20 | vaeseg_norm: 'GroupNormalization'
21 | seg_lossfun: dice_loss_plus_cross_entropy
22 | vaeseg_rec_loss_weight: 0.1
23 | vaeseg_kl_loss_weight: 0.1
24 | init_lr: 0.0001
25 | epoch: 300
26 | snapshot_interval: 100
27 | random_scale: False
28 | eval_interval: 1
29 | is_brats: True
30 | nested_label: True
31 |
--------------------------------------------------------------------------------
/examples/normalize.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.') # NOQA
3 | from src.datasets.preprocess import normalize
4 |
5 |
6 | def main(root_path=None, arr_type='nii.gz', modality='mri'):
7 | # save normalized npz arrays in root_path/normalized/
8 | normalize(root_path, arr_type, modality)
9 |
10 |
11 | if __name__ == '__main__':
12 | from fire import Fire
13 | Fire(main)
14 |
--------------------------------------------------------------------------------
/examples/split_kfold.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.') # NOQA
3 | from src.datasets.preprocess import split_kfold
4 |
5 |
6 | def main(root_path, arr_type='nii.gz', n_splits=5, random_state=42):
7 | # create split file
8 | split_kfold(root_path, arr_type, n_splits, random_state)
9 |
10 |
11 | if __name__ == '__main__':
12 | from fire import Fire
13 | Fire(main)
14 |
--------------------------------------------------------------------------------
/examples/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.') # NOQA
3 | import argparse
4 | import yaml
5 | import warnings
6 | import numpy as np
7 | import random
8 | import chainer
9 | from chainer import global_config
10 |
11 | from src.utils.config import overwrite_config
12 | from src.utils.setup_helpers import setup_trainer
13 |
14 | warnings.simplefilter(action='ignore', category=FutureWarning) # NOQA
15 |
16 |
17 | def reset_seed(seed=42):
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | if chainer.cuda.available:
21 | chainer.cuda.cupy.random.seed(seed)
22 |
23 |
24 | def main():
25 | reset_seed(42)
26 | parser = argparse.ArgumentParser(
27 | description='Medical image segmentation'
28 | )
29 | parser.add_argument('--config', '-c')
30 | parser.add_argument('--out', default='results',
31 | help='Output directory')
32 | parser.add_argument('--batch_size', '-b', type=int, default=1,
33 | help="Batch size")
34 | parser.add_argument('--epoch', '-e', type=int, default=500,
35 | help="Number of epochs")
36 | parser.add_argument('--gpu_start_id', '-g', type=int, default=0,
37 | help="Start ID of gpu. (negative value indicates cpu)")
38 | args = parser.parse_args()
39 | config = overwrite_config(
40 | yaml.load(open(args.config)), dump_yaml_dir=args.out
41 | )
42 |
43 | if config['mn']:
44 | global_config.autotune = True
45 | import multiprocessing
46 | multiprocessing.set_start_method('forkserver')
47 | p = multiprocessing.Process(target=print, args=('Initialize forkserver',))
48 | p.start()
49 | p.join()
50 |
51 | trainer = setup_trainer(config, args.out, args.batch_size, args.epoch, args.gpu_start_id)
52 | trainer.run()
53 |
54 |
55 | if __name__ == '__main__':
56 | main()
57 |
--------------------------------------------------------------------------------
/img/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/img/architecture.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | nibabel==2.3.1
2 | six==1.12.0
3 | numpy==1.16.2
4 | chainer==6.2.0
5 | scikit_image==0.14.2
6 | imageio==2.4.1
7 | cupy_cuda90==6.2.0
8 | fire==0.2.1
9 | pyyaml==3.13
10 | scikit_learn==0.22.1
11 | mpi4py==3.0.0
12 | SimpleITK==1.2.0
13 | pillow==4.3.0
14 | opencv-python==4.1.0.25
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/__init__.py
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/datasets/__init__.py
--------------------------------------------------------------------------------
/src/datasets/generate_msd_edge.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.') # NOQA
3 | import copy
4 | import numpy as np
5 | import nibabel as nib
6 | import os
7 | import chainer
8 | from chainer.backends.cuda import get_device_from_id
9 | import chainer.links as L
10 | import cupy
11 |
12 | from src.utils.encode_one_hot_vector import encode_one_hot_vector
13 | from src.utils.setup_helpers import _setup_communicator
14 |
15 | chainer.backends.cuda.cudnn_enabled = True
16 | chainer.config.use_cudnn = 'always'
17 | chainer.config.cudnn_deterministic = True
18 |
19 | # Generates edge labels for the nested BraTS dataset
20 | # using a 3D convolution with a Laplacian kernel.
21 | dataset = 'nested_brats'
22 | nested = True
23 |
24 | if dataset == 'nested_brats':
25 | print('brats with nested labels!')
26 | label_path = '/PATH_TO_LABELS/labelsTr/'
27 | edge_path = './brats_nested_edges/'
28 | num_range = 484
29 | nb_class = 4
30 | else:
31 |     raise ValueError('Could not identify dataset: {}'.format(dataset))
32 |
33 | if not os.path.isdir(edge_path):
34 | raise FileNotFoundError('No such directory: {}'.format(edge_path))
35 |
36 | # kernel for 3d Laplacian
37 | k = np.array([
38 | [[0, 0, 0],
39 | [0, 1, 0],
40 | [0, 0, 0]],
41 | [[0, 1, 0],
42 | [1, -6, 1],
43 | [0, 1, 0]],
44 | [[0, 0, 0],
45 | [0, 1, 0],
46 | [0, 0, 0]]], dtype=np.float32)
47 |
48 | k = k.reshape((1, 1, 3, 3, 3))
49 | k = k.transpose((0, 1, 3, 4, 2))
50 |
51 | default_config = {
52 | 'mn': True,
53 | 'gpu_start_id': 0,
54 | }
55 |
56 | config = copy.copy(default_config)
57 | comm, is_master, device = _setup_communicator(config, gpu_start_id=0)
58 |
59 | get_device_from_id(device).use()
60 | conv3d = L.Convolution3D(in_channels=1, out_channels=1,
61 | ksize=3, stride=1, pad=1, initialW=k)
62 | conv3d.to_gpu()
63 |
64 | with chainer.no_backprop_mode(), chainer.using_config('train', False):
65 | for num in range(1, num_range+1):
66 | if dataset == 'nested_brats':
67 | num = "{0:0=3d}".format(num)
68 | path = label_path+'BRATS_'+str(num)+'.nii.gz'
69 | try:
70 | nii_img = nib.load(path)
71 | affine = nii_img.affine
72 | img = nii_img.get_data()
73 | shape = img.shape
74 |
75 | edge = np.zeros((1, nb_class) + shape)
76 | img = img.reshape((1,)+shape) # (h,w,d) --> (batch,h,w,d)
77 | img = encode_one_hot_vector(img, nb_class=nb_class) # (batch,channel,h,w,d)
78 | img = cupy.asarray(img, dtype=np.float32)
79 | for c in range(nb_class):
80 | if nested:
81 | if c == 1:
82 | wt = (img[0, 1, :, :, :]).astype(bool)+(img[0, 2, :, :, :]).astype(bool)\
83 | + (img[0, 3, :, :, :]).astype(bool)
84 | crop_img = wt.reshape((1, 1)+shape)
85 | elif c == 2:
86 | tc = (img[0, 2, :, :, :]).astype(bool)+(img[0, 3, :, :, :]).astype(bool)
87 | crop_img = tc.reshape((1, 1)+shape)
88 | elif c == 3:
89 | et = (img[0, 3, :, :, :]).astype(bool)
90 | crop_img = et.reshape((1, 1)+shape)
91 | else:
92 | crop_img = img[0, c, :, :, :].reshape((1, 1)+shape)
93 | else:
94 | crop_img = img[0, c, :, :, :].reshape((1, 1)+shape)
95 | temp_img = conv3d(crop_img.astype(np.float32)).data[0, 0, :, :, :]
96 | edge[0, c, :, :, :] = (-(cupy.asnumpy(temp_img)) > 0.).astype(int)
97 |
98 | edge_labels_nii = nib.Nifti1Image(edge, affine)
99 | nib.save(edge_labels_nii, edge_path+'edge_'+str(num)+'.nii.gz')
100 | print(num)
101 |         except Exception:  # skip missing or unreadable volumes
102 | print('No data for '+str(num)+'.nii.gz')
103 |
--------------------------------------------------------------------------------
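The script above extracts edges by convolving each one-hot class channel with a 6-neighbor 3D Laplacian kernel and thresholding the negated response. The same operation can be sketched in plain NumPy, without Chainer or a GPU (`extract_edge` is a hypothetical helper, not part of this repository):

```python
import numpy as np


def extract_edge(mask):
    """Inner boundary of a binary mask via a 6-neighbour 3D Laplacian.

    The response is -6*mask plus the sum of the 6 face neighbours:
    interior voxels give 0, voxels just inside the boundary give a
    negative value, so thresholding -response > 0 reproduces the
    script's edge labels.
    """
    m = mask.astype(np.float32)
    p = np.pad(m, 1)  # zero padding, like pad=1 in the convolution
    response = -6.0 * m
    response += p[:-2, 1:-1, 1:-1] + p[2:, 1:-1, 1:-1]   # +/- x neighbours
    response += p[1:-1, :-2, 1:-1] + p[1:-1, 2:, 1:-1]   # +/- y neighbours
    response += p[1:-1, 1:-1, :-2] + p[1:-1, 1:-1, 2:]   # +/- z neighbours
    return (-response > 0.0).astype(np.int32)
```

For a solid cube, every surface voxel becomes an edge voxel while the interior stays zero.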
/src/datasets/msd_bound.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from chainer.dataset import DatasetMixin
4 | from src.datasets.preprocess import read_image, random_crop_pair
5 | from src.functions.array.resize_images_3d import resize_images_3d
6 | from src.functions.array.resize_images_3d_nearest_neighbor import resize_images_3d_nearest_neighbor
7 | from numpy.random import rand
8 | import chainer
9 | import chainer.functions as F
10 |
11 |
12 | class MSDBoundDataset(DatasetMixin):
13 | """Edited MSD dataset for 4-channel labels (for boundary aware networks)"""
14 | def __init__(self, config, split_file):
15 |         self.crop_size = eval(config['crop_size'])  # e.g. "(160, 192, 128)" -> tuple
16 | self.shift_intensity = config['shift_intensity']
17 | self.random_flip = config['random_flip']
18 | self.target_label = config['target_label']
19 | self.nb_class = config['nb_labels']
20 | self.unet_cropping = config['unet_cropping']
21 | self.random_scale = config['random_scale']
22 | self.training = True
23 | with open(split_file) as f:
24 | self.split_list = [line.rstrip() for line in f]
25 | image_list = [os.path.join(config['image_path'], name) for name in self.split_list]
26 | label_list = [os.path.join(config['label_path'], name) for name in self.split_list]
27 | edge_list = [os.path.join(config['edge_path'], 'edge_'+str(name)[-3:])
28 | for name in self.split_list]
29 | self.edge_label = config['edge_label']
30 | assert len(image_list) == len(label_list)
31 | self.pair_list = list(zip(image_list, label_list, edge_list))
32 | self.nb_copies = config['structseg_nb_copies']
33 | self.image_file_format = config['image_file_format']
34 | self.ignore_path = config['ignore_path']
35 | self.nested_label = config['nested_label']
36 |
37 | def __len__(self):
38 | return len(self.pair_list) * self.nb_copies
39 |
40 | def _get_image(self, i):
41 | i = i % len(self.pair_list)
42 | image = read_image(self.pair_list[i][0], self.image_file_format)
43 | image = image.transpose((3, 0, 1, 2))
44 | return image
45 |
46 | def _get_label(self, i):
47 | i = i % len(self.pair_list)
48 | label = read_image(self.pair_list[i][1], 'nii.gz')
49 | if self.training and (self.ignore_path is not None):
50 |             # semi-supervised setting: some samples are used without labels
51 | ignore_list = open(self.ignore_path).read().splitlines()
52 | if self.pair_list[i][1][-9:] not in ignore_list:
53 |                 # set the label to NaN when the sample is NOT listed;
54 |                 # despite its name, ignore_list holds the samples whose labels are KEPT
55 | label = np.empty(label.shape)
56 | label[:] = np.nan
57 | if self.unet_cropping:
58 | label = label[20:-20, 20:-20, 20:-20]
59 | return label
60 |
61 | def _get_edge(self, i):
62 | i = i % len(self.pair_list)
63 | edge = read_image(self.pair_list[i][2], 'nii.gz')
64 | edge = edge[0, :, :, :, :]
65 | if self.unet_cropping:
66 | edge = edge[20:-20, 20:-20, 20:-20]
67 | return edge
68 |
69 | def _crop(self, x, y):
70 | return random_crop_pair(x, y, self.crop_size)
71 |
72 | def _change_intensity(self, x):
73 | for i in range(len(x)):
74 | diff = 2 * self.shift_intensity * rand() - self.shift_intensity
75 | x[i] += diff
76 | return x
77 |
78 | @staticmethod
79 | def _flip(x, y):
80 | if np.random.random() < 0.5:
81 | x = x[:, ::-1, :, :]
82 | y = y[:, ::-1, :, :]
83 | if np.random.random() < 0.5:
84 | x = x[:, :, ::-1, :]
85 | y = y[:, :, ::-1, :]
86 | if np.random.random() < 0.5:
87 | x = x[:, :, :, ::-1]
88 | y = y[:, :, :, ::-1]
89 | return x, y
90 |
91 | @staticmethod
92 | def _binarize_label(y, t_idx):
93 | return (y >= t_idx).astype(np.int32)
94 |
95 | @staticmethod
96 | def _binarize_brats(y, dc=1):
97 | if (np.sum(y) < 0) or (np.isnan(y).any()):
98 | # when using ignore list (semi-supervised VAEseg)
99 | return y
100 | else:
101 | if dc == 1:
102 | whole_tumor = (y >= 1).astype(np.int32)
103 | return whole_tumor
104 | elif dc == 2:
105 | tumor_core = ((y == 3) + (y == 2)).astype(np.int32)
106 | return tumor_core
107 | elif dc == 3:
108 | enhancing_tumor = (y == 3).astype(np.int32)
109 | return enhancing_tumor
110 | else:
111 |                 raise ValueError(
112 |                     'dc must be 1, 2 or 3, got {}'.format(dc))
113 |
114 | @staticmethod
115 | def _rand_scale(x, y):
116 |         # randomly scale the image by a factor in [0.9, 1.1) for augmentation
117 | xx = x.reshape((1,)+x.shape)
118 | y = y.astype(np.float32)
119 |
120 | rand_scale = 0.9+rand()*0.2
121 | orig_shape = np.array(xx.shape)
122 | scaled_shape = orig_shape[2:]*rand_scale
123 | scaled_shape = np.round(scaled_shape).astype(int)
124 |
125 | resized_x = resize_images_3d(xx, scaled_shape)
126 | resized_y = resize_images_3d_nearest_neighbor(y.reshape((1,)+y.shape),
127 | scaled_shape)[0, :, :, :, :]
128 | resized_y = chainer.Variable(resized_y)
129 | if rand_scale < 1:
130 | # zero-padding if the image is scaled down
131 | pad_w = orig_shape[2] - scaled_shape[0]
132 | pad_h = orig_shape[3] - scaled_shape[1]
133 | pad_d = orig_shape[4] - scaled_shape[2]
134 |
135 | w_half = int(round(pad_w/2))
136 | h_half = int(round(pad_h/2))
137 | d_half = int(round(pad_d/2))
138 |
139 | pad_width_x = ((0, 0), (0, 0), (w_half, pad_w-w_half),
140 | (h_half, pad_h-h_half), (d_half, pad_d-d_half))
141 | pad_width_y = ((0, 0), (w_half, pad_w-w_half),
142 | (h_half, pad_h-h_half), (d_half, pad_d-d_half))
143 | resized_x = F.pad(resized_x, pad_width=pad_width_x, mode='constant', constant_values=0.)
144 | resized_y = F.pad(resized_y, pad_width=pad_width_y, mode='constant', constant_values=0.)
145 | assert resized_x.shape[2:] == resized_y.shape[1:]
146 | resized_x = resized_x[0, :, :, :, :]
147 |
148 | return resized_x.data, resized_y.data.astype(np.int32)
149 |
150 | def get_example(self, i):
151 | x = self._get_image(i)
152 | y = self._get_label(i)
153 | if self.edge_label:
154 | y = y.reshape((1,)+y.shape)
155 | ye = self._get_edge(i)
156 | y = np.append(y, ye, axis=0)
157 | else:
158 | y = y.reshape((1,)+y.shape)
159 |
160 | if self.random_scale:
161 | x, y = self._rand_scale(x, y)
162 | if self.training:
163 | x, y = self._crop(x, y)
164 | if self.shift_intensity > 0.:
165 | x = self._change_intensity(x)
166 | if self.random_flip:
167 | x, y = self._flip(x, y)
168 | if self.target_label:
169 | assert self.nb_class == 2
170 | y = self._binarize_label(y, self.target_label)
171 | elif self.nested_label:
172 | ys = []
173 |             for c in range(self.nb_class-1):  # avoid shadowing the sample index i
174 |                 ys.append(self._binarize_brats(y[0, :, :, :], c+1))
175 | if self.edge_label:
176 | return x.astype(np.float32),\
177 | np.array(ys).astype(np.int32), y[1:, :, :, :].astype(np.int32)
178 | else:
179 | return x.astype(np.float32), np.array(ys).astype(np.int32)
180 |
181 | if self.edge_label:
182 | return x.astype(np.float32), y[0, :, :, :].astype(np.int32),\
183 | y[1:, :, :, :].astype(np.int32)
184 | else:
185 | return x.astype(np.float32), y[0, :, :, :].astype(np.int32)
186 |
--------------------------------------------------------------------------------
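The nested-label handling in `get_example`/`_binarize_brats` collapses the BraTS label map into three overlapping binary targets. A minimal NumPy sketch of that mapping (`binarize_brats` is an illustrative standalone version, not the dataset method itself):

```python
import numpy as np


def binarize_brats(y, region):
    """Map a multi-class BraTS label map to a binary mask for one
    nested tumour region, mirroring MSDBoundDataset._binarize_brats.

    region 1: whole tumour     (labels 1, 2 and 3)
    region 2: tumour core      (labels 2 and 3)
    region 3: enhancing tumour (label 3)
    """
    if region == 1:
        return (y >= 1).astype(np.int32)
    if region == 2:
        return np.isin(y, (2, 3)).astype(np.int32)
    if region == 3:
        return (y == 3).astype(np.int32)
    raise ValueError('region must be 1, 2 or 3')


# Nested targets stacked into an (nb_class - 1)-channel binary label,
# as built in get_example when nested_label is set.
y = np.array([0, 1, 2, 3])
ys = np.stack([binarize_brats(y, r) for r in (1, 2, 3)])
```

Each region mask contains the next one, which is why the channels are "nested".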
/src/datasets/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import nibabel as nib
5 | import imageio
6 | from skimage.transform import resize
7 | from sklearn.model_selection import KFold, train_test_split
8 |
9 |
10 | def read_image(path, file_format='nii.gz'):
11 |     """ Read an image array from a path.
12 |     Args:
13 |         path (str) : path to the image file, without its file extension.
14 |         file_format (str) : file type {'npy','npz','jpg','png','dcm','nii','nii.gz'}
15 | Returns:
16 | image (np.ndarray) : image array
17 | """
18 | path = path + '.' + file_format
19 | if file_format == 'npy':
20 | image = np.load(path)
21 | elif file_format == 'npz':
22 | image = np.load(path)['arr_0']
23 | elif file_format in ('png', 'jpg'):
24 | image = np.array(imageio.imread(path))
25 | elif file_format == 'dcm':
26 | image = np.array(imageio.volread(path, 'DICOM'))
27 | elif file_format in ('nii', 'nii.gz'):
28 | image = nib.load(path).get_data()
29 | else:
30 | raise ValueError('invalid --input_type : {}'.format(file_format))
31 |
32 | return image
33 |
34 |
35 | def random_crop_pair(image, label, crop_size=(160, 192, 128)):
36 | H, W, D = crop_size
37 |
38 | if image.shape[1] >= H:
39 | top = random.randint(0, image.shape[1] - H)
40 | else:
41 |         raise ValueError('image height must be at least the crop height')
42 | h_slice = slice(top, top + H)
43 |
44 | if image.shape[2] >= W:
45 | left = random.randint(0, image.shape[2] - W)
46 | else:
47 |         raise ValueError('image width must be at least the crop width')
48 | w_slice = slice(left, left + W)
49 |
50 | if image.shape[3] >= D:
51 | rear = random.randint(0, image.shape[3] - D)
52 | else:
53 |         raise ValueError('image depth must be at least the crop depth')
54 | d_slice = slice(rear, rear + D)
55 |
56 | image = image[:, h_slice, w_slice, d_slice]
57 | if label.ndim == 4:
58 | label = label[:, h_slice, w_slice, d_slice]
59 | else:
60 | label = label[h_slice, w_slice, d_slice]
61 | return image, label
62 |
63 |
64 | def compute_stats(root_path, arr_type='nii.gz', modality='mri'):
65 | files = os.listdir(root_path)
66 | means = []
67 | stds = []
68 | for file_i in files:
69 | img = read_image(os.path.join(root_path, file_i), arr_type)
70 | if len(img.shape) == 3:
71 | img = np.expand_dims(img, axis=-1)
72 | img = img.transpose((3, 0, 1, 2))
73 | if modality == 'ct':
74 | np.clip(img, -1000., 1000., out=img)
75 | img += 1000.
76 | mean = []
77 | std = []
78 | for i in range(len(img)):
79 | mean.append(np.mean(img[i][img[i].nonzero()]))
80 | if modality == 'ct':
81 | mean[i] -= 1000.
82 | std.append(np.std(img[i][img[i].nonzero()]))
83 | means.append(mean)
84 | stds.append(std)
85 | return np.mean(np.array(means), axis=0), np.mean(np.array(stds), axis=0)
86 |
87 |
88 | def normalize(root_path, arr_type='nii.gz', modality='mri'):
89 | files = os.listdir(root_path)
90 | if arr_type == 'nii.gz':
91 | files = ['.'.join(file_i.split('.')[:-2]) for file_i in files]
92 | else:
93 | files = [os.path.splitext(file_i)[0] for file_i in files]
94 | base_name = os.path.basename(root_path)
95 | normalized_path = root_path.replace(base_name, base_name+"_normalized")
96 | if not os.path.exists(normalized_path):
97 | os.makedirs(normalized_path, exist_ok=True)
98 | for file_i in files:
99 | img = read_image(os.path.join(root_path, file_i), arr_type)
100 | if len(img.shape) == 3:
101 | img = np.expand_dims(img, axis=-1)
102 | img = img.transpose((3, 0, 1, 2)).astype(np.float32)
103 | if modality == 'ct':
104 | np.clip(img, -1000., 1000., out=img)
105 | img += 1000.
106 | for i in range(len(img)):
107 | mean = np.mean(img[i][img[i].nonzero()])
108 | std = np.std(img[i][img[i].nonzero()])
109 | img[i] = (img[i] - mean) / std
110 | img = img.transpose((1, 2, 3, 0))
111 | np.savez_compressed(os.path.join(normalized_path, file_i.replace(arr_type, 'npz')), img)
112 |
113 |
114 | def resample_isotropic_nifti(root_path, arr_type='nii.gz', is_label=False):
115 | files = os.listdir(root_path)
116 | if arr_type == 'nii.gz':
117 | files = ['.'.join(file_i.split('.')[:-2]) for file_i in files]
118 | else:
119 | files = [os.path.splitext(file_i)[0] for file_i in files]
120 | base_name = os.path.basename(root_path)
121 | resampled_path = root_path.replace(base_name, base_name+"_resampled")
122 | order = 0 if is_label else 1
123 | if not os.path.exists(resampled_path):
124 | os.makedirs(resampled_path, exist_ok=True)
125 | for file_i in files:
126 | img = read_image(os.path.join(root_path, file_i), 'nii.gz')
127 | spacing = nib.load(os.path.join(root_path, file_i)).header['pixdim'][:4]
128 | img = img.transpose((3, 0, 1, 2)).astype(np.float32)
129 | img_shape = img.shape
130 | target_shape = tuple([int(img_shape[i] / spacing[i]) for i in range(1, 4)])
131 | resampled_img = []
132 | for i in range(len(img)):
133 | resampled_img.append(resize(img[i], target_shape, order, mode='constant'))
134 | resampled_img = np.array(resampled_img).transpose((1, 2, 3, 0))
135 | np.savez_compressed(
136 | os.path.join(resampled_path, file_i.replace('nii.gz', 'npz')),
137 | resampled_img)
138 |
139 |
140 | def split_kfold(root_path, arr_type='nii.gz', n_splits=5, random_state=42):
141 | data_list = os.listdir(root_path)
142 | if arr_type == 'nii.gz':
143 | data_list = ['.'.join(data.split('.')[:-2]) for data in data_list]
144 | else:
145 | data_list = [os.path.splitext(data)[0] for data in data_list]
146 | save_path = root_path.replace(os.path.basename(root_path), 'split_list')
147 | if not os.path.exists(save_path):
148 | os.makedirs(save_path, exist_ok=True)
149 | # k-fold
150 | kf = KFold(
151 | n_splits=n_splits, shuffle=True, random_state=random_state)
152 | for i, (train_index, val_test_index) in enumerate(kf.split(data_list)):
153 | val_index, test_index = \
154 | train_test_split(val_test_index, test_size=0.5, shuffle=True, random_state=random_state)
155 | val_index = sorted(val_index)
156 | test_index = sorted(test_index)
157 | f = open(os.path.join(save_path, 'train_list_cv{}.txt'.format(i)), "w")
158 | for j in train_index:
159 | f.write(str(data_list[j]) + "\n")
160 | f.close()
161 | f = open(os.path.join(save_path, 'validation_list_cv{}.txt'.format(i)), "w")
162 | for j in val_index:
163 | f.write(str(data_list[j]) + "\n")
164 | f.close()
165 | f = open(os.path.join(save_path, 'test_list_cv{}.txt'.format(i)), "w")
166 | for j in test_index:
167 | f.write(str(data_list[j]) + "\n")
168 | f.close()
169 |
--------------------------------------------------------------------------------
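`compute_stats` and `normalize` above both take per-channel statistics over nonzero voxels only, so the zero-valued background outside the brain does not skew the mean and standard deviation. A self-contained sketch of that z-scoring step (`zscore_nonzero` is a hypothetical helper name):

```python
import numpy as np


def zscore_nonzero(img):
    """Channel-wise z-score over nonzero voxels, as in `normalize`.

    `img` has shape (C, H, W, D); zero voxels (background outside the
    brain) are excluded from the statistics, but the whole channel is
    still shifted and scaled by the resulting mean and std.
    """
    img = img.astype(np.float32)
    for c in range(len(img)):
        nz = img[c][img[c].nonzero()]
        img[c] = (img[c] - nz.mean()) / nz.std()
    return img
```

After this step, the foreground voxels of each channel have approximately zero mean and unit variance.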
/src/functions/array/__init__.py:
--------------------------------------------------------------------------------
1 | from src.functions.array.resize_images_3d import resize_images_3d # NOQA
2 |
--------------------------------------------------------------------------------
/src/functions/array/resize_images_3d.py:
--------------------------------------------------------------------------------
1 | import numpy
2 |
3 | from chainer.backends import cuda
4 | from chainer import function_node
5 | from chainer.utils import type_check
6 |
7 |
8 | class ResizeImages3D(function_node.FunctionNode):
9 |
10 | def __init__(self, output_shape):
11 | self.out_H = output_shape[0]
12 | self.out_W = output_shape[1]
13 | self.out_D = output_shape[2]
14 |
15 | def check_type_forward(self, in_types):
16 | n_in = in_types.size()
17 | type_check.expect(n_in == 1)
18 |
19 | x_type = in_types[0]
20 | type_check.expect(
21 | x_type.dtype.char == 'f',
22 | x_type.ndim == 5
23 | )
24 |
25 | def forward(self, inputs):
26 | x, = inputs
27 | xp = cuda.get_array_module(x)
28 |
29 | B, C, H, W, D = x.shape
30 |
31 | v_1d = xp.linspace(0, H - 1, num=self.out_H)
32 | u_1d = xp.linspace(0, W - 1, num=self.out_W)
33 | t_1d = xp.linspace(0, D - 1, num=self.out_D)
34 | grid = xp.meshgrid(v_1d, u_1d, t_1d, indexing='ij')
35 | v = grid[0].ravel()
36 | u = grid[1].ravel()
37 | t = grid[2].ravel()
38 |
39 | v0 = xp.floor(v).astype(numpy.int32)
40 | v0 = v0.clip(0, H - 2)
41 | u0 = xp.floor(u).astype(numpy.int32)
42 | u0 = u0.clip(0, W - 2)
43 | t0 = xp.floor(t).astype(numpy.int32)
44 | t0 = t0.clip(0, D - 2)
45 |
46 | y = 0
47 | for i in range(2):
48 | for j in range(2):
49 | for k in range(2):
50 | wv = xp.abs(v0 + (1 - i) - v)
51 | wu = xp.abs(u0 + (1 - j) - u)
52 | wt = xp.abs(t0 + (1 - k) - t)
53 | w = (wv * wu * wt).astype(x.dtype, copy=False)
54 | y += w[None, None, :] * x[:, :, v0 + i, u0 + j, t0 + k]
55 | y = y.reshape(B, C, self.out_H, self.out_W, self.out_D)
56 | return y,
57 |
58 | def backward(self, indexes, grad_outputs):
59 | return ResizeImagesGrad3D(
60 | self.inputs[0].shape,
61 | (self.out_H, self.out_W, self.out_D)).apply(grad_outputs)
62 |
63 |
64 | class ResizeImagesGrad3D(function_node.FunctionNode):
65 |
66 | def __init__(self, input_shape, output_shape):
67 | self.out_H = output_shape[0]
68 | self.out_W = output_shape[1]
69 | self.out_D = output_shape[2]
70 | self.input_shape = input_shape
71 |
72 | def check_type_forward(self, in_types):
73 | n_in = in_types.size()
74 | type_check.expect(n_in == 1)
75 |
76 | x_type = in_types[0]
77 | type_check.expect(
78 | x_type.dtype.char == 'f',
79 | x_type.ndim == 5
80 | )
81 |
82 | def forward(self, inputs):
83 | xp = cuda.get_array_module(*inputs)
84 | gy, = inputs
85 |
86 | B, C, H, W, D = self.input_shape
87 |
88 | v_1d = xp.linspace(0, H - 1, num=self.out_H)
89 | u_1d = xp.linspace(0, W - 1, num=self.out_W)
90 | t_1d = xp.linspace(0, D - 1, num=self.out_D)
91 | grid = xp.meshgrid(v_1d, u_1d, t_1d, indexing='ij')
92 | v = grid[0].ravel()
93 | u = grid[1].ravel()
94 | t = grid[2].ravel()
95 |
96 | v0 = xp.floor(v).astype(numpy.int32)
97 | v0 = v0.clip(0, H - 2)
98 | u0 = xp.floor(u).astype(numpy.int32)
99 | u0 = u0.clip(0, W - 2)
100 | t0 = xp.floor(t).astype(numpy.int32)
101 | t0 = t0.clip(0, D - 2)
102 |
103 | if xp is numpy:
104 | scatter_add = numpy.add.at
105 | else:
106 | scatter_add = cuda.cupyx.scatter_add
107 |
108 | gx = xp.zeros(self.input_shape, dtype=gy.dtype)
109 | gy = gy.reshape(B, C, -1)
110 | for i in range(2):
111 | for j in range(2):
112 | for k in range(2):
113 | wv = xp.abs(v0 + (1 - i) - v)
114 | wu = xp.abs(u0 + (1 - j) - u)
115 | wt = xp.abs(t0 + (1 - k) - t)
116 | w = (wv * wu * wt).astype(gy.dtype, copy=False)
117 | scatter_add(
118 | gx,
119 | (slice(None), slice(None), v0 + i, u0 + j, t0 + k),
120 | gy * w)
121 | return gx,
122 |
123 | def backward(self, indexes, grad_outputs):
124 | return ResizeImages3D(
125 | (self.out_H, self.out_W, self.out_D)).apply(grad_outputs)
126 |
127 |
128 | def resize_images_3d(x, output_shape):
129 | """Resize images to the given shape.
130 | This function resizes 3D data to :obj:`output_shape`.
131 | Currently, only bilinear interpolation is supported as the sampling method.
132 | Notation: here is a notation for dimensionalities.
133 | - :math:`n` is the batch size.
134 | - :math:`c_I` is the number of the input channels.
135 | - :math:`h`, :math:`w` and :math:`d` are the height, width and depth of the
136 | input image, respectively.
137 | - :math:`h_O`, :math:`w_O` and :math:`d_O` are the height, width and depth
138 | of the output image.
139 | Args:
140 | x (~chainer.Variable):
141 | Input variable of shape :math:`(n, c_I, h, w, d)`.
142 | output_shape (tuple):
143 | This is a tuple of length 3 whose values are :obj:`(h_O, w_O, d_O)`
144 | Returns:
145 | ~chainer.Variable: Resized image whose shape is \
146 | :math:`(n, c_I, h_O, w_O, d_O)`.
147 | """
148 | return ResizeImages3D(output_shape).apply((x,))[0]
149 |
--------------------------------------------------------------------------------
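`ResizeImages3D.forward` accumulates the eight corner contributions of trilinear interpolation, each weighted by the product of the per-axis distances. The same math in plain NumPy, handy for checking the op on CPU (a sketch of the forward pass only, not the differentiable Chainer version):

```python
import numpy as np


def resize_trilinear(x, out_shape):
    """NumPy sketch of ResizeImages3D.forward for a (B, C, H, W, D) array."""
    B, C, H, W, D = x.shape
    oH, oW, oD = out_shape
    v, u, t = [g.ravel() for g in np.meshgrid(
        np.linspace(0, H - 1, oH),
        np.linspace(0, W - 1, oW),
        np.linspace(0, D - 1, oD), indexing='ij')]
    v0 = np.floor(v).astype(np.int32).clip(0, H - 2)
    u0 = np.floor(u).astype(np.int32).clip(0, W - 2)
    t0 = np.floor(t).astype(np.int32).clip(0, D - 2)
    y = 0.
    for i in range(2):        # accumulate the 8 corner contributions,
        for j in range(2):    # each weighted by its trilinear weight
            for k in range(2):
                w = (np.abs(v0 + (1 - i) - v)
                     * np.abs(u0 + (1 - j) - u)
                     * np.abs(t0 + (1 - k) - t))
                y = y + w[None, None, :] * x[:, :, v0 + i, u0 + j, t0 + k]
    return y.reshape(B, C, oH, oW, oD)
```

Two useful invariants: resizing to the same shape is the identity, and the corner voxels are preserved because `linspace` includes both endpoints.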
/src/functions/array/resize_images_3d_nearest_neighbor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Tuple
3 |
4 | from chainer.backends import cuda
5 | from chainer.types import NdArray
6 |
7 |
8 | def resize_images_3d_nearest_neighbor(x: NdArray, output_shape: Tuple[int, int, int]) -> NdArray:
9 | # Note: this function is not differentiable
10 |     # Resize a (B, C, H, W, D) array to (B, C, H', W', D')
11 | out_H, out_W, out_D = output_shape
12 | xp = cuda.get_array_module(x)
13 |
14 | B, C, H, W, D = x.shape
15 | v_1d = xp.linspace(0, H - 1, num=out_H)
16 | u_1d = xp.linspace(0, W - 1, num=out_W)
17 | t_1d = xp.linspace(0, D - 1, num=out_D)
18 | grid = xp.meshgrid(v_1d, u_1d, t_1d, indexing='ij')
19 | v = grid[0].ravel() # (H'W'D',)
20 | u = grid[1].ravel()
21 | t = grid[2].ravel()
22 |
23 | v0 = xp.floor(v).astype(np.int32)
24 | v0 = v0.clip(0, H - 2)
25 | u0 = xp.floor(u).astype(np.int32)
26 | u0 = u0.clip(0, W - 2)
27 | t0 = xp.floor(t).astype(np.int32)
28 | t0 = t0.clip(0, D - 2)
29 |
30 | ws, xs = [], []
31 | for i in range(2):
32 | for j in range(2):
33 | for k in range(2):
34 | wv = xp.abs(v0 + (1 - i) - v)
35 | wu = xp.abs(u0 + (1 - j) - u)
36 | wt = xp.abs(t0 + (1 - k) - t)
37 | w = wv * wu * wt
38 | ws.append(w)
39 | xs.append(x[:, :, v0 + i, u0 + j, t0 + k])
40 | ws = xp.stack(ws) # (8, H'W'D')
41 | xs = xp.stack(xs) # (8, B, C, H'W'D')
42 | xs = xs.transpose((0, 3, 1, 2)) # (8, H'W'D', B, C)
43 |
44 | target_indices = xp.argmax(ws, axis=0) # (H'W'D',)
45 | y = xs[target_indices, np.arange(out_H * out_W * out_D)] # (H'W'D', B, C)
46 | y = y.transpose((1, 2, 0)).reshape((B, C, out_H, out_W, out_D)) # (B, C, H', W', D')
47 | return y
48 |
--------------------------------------------------------------------------------
/src/functions/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from src.functions.evaluation.dice_coefficient import dice_coefficient # NOQA
2 | from src.functions.evaluation.dice_coefficient import mean_dice_coefficient # NOQA
3 |
--------------------------------------------------------------------------------
/src/functions/evaluation/dice_coefficient.py:
--------------------------------------------------------------------------------
1 | from six import moves
2 | import numpy as np
3 |
4 | from chainer import function
5 | from chainer import functions as F
6 | from chainer.backends import cuda
7 | from chainer.utils import type_check
8 |
9 |
10 | class DiceCoefficient(function.Function):
11 |
12 | def __init__(self, ret_nan=True, dataset='task8hepatic', eps=1e-7, is_brats=False):
13 | self.ret_nan = ret_nan # return NaN if union==0
14 | self.dataset = dataset
15 | self.eps = eps
16 | self.is_brats = is_brats
17 |
18 | def check_type_forward(self, in_types):
19 | type_check.expect(in_types.size() == 2)
20 | x_type, t_type = in_types
21 |
22 | type_check.expect(
23 | x_type.dtype.kind == 'f',
24 | t_type.dtype == np.int32
25 | )
26 |
27 | t_ndim = type_check.eval(t_type.ndim)
28 | type_check.expect(
29 | x_type.ndim >= t_type.ndim,
30 | x_type.shape[0] == t_type.shape[0],
31 | x_type.shape[2: t_ndim + 1] == t_type.shape[1:]
32 | )
33 | for i in moves.range(t_ndim + 1, type_check.eval(x_type.ndim)):
34 | type_check.expect(x_type.shape[i] == 1)
35 |
36 | def forward(self, inputs):
37 |         r"""
38 |         Compute the per-class Dice coefficient between two label maps.
39 |         Math:
40 |             DC = \frac{2|A\cap B|}{|A|+|B|}
41 |         Args:
42 |             inputs ((array_like, array_like)):
43 |                 Input pair (prediction, ground_truth)
44 |         Returns:
45 |             dice (xp.ndarray) : per-class Dice coefficients
46 | """
47 | xp = cuda.get_array_module(*inputs)
48 | y, t = inputs
49 | number_class = y.shape[1]
50 | y = y.argmax(axis=1).reshape(t.shape)
51 | axis = tuple(range(t.ndim))
52 |
53 | dice = xp.zeros(number_class, dtype=xp.float32)
54 | for i in range(number_class):
55 | if self.is_brats:
56 | if i == 1:
57 | # Enhancing tumor
58 | y_match = xp.equal(y, 3).astype(xp.float32)
59 | t_match = xp.equal(t, 3).astype(xp.float32)
60 | elif i == 2:
61 | # Tumor core
62 | y_match = (xp.equal(y, 2)+xp.equal(y, 3)).astype(xp.float32)
63 | t_match = (xp.equal(t, 2)+xp.equal(t, 3)).astype(xp.float32)
64 | elif i == 3:
65 | # Whole Tumor
66 | y_match = (xp.equal(y, 1)+xp.equal(y, 2)+xp.equal(y, 3)).astype(xp.float32)
67 | t_match = (xp.equal(t, 1)+xp.equal(t, 2)+xp.equal(t, 3)).astype(xp.float32)
68 | else:
69 | y_match = xp.equal(y, i).astype(xp.float32)
70 | t_match = xp.equal(t, i).astype(xp.float32)
71 | else:
72 | y_match = xp.equal(y, i).astype(xp.float32)
73 | t_match = xp.equal(t, i).astype(xp.float32)
74 |
75 | intersect = xp.sum(y_match * t_match, axis=axis)
76 | union = xp.sum(y_match, axis=axis) + xp.sum(t_match, axis=axis)
77 | if union == 0.:
78 |                 # prediction and ground truth are both empty: define
79 |                 # Dice as 0, or NaN so it can be excluded from the mean
80 | dice[i] = 0.0
81 | if self.ret_nan:
82 | dice[i] = np.nan
83 | else:
84 | dice[i] = 2.0 * (intersect / union)
85 |
86 | return xp.asarray(dice, dtype=xp.float32),
87 |
88 |
89 | def dice_coefficient(y, t, ret_nan=False, dataset='task8hepatic', eps=1e-7, is_brats=False):
90 | return DiceCoefficient(ret_nan=ret_nan, dataset=dataset, eps=eps, is_brats=is_brats)(y, t)
91 |
92 |
93 | def mean_dice_coefficient(dice_coefficients, ret_nan=True):
94 | if ret_nan:
95 | xp = cuda.get_array_module(dice_coefficients)
96 | selector = ~xp.isnan(dice_coefficients.data)
97 | dice_coefficients = F.get_item(dice_coefficients, selector)
98 | return F.mean(dice_coefficients, keepdims=True)
99 |
--------------------------------------------------------------------------------
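For BraTS, `DiceCoefficient` does not score the raw labels directly: it collapses the argmax prediction into the three overlapping regions (whole tumor, tumor core, enhancing tumor) and computes Dice per region. A NumPy sketch of that evaluation on toy label vectors (`dice` is an illustrative helper, not the Chainer function above):

```python
import numpy as np


def dice(y_pred, y_true):
    """Plain Dice coefficient, DC = 2|A&B| / (|A|+|B|), for binary masks."""
    intersect = np.sum(y_pred * y_true)
    union = np.sum(y_pred) + np.sum(y_true)
    return 2.0 * intersect / union if union > 0 else np.nan


# Toy flattened label maps with BraTS classes {0, 1, 2, 3}.
y = np.array([0, 1, 2, 3, 3])   # predicted labels
t = np.array([0, 1, 2, 2, 3])   # ground-truth labels

wt = dice(y >= 1, t >= 1)                           # whole tumor
tc = dice(np.isin(y, (2, 3)), np.isin(t, (2, 3)))   # tumor core
et = dice(y == 3, t == 3)                           # enhancing tumor
```

Here the prediction only disagrees on a label-2 voxel called as 3, so WT and TC are perfect while ET drops to 2/3.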
/src/functions/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from src.functions.loss.dice_loss import dice_loss # NOQA
2 | from src.functions.loss.dice_loss import softmax_dice_loss # NOQA
3 | from src.functions.loss.generalized_dice_loss import generalized_dice_loss # NOQA
4 | from src.functions.loss.generalized_dice_loss import softmax_generalized_dice_loss # NOQA
5 |
6 | from src.functions.loss.dice_loss import DiceLoss # NOQA
7 | from src.functions.loss.generalized_dice_loss import GeneralizedDiceLoss # NOQA
8 |
--------------------------------------------------------------------------------
/src/functions/loss/boundary_bce.py:
--------------------------------------------------------------------------------
1 | import chainer
2 | from chainer import function
3 | from chainer.backends import cuda
4 | from chainer.utils import type_check, force_array
5 | from chainer.functions import softmax_cross_entropy
6 |
7 |
8 | class BoundaryBCE(function.Function):
9 | def __init__(self, eps=1e-7):
10 | self.eps = eps
11 |
12 | def check_type_forward(self, in_types):
13 | type_check.expect(in_types.size() == 2)
14 | x_type, t_type = in_types
15 |
16 | type_check.expect(
17 | x_type.dtype.kind == 'f',
18 |
19 | x_type.shape[0] == t_type.shape[0],
20 | x_type.shape[2:] == t_type.shape[2:],
21 | )
22 |
23 | @staticmethod
24 | def _check_input_values(x, t):
25 | if not (((0 <= t) &
26 | (t < x.shape[1]))).all():
27 | msg = ('Each label `t` need to satisfy '
28 | '`0 <= t < x.shape[1]`')
29 | raise ValueError(msg)
30 |
31 | def forward(self, inputs):
32 | xp = cuda.get_array_module(*inputs)
33 | x, t = inputs
34 | if chainer.is_debug():
35 | self._check_input_values(x, t)
36 | num_class = t.shape[1]
37 | bound_loss = 0.
38 | for i in range(1, num_class):
39 |             # binarize the boundary label for class i
40 |             tt = t[0, i, :, :, :].astype(xp.int32)
41 |             tt = tt.reshape(-1)
42 |             beta = xp.sum(tt)/tt.size  # fraction of boundary voxels
43 |             class_weight = xp.stack([1-beta, beta])
44 |
45 | xx = x[:, [0, i], :, :, :]
46 | xx = xx.reshape(-1, 2)
47 | bound_loss += softmax_cross_entropy(xx, tt, class_weight=class_weight)
48 |         return force_array(bound_loss.data),
49 |
50 |
51 | def boundary_bce(x, t, eps=1e-7):
52 | return BoundaryBCE(eps)(x, t)
53 |
--------------------------------------------------------------------------------
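`BoundaryBCE` turns each boundary channel into a two-class problem and weights the softmax cross entropy by class frequency, with `beta` the fraction of boundary voxels. A NumPy sketch of that per-channel loss (`boundary_bce_np` is a hypothetical name; it mirrors the `[1 - beta, beta]` class weighting used above and assumes Chainer's default per-sample averaging):

```python
import numpy as np


def boundary_bce_np(logits, target):
    """NumPy sketch of `boundary_bce` for one class channel.

    `logits` has shape (N, 2) (background vs. boundary scores) and
    `target` is a binary vector of length N.  Following the original
    code, beta is the boundary fraction, so class 0 is weighted by
    1 - beta and class 1 by beta.
    """
    beta = target.sum() / target.size
    weights = np.array([1.0 - beta, beta])
    z = logits - logits.max(axis=1, keepdims=True)   # stabilised log-softmax
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    picked = log_p[np.arange(target.size), target]   # log-prob of the true class
    return float(-(weights[target] * picked).mean())
```

With uniform logits every voxel costs `-log(0.5)` before weighting, so the loss is just the mean class weight times `log 2`.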
/src/functions/loss/cpc_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chainer
3 | from chainer import function
4 | import chainer.functions as F
5 | from chainer.backends import cuda
6 | from chainer.utils import type_check, force_array
7 | import random
8 |
9 |
10 | class CPCLoss(function.Function):
11 | """ Contrastive loss from Contrastive Predictive Coding,
12 | arXiv:1905.09272 """
13 | def __init__(self, upper=True, cpc_pattern='updown', num_neg=20, eps=1e-7):
14 | self.eps = eps
15 | # number of negative samples
16 | self.num_neg = num_neg
17 | self.upper = upper
18 | self.cpc_pattern = cpc_pattern
19 |
20 | def check_type_forward(self, in_types):
21 | type_check.expect(in_types.size() == 2)
22 | x_type, t_type = in_types
23 |
24 | type_check.expect(
25 | x_type.dtype.kind == 'f',
26 | t_type.dtype.kind == 'f',
27 | x_type.shape[0] == t_type.shape[0]
28 | )
29 |
30 | @staticmethod
31 | def _check_input_values(x, t):
32 |         if not (0 <= t).all():
33 | msg = ('Each label `t` need to satisfy '
34 | '`0 <= t `')
35 | raise ValueError(msg)
36 |
37 | def forward(self, inputs):
38 | xp = cuda.get_array_module(*inputs)
39 | x, t = inputs
40 | b, c, gl0, gl1, gl2 = t.shape # gl: grid length
41 | cut_l = int(gl2/2)
42 |
43 | if chainer.is_debug():
44 | self._check_input_values(x, t)
45 |
46 | if self.cpc_pattern == 'ichimatsu':
47 | correct_z = t[:, :, 2:gl0+2:4, 2:gl1+2:4, 1:gl2:2]
48 | neg_z = t[:, :, :, :, 0:gl2:2]
49 | else:
50 | if self.upper:
51 | correct_z = t[:, :, :, :, -cut_l:]
52 | neg_z = t[:, :, :, :, :-cut_l]
53 | else:
54 | correct_z = t[:, :, :, :, :cut_l]
55 | neg_z = t[:, :, :, :, cut_l:]
56 |
57 | if self.cpc_pattern == 'ichimatsu':
58 | xx = F.reshape(x, (c, -1))
59 | correct_z = F.reshape(correct_z, (c, -1))
60 | neg_z = F.reshape(neg_z, (c, -1))
61 | # selecting num_neg negative samples to concat with correct samples of z
62 | selection = random.sample(set(np.arange(neg_z.shape[1])), self.num_neg)
63 | else:
64 | xx = F.reshape(x, (c, gl0*gl1*cut_l))
65 | correct_z = F.reshape(correct_z, (c, gl0*gl1*cut_l))
66 | neg_z = F.reshape(neg_z, (c, gl0*gl1*(gl2-cut_l)))
67 | # selecting num_neg negative samples to concat with correct samples of z
68 | selection = random.sample(set(np.arange(gl0*gl1*(gl2-cut_l))), self.num_neg)
69 |
70 | z_hat_T = F.transpose(xx)
71 |         # note: `selection` is already sampled in each branch above;
72 |         # re-sampling it here would discard the branch-specific indices
73 | neg_z = neg_z[:, selection]
74 | z = F.concat((correct_z, neg_z), axis=1)
75 |
76 | ip = xp.dot(z_hat_T.data, z.data)
77 | if self.cpc_pattern == 'ichimatsu':
78 |             cpc = F.softmax_cross_entropy(ip, xp.arange(12), reduce='mean')  # 12: hard-coded number of positive samples for the 'ichimatsu' pattern
79 | else:
80 | cpc = F.softmax_cross_entropy(ip, xp.arange(gl0*gl1*cut_l), reduce='mean')
81 | return force_array(xp.mean(cpc.data), dtype=xp.float32),
82 |
83 |
84 | def cpc_loss(x, t, upper=True, cpc_pattern='updown', num_neg=20, eps=1e-7):
85 | return CPCLoss(upper, cpc_pattern, num_neg, eps)(x, t)
86 |
--------------------------------------------------------------------------------
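The `CPCLoss` forward above is an InfoNCE-style objective: inner products between predicted embeddings and a set of one positive plus `num_neg` sampled negatives per position, scored with a softmax cross-entropy whose label for position `i` is `i`. A minimal NumPy sketch of the same computation (all shapes and values here are illustrative, not taken from the repo's configs):

```python
import numpy as np

# Hypothetical sizes: c = embedding dim, P = number of "correct" grid
# positions, N = number of sampled negatives.
rng = np.random.default_rng(0)
c, P, N = 8, 5, 20

z_hat = rng.normal(size=(c, P))               # predictions, one column per position
pos = z_hat + 0.01 * rng.normal(size=(c, P))  # positives close to the predictions
neg = rng.normal(size=(c, N))                 # negatives drawn at random

z = np.concatenate([pos, neg], axis=1)        # (c, P + N)
logits = z_hat.T @ z                          # (P, P + N) inner products

# softmax cross-entropy against labels arange(P): position i should
# score highest against its own positive (column i)
logits = logits - logits.max(axis=1, keepdims=True)
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
labels = np.arange(P)
loss = -log_probs[labels, labels].mean()
```

This mirrors `z_hat_T`, `z = F.concat((correct_z, neg_z), axis=1)`, and the `xp.dot` / `softmax_cross_entropy` pair in `CPCLoss.forward`.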
/src/functions/loss/dice_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chainer
3 | from chainer import function
4 | from chainer.backends import cuda
5 | from chainer.utils import type_check, force_array
6 | from chainer.functions.activation import softmax
7 | from src.utils.encode_one_hot_vector import encode_one_hot_vector
8 |
9 |
10 | class DiceLoss(function.Function):
11 | def __init__(self, eps=1e-7, weight=False, encode=True):
12 | self.eps = eps
13 | self.intersect = None
14 | self.union = None
15 | self.weighted_dice = weight
16 | self.encode = encode
17 |
18 | def check_type_forward(self, in_types):
19 | type_check.expect(in_types.size() == 2)
20 | x_type, t_type = in_types
21 |
22 | type_check.expect(
23 | x_type.dtype.kind == 'f',
24 | t_type.dtype == np.int32,
25 |
26 | x_type.shape[0] == t_type.shape[0],
27 | x_type.shape[-3:] == t_type.shape[-3:],
28 | )
29 |
30 | @staticmethod
31 | def _check_input_values(x, t):
32 | if not (((0 <= t) &
33 | (t < x.shape[1]))).all():
34 |             msg = ('Each label `t` needs to satisfy '
35 |                    '`0 <= t < x.shape[1]`')
36 | raise ValueError(msg)
37 |
38 | def forward(self, inputs):
39 | xp = cuda.get_array_module(*inputs)
40 | x, t = inputs
41 | if chainer.is_debug():
42 | self._check_input_values(x, t)
43 | if self.encode:
44 | t = encode_one_hot_vector(t, x.shape[1])
45 | axis = (0,) + tuple(range(2, x.ndim))
46 | self.intersect = xp.sum((x * t), axis=axis)
47 | self.union = xp.sum((x * x), axis=axis) + xp.sum((t * t), axis=axis)
48 | dice = (2. * self.intersect + self.eps) / (self.union + self.eps)
49 | if self.weighted_dice:
50 |             cw = xp.array([  # hard-coded per-class weights (23 classes)
51 | 1., 1., 1., 1., 0.5, 0.5, 0.8, 0.8, 0.5, 0.8, 0.8,
52 | 0.8, 0.5, 0.5, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 1., 1., 1.], dtype='float32')
53 | dice = dice*cw*x.shape[1]/xp.sum(cw)
54 | return force_array(xp.mean(1. - dice), dtype=xp.float32),
55 |
56 | def backward(self, inputs, grad_outputs):
57 | xp = cuda.get_array_module(*inputs)
58 | x, t = inputs
59 | nb_class = x.shape[1]
60 |         t = encode_one_hot_vector(t, nb_class) if self.encode else t
61 |
62 | gx = xp.zeros_like(x)
63 | gloss = grad_outputs[0]
64 |         cw = xp.array([  # same hard-coded 23-class weights as in forward()
65 | 1., 1., 1., 1., 0.5, 0.5, 0.8, 0.8, 0.5, 0.8, 0.8,
66 | 0.8, 0.5, 0.5, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 1., 1., 1.], dtype='float32')
67 | for i, w in zip(range(nb_class), cw):
68 | x_i = x[:, i]
69 | t_i = t[:, i]
70 | intersect = self.intersect[i]
71 | union = self.union[i]
72 |
73 | numerator = xp.multiply(union + self.eps, t_i) - \
74 | xp.multiply(2. * intersect + self.eps, x_i)
75 | denominator = xp.power(union + self.eps, 2)
76 | d_dice = 2 * xp.divide(numerator, denominator).astype(xp.float32)
77 | if self.weighted_dice:
78 | gx[:, i] = d_dice*w*nb_class/xp.sum(cw)
79 | else:
80 | gx[:, i] = d_dice
81 |
82 | gx *= gloss / nb_class
83 | return -gx.astype(xp.float32), None
84 |
85 |
86 | def dice_loss(x, t, eps=1e-7, weight=False, encode=True):
87 | return DiceLoss(eps, weight, encode)(x, t)
88 |
89 |
90 | def softmax_dice_loss(x, t, eps=1e-7, weight=False, encode=True):
91 | x1 = softmax.softmax(x, axis=1)
92 | return dice_loss(x1, t, eps, weight, encode)
93 |
--------------------------------------------------------------------------------
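The forward pass of `DiceLoss` (unweighted) can be sketched in plain NumPy: per-class soft Dice reduced over batch and spatial axes, then `1 - mean` over classes. A perfect prediction drives the loss to (near) zero:

```python
import numpy as np

# NumPy sketch of DiceLoss.forward (unweighted): per-class soft Dice
# reduced over batch and spatial axes, then 1 - mean over classes.
def soft_dice_loss(x, t_onehot, eps=1e-7):
    axis = (0,) + tuple(range(2, x.ndim))       # all axes except channel
    intersect = np.sum(x * t_onehot, axis=axis)
    union = np.sum(x * x, axis=axis) + np.sum(t_onehot * t_onehot, axis=axis)
    dice = (2.0 * intersect + eps) / (union + eps)
    return float(np.mean(1.0 - dice))

# a perfect one-hot prediction gives (near-)zero loss
t = np.zeros((1, 2, 2, 2, 2), dtype=np.float32)
t[:, 0] = 1.0
loss = soft_dice_loss(t, t)
```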
/src/functions/loss/focal_loss.py:
--------------------------------------------------------------------------------
1 | import chainer.functions as F
2 | from chainer.backends.cuda import get_array_module
3 |
4 |
5 | def focal_loss(x, t, gamma=0.5, eps=1e-6):
6 | xp = get_array_module(t)
7 |
8 | p = F.softmax(x)
9 | p = F.clip(p, x_min=eps, x_max=1-eps)
10 | log_p = F.log_softmax(x)
11 |     t_onehot = xp.eye(x.shape[1])[t.ravel()]  # shapes assume x is (N, C) logits and t is (N,) labels
12 |
13 | loss_sce = -1 * t_onehot * log_p
14 | loss_focal = F.sum(loss_sce * (1. - p) ** gamma, axis=1)
15 |
16 | return F.mean(loss_focal)
17 |
--------------------------------------------------------------------------------
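A self-contained NumPy version of the focal loss above, for 2-D logits; with `gamma == 0` the modulating factor `(1 - p)**gamma` is 1 and the loss reduces to plain softmax cross-entropy (the shapes here are illustrative):

```python
import numpy as np

# NumPy sketch of the focal loss above for 2-D inputs (x: (N, C) logits,
# t: (N,) int labels); gamma == 0 reduces to softmax cross-entropy.
def focal_loss_np(x, t, gamma=0.5):
    x = x - x.max(axis=1, keepdims=True)                 # stable softmax
    log_p = x - np.log(np.exp(x).sum(axis=1, keepdims=True))
    p = np.exp(log_p)
    t_onehot = np.eye(x.shape[1])[t]
    sce = -t_onehot * log_p                              # per-class CE terms
    return float((sce * (1.0 - p) ** gamma).sum(axis=1).mean())
```

Raising `gamma` down-weights well-classified examples, so the loss can only shrink relative to cross-entropy.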
/src/functions/loss/generalized_dice_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chainer
3 | from chainer import function
4 | from chainer.backends import cuda
5 | from chainer.utils import type_check, force_array
6 | from chainer.functions.activation import softmax
7 | from src.utils.encode_one_hot_vector import encode_one_hot_vector
8 |
9 |
10 | class GeneralizedDiceLoss(function.Function):
11 | def __init__(self, eps=1e-7):
12 | # avoid zero division error
13 | self.eps = eps
14 | self.w = None
15 | self.intersect = None
16 | self.union = None
17 |
18 | def check_type_forward(self, in_types):
19 | type_check.expect(in_types.size() == 2)
20 | x_type, t_type = in_types
21 |
22 | type_check.expect(
23 | x_type.dtype.kind == 'f',
24 | t_type.dtype == np.int32,
25 | t_type.ndim == x_type.ndim - 1,
26 |
27 | x_type.shape[0] == t_type.shape[0],
28 | x_type.shape[2:] == t_type.shape[1:],
29 | )
30 |
31 | @staticmethod
32 | def _check_input_values(x, t):
33 | if not (((0 <= t) &
34 | (t < x.shape[1]))).all():
35 |             msg = ('Each label `t` needs to satisfy '
36 |                    '`0 <= t < x.shape[1]`')
37 | raise ValueError(msg)
38 |
39 | def forward(self, inputs):
40 | xp = cuda.get_array_module(*inputs)
41 | x, t = inputs
42 | if chainer.is_debug():
43 | self._check_input_values(x, t)
44 | # one-hot encoding of ground truth
45 | t = encode_one_hot_vector(t, x.shape[1])
46 | # compute weight, intersection, and union
47 | axis = (0,) + tuple(range(2, x.ndim))
48 | sum_t = xp.sum(t, axis=axis)
49 | # avoid zero division error
50 | sum_t[sum_t == 0] = 1
51 | self.w = 1. / (sum_t*sum_t)
52 |         self.intersect = xp.sum(xp.multiply(xp.sum((x * t), axis=axis), self.w))
53 |         self.union = xp.sum(xp.multiply(
54 |             xp.sum(x, axis=axis) + xp.sum(t, axis=axis), self.w))
55 | # compute dice loss
56 | dice = (2. * self.intersect + self.eps) / (self.union + self.eps)
57 | return force_array(1. - dice, dtype=xp.float32),
58 |
59 | def backward(self, inputs, grad_outputs):
60 | xp = cuda.get_array_module(*inputs)
61 | x, t = inputs
62 | nb_class = x.shape[1]
63 | t = encode_one_hot_vector(t, nb_class)
64 |
65 | gx = xp.zeros_like(x)
66 | gloss = grad_outputs[0]
67 | for i in range(nb_class):
68 | t_i = t[:, i]
69 | intersect = self.intersect
70 | union = self.union
71 | w = self.w[i]
72 | numerator = xp.multiply(
73 | union + self.eps, t_i) - intersect + self.eps
74 | denominator = xp.power(union + self.eps, 2)
75 | d_dice = 2 * xp.divide(
76 | w * numerator, denominator).astype(xp.float32)
77 | gx[:, i] = d_dice
78 |
79 | gx *= gloss
80 | return -gx.astype(xp.float32), None
81 |
82 |
83 | def generalized_dice_loss(x, t, eps=1e-7):
84 | return GeneralizedDiceLoss(eps)(x, t)
85 |
86 |
87 | def softmax_generalized_dice_loss(x, t, eps=1e-7):
88 | x1 = softmax.softmax(x, axis=1)
89 | return generalized_dice_loss(x1, t, eps)
90 |
--------------------------------------------------------------------------------
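The generalized Dice formulation (Sudre et al. 2017) weights each class by the inverse squared label volume, and the weighted intersect and union are reduced over classes before taking the ratio, yielding a scalar loss. A NumPy sketch of that computation:

```python
import numpy as np

# NumPy sketch of generalized Dice (Sudre et al. 2017): class weights
# w_l = 1 / (sum of t_l)^2, with intersect and union reduced over classes
# before taking the ratio, so the loss is a scalar.
def generalized_dice_loss_np(x, t_onehot, eps=1e-7):
    axis = (0,) + tuple(range(2, x.ndim))
    sum_t = np.sum(t_onehot, axis=axis)
    sum_t[sum_t == 0] = 1                      # avoid division by zero
    w = 1.0 / (sum_t * sum_t)
    intersect = np.sum(w * np.sum(x * t_onehot, axis=axis))
    union = np.sum(w * (np.sum(x, axis=axis) + np.sum(t_onehot, axis=axis)))
    return float(1.0 - (2.0 * intersect + eps) / (union + eps))

# a perfect one-hot prediction gives (near-)zero loss
t = np.zeros((1, 2, 2, 2, 2), dtype=np.float32)
t[:, 0] = 1.0
loss = generalized_dice_loss_np(t, t)
```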
/src/functions/loss/mixed_dice_loss.py:
--------------------------------------------------------------------------------
1 | import chainer.functions as F
2 | from src.functions.loss import softmax_dice_loss
3 | from src.functions.loss import focal_loss
4 | import cupy
5 |
6 |
7 | def struct_dice_loss_plus_cross_entropy(x, t, w=0.5):
8 | cw = cupy.array([
9 | 1., 1., 1., 1., 0.5, 0.5, 0.8, 0.8, 0.5, 0.8, 0.8,
10 | 0.8, 0.5, 0.5, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 1., 1., 1.], dtype='float32')
11 | return w * softmax_dice_loss(x, t) + (1 - w) * F.softmax_cross_entropy(x, t, class_weight=cw)
12 |
13 |
14 | def dice_loss_plus_cross_entropy(x, t, w=0.5, encode=True):
15 | return w * softmax_dice_loss(x, t, encode=encode) + (1 - w) * F.softmax_cross_entropy(x, t)
16 |
17 |
18 | def dice_loss_plus_focal_loss(x, t, w=0.5):
19 | return w * softmax_dice_loss(x, t) + (1 - w) * focal_loss(x, t)
20 |
--------------------------------------------------------------------------------
/src/links/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/links/__init__.py
--------------------------------------------------------------------------------
/src/links/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/links/model/__init__.py
--------------------------------------------------------------------------------
/src/links/model/resize_images_3d.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | from chainer.backends import cuda
3 | from chainer import function_node
4 | from chainer.utils import type_check
5 |
6 |
7 | class ResizeImages3D(function_node.FunctionNode):
8 |
9 | def __init__(self, output_shape):
10 | self.out_H = output_shape[0]
11 | self.out_W = output_shape[1]
12 | self.out_D = output_shape[2]
13 |
14 | def check_type_forward(self, in_types):
15 | n_in = in_types.size()
16 | type_check.expect(n_in == 1)
17 |
18 | x_type = in_types[0]
19 | type_check.expect(
20 | x_type.dtype.char == 'f',
21 | x_type.ndim == 5
22 | )
23 |
24 | def forward(self, inputs):
25 | x, = inputs
26 | xp = cuda.get_array_module(x)
27 |
28 | B, C, H, W, D = x.shape
29 |
30 | u_1d = xp.linspace(0, W - 1, num=self.out_W)
31 | v_1d = xp.linspace(0, H - 1, num=self.out_H)
32 | t_1d = xp.linspace(0, D - 1, num=self.out_D)
33 | grid = xp.meshgrid(u_1d, v_1d, t_1d)
34 | u = grid[0].ravel()
35 | v = grid[1].ravel()
36 | t = grid[2].ravel()
37 |
38 | u0 = xp.floor(u).astype(numpy.int32)
39 | u0 = u0.clip(0, W - 2)
40 | u1 = u0 + 1
41 | v0 = xp.floor(v).astype(numpy.int32)
42 | v0 = v0.clip(0, H - 2)
43 | v1 = v0 + 1
44 | t0 = xp.floor(t).astype(numpy.int32)
45 | t0 = t0.clip(0, D - 2)
46 | t1 = t0 + 1
47 |
48 | # weights
49 | w1 = (u1 - u) * (v1 - v) * (t1 - t)
50 | w2 = (u - u0) * (v1 - v) * (t1 - t)
51 | w3 = (u1 - u) * (v - v0) * (t1 - t)
52 | w4 = (u - u0) * (v - v0) * (t1 - t)
53 | w5 = (u1 - u) * (v1 - v) * (t - t0)
54 | w6 = (u - u0) * (v1 - v) * (t - t0)
55 | w7 = (u1 - u) * (v - v0) * (t - t0)
56 | w8 = (u - u0) * (v - v0) * (t - t0)
57 | w1 = w1.astype(x.dtype)
58 | w2 = w2.astype(x.dtype)
59 | w3 = w3.astype(x.dtype)
60 | w4 = w4.astype(x.dtype)
61 | w5 = w5.astype(x.dtype)
62 | w6 = w6.astype(x.dtype)
63 | w7 = w7.astype(x.dtype)
64 | w8 = w8.astype(x.dtype)
65 |
66 | y = (w1[None, None, :] * x[:, :, v0, u0, t0] +
67 | w2[None, None, :] * x[:, :, v0, u1, t0] +
68 | w3[None, None, :] * x[:, :, v1, u0, t0] +
69 | w4[None, None, :] * x[:, :, v1, u1, t0] +
70 | w5[None, None, :] * x[:, :, v0, u0, t1] +
71 | w6[None, None, :] * x[:, :, v0, u1, t1] +
72 | w7[None, None, :] * x[:, :, v1, u0, t1] +
73 | w8[None, None, :] * x[:, :, v1, u1, t1])
74 | y = y.reshape(B, C, self.out_H, self.out_W, self.out_D)
75 | return y,
76 |
77 | def backward(self, indexes, grad_outputs):
78 | return ResizeImagesGrad3D(
79 | self.inputs[0].shape,
80 | (self.out_H, self.out_W, self.out_D)).apply(grad_outputs)
81 |
82 |
83 | class ResizeImagesGrad3D(function_node.FunctionNode):
84 |
85 | def __init__(self, input_shape, output_shape):
86 | self.out_H = output_shape[0]
87 | self.out_W = output_shape[1]
88 | self.out_D = output_shape[2]
89 | self.input_shape = input_shape
90 |
91 | def check_type_forward(self, in_types):
92 | n_in = in_types.size()
93 | type_check.expect(n_in == 1)
94 |
95 | x_type = in_types[0]
96 | type_check.expect(
97 | x_type.dtype.char == 'f',
98 | x_type.ndim == 5
99 | )
100 |
101 | def forward(self, inputs):
102 | xp = cuda.get_array_module(*inputs)
103 | gy, = inputs
104 |
105 | B, C, H, W, D = self.input_shape
106 |
107 | u_1d = xp.linspace(0, W - 1, num=self.out_W)
108 | v_1d = xp.linspace(0, H - 1, num=self.out_H)
109 | t_1d = xp.linspace(0, D - 1, num=self.out_D)
110 | grid = xp.meshgrid(u_1d, v_1d, t_1d)
111 | u = grid[0].ravel()
112 | v = grid[1].ravel()
113 | t = grid[2].ravel()
114 |
115 | u0 = xp.floor(u).astype(numpy.int32)
116 | u0 = u0.clip(0, W - 2)
117 | u1 = u0 + 1
118 | v0 = xp.floor(v).astype(numpy.int32)
119 | v0 = v0.clip(0, H - 2)
120 | v1 = v0 + 1
121 | t0 = xp.floor(t).astype(numpy.int32)
122 | t0 = t0.clip(0, D - 2)
123 | t1 = t0 + 1
124 |
125 | # weights
126 | wu0 = u - u0
127 | wu1 = u1 - u
128 | wv0 = v - v0
129 | wv1 = v1 - v
130 | wt0 = t - t0
131 | wt1 = t1 - t
132 | wu0 = wu0.astype(gy.dtype)
133 | wu1 = wu1.astype(gy.dtype)
134 | wv0 = wv0.astype(gy.dtype)
135 | wv1 = wv1.astype(gy.dtype)
136 | wt0 = wt0.astype(gy.dtype)
137 | wt1 = wt1.astype(gy.dtype)
138 |
139 | # --- gx
140 | if xp is numpy:
141 | scatter_add = numpy.add.at
142 | else:
143 | scatter_add = cuda.cupyx.scatter_add
144 |
145 | gx = xp.zeros(self.input_shape, dtype=gy.dtype)
146 | gy = gy.reshape(B, C, -1)
147 | scatter_add(gx, (slice(None), slice(None), v0, u0, t0),
148 | gy * wu1 * wv1 * wt1)
149 | scatter_add(gx, (slice(None), slice(None), v0, u1, t0),
150 | gy * wu0 * wv1 * wt1)
151 | scatter_add(gx, (slice(None), slice(None), v1, u0, t0),
152 | gy * wu1 * wv0 * wt1)
153 | scatter_add(gx, (slice(None), slice(None), v1, u1, t0),
154 | gy * wu0 * wv0 * wt1)
155 | scatter_add(gx, (slice(None), slice(None), v0, u0, t1),
156 | gy * wu1 * wv1 * wt0)
157 | scatter_add(gx, (slice(None), slice(None), v0, u1, t1),
158 | gy * wu0 * wv1 * wt0)
159 | scatter_add(gx, (slice(None), slice(None), v1, u0, t1),
160 | gy * wu1 * wv0 * wt0)
161 | scatter_add(gx, (slice(None), slice(None), v1, u1, t1),
162 | gy * wu0 * wv0 * wt0)
163 | return gx,
164 |
165 | def backward(self, indexes, grad_outputs):
166 | return ResizeImages3D(
167 | (self.out_H, self.out_W, self.out_D)).apply(grad_outputs)
168 |
169 |
170 | def resize_images_3d(x, output_shape):
171 | """Resize images to the given shape.
172 | This function resizes 3D data to :obj:`output_shape`.
173 | Currently, only trilinear interpolation is supported as the sampling method.
174 | Notation for dimensionalities:
175 | - :math:`n` is the batch size.
176 | - :math:`c_I` is the number of the input channels.
177 | - :math:`h`, :math:`w` and :math:`d` are the height, width and depth of the
178 | input image, respectively.
179 | - :math:`h_O`, :math:`w_O` and :math:`d_O` are the height, width and depth
180 | of the output image.
181 | Args:
182 | x (~chainer.Variable):
183 | Input variable of shape :math:`(n, c_I, h, w, d)`.
184 | output_shape (tuple):
185 | This is a tuple of length 3 whose values are :obj:`(h_O, w_O, d_O)`.
186 | Returns:
187 | ~chainer.Variable: Resized image whose shape is \
188 | :math:`(n, c_I, h_O, w_O, d_O)`.
189 | """
190 | return ResizeImages3D(output_shape).apply((x,))[0]
191 |
--------------------------------------------------------------------------------
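`ResizeImages3D` samples an "align corners" grid: output index `j` maps to input coordinate `linspace(0, W-1, out_W)[j]`, so corner voxels map exactly onto corners and interior points are interpolated linearly per axis (trilinearly in 3-D). The same arithmetic can be checked on a 1-D slice (sizes here are illustrative):

```python
import numpy as np

# 1-D slice of the grid and weight arithmetic used in ResizeImages3D.forward:
# "align corners" sampling plus linear interpolation between neighbors.
W, out_W = 4, 7
u = np.linspace(0, W - 1, num=out_W)          # [0, 0.5, 1, ..., 3]

u0 = np.clip(np.floor(u).astype(np.int32), 0, W - 2)
u1 = u0 + 1
x = np.array([0.0, 1.0, 2.0, 3.0])            # a linear ramp signal
y = (u1 - u) * x[u0] + (u - u0) * x[u1]       # the w1/w2-style weights
```

A linear ramp is reproduced exactly, which is the defining property of (tri)linear interpolation.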
/src/links/model/vaeseg.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chainer
3 | import chainer.links as L
4 | import chainer.functions as F
5 | from chainer.links import GroupNormalization
6 | from src.functions.array import resize_images_3d
7 | from chainer.utils.conv_nd import im2col_nd
8 |
9 |
10 | class CNR(chainer.Chain):
11 | """Convolution, Normalize, then ReLU"""
12 |
13 | def __init__(
14 | self,
15 | channels,
16 | norm=GroupNormalization,
17 | down_sampling=False,
18 | comm=None
19 | ):
20 | super(CNR, self).__init__()
21 | with self.init_scope():
22 | if down_sampling:
23 | self.c = L.Convolution3D(None, channels, 3, 2, 1)
24 | else:
25 | self.c = L.Convolution3D(None, channels, 3, 1, 1)
26 | if norm.__name__ == 'MultiNodeBatchNormalization':
27 | self.n = norm(channels, comm, eps=1e-5)
28 | elif norm.__name__ == 'BatchNormalization':
29 | self.n = norm(channels, eps=1e-5)
30 | elif norm.__name__ == 'GroupNormalization':
31 | self.n = norm(groups=8, size=channels)
32 | else:
33 | self.n = norm(channels)
34 |
35 | def forward(self, x):
36 | h = F.relu(self.n(self.c(x)))
37 | return h
38 |
39 |
40 | class NRC(chainer.Chain):
41 | """Normalize, ReLU, then Convolution"""
42 | def __init__(
43 | self,
44 | in_channels,
45 | out_channels,
46 | norm=GroupNormalization,
47 | down_sampling=False,
48 | comm=None
49 | ):
50 | super(NRC, self).__init__()
51 | with self.init_scope():
52 | if norm.__name__ == 'MultiNodeBatchNormalization':
53 | self.n = norm(in_channels, comm, eps=1e-5)
54 | elif norm.__name__ == 'BatchNormalization':
55 | self.n = norm(in_channels, eps=1e-5)
56 | elif norm.__name__ == 'GroupNormalization':
57 | self.n = norm(groups=8, size=in_channels)
58 | else:
59 | self.n = norm(in_channels)
60 | if down_sampling:
61 | self.c = L.Convolution3D(None, out_channels, 3, 2, 1)
62 | else:
63 | self.c = L.Convolution3D(None, out_channels, 3, 1, 1)
64 |
65 | def forward(self, x):
66 | h = self.c(F.relu(self.n(x)))
67 | return h
68 |
69 |
70 | class ResBlock(chainer.Chain):
71 |
72 | def __init__(
73 | self,
74 | channels,
75 | norm=GroupNormalization,
76 | bn_first=True,
77 | comm=None,
78 | concat_mode=False
79 | ):
80 | super(ResBlock, self).__init__()
81 | with self.init_scope():
82 | if bn_first:
83 | if concat_mode:
84 | channels *= 2
85 | self.block1 = NRC(channels, channels, norm, False, comm)
86 | self.block2 = NRC(channels, channels, norm, False, comm)
87 | else:
88 | self.block1 = CNR(channels, norm, False, comm)
89 | self.block2 = CNR(channels, norm, False, comm)
90 |
91 | def forward(self, x):
92 | h = self.block1(x)
93 | h = self.block2(h)
94 | return h + x
95 |
96 |
97 | class DownBlock(chainer.Chain):
98 | """down sample (conv stride2), then ResBlock * num of layers"""
99 |
100 | def __init__(
101 | self,
102 | channels,
103 | norm=GroupNormalization,
104 | down_sample=True,
105 | n_blocks=1,
106 | bn_first=True,
107 | comm=None
108 | ):
109 | self.down_sample = down_sample
110 | self.n_blocks = n_blocks
111 | super(DownBlock, self).__init__()
112 | with self.init_scope():
113 | if down_sample:
114 | self.d = L.Convolution3D(None, channels, 3, 2, 1)
115 | else:
116 | self.d = L.Convolution3D(None, channels, 3, 1, 1)
117 | for i in range(n_blocks):
118 | layer = ResBlock(channels, norm, bn_first, comm)
119 | setattr(self, 'block{}'.format(i), layer)
120 |
121 | def forward(self, x):
122 | h = self.d(x)
123 | for i in range(self.n_blocks):
124 | h = getattr(self, 'block{}'.format(i))(h)
125 | return h
126 |
127 |
128 | class VD(chainer.Chain):
129 |     """VAE bottleneck (down to a 256-dim. representation); see Table 1"""
130 | def __init__(
131 | self,
132 | channels,
133 | norm=GroupNormalization,
134 | bn_first=True,
135 | ndim_latent=128,
136 | comm=None
137 | ):
138 | super(VD, self).__init__()
139 | with self.init_scope():
140 | if bn_first:
141 | self.b = NRC(channels, 16, norm, True, comm)
142 | else:
143 | self.b = CNR(16, norm, True, comm)
144 | self.d_mu = L.Linear(None, ndim_latent)
145 | self.d_ln_var = L.Linear(None, ndim_latent)
146 |
147 | def forward(self, x):
148 | h = self.b(x)
149 | mu = self.d_mu(h)
150 | ln_var = self.d_ln_var(h)
151 | return mu, ln_var
152 |
153 |
154 | class Encoder(chainer.Chain):
155 |
156 | def __init__(
157 | self,
158 | base_channels=32,
159 | norm=GroupNormalization,
160 | bn_first=True,
161 | ndim_latent=128,
162 | comm=None
163 | ):
164 | super(Encoder, self).__init__()
165 | with self.init_scope():
166 | self.enc_initconv = L.Convolution3D(None, base_channels, 3, 1, 1)
167 | self.enc_block0 = DownBlock(base_channels, norm, False, 1, bn_first, comm)
168 | self.enc_block1 = DownBlock(2*base_channels, norm, True, 2, bn_first, comm)
169 | self.enc_block2 = DownBlock(4*base_channels, norm, True, 2, bn_first, comm)
170 | self.enc_block3 = DownBlock(8*base_channels, norm, True, 4, bn_first, comm)
171 |
172 | def forward(self, x):
173 | h = F.dropout(self.enc_initconv(x), ratio=0.2)
174 | hs = []
175 | for i in range(4):
176 | h = getattr(self, 'enc_block{}'.format(i))(h)
177 | hs.append(h)
178 | return hs
179 |
180 |
181 | class UpBlock(chainer.Chain):
182 |
183 | def __init__(
184 | self,
185 | channels,
186 | norm=GroupNormalization,
187 | bn_first=True,
188 | mode='sum',
189 | comm=None
190 | ):
191 | self.mode = mode
192 | concat_mode = True if mode == 'concat' else False
193 | super(UpBlock, self).__init__()
194 | with self.init_scope():
195 | self.c1 = L.Convolution3D(None, channels, 1, 1, 0)
196 | self.rb = ResBlock(channels, norm, bn_first, comm, concat_mode)
197 |
198 | def forward(self, xd, xu=None):
199 | xd_shape = xd.shape[2:]
200 | out_shape = tuple(i * 2 for i in xd_shape)
201 | h = resize_images_3d(self.c1(xd), output_shape=out_shape)
202 | if xu is not None:
203 | if self.mode == 'sum':
204 | h += xu
205 | elif self.mode == 'concat':
206 | h = F.concat((h, xu), axis=1)
207 | h = self.rb(h)
208 | return h
209 |
210 |
211 | class Decoder(chainer.Chain):
212 |
213 | def __init__(
214 | self,
215 | base_channels,
216 | out_channels,
217 | norm=GroupNormalization,
218 | bn_first=True,
219 | mode='sum',
220 | comm=None
221 | ):
222 | super(Decoder, self).__init__()
223 | with self.init_scope():
224 | for i in range(3):
225 | layer = UpBlock(2**i * base_channels, norm, bn_first, mode, comm)
226 | setattr(self, 'dec_block{}'.format(i), layer)
227 | self.dec_end = L.Convolution3D(None, out_channels, 1, 1, 0)
228 |
229 | def forward(self, hs, bs=None):
230 | # bs: output from boundary stream
231 | y = hs[-1]
232 | for i in reversed(range(3)):
233 | h = hs[i]
234 | y = getattr(self, 'dec_block{}'.format(i))(y, h)
235 | if bs is not None:
236 | y = F.concat((y, bs), axis=1)
237 | y = self.dec_end(y)
238 | return y
239 |
240 |
241 | class VU(chainer.Chain):
242 |
243 | def __init__(
244 | self,
245 | input_shape
246 | ):
247 | self.bottom_shape = tuple(i // 16 for i in input_shape)
248 | bottom_size = 16 * np.prod(list(self.bottom_shape))
249 | self.output_shape = tuple(i // 8 for i in input_shape)
250 | super(VU, self).__init__()
251 | with self.init_scope():
252 | self.dense = L.Linear(None, bottom_size)
253 | self.conv1 = L.Convolution3D(None, 256, 1, 1, 0)
254 |
255 | def forward(self, x):
256 | x_shape = x.shape
257 | h = F.relu(self.dense(x))
258 | h = F.reshape(h, (x_shape[0], 16) + self.bottom_shape)
259 | h = resize_images_3d(self.conv1(h), output_shape=self.output_shape)
260 | return h
261 |
262 |
263 | class VAE(chainer.Chain):
264 |
265 | def __init__(
266 | self,
267 | in_channels,
268 | base_channels,
269 | norm=GroupNormalization,
270 | bn_first=True,
271 | input_shape=(160, 192, 128),
272 | comm=None
273 | ):
274 | super(VAE, self).__init__()
275 | with self.init_scope():
276 | self.vu = VU(input_shape)
277 | for i in range(3):
278 | layer = UpBlock(2**i * base_channels, norm, bn_first, None, comm)
279 | setattr(self, 'vae_block{}'.format(i), layer)
280 | self.vae_end = L.Convolution3D(None, in_channels, 1, 1, 0)
281 |
282 | def forward(self, x):
283 | h = self.vu(x)
284 | for i in reversed(range(3)):
285 | h = getattr(self, 'vae_block{}'.format(i))(h)
286 | h = self.vae_end(h)
287 | return h
288 |
289 |
290 | def divide_img(x, grid_size=32):
291 | b, c, x1, x2, x3 = x.shape
292 | kersize = grid_size # kernel size
293 | ssize = int(grid_size*0.5) # stride size
294 | gl0 = int(x1/ssize-1) # grid length
295 | gl1 = int(x2/ssize-1)
296 | gl2 = int(x3/ssize-1)
297 |
298 |     if isinstance(x, chainer.variable.Variable):
299 | h = im2col_nd(
300 | x.data, ksize=(kersize, kersize, kersize),
301 | stride=(ssize, ssize, ssize), pad=(0, 0, 0))
302 | else:
303 | h = im2col_nd(
304 | x, ksize=(kersize, kersize, kersize),
305 | stride=(ssize, ssize, ssize), pad=(0, 0, 0))
306 | h = F.reshape(h, (1, 4, kersize, kersize, kersize, gl0*gl1*gl2))
307 | h = F.transpose(h, axes=(5, 1, 2, 3, 4, 0))
308 | h = F.reshape(h, (gl0*gl1*gl2, 4, kersize, kersize, kersize))
309 | return h
310 |
311 |
312 | class CPCPredictor(chainer.Chain):
313 |
314 | def __init__(
315 | self,
316 | base_channels=256,
317 | norm=GroupNormalization,
318 | bn_first=True,
319 | grid_size=32,
320 | input_shape=(160, 160, 128),
321 | upper=True, # whether to predict the upper half in the CPC task
322 | cpc_pattern='updown',
323 | comm=None
324 | ):
325 | x1, x2, x3 = input_shape
326 | ssize = int(grid_size*0.5)
327 | self.gl0 = int(x1/ssize-1)
328 | self.gl1 = int(x2/ssize-1)
329 | self.gl2 = int(x3/ssize-1)
330 | self.cut_l = int(self.gl2/2)
331 | self.base_channels = base_channels
332 | self.upper = upper
333 | self.cpc_pattern = cpc_pattern
334 |
335 | super(CPCPredictor, self).__init__()
336 | with self.init_scope():
337 | for i in range(8):
338 | layer = ResBlock(base_channels, norm, bn_first, comm)
339 | setattr(self, 'pred_block{}'.format(i), layer)
340 | self.pred1 = L.Convolution3D(
341 | None, base_channels,
342 | ksize=(1, 1, 1), stride=1, pad=0)
343 |
344 | def forward(self, x):
345 | h = F.transpose(x, axes=(1, 0))
346 | h = F.reshape(h, (1, self.base_channels, self.gl0, self.gl1, self.gl2))
347 | if self.cpc_pattern == 'ichimatsu':
348 | hs = h[:, :, 0:self.gl0:4, 0:self.gl1:4, 0:6:2]
349 | else:
350 | if self.upper:
351 | hs = h[:, :, :, :, :self.cut_l]
352 | else:
353 | hs = h[:, :, :, :, -self.cut_l:]
354 | for i in range(8):
355 | hs = getattr(self, 'pred_block{}'.format(i))(hs)
356 | hs = self.pred1(hs)
357 | return hs
358 |
359 |
360 | class Attention(chainer.Chain):
361 | """concatenate inputs from the attention boundary stream,
362 | and from the Encoder. Then apply a 1*1 conv and a sigmoid activation."""
363 | def __init__(
364 | self,
365 | comm=None
366 | ):
367 | super(Attention, self).__init__()
368 | with self.init_scope():
369 | self.c = L.Convolution3D(None, out_channels=1, ksize=1, stride=1, pad=0)
370 |
371 | def forward(self, s, m):
372 | alpha = F.concat((s, m), axis=1)
373 | alpha = F.sigmoid(self.c(alpha))
374 | o = s*alpha
375 | return o
376 |
377 |
378 | class BoundaryStream(chainer.Chain):
379 | """Combination of upblock and attention for the boundary stream"""
380 | def __init__(
381 | self,
382 | base_channels,
383 | out_channels,
384 | norm=GroupNormalization,
385 | bn_first=True,
386 | mode='sum',
387 | comm=None
388 | ):
389 | super(BoundaryStream, self).__init__()
390 | with self.init_scope():
391 | self.c = L.Convolution3D(None, 2**2*base_channels, 1, 1, 0)
392 | self.rb = ResBlock(2**2*base_channels, norm, bn_first, comm)
393 | self.att = Attention(comm)
394 | for i in range(3):
395 | layer = UpBlock(2**i * base_channels, norm, bn_first, mode, comm)
396 | setattr(self, 'dec_block{}'.format(i), layer)
397 | att = Attention(comm)
398 | setattr(self, 'att_block{}'.format(i), att)
399 | self.dec_end = L.Convolution3D(None, out_channels, ksize=1, stride=1, pad=0)
400 |
401 | def forward(self, hs):
402 | bs = hs[-1]
403 | bs = self.att(self.rb(self.c(bs)), hs[-1])
404 | for i in reversed(range(3)):
405 | h = hs[i]
406 | bs = getattr(self, 'dec_block{}'.format(i))(bs)
407 | bs = getattr(self, 'att_block{}'.format(i))(bs, h)
408 | y = self.dec_end(bs)
409 | return y, bs
410 |
--------------------------------------------------------------------------------
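`divide_img` and `CPCPredictor` share the same grid arithmetic: with kernel size `k`, stride `s = k/2` (50% overlap, no padding), an axis of length `X` yields `X/s - 1` patches. A standalone check, using sizes matching the defaults in this file (e.g. a 160x160x128 crop with 32-voxel grids):

```python
# Grid arithmetic shared by divide_img and CPCPredictor: kernel k,
# stride s = k // 2 (50% overlap, no padding) gives X/s - 1 patches
# per axis, which is the generic (X - k) // s + 1 count.
def n_patches(X, k):
    s = k // 2
    assert (X - k) % s == 0, "axis length must align with the grid"
    return (X - k) // s + 1
```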
/src/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/training/__init__.py
--------------------------------------------------------------------------------
/src/training/extensions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/training/extensions/__init__.py
--------------------------------------------------------------------------------
/src/training/extensions/boundseg_evaluator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import cupy
3 | import math
4 |
5 | import chainer
6 | import chainer.functions as F
7 | from chainer.backends import cuda
8 | from chainer.training.extensions import Evaluator
9 | from chainer.dataset import convert
10 | from chainer import Reporter
11 | from chainer import reporter as reporter_module
12 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient
13 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy
14 |
15 |
16 | class BoundSegEvaluator(Evaluator):
17 |
18 | def __init__(
19 | self,
20 | config,
21 | iterator,
22 | target,
23 | converter=convert.concat_examples,
24 | device=None
25 | ):
26 | super().__init__(iterator, target,
27 | converter, device,
28 | None, None)
29 | self.nested_label = config['nested_label']
30 | self.seg_lossfun = eval(config['seg_lossfun'])
31 | self.dataset = config['dataset_name']
32 | self.nb_labels = config['nb_labels']
33 | self.crop_size = eval(config['crop_size'])
34 | self.is_brats = config['is_brats']
35 |
36 | def compute_loss(self, y, t):
37 | if self.nested_label:
38 | loss = 0.
39 | b, c, h, w, d = t.shape
40 | for i in range(c):
41 | loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
42 | else:
43 | loss = self.seg_lossfun(y, t)
44 | return loss
45 |
46 | def compute_accuracy(self, y, t):
47 | if self.nested_label:
48 | b, c, h, w, d = t.shape
49 | y = F.reshape(y, (b, 2, h*c, w, d))
50 | t = F.reshape(t, (b, h*c, w, d))
51 | return F.accuracy(y, t)
52 |
53 | def compute_dice_coef(self, y, t):
54 | if self.nested_label:
55 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...]))
56 | for i in range(1, t.shape[1]):
57 | dices = dice_coefficient(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
58 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0)
59 | else:
60 | dice = dice_coefficient(y, t, is_brats=self.is_brats)
61 | return dice
62 |
63 | def evaluate(self):
64 | summary = reporter_module.DictSummary()
65 | iterator = self._iterators['main']
66 | enc = self._targets['enc']
67 | dec = self._targets['dec']
68 | bound = self._targets['bound']
69 | reporter = Reporter()
70 | observer = object()
71 | reporter.add_observer(self.default_name, observer)
72 |
73 | if hasattr(iterator, 'reset'):
74 | iterator.reset()
75 | it = iterator
76 | else:
77 | it = copy.copy(iterator)
78 |
79 | for batch in it:
80 | x, t, te = self.converter(batch, self.device)
81 | with chainer.no_backprop_mode(), chainer.using_config('train', False):
82 | if self.dataset == 'msd_bound':
83 | h, w, d = x.shape[2:]
84 | hc, wc, dc = self.crop_size
85 | if self.nested_label:
86 | y = cupy.zeros((1, 2*(self.nb_labels-1), h, w, d), dtype='float32')
87 | else:
88 | y = cupy.zeros((1, self.nb_labels, h, w, d), dtype='float32')
89 | s = 128 # stride
90 | ker = 256 # kernel size
91 | dker = dc # kernel size for depth
92 |                     ds = dker // 2  # stride for depth (int: used as a slice index)
93 | dsteps = int(math.floor((d-dker)/ds) + 1)
94 | steps = round((h - ker)/s + 1)
95 | for i in range(steps):
96 | for j in range(steps):
97 | for k in range(dsteps):
98 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k]
99 | hhs = enc(xx)
100 | yye, bbs = bound(hhs)
101 | yy = dec(hhs, bbs)
102 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] += yy.data
103 | # for the bottom depth part of the image
104 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:]
105 | hhs = enc(xx)
106 | yye, bbs = bound(hhs)
107 | yy = dec(hhs, bbs)
108 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] += yy.data
109 | else:
110 | hs = enc(x)
111 | ye, bs = bound(hs)
112 | y = dec(hs, bs)
113 | seg_loss = self.compute_loss(y, t)
114 | accuracy = self.compute_accuracy(y, t)
115 | dice = self.compute_dice_coef(y, t)
116 | mean_dice = mean_dice_coefficient(dice)
117 |
118 | weighted_loss = seg_loss
119 |
120 | observation = {}
121 | with reporter.scope(observation):
122 | reporter.report({
123 | 'loss/seg': seg_loss,
124 | 'loss/total': weighted_loss,
125 | 'acc': accuracy,
126 | 'mean_dc': mean_dice
127 | }, observer)
128 | xp = cuda.get_array_module(y)
129 | for i in range(len(dice)):
130 | if not xp.isnan(dice.data[i]):
131 | reporter.report({'dc_{}'.format(i): dice[i]}, observer)
132 | summary.add(observation)
133 | return summary.compute_mean()
134 |
--------------------------------------------------------------------------------
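The sliding-window inference in the evaluators above (in-plane window `ker=256` with stride `s=128`, half-window stride in depth, plus one extra pass over the bottom depth slab) can be sketched in isolation with NumPy in place of CuPy. `predict` here is a hypothetical stand-in for the `enc`/`bound`/`dec` pipeline; this is a sketch of the tiling arithmetic, not the repository's exact code path.

```python
import math
import numpy as np

def sliding_window_predict(x, predict, nb_labels, ker=256, s=128, dker=128):
    # x: (1, c, h, w, d) volume; predict maps a crop to (1, nb_labels, hc, wc, dc) logits
    _, _, h, w, d = x.shape
    y = np.zeros((1, nb_labels, h, w, d), dtype='float32')
    ds = dker // 2                      # depth stride: half the depth kernel
    dsteps = int(math.floor((d - dker) / ds) + 1)
    steps = round((h - ker) / s + 1)    # assumes h == w, as in the evaluators
    for i in range(steps):
        for j in range(steps):
            for k in range(dsteps):
                sl = (slice(None), slice(None),
                      slice(s * i, ker + s * i), slice(s * j, ker + s * j),
                      slice(ds * k, dker + ds * k))
                y[sl] += predict(x[sl])
            # extra pass over the bottom depth slab so the tail of the volume is covered
            sl = (slice(None), slice(None),
                  slice(s * i, ker + s * i), slice(s * j, ker + s * j),
                  slice(d - dker, d))
            y[sl] += predict(x[sl])
    return y
```

Overlapping windows simply accumulate logits; an argmax over the channel axis then yields the final label map.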
/src/training/extensions/cpcseg_evaluator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import cupy
3 | import chainer
4 | import chainer.functions as F
5 | from chainer.backends import cuda
6 | from chainer.training.extensions import Evaluator
7 | from chainer.dataset import convert
8 | from chainer import Reporter
9 | from chainer import reporter as reporter_module
10 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient
11 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy
12 |
13 |
14 | class CPCSegEvaluator(Evaluator):
15 |
16 | def __init__(
17 | self,
18 | config,
19 | iterator,
20 | target,
21 | converter=convert.concat_examples,
22 | device=None,
23 | ):
24 | super().__init__(iterator, target,
25 | converter, device,
26 | None, None)
27 | self.nested_label = config['nested_label']
28 | self.seg_lossfun = eval(config['seg_lossfun'])
29 | self.rec_loss_weight = config['vaeseg_rec_loss_weight']
30 | self.kl_loss_weight = config['vaeseg_kl_loss_weight']
31 | self.grid_size = config['grid_size']
32 | self.base_channels = config['base_channels']
33 | self.cpc_loss_weight = config['cpc_vaeseg_cpc_loss_weight']
34 | self.cpc_pattern = config['cpc_pattern']
35 | self.is_brats = config['is_brats']
36 | self.dataset = config['dataset_name']
37 | self.nb_labels = config['nb_labels']
38 | self.crop_size = eval(config['crop_size'])
39 |
40 | def compute_loss(self, y, t):
41 | if self.nested_label:
42 | loss = 0.
43 | b, c, h, w, d = t.shape
44 | for i in range(c):
45 | loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
46 | else:
47 | loss = self.seg_lossfun(y, t)
48 | return loss
49 |
50 | def compute_accuracy(self, y, t):
51 | if self.nested_label:
52 | b, c, h, w, d = t.shape
53 | y = F.reshape(y, (b, 2, h*c, w, d))
54 | t = F.reshape(t, (b, h*c, w, d))
55 | return F.accuracy(y, t)
56 |
57 | def compute_dice_coef(self, y, t):
58 | if self.nested_label:
59 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...]))
60 | for i in range(1, t.shape[1]):
61 | dices = dice_coefficient(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
62 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0)
63 | else:
64 | dice = dice_coefficient(y, t, is_brats=self.is_brats)
65 | return dice
66 |
67 | def evaluate(self):
68 | summary = reporter_module.DictSummary()
69 | iterator = self._iterators['main']
70 | enc = self._targets['enc']
71 | dec = self._targets['dec']
72 | reporter = Reporter()
73 | observer = object()
74 | reporter.add_observer(self.default_name, observer)
75 |
76 | if hasattr(iterator, 'reset'):
77 | iterator.reset()
78 | it = iterator
79 | else:
80 | it = copy.copy(iterator)
81 |
82 | for batch in it:
83 | x, t = self.converter(batch, self.device)
84 | with chainer.no_backprop_mode(), chainer.using_config('train', False):
85 | if self.dataset == 'msd_bound':
86 | # sliding-window evaluation used only for the MSD ('msd_bound') dataset
87 | h, w, d = x.shape[2:]
88 | hc, wc, dc = self.crop_size
89 | if self.nested_label:
90 | y = cupy.zeros((1, 2*(self.nb_labels-1), h, w, d), dtype='float32')
91 | else:
92 | y = cupy.zeros((1, self.nb_labels, h, w, d), dtype='float32')
93 | hker = hc  # kernel size for height
94 | wker = wc  # kernel size for width
95 | dker = dc  # kernel size for depth
96 | # the eight corner crops below (two per axis) tile the volume;
97 | # overlapping regions simply accumulate logits
98 | for i in range(2):
99 | for j in range(2):
100 | for k in range(2):
101 | xx = x[:, :, -i*hker:min(hker*(i+1), h),
102 | -j*wker:min(wker*(j+1), w), -k*dker:min(dker*(k+1), d)]
103 | hs = enc(xx)
104 | yy = dec(hs)
105 | y[:, :, -i*hker:min(hker*(i+1), h),
106 | -j*wker:min(wker*(j+1), w),
107 | -k*dker:min(dker*(k+1), d)] += yy.data
108 |
109 | else:
110 | hs = enc(x)
111 | y = dec(hs)
112 | seg_loss = self.compute_loss(y, t)
113 | accuracy = self.compute_accuracy(y, t)
114 | dice = self.compute_dice_coef(y, t)
115 | mean_dice = mean_dice_coefficient(dice)
116 |
117 | observation = {}
118 | with reporter.scope(observation):
119 | reporter.report({
120 | 'loss/seg': seg_loss,
121 | 'acc': accuracy,
122 | 'mean_dc': mean_dice
123 | }, observer)
124 | xp = cuda.get_array_module(y)
125 | for i in range(len(dice)):
126 | if not xp.isnan(dice.data[i]):
127 | reporter.report({'dc_{}'.format(i): dice[i]}, observer)
128 | summary.add(observation)
129 | return summary.compute_mean()
130 |
--------------------------------------------------------------------------------
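The 2x2x2 corner tiling in `CPCSegEvaluator.evaluate` takes one crop from each end of every axis (`-i*hker` is `0` for `i = 0` and `-hker` for `i = 1`), which covers the whole volume whenever each extent is at most twice the kernel. A minimal NumPy sketch of the index arithmetic, with a hypothetical `predict` standing in for `enc`/`dec` and a single kernel size shared across axes for brevity (the evaluator uses `hker`/`wker`/`dker` per axis):

```python
import numpy as np

def corner_slices(extent, ker):
    """Start/stop pairs for the two corner crops along one axis (assumes ker <= extent <= 2*ker)."""
    return [(0, min(ker, extent)),      # i = 0: front crop
            (extent - ker, extent)]     # i = 1: back crop, i.e. x[-ker:]

def corner_tile_accumulate(x, predict, nb_labels, ker):
    # x: (1, c, h, w, d); one crop per corner, eight crops in total
    _, _, h, w, d = x.shape
    y = np.zeros((1, nb_labels, h, w, d), dtype='float32')
    for h0, h1 in corner_slices(h, ker):
        for w0, w1 in corner_slices(w, ker):
            for d0, d1 in corner_slices(d, ker):
                y[:, :, h0:h1, w0:w1, d0:d1] += predict(x[:, :, h0:h1, w0:w1, d0:d1])
    return y
```

Voxels inside the overlap of front and back crops receive multiple contributions (up to eight at the centre), which is harmless because only the relative channel magnitudes matter for the final argmax.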
/src/training/extensions/encdec_seg_evaluator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import cupy
3 | import math
4 | import chainer
5 | import chainer.functions as F
6 | from chainer.backends import cuda
7 | from chainer.training.extensions import Evaluator
8 | from chainer.dataset import convert
9 | from chainer import Reporter
10 | from chainer import reporter as reporter_module
11 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient
12 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy
13 |
14 |
15 | class EncDecSegEvaluator(Evaluator):
16 |
17 | def __init__(
18 | self,
19 | config,
20 | iterator,
21 | target,
22 | converter=convert.concat_examples,
23 | device=None,
24 | ):
25 | super().__init__(iterator, target,
26 | converter, device,
27 | None, None)
28 | self.nested_label = config['nested_label']
29 | self.seg_lossfun = eval(config['seg_lossfun'])
30 | self.dataset = config['dataset_name']
31 | self.nb_labels = config['nb_labels']
32 | self.crop_size = eval(config['crop_size'])
33 | self.is_brats = config['is_brats']
34 |
35 | def compute_loss(self, y, t):
36 | if self.nested_label:
37 | loss = 0.
38 | b, c, h, w, d = t.shape
39 | for i in range(c):
40 | loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
41 | else:
42 | loss = self.seg_lossfun(y, t)
43 | return loss
44 |
45 | def compute_accuracy(self, y, t):
46 | if self.nested_label:
47 | b, c, h, w, d = t.shape
48 | y = F.reshape(y, (b, 2, h*c, w, d))
49 | t = F.reshape(t, (b, h*c, w, d))
50 | return F.accuracy(y, t)
51 |
52 | def compute_dice_coef(self, y, t):
53 | if self.nested_label:
54 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...]))
55 | for i in range(1, t.shape[1]):
56 | dices = dice_coefficient(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
57 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0)
58 | else:
59 | dice = dice_coefficient(y, t, is_brats=self.is_brats)
60 | return dice
61 |
62 | def evaluate(self):
63 | summary = reporter_module.DictSummary()
64 | iterator = self._iterators['main']
65 | enc = self._targets['enc']
66 | dec = self._targets['dec']
67 | reporter = Reporter()
68 | observer = object()
69 | reporter.add_observer(self.default_name, observer)
70 |
71 | if hasattr(iterator, 'reset'):
72 | iterator.reset()
73 | it = iterator
74 | else:
75 | it = copy.copy(iterator)
76 |
77 | for batch in it:
78 | x, t = self.converter(batch, self.device)
79 | with chainer.no_backprop_mode(), chainer.using_config('train', False):
80 | if self.dataset == 'msd_bound':
81 | h, w, d = x.shape[2:]
82 | hc, wc, dc = self.crop_size
83 | if self.nested_label:
84 | y = cupy.zeros((1, 2*(self.nb_labels-1), h, w, d), dtype='float32')
85 | else:
86 | y = cupy.zeros((1, self.nb_labels, h, w, d), dtype='float32')
87 | s = 128 # stride
88 | ker = 256 # kernel size
89 | dker = dc # kernel size for depth
90 | ds = dker // 2  # stride for depth (kept integral: it is used as a slice index)
91 | dsteps = int(math.floor((d-dker)/ds) + 1)
92 | steps = round((h - ker)/s + 1)
93 | for i in range(steps):
94 | for j in range(steps):
95 | for k in range(dsteps):
96 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k]
97 | hs = enc(xx)
98 | yy = dec(hs)
99 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] += yy.data
100 | # for the bottom depth part of the image
101 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:]
102 | hs = enc(xx)
103 | yy = dec(hs)
104 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] += yy.data
105 | else:
106 | hs = enc(x)
107 | y = dec(hs)
108 | seg_loss = self.compute_loss(y, t)
109 | accuracy = self.compute_accuracy(y, t)
110 | dice = self.compute_dice_coef(y, t)
111 | mean_dice = mean_dice_coefficient(dice)
112 | weighted_loss = seg_loss
113 |
114 | observation = {}
115 | with reporter.scope(observation):
116 | reporter.report({
117 | 'loss/seg': seg_loss,
118 | 'loss/total': weighted_loss,
119 | 'acc': accuracy,
120 | 'mean_dc': mean_dice
121 | }, observer)
122 | xp = cuda.get_array_module(y)
123 | for i in range(len(dice)):
124 | if not xp.isnan(dice.data[i]):
125 | reporter.report({'dc_{}'.format(i): dice[i]}, observer)
126 | summary.add(observation)
127 | return summary.compute_mean()
128 |
--------------------------------------------------------------------------------
/src/training/extensions/vaeseg_evaluator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import cupy
3 | import math
4 | import chainer
5 | import chainer.functions as F
6 | from chainer.backends import cuda
7 | from chainer.training.extensions import Evaluator
8 | from chainer.dataset import convert
9 | from chainer import Reporter
10 | from chainer import reporter as reporter_module
11 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient
12 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy
13 |
14 |
15 | class VAESegEvaluator(Evaluator):
16 |
17 | def __init__(
18 | self,
19 | config,
20 | iterator,
21 | target,
22 | converter=convert.concat_examples,
23 | device=None,
24 | ):
25 | super().__init__(iterator, target,
26 | converter, device,
27 | None, None)
28 | self.nested_label = config['nested_label']
29 | self.seg_lossfun = eval(config['seg_lossfun'])
30 | self.rec_loss_weight = config['vaeseg_rec_loss_weight']
31 | self.kl_loss_weight = config['vaeseg_kl_loss_weight']
32 | self.dataset = config['dataset_name']
33 | self.nb_labels = config['nb_labels']
34 | self.crop_size = eval(config['crop_size'])
35 | self.is_brats = config['is_brats']
36 |
37 | def compute_loss(self, y, t):
38 | if self.nested_label:
39 | loss = 0.
40 | b, c, h, w, d = t.shape
41 | for i in range(c):
42 | loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
43 | else:
44 | loss = self.seg_lossfun(y, t)
45 | return loss
46 |
47 | def compute_accuracy(self, y, t):
48 | if self.nested_label:
49 | b, c, h, w, d = t.shape
50 | y = F.reshape(y, (b, 2, h*c, w, d))
51 | t = F.reshape(t, (b, h*c, w, d))
52 | return F.accuracy(y, t)
53 |
54 | def compute_dice_coef(self, y, t):
55 | if self.nested_label:
56 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...]))
57 | for i in range(1, t.shape[1]):
58 | dices = dice_coefficient(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
59 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0)
60 | else:
61 | dice = dice_coefficient(y, t, is_brats=self.is_brats)
62 | return dice
63 |
64 | def evaluate(self):
65 | summary = reporter_module.DictSummary()
66 | iterator = self._iterators['main']
67 | enc = self._targets['enc']
68 | dec = self._targets['dec']
69 | reporter = Reporter()
70 | observer = object()
71 | reporter.add_observer(self.default_name, observer)
72 |
73 | if hasattr(iterator, 'reset'):
74 | iterator.reset()
75 | it = iterator
76 | else:
77 | it = copy.copy(iterator)
78 |
79 | for batch in it:
80 | x, t = self.converter(batch, self.device)
81 | with chainer.no_backprop_mode(), chainer.using_config('train', False):
82 | if self.dataset == 'msd_bound':
83 | h, w, d = x.shape[2:]
84 | hc, wc, dc = self.crop_size
85 | if self.nested_label:
86 | y = cupy.zeros((1, 2*(self.nb_labels-1), h, w, d), dtype='float32')
87 | else:
88 | y = cupy.zeros((1, self.nb_labels, h, w, d), dtype='float32')
89 | s = 128 # stride
90 | ker = 256 # kernel size
91 | dker = dc # kernel size for depth
92 | ds = dker // 2  # stride for depth (kept integral: it is used as a slice index)
93 | dsteps = int(math.floor((d-dker)/ds) + 1)
94 | steps = round((h - ker)/s + 1)
95 | for i in range(steps):
96 | for j in range(steps):
97 | for k in range(dsteps):
98 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k]
99 | hs = enc(xx)
100 | yy = dec(hs)
101 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] += yy.data
102 | # for the bottom depth part of the image
103 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:]
104 | hs = enc(xx)
105 | yy = dec(hs)
106 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] += yy.data
107 |
108 | else:
109 | hs = enc(x)
110 | y = dec(hs)
111 | seg_loss = self.compute_loss(y, t)
112 | accuracy = self.compute_accuracy(y, t)
113 | dice = self.compute_dice_coef(y, t)
114 | mean_dice = mean_dice_coefficient(dice)
115 | weighted_loss = seg_loss
116 |
117 | observation = {}
118 | with reporter.scope(observation):
119 | reporter.report({
120 | 'loss/seg': seg_loss,
121 | 'loss/total': weighted_loss,
122 | 'acc': accuracy,
123 | 'mean_dc': mean_dice
124 | }, observer)
125 | xp = cuda.get_array_module(y)
126 | for i in range(len(dice)):
127 | if not xp.isnan(dice.data[i]):
128 | reporter.report({'dc_{}'.format(i): dice[i]}, observer)
129 | summary.add(observation)
130 | return summary.compute_mean()
131 |
--------------------------------------------------------------------------------
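All four evaluators report per-class Dice scores and skip NaN entries. As a point of reference, here is a hedged sketch of the standard Dice formula, `2|A intersect B| / (|A| + |B|)`, over hard label maps; the repository's `dice_coefficient` in `src/functions/evaluation` operates on logits and additionally handles the `is_brats` label grouping, so this is an illustration of the metric, not that function:

```python
import numpy as np

def dice_per_class(pred_labels, target, nb_labels, eps=1e-7):
    """Per-class Dice; NaN when a class is absent from both prediction and target."""
    dices = []
    for c in range(nb_labels):
        a = (pred_labels == c)
        b = (target == c)
        denom = a.sum() + b.sum()
        if denom == 0:
            # undefined for an absent class; the reporters above skip NaN entries
            dices.append(np.nan)
        else:
            dices.append(2.0 * np.logical_and(a, b).sum() / (denom + eps))
    return np.asarray(dices)
```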
/src/training/updaters/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/training/updaters/__init__.py
--------------------------------------------------------------------------------
/src/training/updaters/boundseg_updater.py:
--------------------------------------------------------------------------------
1 | import chainer
2 | import chainer.functions as F
3 | from chainer.backends import cuda
4 | from chainer import reporter
5 | from chainer import Variable
6 | from chainer.training import StandardUpdater
7 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient
8 | from src.functions.loss import softmax_dice_loss
9 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy
10 | from src.functions.loss.boundary_bce import boundary_bce
11 |
12 |
13 | class BoundSegUpdater(StandardUpdater):
14 |
15 | def __init__(self, config, **kwargs):
16 | self.nested_label = config['nested_label']
17 | self.seg_lossfun = eval(config['seg_lossfun'])
18 | self.optimizer_name = config['optimizer']
19 | self.init_lr = config['init_lr']
20 | self.nb_epoch = config['epoch']
21 | self.init_weight = config['init_encoder']
22 | self.enc_freeze = config['enc_freeze']
23 | self.edge_label = config['edge_label']
24 | super(BoundSegUpdater, self).__init__(**kwargs)
25 |
26 | def get_optimizer_and_model(self, key):
27 | optimizer = self.get_optimizer(key)
28 | return optimizer, optimizer.target
29 |
30 | def get_batch(self):
31 | batch = self.get_iterator('main').next()
32 | batchsize = len(batch)
33 | in_arrays = self.converter(batch, self.device)
34 | return Variable(in_arrays[0]), in_arrays[1], in_arrays[2], batchsize
35 |
36 | def report_scores(self, y, t):
37 | with chainer.no_backprop_mode():
38 | if self.nested_label:
39 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...]))
40 | for i in range(1, t.shape[1]):
41 | dices = dice_coefficient(y[:, 2 * i:2 * (i + 1), ...], t[:, i, ...])
42 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0)
43 | else:
44 | dice = dice_coefficient(y, t)
45 | mean_dice = mean_dice_coefficient(dice)
46 |
47 | if self.nested_label:
48 | b, c, h, w, d = t.shape
49 | y = F.reshape(y, (b, 2, h * c, w, d))
50 | t = F.reshape(t, (b, h * c, w, d))
51 | accuracy = F.accuracy(y, t)
52 |
53 | reporter.report({
54 | 'acc': accuracy,
55 | 'mean_dc': mean_dice
56 | })
57 | xp = cuda.get_array_module(y)
58 | for i in range(len(dice)):
59 | if not xp.isnan(dice.data[i]):
60 | reporter.report({'dc_{}'.format(i): dice[i]})
61 |
62 | def update_core(self):
63 | opt_e, enc = self.get_optimizer_and_model('enc')
64 | opt_d, dec = self.get_optimizer_and_model('dec')
65 | opt_b, bound = self.get_optimizer_and_model('bound')
66 |
67 | if self.is_new_epoch:
68 | decay_rate = (1. - float(self.epoch / self.nb_epoch)) ** 0.9
69 | if self.optimizer_name == 'Adam':
70 | if self.init_weight is not None:
71 | opt_e.alpha = self.init_lr*self.enc_freeze * decay_rate
72 | # small learning rate for encoder
73 | else:
74 | opt_e.alpha = self.init_lr * decay_rate
75 | opt_d.alpha = self.init_lr * decay_rate
76 | opt_b.alpha = self.init_lr * decay_rate
77 | else:
78 | if self.init_weight is not None:
79 | opt_e.lr = self.init_lr*self.enc_freeze * decay_rate
80 | # small learning rate for encoder
81 | else:
82 | opt_e.lr = self.init_lr * decay_rate
83 | opt_d.lr = self.init_lr * decay_rate
84 | opt_b.lr = self.init_lr * decay_rate
85 |
86 | x, t, te, batchsize = self.get_batch()
87 |
88 | hs = enc(x)
89 | ye, bs = bound(hs)  # ye: output for the edge loss, bs: output for the decoder
90 | y = dec(hs, bs)
91 |
92 | if self.nested_label:
93 | seg_loss = 0.
94 | for i in range(t.shape[1]):
95 | seg_loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
96 | else:
97 | seg_loss = self.seg_lossfun(y, t)
98 |
99 | ye_ = ye[:, 1:, :, :, :]  # exclude the background channel (keep the Variable so the edge loss backpropagates)
100 | te_ = te[:, 1:, :, :, :]
101 | if self.nested_label:
102 | edge_loss = 0.
103 | bce_loss = 0.
104 | for i in range(te_.shape[1]):
105 | edge_loss += softmax_dice_loss(ye[:, 2*i:2*(i+1), ...], te_[:, i, ...])
106 | bce_loss += boundary_bce(ye[:, 2*i:2*(i+1), ...], te[:, [0, i+1], ...])
107 | else:
108 | edge_loss = softmax_dice_loss(ye_, te_, encode=False)
109 | bce_loss = boundary_bce(ye, te)
110 |
111 | opt_e.target.cleargrads()
112 | opt_d.target.cleargrads()
113 | opt_b.target.cleargrads()
114 |
115 | weighted_loss = seg_loss + edge_loss + bce_loss
116 | weighted_loss.backward()
117 |
118 | opt_e.update()
119 | opt_d.update()
120 | opt_b.update()
121 |
122 | self.report_scores(y, t)
123 | reporter.report({
124 | 'loss/seg': seg_loss,
125 | 'loss/total': weighted_loss,
126 | 'loss/bound': edge_loss,
127 | 'loss/bce': bce_loss
128 | })
129 |
--------------------------------------------------------------------------------
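Every updater in this directory applies the same polynomial learning-rate decay at each epoch boundary, with an extra `enc_freeze` factor on the encoder when a pretrained encoder was loaded (`init_encoder` is not None). A standalone sketch of that schedule (the tuple return and function name are illustrative, not part of the repository API):

```python
def poly_decay_lr(init_lr, epoch, nb_epoch, power=0.9, enc_freeze=None):
    """Polynomial decay used by the updaters: lr = init_lr * (1 - epoch/nb_epoch) ** power.

    Returns (encoder_lr, other_lr); the encoder rate is scaled down by
    enc_freeze when a pretrained encoder is being fine-tuned.
    """
    decay_rate = (1.0 - float(epoch) / nb_epoch) ** power
    lr = init_lr * decay_rate
    if enc_freeze is not None:
        return lr * enc_freeze, lr
    return lr, lr
```

In the updaters the resulting value is written to `opt.alpha` for Adam and to `opt.lr` for the other optimizers.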
/src/training/updaters/cpcseg_updater.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chainer
3 | import chainer.functions as F
4 | from chainer.backends import cuda
5 | from chainer import reporter
6 | from chainer import Variable
7 | from chainer.training import StandardUpdater
8 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy
9 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient
10 | from src.functions.loss.cpc_loss import cpc_loss
11 | from src.links.model.vaeseg import divide_img
12 |
13 |
14 | class CPCSegUpdater(StandardUpdater):
15 |
16 | def __init__(self, config, **kwargs):
17 | self.nested_label = config['nested_label']
18 | self.seg_lossfun = eval(config['seg_lossfun'])
19 | self.k = config['vaeseg_nb_sampling']
20 | self.rec_loss_weight = config['vaeseg_rec_loss_weight']
21 | self.kl_loss_weight = config['vaeseg_kl_loss_weight']
22 | self.optimizer_name = config['optimizer']
23 | self.init_lr = config['init_lr']
24 | self.nb_epoch = config['epoch']
25 | self.pretrain = config['pretrain']
26 | self.init_weight = config['init_encoder']
27 | self.grid_size = config['grid_size']
28 | self.base_channels = config['base_channels']
29 | self.cpc_loss_weight = config['cpc_vaeseg_cpc_loss_weight']
30 | self.enc_freeze = config['enc_freeze']
31 | self.cpc_pattern = config['cpc_pattern']
32 | self.idle_weight = config['vae_idle_weight']
33 | super(CPCSegUpdater, self).__init__(**kwargs)
34 |
35 | def get_optimizer_and_model(self, key):
36 | optimizer = self.get_optimizer(key)
37 | return optimizer, optimizer.target
38 |
39 | def get_batch(self):
40 | batch = self.get_iterator('main').next()
41 | batchsize = len(batch)
42 | in_arrays = self.converter(batch, self.device)
43 | return Variable(in_arrays[0]), in_arrays[1], batchsize
44 |
45 | def report_scores(self, y, t):
46 | with chainer.no_backprop_mode():
47 | if self.nested_label:
48 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...]))
49 | for i in range(1, t.shape[1]):
50 | dices = dice_coefficient(y[:, 2 * i:2 * (i + 1), ...], t[:, i, ...])
51 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0)
52 | else:
53 | dice = dice_coefficient(y, t)
54 | mean_dice = mean_dice_coefficient(dice)
55 | if self.nested_label:
56 | b, c, h, w, d = t.shape
57 | y = F.reshape(y, (b, 2, h * c, w, d))
58 | t = F.reshape(t, (b, h * c, w, d))
59 | accuracy = F.accuracy(y, t)
60 |
61 | reporter.report({
62 | 'acc': accuracy,
63 | 'mean_dc': mean_dice
64 | })
65 | xp = cuda.get_array_module(y)
66 | for i in range(len(dice)):
67 | if not xp.isnan(dice.data[i]):
68 | reporter.report({'dc_{}'.format(i): dice[i]})
69 |
70 | def update_core(self):
71 | opt_e, enc = self.get_optimizer_and_model('enc')
72 | opt_d, dec = self.get_optimizer_and_model('dec')
73 |
74 | opt_p1, cpcpred1 = self.get_optimizer_and_model('cpcpred1')
75 |
76 | if self.is_new_epoch:
77 | decay_rate = (1. - float(self.epoch / self.nb_epoch)) ** 0.9
78 | if self.optimizer_name == 'Adam':
79 | if self.init_weight is not None:
80 | opt_e.alpha = self.init_lr*self.enc_freeze * decay_rate
81 | else:
82 | opt_e.alpha = self.init_lr * decay_rate
83 | opt_d.alpha = self.init_lr * decay_rate
84 | opt_p1.alpha = self.init_lr * decay_rate
85 | else:
86 | if self.init_weight is not None:
87 | opt_e.lr = self.init_lr*self.enc_freeze * decay_rate
88 | else:
89 | opt_e.lr = self.init_lr * decay_rate
90 | opt_d.lr = self.init_lr * decay_rate
91 | opt_p1.lr = self.init_lr * decay_rate
92 |
93 | x, t, batchsize = self.get_batch()
94 |
95 | b, c, x1, x2, x3 = x.shape
96 | ssize = int(self.grid_size*0.5) # stride size
97 | gl0 = int(x1/ssize-1) # grid length
98 | gl1 = int(x2/ssize-1)
99 | gl2 = int(x3/ssize-1)
100 |
101 | hs = enc(x)
102 | h2 = divide_img(x)
103 | h2 = enc(h2)
104 | h2 = F.average(h2[-1], axis=(2, 3, 4))
105 | cpc_t = F.transpose(h2, axes=(1, 0))
106 | cpc_t = F.reshape(cpc_t, (1, self.base_channels*8, gl0, gl1, gl2))
107 |
108 | y1 = cpcpred1(h2)
109 | cpc_loss1 = cpc_loss(y1, cpc_t, upper=True, cpc_pattern=self.cpc_pattern)
110 | cpc_loss_tot = cpc_loss1
111 |
112 | y = dec(hs)
113 |
114 | opt_e.target.cleargrads()
115 | opt_d.target.cleargrads()
116 | opt_p1.target.cleargrads()
117 |
118 | xp = cuda.get_array_module(t)
119 | if xp.sum(t) < -(2**31):  # unlabeled sample (sentinel labels): train on the CPC loss only
120 | weighted_loss = self.idle_weight * self.cpc_loss_weight * cpc_loss_tot
121 | weighted_loss.backward()
122 | t = xp.zeros(t.shape, dtype=np.int32)
123 |
124 | else:
125 | if self.nested_label:
126 | seg_loss = 0.
127 | for i in range(t.shape[1]):
128 | seg_loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
129 | else:
130 | seg_loss = self.seg_lossfun(y, t)
131 |
132 | weighted_loss = seg_loss + self.cpc_loss_weight*cpc_loss_tot
133 | weighted_loss.backward()
134 |
135 | opt_e.update()
136 | opt_d.update()
137 | opt_p1.update()
138 |
139 | self.report_scores(y, t)
140 | reporter.report({
141 | 'loss/cpc': cpc_loss_tot,
142 | 'loss/total': weighted_loss
143 | })
144 |
--------------------------------------------------------------------------------
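In `CPCSegUpdater.update_core`, the grid lengths `gl0`, `gl1`, `gl2` count half-overlapping `grid_size` patches per axis: with stride `ssize = grid_size / 2`, an axis of length `L` holds `L/ssize - 1` patches. A small sketch of just that arithmetic (the actual patch extraction lives in `divide_img` in `src/links/model/vaeseg.py`; this helper name is illustrative):

```python
def cpc_grid_lengths(shape, grid_size):
    """Patches per axis for a CPC grid with 50% overlap between neighbours."""
    ssize = grid_size // 2              # stride between patch starts
    return tuple(dim // ssize - 1 for dim in shape)
```

For the default crop `(160, 192, 128)` and a grid size of 32, this gives a 9 x 11 x 7 grid, matching the `(1, base_channels*8, gl0, gl1, gl2)` reshape of the pooled patch features into `cpc_t`.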
/src/training/updaters/encdec_seg_updater.py:
--------------------------------------------------------------------------------
1 | import chainer
2 | import chainer.functions as F
3 | from chainer.backends import cuda
4 | from chainer import reporter
5 | from chainer import Variable
6 | from chainer.training import StandardUpdater
7 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient
8 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy
9 |
10 |
11 | class EncDecSegUpdater(StandardUpdater):
12 |
13 | def __init__(self, config, **kwargs):
14 | self.nested_label = config['nested_label']
15 | self.seg_lossfun = eval(config['seg_lossfun'])
16 | self.optimizer_name = config['optimizer']
17 | self.init_lr = config['init_lr']
18 | self.nb_epoch = config['epoch']
19 | self.init_weight = config['init_encoder']
20 | self.enc_freeze = config['enc_freeze']
21 | super(EncDecSegUpdater, self).__init__(**kwargs)
22 |
23 | def get_optimizer_and_model(self, key):
24 | optimizer = self.get_optimizer(key)
25 | return optimizer, optimizer.target
26 |
27 | def get_batch(self):
28 | batch = self.get_iterator('main').next()
29 | batchsize = len(batch)
30 | in_arrays = self.converter(batch, self.device)
31 | return Variable(in_arrays[0]), in_arrays[1], batchsize
32 |
33 | def report_scores(self, y, t):
34 | with chainer.no_backprop_mode():
35 | if self.nested_label:
36 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...]))
37 | for i in range(1, t.shape[1]):
38 | dices = dice_coefficient(y[:, 2 * i:2 * (i + 1), ...], t[:, i, ...])
39 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0)
40 | else:
41 | dice = dice_coefficient(y, t)
42 | mean_dice = mean_dice_coefficient(dice)
43 | if self.nested_label:
44 | b, c, h, w, d = t.shape
45 | y = F.reshape(y, (b, 2, h * c, w, d))
46 | t = F.reshape(t, (b, h * c, w, d))
47 | accuracy = F.accuracy(y, t)
48 |
49 | reporter.report({
50 | 'acc': accuracy,
51 | 'mean_dc': mean_dice
52 | })
53 | xp = cuda.get_array_module(y)
54 | for i in range(len(dice)):
55 | if not xp.isnan(dice.data[i]):
56 | reporter.report({'dc_{}'.format(i): dice[i]})
57 |
58 | def update_core(self):
59 | opt_e, enc = self.get_optimizer_and_model('enc')
60 | opt_d, dec = self.get_optimizer_and_model('dec')
61 |
62 | if self.is_new_epoch:
63 | decay_rate = (1. - float(self.epoch / self.nb_epoch)) ** 0.9
64 | if self.optimizer_name == 'Adam':
65 | if self.init_weight is not None:
66 | opt_e.alpha = self.init_lr*self.enc_freeze * decay_rate
67 | else:
68 | opt_e.alpha = self.init_lr * decay_rate
69 | opt_d.alpha = self.init_lr * decay_rate
70 | else:
71 | if self.init_weight is not None:
72 | opt_e.lr = self.init_lr*self.enc_freeze * decay_rate
73 | else:
74 | opt_e.lr = self.init_lr * decay_rate
75 | opt_d.lr = self.init_lr * decay_rate
76 |
77 | x, t, batchsize = self.get_batch()
78 |
79 | hs = enc(x)
80 |
81 | y = dec(hs)
82 |
83 | if self.nested_label:
84 | seg_loss = 0.
85 | for i in range(t.shape[1]):
86 | seg_loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
87 | else:
88 | seg_loss = self.seg_lossfun(y, t)
89 |
90 | opt_e.target.cleargrads()
91 | opt_d.target.cleargrads()
92 |
93 | weighted_loss = seg_loss
94 | weighted_loss.backward()
95 |
96 | opt_e.update()
97 | opt_d.update()
98 |
99 | self.report_scores(y, t)
100 | reporter.report({
101 | 'loss/seg': seg_loss,
102 | 'loss/total': weighted_loss
103 | })
104 |
--------------------------------------------------------------------------------
/src/training/updaters/vaeseg_updater.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import chainer
3 | import chainer.functions as F
4 | from chainer.backends import cuda
5 | from chainer import reporter
6 | from chainer import Variable
7 | from chainer.training import StandardUpdater
8 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient
9 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy
10 |
11 |
12 | class VAESegUpdater(StandardUpdater):
13 |
14 | def __init__(self, config, **kwargs):
15 | self.nested_label = config['nested_label']
16 | self.seg_lossfun = eval(config['seg_lossfun'])
17 | self.k = config['vaeseg_nb_sampling']
18 | self.rec_loss_weight = config['vaeseg_rec_loss_weight']
19 | self.kl_loss_weight = config['vaeseg_kl_loss_weight']
20 | self.optimizer_name = config['optimizer']
21 | self.init_lr = config['init_lr']
22 | self.nb_epoch = config['epoch']
23 | self.pretrain = config['pretrain']
24 | self.init_weight = config['init_encoder']
25 | self.enc_freeze = config['enc_freeze']
26 | self.idle_weight = config['vae_idle_weight']
27 | super(VAESegUpdater, self).__init__(**kwargs)
28 |
29 | def get_optimizer_and_model(self, key):
30 | optimizer = self.get_optimizer(key)
31 | return optimizer, optimizer.target
32 |
33 | def get_batch(self):
34 | batch = self.get_iterator('main').next()
35 | batchsize = len(batch)
36 | in_arrays = self.converter(batch, self.device)
37 | return Variable(in_arrays[0]), in_arrays[1], batchsize
38 |
39 | def report_scores(self, y, t):
40 | with chainer.no_backprop_mode():
41 | if self.nested_label:
42 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...]))
43 | for i in range(1, t.shape[1]):
44 | dices = dice_coefficient(y[:, 2 * i:2 * (i + 1), ...], t[:, i, ...])
45 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0)
46 | else:
47 | dice = dice_coefficient(y, t)
48 | mean_dice = mean_dice_coefficient(dice)
49 | if self.nested_label:
50 | b, c, h, w, d = t.shape
51 | y = F.reshape(y, (b, 2, h * c, w, d))
52 | t = F.reshape(t, (b, h * c, w, d))
53 | accuracy = F.accuracy(y, t)
54 |
55 | reporter.report({
56 | 'acc': accuracy,
57 | 'mean_dc': mean_dice
58 | })
59 | xp = cuda.get_array_module(y)
60 | for i in range(len(dice)):
61 | if not xp.isnan(dice.data[i]):
62 | reporter.report({'dc_{}'.format(i): dice[i]})
63 |
64 | def update_core(self):
65 | opt_e, enc = self.get_optimizer_and_model('enc')
66 | opt_em, emb = self.get_optimizer_and_model('emb')
67 | opt_d, dec = self.get_optimizer_and_model('dec')
68 | opt_v, vae = self.get_optimizer_and_model('vae')
69 |
70 | if self.is_new_epoch:
71 | decay_rate = (1. - float(self.epoch / self.nb_epoch)) ** 0.9
72 | if self.optimizer_name == 'Adam':
73 | if self.init_weight is not None:
74 | opt_e.alpha = self.init_lr*self.enc_freeze * decay_rate
75 | else:
76 | opt_e.alpha = self.init_lr * decay_rate
77 | opt_em.alpha = self.init_lr * decay_rate
78 | opt_d.alpha = self.init_lr * decay_rate
79 | opt_v.alpha = self.init_lr * decay_rate
80 | else:
81 | if self.init_weight is not None:
82 | opt_e.lr = self.init_lr*self.enc_freeze * decay_rate
83 | else:
84 | opt_e.lr = self.init_lr * decay_rate
85 | opt_em.lr = self.init_lr * decay_rate
86 | opt_d.lr = self.init_lr * decay_rate
87 | opt_v.lr = self.init_lr * decay_rate
88 |
89 | x, t, batchsize = self.get_batch()
90 |
91 | hs = enc(x)
92 | mu, ln_var = emb(hs[-1])
93 | latent_size = np.prod(list(mu.shape))
94 |
95 | kl_loss = F.gaussian_kl_divergence(mu, ln_var, reduce='sum') / latent_size
96 |
97 | rec_loss = 0.
98 | for i in range(self.k):
99 | z = F.gaussian(mu, ln_var)
100 | rec_x = vae(z)
101 | rec_loss += F.mean_squared_error(x, rec_x)
102 | rec_loss /= self.k
103 |
104 | opt_e.target.cleargrads()
105 | opt_em.target.cleargrads()
106 | opt_v.target.cleargrads()
107 | opt_d.target.cleargrads()
108 | xp = cuda.get_array_module(t)
109 | if xp.sum(t) < -(2**31):
110 |             # unlabeled sample (label filled with a large negative sentinel):
111 |             # optimize only the encoder and VAE branches
112 |             weighted_loss = self.idle_weight*(
112 | self.rec_loss_weight
113 | * rec_loss + self.kl_loss_weight * kl_loss)
114 | t = xp.zeros(t.shape, dtype=np.int32)
115 | y = dec(hs)
116 |
117 | else:
118 | y = dec(hs)
119 | if self.nested_label:
120 | seg_loss = 0.
121 | for i in range(t.shape[1]):
122 | seg_loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...])
123 | else:
124 | seg_loss = self.seg_lossfun(y, t)
125 | if self.pretrain:
126 | weighted_loss = self.rec_loss_weight * rec_loss + self.kl_loss_weight * kl_loss
127 | else:
128 | weighted_loss = seg_loss + self.rec_loss_weight * rec_loss \
129 | + self.kl_loss_weight * kl_loss
130 |
131 | weighted_loss.backward()
132 | opt_e.update()
133 | opt_em.update()
134 | opt_v.update()
135 | opt_d.update()
136 |
137 | self.report_scores(y, t)
138 | reporter.report({
139 | 'loss/rec': rec_loss,
140 | 'loss/kl': kl_loss,
141 | 'loss/total': weighted_loss
142 | })
143 |
--------------------------------------------------------------------------------
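The updater above anneals every optimizer's learning rate with a polynomial decay schedule of power 0.9, applied once per epoch in `update_core`. A minimal standalone sketch of that schedule (the values for `init_lr` and `nb_epoch` are the config defaults, used here only for illustration):

```python
def poly_decay_lr(init_lr, epoch, nb_epoch, power=0.9):
    """Polynomial decay: lr = init_lr * (1 - epoch / nb_epoch) ** power."""
    return init_lr * (1.0 - float(epoch) / nb_epoch) ** power

init_lr, nb_epoch = 1e-4, 200
print(poly_decay_lr(init_lr, 0, nb_epoch))    # full rate at the start
print(poly_decay_lr(init_lr, 100, nb_epoch))  # ~0.54x half-way through
print(poly_decay_lr(init_lr, 200, nb_epoch))  # reaches zero at the final epoch
```

With power below 1 the rate stays close to `init_lr` for most of training and drops off sharply near the end, which is why the updater recomputes it only at epoch boundaries.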
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import time
4 | import yaml
5 |
6 |
7 | default_config = {
8 | 'dataset_name': 'msd_bound',
9 | 'image_path': '/PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized',
10 | 'label_path': '/PATH_TO_LABELS/labelsTr',
11 | 'image_file_format': 'npz',
12 | 'label_file_format': 'nii',
13 | 'train_list_path': '/PATH_TO_TRAINING_LIST/train_list_cv0.txt',
14 | 'validation_list_path': '/PATH_TO_VALIDATION_LIST/validation_list_cv0.txt',
15 | 'test_list_path': '/PATH_TO_TEST_LIST/test_list_cv0.txt',
16 | 'target_label': 0,
17 | 'nested_label': False,
18 | 'crop_size': '(160,192,128)',
19 | 'shift_intensity': 0.1,
20 | 'random_flip': True,
21 |
22 | 'segmentor_name': 'vaeseg',
23 | 'in_channels': 4,
24 | 'base_channels': 32,
25 | 'nb_labels': 4,
26 | 'vaeseg_norm': 'GroupNormalization',
27 | 'vaeseg_bn_first': True,
28 | 'vaeseg_ndim_latent': 128,
29 | 'vaeseg_skip_connect_mode': 'sum',
30 | 'unet_cropping': False,
31 | 'dv_num_trans_layers': 12,
32 |
33 | 'seg_lossfun': 'softmax_dice_loss',
34 | 'auxiliary_weights': None,
35 | 'vaeseg_nb_sampling': 1,
36 | 'vaeseg_rec_loss_weight': 0.1,
37 | 'vaeseg_kl_loss_weight': 0.1,
38 | 'structseg_nb_copies': 1,
39 |
40 | 'mn': True,
41 | 'gpu_start_id': 0,
42 | 'loaderjob': 2,
43 | 'batchsize': 1,
44 | 'val_batchsize': 2,
45 | 'epoch': 200,
46 | 'optimizer': 'Adam',
47 | 'init_lr': 1e-4,
48 | 'lr_reduction_ratio': 0.99,
49 | 'lr_reduction_interval': (1, 'epoch'),
50 | 'weight_decay': 1e-5,
51 | 'report_interval': 10,
52 | 'eval_interval': 2,
53 | 'snapshot_interval': 20,
54 | 'init_segmentor': None,
55 | 'init_encoder': None,
56 | 'init_decoder': None,
57 | 'init_vae': None,
58 | 'resume': None,
59 |
60 | 'pretrain': False,
61 | 'grid_size': 32,
62 | 'init_embedder': None,
63 | 'cpc_vaeseg_cpc_loss_weight': 0.001,
64 | 'init_cpcpred': None,
65 | 'enc_freeze': 0.01,
66 | 'cpc_pattern': 'updown',
67 | 'random_scale': False,
68 | 'edge_path': '/PATH_TO_EDGE_LABELS/edgesTr',
69 | 'edge_file_format': 'npy',
70 | 'edge_label': False,
71 | 'print_each_dc': True,
72 | 'is_brats': False,
73 | 'ignore_path': None,
74 | 'vae_idle_weight': 1
75 | }
76 |
77 |
78 | def overwrite_config(
79 | input_cfg,
80 | dump_yaml_dir=None
81 | ):
82 | output_cfg = copy.copy(default_config)
83 | for key, val in input_cfg.items():
84 | if key not in output_cfg:
85 | raise ValueError('Unknown configuration key: {}'.format(key))
86 | output_cfg[key] = val
87 | if dump_yaml_dir is not None:
88 | os.makedirs(dump_yaml_dir, exist_ok=True)
89 | cur_time = time.strftime("%Y-%m-%d--%H-%M-%S", time.gmtime())
90 | dump_yaml_path = os.path.join(
91 | dump_yaml_dir, '{}.yaml'.format(cur_time))
92 | with open(dump_yaml_path, 'w') as f:
93 | yaml.dump(output_cfg, f)
94 | return output_cfg
95 |
--------------------------------------------------------------------------------
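`overwrite_config` merges a user-supplied config over the defaults and rejects unknown keys, so a typo in a YAML option fails loudly instead of being silently ignored. Its merge logic can be exercised standalone like this (the defaults dict is trimmed to three illustrative entries):

```python
import copy

# Trimmed-down stand-in for the full default_config above.
default_config = {'epoch': 200, 'init_lr': 1e-4, 'optimizer': 'Adam'}

def overwrite_config(input_cfg):
    output_cfg = copy.copy(default_config)
    for key, val in input_cfg.items():
        # keys absent from the defaults are treated as errors,
        # catching misspelled options early
        if key not in output_cfg:
            raise ValueError('Unknown configuration key: {}'.format(key))
        output_cfg[key] = val
    return output_cfg

cfg = overwrite_config({'epoch': 50})
print(cfg)  # {'epoch': 50, 'init_lr': 0.0001, 'optimizer': 'Adam'}

try:
    overwrite_config({'epohc': 50})  # typo in 'epoch'
except ValueError as e:
    print(e)
```

The real function additionally dumps the merged config to a timestamped YAML file when `dump_yaml_dir` is given, which records exactly what each run was trained with.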
/src/utils/encode_one_hot_vector.py:
--------------------------------------------------------------------------------
1 | from chainer.backends import cuda
2 |
3 |
4 | def _encode_one_hot_vector_core(x, nb_class):
5 | xp = cuda.get_array_module(x)
6 | batch, h, w, d = x.shape
7 |
8 | res = xp.zeros((batch, nb_class, h, w, d), dtype=xp.float32)
9 | x = x.reshape(batch, -1)
10 | for i in range(batch):
11 | y = xp.identity(nb_class, dtype=xp.float32)[x[i]]
12 | res[i] = xp.swapaxes(y, 0, 1).reshape((nb_class, h, w, d))
13 | return res
14 |
15 |
16 | def encode_one_hot_vector(x, nb_class):
17 | if isinstance(x, cuda.ndarray):
18 | with x.device:
19 | return _encode_one_hot_vector_core(x, nb_class)
20 | else:
21 | return _encode_one_hot_vector_core(x, nb_class)
22 |
--------------------------------------------------------------------------------
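`encode_one_hot_vector` turns an integer label volume of shape `(batch, h, w, d)` into a one-hot tensor of shape `(batch, nb_class, h, w, d)` by indexing rows of an identity matrix. The same trick works with plain NumPy; here is a minimal CPU-only sketch on a tiny volume:

```python
import numpy as np

def encode_one_hot(x, nb_class):
    """x: int labels, shape (batch, h, w, d) -> float32 (batch, nb_class, h, w, d)."""
    batch, h, w, d = x.shape
    res = np.zeros((batch, nb_class, h, w, d), dtype=np.float32)
    flat = x.reshape(batch, -1)
    for i in range(batch):
        # identity-matrix rows give the one-hot vector for each voxel label
        y = np.identity(nb_class, dtype=np.float32)[flat[i]]
        res[i] = np.swapaxes(y, 0, 1).reshape((nb_class, h, w, d))
    return res

labels = np.array([[[[0, 2], [1, 1]]]])  # shape (1, 1, 2, 2)
one_hot = encode_one_hot(labels, nb_class=3)
print(one_hot.shape)           # (1, 3, 1, 2, 2)
print(one_hot[0, 2, 0, 0, 1])  # 1.0 -- the voxel labelled 2
```

The CuPy path in the module above is identical; `encode_one_hot_vector` only adds the `with x.device:` guard so the result is allocated on the same GPU as the input.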
/src/utils/setup_helpers.py:
--------------------------------------------------------------------------------
1 | from chainer.links import BatchNormalization, GroupNormalization
2 | from chainermn.links import MultiNodeBatchNormalization
3 | from chainer.functions import softmax_cross_entropy
4 | from chainer.optimizers import Adam, CorrectedMomentumSGD, MomentumSGD, RMSprop, SGD
5 | from chainer.iterators import MultiprocessIterator, SerialIterator
6 | from chainer.optimizer import WeightDecay
7 | from chainer import serializers
8 | from chainer import training
9 | from chainer.training import extensions
10 | from chainer.backends.cuda import get_device_from_id
11 | import chainermn
12 | from src.datasets.msd_bound import MSDBoundDataset
13 | from src.links.model.vaeseg import BoundaryStream, CPCPredictor, Decoder, Encoder, VAE, VD
14 | from src.training.updaters.vaeseg_updater import VAESegUpdater
15 | from src.training.extensions.vaeseg_evaluator import VAESegEvaluator
16 | from src.training.updaters.encdec_seg_updater import EncDecSegUpdater
17 | from src.training.extensions.encdec_seg_evaluator import EncDecSegEvaluator
18 | from src.training.updaters.boundseg_updater import BoundSegUpdater
19 | from src.training.extensions.boundseg_evaluator import BoundSegEvaluator
20 | from src.training.updaters.cpcseg_updater import CPCSegUpdater
21 | from src.training.extensions.cpcseg_evaluator import CPCSegEvaluator
22 |
23 |
24 | def _setup_communicator(config, gpu_start_id=0):
25 | if config['mn']:
26 | comm = chainermn.create_communicator('pure_nccl')
27 | is_master = (comm.rank == 0)
28 | device = comm.intra_rank + gpu_start_id
29 | else:
30 | comm = None
31 | is_master = True
32 | device = gpu_start_id
33 | return comm, is_master, device
34 |
35 |
36 | def _setup_datasets(config, comm, is_master):
37 | if is_master:
38 | if config['dataset_name'] == 'msd_bound':
39 | train_data = MSDBoundDataset(config, config['train_list_path'])
40 | validation_data = MSDBoundDataset(config, config['validation_list_path'])
41 | test_data = MSDBoundDataset(config, config['test_list_path'])
42 | validation_data.random_scale = False
43 | test_data.random_scale = False
44 | validation_data.shift_intensity = 0
45 | test_data.shift_intensity = 0
46 | validation_data.random_flip = False
47 | test_data.random_flip = False
48 | validation_data.nb_copies = 1
49 | test_data.nb_copies = 1
50 | validation_data.training = False
51 | test_data.training = False
52 | else:
53 | raise ValueError('Unknown dataset_name: {}'.format(config['dataset_name']))
54 | print('Training dataset size: {}'.format(len(train_data)))
55 | print('Validation dataset size: {}'.format(len(validation_data)))
56 | print('Test dataset size: {}'.format(len(test_data)))
57 | else:
58 | train_data = None
59 | validation_data = None
60 | test_data = None
61 |
62 | # scatter dataset
63 | if comm is not None:
64 | train_data = chainermn.scatter_dataset(train_data, comm, shuffle=True)
65 | validation_data = chainermn.scatter_dataset(validation_data, comm, shuffle=True)
66 | test_data = chainermn.scatter_dataset(test_data, comm, shuffle=True)
67 |
68 | return train_data, validation_data, test_data
69 |
70 |
71 | def _setup_vae_segmentor(config, comm=None):
72 | in_channels = config['in_channels']
73 | base_channels = config['base_channels']
74 | out_channels = config['nb_labels']
75 | nested_label = config['nested_label']
76 | norm = eval(config['vaeseg_norm'])
77 | bn_first = config['vaeseg_bn_first']
78 | ndim_latent = config['vaeseg_ndim_latent']
79 | mode = config['vaeseg_skip_connect_mode']
80 | input_shape = eval(config['crop_size'])
81 | if nested_label:
82 | out_channels = 2 * (out_channels - 1)
83 |
84 | encoder = Encoder(
85 | base_channels=base_channels,
86 | norm=norm,
87 | bn_first=bn_first,
88 | ndim_latent=ndim_latent,
89 | comm=comm
90 | )
91 |
92 | embedder = VD(
93 | channels=8*base_channels,
94 | norm=norm,
95 | bn_first=bn_first,
96 | ndim_latent=ndim_latent,
97 | comm=comm
98 | )
99 |
100 | decoder = Decoder(
101 | base_channels=base_channels,
102 | out_channels=out_channels,
103 | norm=norm,
104 | bn_first=bn_first,
105 | mode=mode,
106 | comm=comm
107 | )
108 |
109 | vae = VAE(
110 | in_channels=in_channels,
111 | base_channels=base_channels,
112 | norm=norm,
113 | bn_first=bn_first,
114 | input_shape=input_shape,
115 | comm=comm
116 | )
117 |
118 | return encoder, embedder, decoder, vae
119 |
120 |
121 | def _setup_vae_segmentor_only(config, comm=None):
122 | base_channels = config['base_channels']
123 | out_channels = config['nb_labels']
124 | nested_label = config['nested_label']
125 | norm = eval(config['vaeseg_norm'])
126 | bn_first = config['vaeseg_bn_first']
127 | ndim_latent = config['vaeseg_ndim_latent']
128 | mode = config['vaeseg_skip_connect_mode']
129 | if nested_label:
130 | out_channels = 2 * (out_channels - 1)
131 |
132 | encoder = Encoder(
133 | base_channels=base_channels,
134 | norm=norm,
135 | bn_first=bn_first,
136 | ndim_latent=ndim_latent,
137 | comm=comm
138 | )
139 |
140 | decoder = Decoder(
141 | base_channels=base_channels,
142 | out_channels=out_channels,
143 | norm=norm,
144 | bn_first=bn_first,
145 | mode=mode,
146 | comm=comm
147 | )
148 |
149 | return encoder, decoder
150 |
151 |
152 | def _setup_cpc_segmentor(config, comm=None):
153 | base_channels = config['base_channels']
154 | out_channels = config['nb_labels']
155 | nested_label = config['nested_label']
156 | norm = eval(config['vaeseg_norm'])
157 | bn_first = config['vaeseg_bn_first']
158 | ndim_latent = config['vaeseg_ndim_latent']
159 | mode = config['vaeseg_skip_connect_mode']
160 | input_shape = eval(config['crop_size'])
161 | grid_size = config['grid_size']
162 | cpc_pattern = config['cpc_pattern']
163 | if nested_label:
164 | out_channels = 2 * (out_channels - 1)
165 |
166 | encoder = Encoder(
167 | base_channels=base_channels,
168 | norm=norm,
169 | bn_first=bn_first,
170 | ndim_latent=ndim_latent,
171 | comm=comm
172 | )
173 |
174 | decoder = Decoder(
175 | base_channels=base_channels,
176 | out_channels=out_channels,
177 | norm=norm,
178 | bn_first=bn_first,
179 | mode=mode,
180 | comm=comm
181 | )
182 |
183 | cpcpred1 = CPCPredictor(
184 | base_channels=base_channels*8,
185 | norm=norm,
186 | bn_first=bn_first,
187 | grid_size=grid_size,
188 | input_shape=input_shape,
189 | upper=True,
190 | cpc_pattern=cpc_pattern,
191 | comm=comm
192 | )
193 | return encoder, decoder, cpcpred1
194 |
195 |
196 | def _setup_bound_segmentor(config, comm=None):
197 | base_channels = config['base_channels']
198 | out_channels = config['nb_labels']
199 | nested_label = config['nested_label']
200 | norm = eval(config['vaeseg_norm'])
201 | bn_first = config['vaeseg_bn_first']
202 | mode = config['vaeseg_skip_connect_mode']
203 | ndim_latent = config['vaeseg_ndim_latent']
204 | if nested_label:
205 | out_channels = 2 * (out_channels - 1)
206 |
207 | encoder = Encoder(
208 | base_channels=base_channels,
209 | norm=norm,
210 | bn_first=bn_first,
211 | ndim_latent=ndim_latent,
212 | comm=comm
213 | )
214 |
215 | decoder = Decoder(
216 | base_channels=base_channels,
217 | out_channels=out_channels,
218 | norm=norm,
219 | bn_first=bn_first,
220 | mode=mode,
221 | comm=comm
222 | )
223 |
224 | boundary = BoundaryStream(
225 | base_channels=base_channels,
226 | out_channels=out_channels,
227 | norm=norm,
228 | comm=comm
229 | )
230 |
231 | return encoder, decoder, boundary
232 |
233 |
234 | def _setup_iterators(config, batch_size, train_data, validation_data, test_data):
235 | if isinstance(config['loaderjob'], int) and config['loaderjob'] > 1:
236 | train_iterator = MultiprocessIterator(
237 | train_data, batch_size, n_processes=config['loaderjob'])
238 | validation_iterator = MultiprocessIterator(
239 | validation_data, batch_size, n_processes=config['loaderjob'],
240 | repeat=False, shuffle=False)
241 | test_iterator = MultiprocessIterator(
242 | test_data, batch_size, n_processes=config['loaderjob'],
243 | repeat=False, shuffle=False)
244 | else:
245 | train_iterator = SerialIterator(train_data, batch_size)
246 | validation_iterator = SerialIterator(
247 | validation_data, batch_size, repeat=False, shuffle=False)
248 | test_iterator = SerialIterator(
249 | test_data, batch_size, repeat=False, shuffle=False)
250 |
251 | return train_iterator, validation_iterator, test_iterator
252 |
253 |
254 | # Optimizer
255 | def _setup_optimizer(config, model, comm):
256 | optimizer_name = config['optimizer']
257 | lr = float(config['init_lr'])
258 | weight_decay = float(config['weight_decay'])
259 | if optimizer_name == 'Adam':
260 | optimizer = Adam(alpha=lr, weight_decay_rate=weight_decay)
261 | elif optimizer_name in \
262 | ('SGD', 'MomentumSGD', 'CorrectedMomentumSGD', 'RMSprop'):
263 | optimizer = eval(optimizer_name)(lr=lr)
264 | if weight_decay > 0.:
265 | optimizer.add_hook(WeightDecay(weight_decay))
266 | else:
267 | raise ValueError('Invalid optimizer: {}'.format(optimizer_name))
268 | if comm is not None:
269 | optimizer = chainermn.create_multi_node_optimizer(optimizer, comm)
270 | optimizer.setup(model)
271 |
272 | return optimizer
273 |
274 |
275 | # Updater
276 | def _setup_updater(config, device, train_iterator, optimizers):
277 | updater_kwargs = dict()
278 | updater_kwargs['iterator'] = train_iterator
279 | updater_kwargs['optimizer'] = optimizers
280 | updater_kwargs['device'] = device
281 |
282 | if config['segmentor_name'] == 'vaeseg':
283 | return VAESegUpdater(config, **updater_kwargs)
284 | elif config['segmentor_name'] == 'encdec_seg':
285 | return EncDecSegUpdater(config, **updater_kwargs)
286 | elif config['segmentor_name'] == 'boundseg':
287 | return BoundSegUpdater(config, **updater_kwargs)
288 | elif config['segmentor_name'] == 'cpcseg':
289 | return CPCSegUpdater(config, **updater_kwargs)
290 | else:
291 | return training.StandardUpdater(**updater_kwargs)
292 |
293 |
294 | def _setup_extensions(config, trainer, optimizers, logging_counts, logging_attributes):
295 | if config['segmentor_name'] == 'vaeseg':
296 | trainer.extend(extensions.dump_graph('loss/total', out_name="segmentor.dot"))
297 | elif config['segmentor_name'] == 'encdec_seg':
298 | trainer.extend(extensions.dump_graph('loss/seg', out_name="segmentor.dot"))
299 | elif config['segmentor_name'] == 'boundseg':
300 | trainer.extend(extensions.dump_graph('loss/seg', out_name="segmentor.dot"))
301 | elif config['segmentor_name'] == 'cpcseg':
302 | trainer.extend(extensions.dump_graph('loss/total', out_name="segmentor.dot"))
303 | else:
304 | trainer.extend(extensions.dump_graph('main/loss', out_name="segmentor.dot"))
305 |
306 | # Report
307 | repo_trigger = (config['report_interval'], 'iteration')
308 | trainer.extend(
309 | extensions.LogReport(
310 | trigger=repo_trigger
311 | )
312 | )
313 | trainer.extend(
314 | extensions.PrintReport(logging_counts + logging_attributes),
315 | trigger=repo_trigger
316 | )
317 | trainer.extend(
318 | extensions.ProgressBar()
319 | )
320 |
321 | snap_trigger = (config['snapshot_interval'], 'epoch')
322 | trainer.extend(
323 | extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'),
324 | trigger=snap_trigger
325 | )
326 | for k, v in optimizers.items():
327 | trainer.extend(
328 | extensions.snapshot_object(v.target, k+'_epoch_{.updater.epoch}'),
329 | trigger=snap_trigger
330 | )
331 |
332 | for attr in logging_attributes:
333 | trainer.extend(
334 | extensions.PlotReport([attr, 'validation/' + attr], 'epoch',
335 | file_name=attr.replace('/', '_') + '.png')
336 | )
337 |
338 |
339 | # Trainer
340 | def setup_trainer(config, out, batch_size, epoch, gpu_start_id):
341 |
342 | comm, is_master, device = _setup_communicator(config, gpu_start_id)
343 |
344 | train_data, validation_data, test_data = _setup_datasets(config, comm, is_master)
345 |
346 | if config['segmentor_name'] == 'vaeseg':
347 | encoder, embedder, decoder, vae = _setup_vae_segmentor(config, comm)
348 | # load weights
349 | if config['init_encoder'] is not None:
350 | serializers.load_npz(config['init_encoder'], encoder)
351 | if config['init_embedder'] is not None:
352 | serializers.load_npz(config['init_embedder'], embedder)
353 | if config['init_decoder'] is not None:
354 | serializers.load_npz(config['init_decoder'], decoder)
355 | if config['init_vae'] is not None:
356 | serializers.load_npz(config['init_vae'], vae)
357 | if device is not None:
358 | get_device_from_id(device).use()
359 | encoder.to_gpu()
360 | embedder.to_gpu()
361 | decoder.to_gpu()
362 | vae.to_gpu()
363 |
364 | opt_enc = _setup_optimizer(config, encoder, comm)
365 | opt_emb = _setup_optimizer(config, embedder, comm)
366 | opt_dec = _setup_optimizer(config, decoder, comm)
367 | opt_vae = _setup_optimizer(config, vae, comm)
368 | optimizers = {'enc': opt_enc, 'emb': opt_emb, 'dec': opt_dec, 'vae': opt_vae}
369 |
370 | elif config['segmentor_name'] == 'cpcseg':
371 |
372 | encoder, decoder, cpcpred1 = _setup_cpc_segmentor(config, comm)
373 | # load weights
374 | if config['init_encoder'] is not None:
375 | serializers.load_npz(config['init_encoder'], encoder)
376 | if config['init_decoder'] is not None:
377 | serializers.load_npz(config['init_decoder'], decoder)
378 | if config['init_cpcpred'] is not None:
379 | serializers.load_npz(config['init_cpcpred'], cpcpred1)
380 | if device is not None:
381 | get_device_from_id(device).use()
382 | encoder.to_gpu()
383 | decoder.to_gpu()
384 | cpcpred1.to_gpu()
385 |
386 | opt_enc = _setup_optimizer(config, encoder, comm)
387 | opt_dec = _setup_optimizer(config, decoder, comm)
388 | opt_p1 = _setup_optimizer(config, cpcpred1, comm)
389 | optimizers = {'enc': opt_enc, 'dec': opt_dec, 'cpcpred1': opt_p1}
390 |
391 | elif config['segmentor_name'] == 'encdec_seg':
392 |
393 | encoder, decoder = _setup_vae_segmentor_only(config, comm)
394 | # load weights
395 | if config['init_encoder'] is not None:
396 | serializers.load_npz(config['init_encoder'], encoder)
397 | if config['init_decoder'] is not None:
398 | serializers.load_npz(config['init_decoder'], decoder)
399 | if device is not None:
400 | get_device_from_id(device).use()
401 | encoder.to_gpu()
402 | decoder.to_gpu()
403 |
404 | opt_enc = _setup_optimizer(config, encoder, comm)
405 | opt_dec = _setup_optimizer(config, decoder, comm)
406 | optimizers = {'enc': opt_enc, 'dec': opt_dec}
407 |
408 | elif config['segmentor_name'] == 'boundseg':
409 |
410 | encoder, decoder, boundary = _setup_bound_segmentor(config, comm)
411 | # load weights
412 | if config['init_encoder'] is not None:
413 | serializers.load_npz(config['init_encoder'], encoder)
414 | if config['init_decoder'] is not None:
415 | serializers.load_npz(config['init_decoder'], decoder)
416 | if device is not None:
417 | get_device_from_id(device).use()
418 | encoder.to_gpu()
419 | decoder.to_gpu()
420 | boundary.to_gpu()
421 |
422 | opt_enc = _setup_optimizer(config, encoder, comm)
423 | opt_dec = _setup_optimizer(config, decoder, comm)
424 | opt_bound = _setup_optimizer(config, boundary, comm)
425 | optimizers = {'enc': opt_enc, 'dec': opt_dec, 'bound': opt_bound}
426 |
427 | train_iterator, validation_iterator, test_iterator = \
428 | _setup_iterators(config, batch_size, train_data, validation_data, test_data)
429 |
430 | logging_counts = ['epoch', 'iteration']
431 | if config['segmentor_name'] == 'vaeseg':
432 | logging_attributes = \
433 | ['loss/rec', 'loss/kl', 'loss/total', 'acc',
434 | 'mean_dc', 'val/mean_dc', 'test/mean_dc']
435 | if config['print_each_dc']:
436 | for i in range(0, config['nb_labels']):
437 | logging_attributes.append('dc_{}'.format(i))
438 | logging_attributes.append('val/dc_{}'.format(i))
439 | logging_attributes.append('test/dc_{}'.format(i))
440 |
441 | elif config['segmentor_name'] == 'cpcseg':
442 | logging_attributes = \
443 | ['loss/total', 'acc', 'loss/cpc']
444 | for i in range(0, config['nb_labels']):
445 | logging_attributes.append('dc_{}'.format(i))
446 | logging_attributes.append('val/dc_{}'.format(i))
447 | logging_attributes.append('test/dc_{}'.format(i))
448 |
449 | elif config['segmentor_name'] == 'encdec_seg':
450 | logging_attributes = \
451 | ['loss/seg', 'loss/total', 'acc']
452 | if config['print_each_dc']:
453 | for i in range(0, config['nb_labels']):
454 | logging_attributes.append('dc_{}'.format(i))
455 | logging_attributes.append('val/dc_{}'.format(i))
456 | logging_attributes.append('test/dc_{}'.format(i))
457 |
458 | elif config['segmentor_name'] == 'boundseg':
459 | logging_attributes = \
460 | ['loss/seg', 'loss/total', 'acc', 'loss/bound', 'loss/bce']
461 | if config['print_each_dc']:
462 | for i in range(0, config['nb_labels']):
463 | logging_attributes.append('dc_{}'.format(i))
464 | logging_attributes.append('val/dc_{}'.format(i))
465 | logging_attributes.append('test/dc_{}'.format(i))
466 |
467 | else:
468 | logging_attributes = ['main/loss', 'main/acc']
469 | for i in range(1, config['nb_labels']):
470 | logging_attributes.append('main/dc_{}'.format(i))
471 | logging_attributes.append('val/main/dc_{}'.format(i))
472 | logging_attributes.append('test/main/dc_{}'.format(i))
473 |
474 | updater = _setup_updater(config, device, train_iterator, optimizers)
475 |
476 | trainer = training.Trainer(updater, (epoch, 'epoch'), out=out)
477 |
478 | if is_master:
479 | _setup_extensions(config, trainer, optimizers, logging_counts, logging_attributes)
480 |
481 | if config['segmentor_name'] == 'vaeseg':
482 | targets = {'enc': encoder, 'emb': embedder, 'dec': decoder, 'vae': vae}
483 | val_evaluator = VAESegEvaluator(config, validation_iterator, targets, device=device)
484 | test_evaluator = VAESegEvaluator(config, test_iterator, targets, device=device)
485 |
486 | elif config['segmentor_name'] == 'cpcseg':
487 | targets = {'enc': encoder, 'dec': decoder, 'cpcpred1': cpcpred1}
488 | val_evaluator = CPCSegEvaluator(config, validation_iterator, targets, device=device)
489 | test_evaluator = CPCSegEvaluator(config, test_iterator, targets, device=device)
490 |
491 | elif config['segmentor_name'] == 'encdec_seg':
492 | targets = {'enc': encoder, 'dec': decoder}
493 | val_evaluator = EncDecSegEvaluator(config, validation_iterator, targets, device=device)
494 | test_evaluator = EncDecSegEvaluator(config, test_iterator, targets, device=device)
495 |
496 | elif config['segmentor_name'] == 'boundseg':
497 | targets = {'enc': encoder, 'dec': decoder, 'bound': boundary}
498 | val_evaluator = BoundSegEvaluator(config, validation_iterator, targets, device=device)
499 | test_evaluator = BoundSegEvaluator(config, test_iterator, targets, device=device)
500 |
501 | val_evaluator.default_name = 'val'
502 | test_evaluator.default_name = 'test'
503 |
504 | if comm is not None:
505 | val_evaluator = chainermn.create_multi_node_evaluator(val_evaluator, comm)
506 | test_evaluator = chainermn.create_multi_node_evaluator(test_evaluator, comm)
507 | trainer.extend(val_evaluator, trigger=(config['eval_interval'], 'epoch'))
508 | trainer.extend(test_evaluator, trigger=(config['eval_interval'], 'epoch'))
509 |
510 | # Resume
511 | if config['resume'] is not None:
512 | serializers.load_npz(config['resume'], trainer)
513 |
514 | return trainer
515 |
--------------------------------------------------------------------------------
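`_setup_optimizer` selects the optimizer class by `eval`-ing its configured name, which only works because every listed class is imported into the module's namespace. The same dispatch can be written with an explicit name-to-class mapping, sketched below; the classes here are dummy stand-ins for the Chainer optimizers, kept only so the logic runs standalone:

```python
# Dummy stand-ins for the Chainer optimizer classes (illustration only).
class Adam:
    def __init__(self, alpha):
        self.alpha = alpha

class SGD:
    def __init__(self, lr):
        self.lr = lr

class MomentumSGD(SGD):
    pass

LR_OPTIMIZERS = {'SGD': SGD, 'MomentumSGD': MomentumSGD}

def make_optimizer(name, lr):
    if name == 'Adam':
        return Adam(alpha=lr)  # Adam takes `alpha` rather than `lr`
    if name in LR_OPTIMIZERS:
        return LR_OPTIMIZERS[name](lr=lr)
    raise ValueError('Invalid optimizer: {}'.format(name))

opt = make_optimizer('MomentumSGD', 1e-4)
print(type(opt).__name__)  # MomentumSGD
```

A mapping makes the set of accepted names explicit in one place and fails with the same `ValueError` as the original for anything outside it.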