├── LICENSE ├── README.md ├── brats_split_list ├── test_list_cv0.txt ├── train_list_cv0.txt ├── train_list_cv0_12samples.txt ├── train_list_cv0_24samples.txt ├── train_list_cv0_3samples1.txt ├── train_list_cv0_6samples1.txt ├── train_list_cv0_96samples.txt └── validation_list_cv0.txt ├── configs ├── brats_boundseg.yml ├── brats_cpcseg.yml ├── brats_encdec.yml ├── brats_semi_cpcseg.yml ├── brats_semi_vaeseg.yml └── brats_vaeseg.yml ├── examples ├── normalize.py ├── split_kfold.py └── train.py ├── img └── architecture.png ├── requirements.txt └── src ├── __init__.py ├── datasets ├── __init__.py ├── generate_msd_edge.py ├── msd_bound.py └── preprocess.py ├── functions ├── array │ ├── __init__.py │ ├── resize_images_3d.py │ └── resize_images_3d_nearest_neighbor.py ├── evaluation │ ├── __init__.py │ └── dice_coefficient.py └── loss │ ├── __init__.py │ ├── boundary_bce.py │ ├── cpc_loss.py │ ├── dice_loss.py │ ├── focal_loss.py │ ├── generalized_dice_loss.py │ └── mixed_dice_loss.py ├── links ├── __init__.py └── model │ ├── __init__.py │ ├── resize_images_3d.py │ └── vaeseg.py ├── training ├── __init__.py ├── extensions │ ├── __init__.py │ ├── boundseg_evaluator.py │ ├── cpcseg_evaluator.py │ ├── encdec_seg_evaluator.py │ └── vaeseg_evaluator.py └── updaters │ ├── __init__.py │ ├── boundseg_updater.py │ ├── cpcseg_updater.py │ ├── encdec_seg_updater.py │ └── vaeseg_updater.py └── utils ├── config.py ├── encode_one_hot_vector.py └── setup_helpers.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Preferred Networks, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Label-Efficient Multi-Task Segmentation using Contrastive Learning 2 | 3 | This repository contains the Chainer implementation used for the following paper: 4 | > [1] **Label-Efficient Multi-Task Segmentation using Contrastive Learning** 5 | > Junichiro Iwasawa, Yuichiro Hirano, Yohei Sugawara 6 | > Accepted to the MICCAI BrainLes 2020 workshop, [preprint on arXiv](https://arxiv.org/abs/2009.11160). 7 | > 8 | > 9 | > **Abstract:** *Obtaining annotations for 3D medical images is expensive and time-consuming, despite its importance for automating segmentation tasks.
Although multi-task learning is considered an effective method for training segmentation models using small amounts of annotated data, a systematic understanding of various subtasks is still lacking. In this study, we propose a multi-task segmentation model with a contrastive learning based subtask and compare its performance with other multi-task models, varying the number of labeled data for training. We further extend our model so that it can utilize unlabeled data through the regularization branch in a semi-supervised manner. We experimentally show that our proposed method outperforms other multi-task methods including the state-of-the-art fully supervised model when the amount of annotated data is limited.* 10 | 11 | Disclaimer: PFN provides no warranty or support for this implementation. Use it at your own risk. 12 | 13 | ## Dependencies 14 | 15 | We have tested this code using: 16 | 17 | - Ubuntu 16.04.6 18 | - Python 3.7 19 | - CUDA 9.0 20 | - cuDNN 7.0 21 | 22 | The full list of Python packages for the code is given in `requirements.txt`. These can be installed using: 23 | 24 | ```bash 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Usage 29 | 30 |

31 | ![architecture](img/architecture.png) 32 |

33 | 34 | The schematic image for the model is given above. To compare the effect of different regularization branches, we implemented four different models: the encoder-decoder alone (EncDec), EncDec with a VAE branch (VAEseg), EncDec with a boundary attention branch (Boundseg), and EncDec with a CPC branch (CPCseg). We have also implemented the semi-supervised VAEseg (ssVAEseg) and semi-supervised CPCseg (ssCPCseg) to investigate the effect of utilizing unlabeled data. The brain tumor dataset from the 2016 and 2017 Brain Tumor Image Segmentation ([BraTS](http://dx.doi.org/10.1109/TMI.2014.2377694)) Challenges was used for evaluation. 35 | 36 | ### Preparation 37 | 38 | #### Normalization 39 | 40 | The models were trained on images normalized to zero mean and unit standard deviation. To normalize the original image data in the BraTS dataset (`imagesTr`), run the command below, which creates a directory with the normalized files (`imagesTr_normalized`). 41 | 42 | ```bash 43 | python examples/normalize.py --root_path PATH_TO_BRATS_DATA/imagesTr 44 | ``` 45 | 46 | #### Generating boundary labels for Boundseg 47 | 48 | The Boundseg model takes the boundary labels of the dataset as an additional input. To generate the boundary labels, run 49 | 50 | ```bash 51 | python src/datasets/generate_msd_edge.py 52 | ``` 53 | 54 | after modifying `label_path` in the file. 55 | 56 | #### Fixing training configs 57 | 58 | - Before running a model, specify the paths `image_path`, `label_path` (and `edge_path` for Boundseg) in the config files within `configs/`. 59 | 60 | - To change the number of labeled samples used for training, rewrite `train_list_path` in the config file. For example, to use only 6 labeled samples, rewrite it as `train_list_path: ./brats_split_list/train_list_cv0_6samples1.txt`.
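For reference, the normalization step above amounts to a per-volume z-score transform. The sketch below is illustrative only: the `zscore_normalize` helper is hypothetical and not part of this repository, and the actual `src.datasets.preprocess.normalize` additionally handles file I/O (reading `nii.gz`, writing `npz`) and modality-specific details.

```python
import numpy as np


def zscore_normalize(volume, eps=1e-8):
    """Rescale a volume to zero mean and unit standard deviation.

    Hypothetical sketch of the core operation; the repository's
    normalize() also converts nii.gz inputs into npz arrays.
    """
    volume = volume.astype(np.float32)
    return (volume - volume.mean()) / (volume.std() + eps)


# Example: a synthetic volume with a BraTS-like crop size
vol = np.random.rand(160, 192, 128).astype(np.float32) * 1000.0
normed = zscore_normalize(vol)
print(normed.mean(), normed.std())  # close to 0 and 1
```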
61 | 62 | ### Training 63 | 64 | #### EncDec 65 | 66 | ```bash 67 | mpiexec -n 8 python3 examples/train.py \ 68 | --config configs/brats_encdec.yml \ 69 | --out results 70 | ``` 71 | 72 | #### VAEseg 73 | 74 | ```bash 75 | mpiexec -n 8 python3 examples/train.py \ 76 | --config configs/brats_vaeseg.yml \ 77 | --out results 78 | ``` 79 | 80 | #### Boundseg 81 | 82 | ```bash 83 | mpiexec -n 8 python3 examples/train.py \ 84 | --config configs/brats_boundseg.yml \ 85 | --out results 86 | ``` 87 | 88 | #### CPCseg 89 | 90 | ```bash 91 | mpiexec -n 8 python3 examples/train.py \ 92 | --config configs/brats_cpcseg.yml \ 93 | --out results 94 | ``` 95 | 96 | #### ssVAEseg 97 | 98 | ```bash 99 | mpiexec -n 8 python3 examples/train.py \ 100 | --config configs/brats_semi_vaeseg.yml \ 101 | --out results 102 | ``` 103 | 104 | #### ssCPCseg 105 | 106 | ```bash 107 | mpiexec -n 8 python3 examples/train.py \ 108 | --config configs/brats_semi_cpcseg.yml \ 109 | --out results 110 | ``` 111 | 112 | ## Results 113 | 114 | The performance of each model, trained with 6 labeled samples (`train_list_cv0_6samples1.txt`), was evaluated using the mean Dice score and the 95th percentile Hausdorff distance on the test dataset (`brats_split_list/test_list_cv0.txt`) ([[1](#Label-Efficient-Multi-Task-Segmentation-using-Contrastive-Learning)], Table 1). For more detailed results, please see the paper [[1](#Label-Efficient-Multi-Task-Segmentation-using-Contrastive-Learning)].
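The mean Dice score reported in the table below measures the overlap between a predicted mask A and a ground-truth mask B as 2|A∩B| / (|A| + |B|). A minimal NumPy sketch for binary masks follows; it is illustrative only, and the repository's own implementation in `src/functions/evaluation/dice_coefficient.py` operates on Chainer arrays and may handle edge cases (e.g. empty masks) differently.

```python
import numpy as np


def dice_score(pred, target, eps=1e-8):
    """Dice coefficient between two binary masks: 2|A ∩ B| / (|A| + |B|)."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return (2.0 * intersection) / (pred.sum() + target.sum() + eps)


# Identical masks score ~1, disjoint masks score 0.
a = np.zeros((8, 8, 8), dtype=bool)
a[2:6, 2:6, 2:6] = True
print(dice_score(a, a))  # ~1.0
```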
115 | 116 | | | Dice (ET) | Dice (TC) | Dice (WT) | Hausdorff (ET) | Hausdorff (TC) | Hausdorff (WT) | 117 | |:-|:---:|:--------:|:-----------------------:|:-----------------:|:-----------------------:|:-----------------:| 118 | |EncDec |0.8412 |0.8383 |0.9144 | 11.4697 |20.12 |24.3726 119 | |VAEseg |0.8234 |0.8036 |0.8998 |14.3467 |22.4926 |17.9775 120 | |Boundseg |0.8356 |0.8378 |0.9041 |17.0323 |27.2128 |25.8112 121 | |CPCseg |0.8374 |0.8386 |0.9057 |10.2839 |14.9661 |15.0633 122 | |ssVAEseg |0.8626 |0.8425 |0.9131 |9.1966 | **12.5302** |14.8056 123 | |ssCPCseg | **0.8873** | **0.8761** | **0.9151** | **8.7092** |16.0947 | **12.3962** 124 | 125 | ## License 126 | 127 | [MIT License](LICENSE) 128 | -------------------------------------------------------------------------------- /brats_split_list/test_list_cv0.txt: -------------------------------------------------------------------------------- 1 | BRATS_072 2 | BRATS_207 3 | BRATS_102 4 | BRATS_294 5 | BRATS_265 6 | BRATS_152 7 | BRATS_427 8 | BRATS_209 9 | BRATS_418 10 | BRATS_377 11 | BRATS_344 12 | BRATS_234 13 | BRATS_095 14 | BRATS_261 15 | BRATS_348 16 | BRATS_448 17 | BRATS_035 18 | BRATS_148 19 | BRATS_069 20 | BRATS_191 21 | BRATS_482 22 | BRATS_256 23 | BRATS_381 24 | BRATS_400 25 | BRATS_055 26 | BRATS_065 27 | BRATS_408 28 | BRATS_449 29 | BRATS_404 30 | BRATS_161 31 | BRATS_166 32 | BRATS_469 33 | BRATS_124 34 | BRATS_280 35 | BRATS_401 36 | BRATS_421 37 | BRATS_406 38 | BRATS_335 39 | BRATS_248 40 | BRATS_129 41 | BRATS_199 42 | BRATS_125 43 | BRATS_099 44 | BRATS_063 45 | BRATS_410 46 | BRATS_442 47 | BRATS_366 48 | BRATS_048 49 | BRATS_300 50 | -------------------------------------------------------------------------------- /brats_split_list/train_list_cv0.txt: -------------------------------------------------------------------------------- 1 | BRATS_446 2 | BRATS_415 3 | BRATS_165 4 | BRATS_293 5 | BRATS_121 6 | BRATS_435 7 | BRATS_038 8 | BRATS_140 9 | BRATS_305 10 | BRATS_197 11 | BRATS_452 12 
| BRATS_375 13 | BRATS_177 14 | BRATS_093 15 | BRATS_439 16 | BRATS_432 17 | BRATS_281 18 | BRATS_451 19 | BRATS_126 20 | BRATS_247 21 | BRATS_385 22 | BRATS_361 23 | BRATS_275 24 | BRATS_412 25 | BRATS_198 26 | BRATS_282 27 | BRATS_262 28 | BRATS_307 29 | BRATS_468 30 | BRATS_082 31 | BRATS_016 32 | BRATS_374 33 | BRATS_172 34 | BRATS_341 35 | BRATS_220 36 | BRATS_271 37 | BRATS_117 38 | BRATS_419 39 | BRATS_425 40 | BRATS_338 41 | BRATS_008 42 | BRATS_122 43 | BRATS_360 44 | BRATS_382 45 | BRATS_229 46 | BRATS_359 47 | BRATS_470 48 | BRATS_163 49 | BRATS_445 50 | BRATS_150 51 | BRATS_295 52 | BRATS_264 53 | BRATS_353 54 | BRATS_340 55 | BRATS_175 56 | BRATS_428 57 | BRATS_235 58 | BRATS_014 59 | BRATS_213 60 | BRATS_403 61 | BRATS_376 62 | BRATS_326 63 | BRATS_079 64 | BRATS_484 65 | BRATS_164 66 | BRATS_086 67 | BRATS_096 68 | BRATS_393 69 | BRATS_136 70 | BRATS_073 71 | BRATS_001 72 | BRATS_013 73 | BRATS_106 74 | BRATS_456 75 | BRATS_066 76 | BRATS_169 77 | BRATS_402 78 | BRATS_196 79 | BRATS_257 80 | BRATS_329 81 | BRATS_236 82 | BRATS_395 83 | BRATS_173 84 | BRATS_396 85 | BRATS_123 86 | BRATS_186 87 | BRATS_044 88 | BRATS_313 89 | BRATS_194 90 | BRATS_171 91 | BRATS_223 92 | BRATS_023 93 | BRATS_279 94 | BRATS_347 95 | BRATS_330 96 | BRATS_333 97 | BRATS_182 98 | BRATS_002 99 | BRATS_215 100 | BRATS_268 101 | BRATS_352 102 | BRATS_159 103 | BRATS_355 104 | BRATS_143 105 | BRATS_369 106 | BRATS_046 107 | BRATS_481 108 | BRATS_398 109 | BRATS_091 110 | BRATS_454 111 | BRATS_417 112 | BRATS_443 113 | BRATS_349 114 | BRATS_059 115 | BRATS_414 116 | BRATS_115 117 | BRATS_224 118 | BRATS_342 119 | BRATS_137 120 | BRATS_253 121 | BRATS_110 122 | BRATS_283 123 | BRATS_378 124 | BRATS_321 125 | BRATS_409 126 | BRATS_087 127 | BRATS_273 128 | BRATS_478 129 | BRATS_025 130 | BRATS_088 131 | BRATS_158 132 | BRATS_424 133 | BRATS_222 134 | BRATS_206 135 | BRATS_118 136 | BRATS_433 137 | BRATS_092 138 | BRATS_050 139 | BRATS_230 140 | BRATS_397 141 | BRATS_193 142 | 
BRATS_277 143 | BRATS_350 144 | BRATS_101 145 | BRATS_430 146 | BRATS_331 147 | BRATS_301 148 | BRATS_440 149 | BRATS_068 150 | BRATS_081 151 | BRATS_237 152 | BRATS_058 153 | BRATS_285 154 | BRATS_363 155 | BRATS_018 156 | BRATS_303 157 | BRATS_183 158 | BRATS_076 159 | BRATS_336 160 | BRATS_390 161 | BRATS_108 162 | BRATS_184 163 | BRATS_019 164 | BRATS_394 165 | BRATS_231 166 | BRATS_320 167 | BRATS_179 168 | BRATS_030 169 | BRATS_138 170 | BRATS_324 171 | BRATS_216 172 | BRATS_011 173 | BRATS_447 174 | BRATS_258 175 | BRATS_100 176 | BRATS_109 177 | BRATS_370 178 | BRATS_372 179 | BRATS_131 180 | BRATS_168 181 | BRATS_105 182 | BRATS_422 183 | BRATS_187 184 | BRATS_200 185 | BRATS_318 186 | BRATS_241 187 | BRATS_202 188 | BRATS_132 189 | BRATS_367 190 | BRATS_189 191 | BRATS_471 192 | BRATS_479 193 | BRATS_483 194 | BRATS_103 195 | BRATS_356 196 | BRATS_120 197 | BRATS_328 198 | BRATS_036 199 | BRATS_085 200 | BRATS_388 201 | BRATS_284 202 | BRATS_411 203 | BRATS_005 204 | BRATS_392 205 | BRATS_021 206 | BRATS_015 207 | BRATS_228 208 | BRATS_245 209 | BRATS_480 210 | BRATS_365 211 | BRATS_060 212 | BRATS_210 213 | BRATS_208 214 | BRATS_437 215 | BRATS_343 216 | BRATS_116 217 | BRATS_135 218 | BRATS_274 219 | BRATS_154 220 | BRATS_371 221 | BRATS_444 222 | BRATS_431 223 | BRATS_296 224 | BRATS_064 225 | BRATS_139 226 | BRATS_345 227 | BRATS_251 228 | BRATS_017 229 | BRATS_047 230 | BRATS_218 231 | BRATS_051 232 | BRATS_204 233 | BRATS_386 234 | BRATS_057 235 | BRATS_225 236 | BRATS_097 237 | BRATS_312 238 | BRATS_195 239 | BRATS_034 240 | BRATS_111 241 | BRATS_233 242 | BRATS_354 243 | BRATS_062 244 | BRATS_465 245 | BRATS_467 246 | BRATS_029 247 | BRATS_053 248 | BRATS_286 249 | BRATS_083 250 | BRATS_311 251 | BRATS_089 252 | BRATS_461 253 | BRATS_314 254 | BRATS_040 255 | BRATS_387 256 | BRATS_380 257 | BRATS_337 258 | BRATS_212 259 | BRATS_056 260 | BRATS_327 261 | BRATS_289 262 | BRATS_473 263 | BRATS_246 264 | BRATS_466 265 | BRATS_219 266 | BRATS_332 267 | 
BRATS_334 268 | BRATS_155 269 | BRATS_104 270 | BRATS_438 271 | BRATS_266 272 | BRATS_453 273 | BRATS_181 274 | BRATS_045 275 | BRATS_291 276 | BRATS_211 277 | BRATS_477 278 | BRATS_144 279 | BRATS_141 280 | BRATS_290 281 | BRATS_190 282 | BRATS_269 283 | BRATS_098 284 | BRATS_156 285 | BRATS_170 286 | BRATS_071 287 | BRATS_310 288 | BRATS_319 289 | BRATS_297 290 | BRATS_420 291 | BRATS_263 292 | BRATS_379 293 | BRATS_039 294 | BRATS_239 295 | BRATS_007 296 | BRATS_145 297 | BRATS_240 298 | BRATS_407 299 | BRATS_436 300 | BRATS_476 301 | BRATS_304 302 | BRATS_373 303 | BRATS_003 304 | BRATS_112 305 | BRATS_317 306 | BRATS_176 307 | BRATS_153 308 | BRATS_351 309 | BRATS_078 310 | BRATS_260 311 | BRATS_037 312 | BRATS_383 313 | BRATS_288 314 | BRATS_475 315 | BRATS_180 316 | BRATS_042 317 | BRATS_205 318 | BRATS_302 319 | BRATS_157 320 | BRATS_107 321 | BRATS_149 322 | BRATS_458 323 | BRATS_043 324 | BRATS_322 325 | BRATS_278 326 | BRATS_315 327 | BRATS_214 328 | BRATS_457 329 | BRATS_339 330 | BRATS_255 331 | BRATS_028 332 | BRATS_167 333 | BRATS_250 334 | BRATS_004 335 | BRATS_272 336 | BRATS_054 337 | BRATS_134 338 | BRATS_306 339 | BRATS_033 340 | BRATS_128 341 | BRATS_270 342 | BRATS_460 343 | BRATS_192 344 | BRATS_244 345 | BRATS_242 346 | BRATS_299 347 | BRATS_188 348 | BRATS_462 349 | BRATS_292 350 | BRATS_308 351 | BRATS_022 352 | BRATS_459 353 | BRATS_472 354 | BRATS_323 355 | BRATS_031 356 | BRATS_287 357 | BRATS_113 358 | BRATS_464 359 | BRATS_090 360 | BRATS_203 361 | BRATS_061 362 | BRATS_405 363 | BRATS_389 364 | BRATS_384 365 | BRATS_074 366 | BRATS_114 367 | BRATS_298 368 | BRATS_067 369 | BRATS_346 370 | BRATS_049 371 | BRATS_226 372 | BRATS_127 373 | BRATS_426 374 | BRATS_249 375 | BRATS_130 376 | BRATS_010 377 | BRATS_357 378 | BRATS_077 379 | BRATS_358 380 | BRATS_399 381 | BRATS_276 382 | BRATS_119 383 | BRATS_146 384 | BRATS_474 385 | BRATS_254 386 | BRATS_080 387 | BRATS_232 
-------------------------------------------------------------------------------- /brats_split_list/train_list_cv0_12samples.txt: -------------------------------------------------------------------------------- 1 | BRATS_446 2 | BRATS_415 3 | BRATS_165 4 | BRATS_293 5 | BRATS_121 6 | BRATS_435 7 | BRATS_038 8 | BRATS_140 9 | BRATS_305 10 | BRATS_197 11 | BRATS_452 12 | BRATS_375 -------------------------------------------------------------------------------- /brats_split_list/train_list_cv0_24samples.txt: -------------------------------------------------------------------------------- 1 | BRATS_446 2 | BRATS_415 3 | BRATS_165 4 | BRATS_293 5 | BRATS_121 6 | BRATS_435 7 | BRATS_038 8 | BRATS_140 9 | BRATS_305 10 | BRATS_197 11 | BRATS_452 12 | BRATS_375 13 | BRATS_177 14 | BRATS_093 15 | BRATS_439 16 | BRATS_432 17 | BRATS_281 18 | BRATS_451 19 | BRATS_126 20 | BRATS_247 21 | BRATS_385 22 | BRATS_361 23 | BRATS_275 24 | BRATS_412 -------------------------------------------------------------------------------- /brats_split_list/train_list_cv0_3samples1.txt: -------------------------------------------------------------------------------- 1 | BRATS_446 2 | BRATS_415 3 | BRATS_165 -------------------------------------------------------------------------------- /brats_split_list/train_list_cv0_6samples1.txt: -------------------------------------------------------------------------------- 1 | BRATS_446 2 | BRATS_415 3 | BRATS_165 4 | BRATS_293 5 | BRATS_121 6 | BRATS_435 -------------------------------------------------------------------------------- /brats_split_list/train_list_cv0_96samples.txt: -------------------------------------------------------------------------------- 1 | BRATS_446 2 | BRATS_415 3 | BRATS_165 4 | BRATS_293 5 | BRATS_121 6 | BRATS_435 7 | BRATS_038 8 | BRATS_140 9 | BRATS_305 10 | BRATS_197 11 | BRATS_452 12 | BRATS_375 13 | BRATS_177 14 | BRATS_093 15 | BRATS_439 16 | BRATS_432 17 | BRATS_281 18 | BRATS_451 19 | BRATS_126 20 | BRATS_247 21 | 
BRATS_385 22 | BRATS_361 23 | BRATS_275 24 | BRATS_412 25 | BRATS_198 26 | BRATS_282 27 | BRATS_262 28 | BRATS_307 29 | BRATS_468 30 | BRATS_082 31 | BRATS_016 32 | BRATS_374 33 | BRATS_172 34 | BRATS_341 35 | BRATS_220 36 | BRATS_271 37 | BRATS_117 38 | BRATS_419 39 | BRATS_425 40 | BRATS_338 41 | BRATS_008 42 | BRATS_122 43 | BRATS_360 44 | BRATS_382 45 | BRATS_229 46 | BRATS_359 47 | BRATS_470 48 | BRATS_163 49 | BRATS_445 50 | BRATS_150 51 | BRATS_295 52 | BRATS_264 53 | BRATS_353 54 | BRATS_340 55 | BRATS_175 56 | BRATS_428 57 | BRATS_235 58 | BRATS_014 59 | BRATS_213 60 | BRATS_403 61 | BRATS_376 62 | BRATS_326 63 | BRATS_079 64 | BRATS_484 65 | BRATS_164 66 | BRATS_086 67 | BRATS_096 68 | BRATS_393 69 | BRATS_136 70 | BRATS_073 71 | BRATS_001 72 | BRATS_013 73 | BRATS_106 74 | BRATS_456 75 | BRATS_066 76 | BRATS_169 77 | BRATS_402 78 | BRATS_196 79 | BRATS_257 80 | BRATS_329 81 | BRATS_236 82 | BRATS_395 83 | BRATS_173 84 | BRATS_396 85 | BRATS_123 86 | BRATS_186 87 | BRATS_044 88 | BRATS_313 89 | BRATS_194 90 | BRATS_171 91 | BRATS_223 92 | BRATS_023 93 | BRATS_279 94 | BRATS_347 95 | BRATS_330 96 | BRATS_333 -------------------------------------------------------------------------------- /brats_split_list/validation_list_cv0.txt: -------------------------------------------------------------------------------- 1 | BRATS_325 2 | BRATS_364 3 | BRATS_413 4 | BRATS_041 5 | BRATS_094 6 | BRATS_174 7 | BRATS_217 8 | BRATS_012 9 | BRATS_147 10 | BRATS_026 11 | BRATS_185 12 | BRATS_450 13 | BRATS_441 14 | BRATS_368 15 | BRATS_151 16 | BRATS_052 17 | BRATS_259 18 | BRATS_133 19 | BRATS_227 20 | BRATS_070 21 | BRATS_423 22 | BRATS_084 23 | BRATS_178 24 | BRATS_316 25 | BRATS_142 26 | BRATS_201 27 | BRATS_309 28 | BRATS_267 29 | BRATS_020 30 | BRATS_463 31 | BRATS_006 32 | BRATS_416 33 | BRATS_362 34 | BRATS_238 35 | BRATS_162 36 | BRATS_391 37 | BRATS_455 38 | BRATS_160 39 | BRATS_252 40 | BRATS_434 41 | BRATS_009 42 | BRATS_024 43 | BRATS_221 44 | BRATS_075 45 | 
BRATS_243 46 | BRATS_032 47 | BRATS_429 48 | BRATS_027 49 | -------------------------------------------------------------------------------- /configs/brats_boundseg.yml: -------------------------------------------------------------------------------- 1 | dataset_name: msd_bound 2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized 3 | label_path: /PATH_TO_LABELS/labelsTr 4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges 5 | edge_label: True 6 | image_file_format: npz 7 | label_file_format: nii.gz 8 | train_list_path: ./brats_split_list/train_list_cv0.txt 9 | validation_list_path: ./brats_split_list/validation_list_cv0.txt 10 | test_list_path: ./brats_split_list/test_list_cv0.txt 11 | crop_size: (160,192,128) 12 | random_flip: True 13 | in_channels: 4 14 | nb_labels: 4 15 | print_each_dc: True 16 | val_batchsize: 1 17 | structseg_nb_copies: 1 18 | shift_intensity: 0.1 19 | segmentor_name: boundseg 20 | vaeseg_norm: 'GroupNormalization' 21 | seg_lossfun: dice_loss_plus_cross_entropy 22 | init_lr: 0.0001 23 | epoch: 300 24 | snapshot_interval: 100 25 | random_scale: True 26 | eval_interval: 1 27 | is_brats: True 28 | nested_label: True 29 | -------------------------------------------------------------------------------- /configs/brats_cpcseg.yml: -------------------------------------------------------------------------------- 1 | dataset_name: msd_bound 2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized 3 | label_path: /PATH_TO_LABELS/labelsTr 4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges 5 | edge_label: False 6 | image_file_format: npz 7 | label_file_format: nii.gz 8 | train_list_path: ./brats_split_list/train_list_cv0.txt 9 | validation_list_path: ./brats_split_list/validation_list_cv0.txt 10 | test_list_path: ./brats_split_list/test_list_cv0.txt 11 | crop_size: (144,144,128) 12 | random_flip: True 13 | in_channels: 4 14 | nb_labels: 4 15 | print_each_dc: True 16 | val_batchsize: 1 17 | structseg_nb_copies: 1 18 | shift_intensity: 0.1 19 | 
segmentor_name: cpcseg 20 | vaeseg_norm: 'GroupNormalization' 21 | seg_lossfun: dice_loss_plus_cross_entropy 22 | cpc_vaeseg_cpc_loss_weight: 0.001 # 0.0005 23 | init_lr: 0.0001 24 | epoch: 300 25 | cpc_pattern: 'updown' 26 | snapshot_interval: 100 27 | random_scale: True 28 | eval_interval: 5 29 | is_brats: True 30 | nested_label: True 31 | -------------------------------------------------------------------------------- /configs/brats_encdec.yml: -------------------------------------------------------------------------------- 1 | dataset_name: msd_bound 2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized 3 | label_path: /PATH_TO_LABELS/labelsTr 4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges 5 | edge_label: False 6 | image_file_format: npz 7 | label_file_format: nii.gz 8 | train_list_path: ./brats_split_list/train_list_cv0.txt 9 | validation_list_path: ./brats_split_list/validation_list_cv0.txt 10 | test_list_path: ./brats_split_list/test_list_cv0.txt 11 | crop_size: (160,192,128) 12 | random_flip: True 13 | in_channels: 4 14 | nb_labels: 4 15 | print_each_dc: True 16 | val_batchsize: 1 17 | structseg_nb_copies: 1 18 | shift_intensity: 0.1 19 | segmentor_name: encdec_seg 20 | vaeseg_norm: 'GroupNormalization' 21 | seg_lossfun: dice_loss_plus_cross_entropy 22 | init_lr: 0.0001 23 | epoch: 300 24 | snapshot_interval: 100 25 | random_scale: True 26 | eval_interval: 5 27 | is_brats: True 28 | nested_label: True 29 | -------------------------------------------------------------------------------- /configs/brats_semi_cpcseg.yml: -------------------------------------------------------------------------------- 1 | dataset_name: msd_bound 2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized 3 | label_path: /PATH_TO_LABELS/labelsTr 4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges 5 | edge_label: False 6 | image_file_format: npz 7 | label_file_format: nii.gz 8 | train_list_path: ./brats_split_list/train_list_cv0.txt 9 | ignore_path: 
./brats_split_list/train_list_cv0_3samples1.txt 10 | validation_list_path: ./brats_split_list/validation_list_cv0.txt 11 | test_list_path: ./brats_split_list/test_list_cv0.txt 12 | crop_size: (144,144,128) 13 | random_flip: True 14 | in_channels: 4 15 | nb_labels: 4 16 | print_each_dc: True 17 | val_batchsize: 1 18 | structseg_nb_copies: 1 19 | shift_intensity: 0.1 20 | segmentor_name: cpcseg 21 | vaeseg_norm: 'GroupNormalization' 22 | seg_lossfun: dice_loss_plus_cross_entropy 23 | cpc_vaeseg_cpc_loss_weight: 0.001 # 0.0005 24 | init_lr: 0.0001 25 | epoch: 300 26 | cpc_pattern: 'updown' 27 | snapshot_interval: 100 28 | random_scale: True 29 | eval_interval: 2 30 | is_brats: True 31 | report_interval: 10 32 | vae_idle_weight: 1 33 | nested_label: True 34 | -------------------------------------------------------------------------------- /configs/brats_semi_vaeseg.yml: -------------------------------------------------------------------------------- 1 | dataset_name: msd_bound 2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized 3 | label_path: /PATH_TO_LABELS/labelsTr 4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges 5 | edge_label: False 6 | image_file_format: npz 7 | label_file_format: nii.gz 8 | train_list_path: ./brats_split_list/train_list_cv0.txt 9 | ignore_path: ./brats_split_list/train_list_cv0_3samples1.txt 10 | validation_list_path: ./brats_split_list/validation_list_cv0.txt 11 | test_list_path: ./brats_split_list/test_list_cv0.txt 12 | crop_size: (160,192,128) 13 | random_flip: True 14 | in_channels: 4 15 | nb_labels: 4 16 | print_each_dc: True 17 | val_batchsize: 1 18 | structseg_nb_copies: 1 19 | shift_intensity: 0.1 20 | segmentor_name: vaeseg 21 | vaeseg_norm: 'GroupNormalization' 22 | seg_lossfun: dice_loss_plus_cross_entropy 23 | vaeseg_rec_loss_weight: 0.1 24 | vaeseg_kl_loss_weight: 0.1 25 | init_lr: 0.0001 26 | epoch: 300 27 | snapshot_interval: 100 28 | random_scale: True 29 | eval_interval: 5 30 | is_brats: True 31 | report_interval: 10 
32 | vae_idle_weight: 1 33 | nested_label: True 34 | -------------------------------------------------------------------------------- /configs/brats_vaeseg.yml: -------------------------------------------------------------------------------- 1 | dataset_name: msd_bound 2 | image_path: /PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized 3 | label_path: /PATH_TO_LABELS/labelsTr 4 | edge_path: /PATH_TO_EDGE_LABELS/brats_edges 5 | edge_label: False 6 | image_file_format: npz 7 | label_file_format: nii.gz 8 | train_list_path: ./brats_split_list/train_list_cv0.txt 9 | validation_list_path: ./brats_split_list/validation_list_cv0.txt 10 | test_list_path: ./brats_split_list/test_list_cv0.txt 11 | crop_size: (160,192,128) 12 | random_flip: True 13 | in_channels: 4 14 | nb_labels: 4 15 | print_each_dc: True 16 | val_batchsize: 1 17 | structseg_nb_copies: 1 18 | shift_intensity: 0.1 19 | segmentor_name: vaeseg 20 | vaeseg_norm: 'GroupNormalization' 21 | seg_lossfun: dice_loss_plus_cross_entropy 22 | vaeseg_rec_loss_weight: 0.1 23 | vaeseg_kl_loss_weight: 0.1 24 | init_lr: 0.0001 25 | epoch: 300 26 | snapshot_interval: 100 27 | random_scale: False 28 | eval_interval: 1 29 | is_brats: True 30 | nested_label: True 31 | -------------------------------------------------------------------------------- /examples/normalize.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') # NOQA 3 | from src.datasets.preprocess import normalize 4 | 5 | 6 | def main(root_path=None, arr_type='nii.gz', modality='mri'): 7 | # save normalized npz arrays in root_path/normalized/ 8 | normalize(root_path, arr_type, modality) 9 | 10 | 11 | if __name__ == '__main__': 12 | from fire import Fire 13 | Fire(main) 14 | -------------------------------------------------------------------------------- /examples/split_kfold.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') # 
NOQA 3 | from src.datasets.preprocess import split_kfold 4 | 5 | 6 | def main(root_path, arr_type='nii.gz', n_splits=5, random_state=42): 7 | # create split file 8 | split_kfold(root_path, arr_type, n_splits, random_state) 9 | 10 | 11 | if __name__ == '__main__': 12 | from fire import Fire 13 | Fire(main) 14 | -------------------------------------------------------------------------------- /examples/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') # NOQA 3 | import argparse 4 | import yaml 5 | import warnings 6 | import numpy as np 7 | import random 8 | import chainer 9 | from chainer import global_config 10 | 11 | from src.utils.config import overwrite_config 12 | from src.utils.setup_helpers import setup_trainer 13 | 14 | warnings.simplefilter(action='ignore', category=FutureWarning) # NOQA 15 | 16 | 17 | def reset_seed(seed=42): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | if chainer.cuda.available: 21 | chainer.cuda.cupy.random.seed(seed) 22 | 23 | 24 | def main(): 25 | reset_seed(42) 26 | parser = argparse.ArgumentParser( 27 | description='Medical image segmentation' 28 | ) 29 | parser.add_argument('--config', '-c') 30 | parser.add_argument('--out', default='results', 31 | help='Output directory') 32 | parser.add_argument('--batch_size', '-b', type=int, default=1, 33 | help="Batch size") 34 | parser.add_argument('--epoch', '-e', type=int, default=500, 35 | help="Number of epochs") 36 | parser.add_argument('--gpu_start_id', '-g', type=int, default=0, 37 | help="Start ID of gpu. 
(negative value indicates cpu)") 38 | args = parser.parse_args() 39 | config = overwrite_config( 40 | yaml.load(open(args.config)), dump_yaml_dir=args.out 41 | ) 42 | 43 | if config['mn']: 44 | global_config.autotune = True 45 | import multiprocessing 46 | multiprocessing.set_start_method('forkserver') 47 | p = multiprocessing.Process(target=print, args=('Initialize forkserver',)) 48 | p.start() 49 | p.join() 50 | 51 | trainer = setup_trainer(config, args.out, args.batch_size, args.epoch, args.gpu_start_id) 52 | trainer.run() 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /img/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/img/architecture.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nibabel==2.3.1 2 | six==1.12.0 3 | numpy==1.16.2 4 | chainer==6.2.0 5 | scikit_image==0.14.2 6 | imageio==2.4.1 7 | cupy_cuda90==6.2.0 8 | fire==0.2.1 9 | pyyaml==3.13 10 | scikit_learn==0.22.1 11 | mpi4py==3.0.0 12 | SimpleITK==1.2.0 13 | pillow==4.3.0 14 | opencv-python==4.1.0.25 -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/__init__.py -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/datasets/generate_msd_edge.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') # NOQA 3 | import copy 4 | import numpy as np 5 | import nibabel as nib 6 | import os 7 | import chainer 8 | from chainer.backends.cuda import get_device_from_id 9 | import chainer.links as L 10 | import cupy 11 | 12 | from src.utils.encode_one_hot_vector import encode_one_hot_vector 13 | from src.utils.setup_helpers import _setup_communicator 14 | 15 | chainer.backends.cuda.cudnn_enabled = True 16 | chainer.config.use_cudnn = 'always' 17 | chainer.cudnn_deterministic = True 18 | 19 | # generates edge information for structseg2019 20 | # dataset using 3D convolution with a Laplacian kernel 21 | dataset = 'nested_brats' 22 | nested = True 23 | 24 | if dataset == 'nested_brats': 25 | print('brats with nested labels!') 26 | label_path = '/PATH_TO_LABELS/labelsTr/' 27 | edge_path = './brats_nested_edges/' 28 | num_range = 484 29 | nb_class = 4 30 | else: 31 | print('Could not identify dataset') 32 | 33 | if not os.path.isdir(edge_path): 34 | raise FileNotFoundError('No such directory: {}'.format(edge_path)) 35 | 36 | # kernel for 3d Laplacian 37 | k = np.array([ 38 | [[0, 0, 0], 39 | [0, 1, 0], 40 | [0, 0, 0]], 41 | [[0, 1, 0], 42 | [1, -6, 1], 43 | [0, 1, 0]], 44 | [[0, 0, 0], 45 | [0, 1, 0], 46 | [0, 0, 0]]], dtype=np.float32) 47 | 48 | k = k.reshape((1, 1, 3, 3, 3)) 49 | k = k.transpose((0, 1, 3, 4, 2)) 50 | 51 | default_config = { 52 | 'mn': True, 53 | 'gpu_start_id': 0, 54 | } 55 | 56 | config = copy.copy(default_config) 57 | comm, is_master, device = _setup_communicator(config, gpu_start_id=0) 58 | 59 | get_device_from_id(device).use() 60 | conv3d = 
L.Convolution3D(in_channels=1, out_channels=1, 61 | ksize=3, stride=1, pad=1, initialW=k) 62 | conv3d.to_gpu() 63 | 64 | with chainer.no_backprop_mode(), chainer.using_config('train', False): 65 | for num in range(1, num_range+1): 66 | if dataset == 'nested_brats': 67 | num = "{0:0=3d}".format(num) 68 | path = label_path+'BRATS_'+str(num)+'.nii.gz' 69 | try: 70 | nii_img = nib.load(path) 71 | affine = nii_img.affine 72 | img = nii_img.get_data() 73 | shape = img.shape 74 | 75 | edge = np.zeros((1, nb_class) + shape) 76 | img = img.reshape((1,)+shape) # (h,w,d) --> (batch,h,w,d) 77 | img = encode_one_hot_vector(img, nb_class=nb_class) # (batch,channel,h,w,d) 78 | img = cupy.asarray(img, dtype=np.float32) 79 | for c in range(nb_class): 80 | if nested: 81 | if c == 1: 82 | wt = (img[0, 1, :, :, :]).astype(bool)+(img[0, 2, :, :, :]).astype(bool)\ 83 | + (img[0, 3, :, :, :]).astype(bool) 84 | crop_img = wt.reshape((1, 1)+shape) 85 | elif c == 2: 86 | tc = (img[0, 2, :, :, :]).astype(bool)+(img[0, 3, :, :, :]).astype(bool) 87 | crop_img = tc.reshape((1, 1)+shape) 88 | elif c == 3: 89 | et = (img[0, 3, :, :, :]).astype(bool) 90 | crop_img = et.reshape((1, 1)+shape) 91 | else: 92 | crop_img = img[0, c, :, :, :].reshape((1, 1)+shape) 93 | else: 94 | crop_img = img[0, c, :, :, :].reshape((1, 1)+shape) 95 | temp_img = conv3d(crop_img.astype(np.float32)).data[0, 0, :, :, :] 96 | edge[0, c, :, :, :] = (-(cupy.asnumpy(temp_img)) > 0.).astype(int) 97 | 98 | edge_labels_nii = nib.Nifti1Image(edge, affine) 99 | nib.save(edge_labels_nii, edge_path+'edge_'+str(num)+'.nii.gz') 100 | print(num) 101 | except Exception: 102 | print('No data for '+str(num)+'.nii.gz') 103 | -------------------------------------------------------------------------------- /src/datasets/msd_bound.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from chainer.dataset import DatasetMixin 4 | from src.datasets.preprocess import read_image, 
random_crop_pair 5 | from src.functions.array.resize_images_3d import resize_images_3d 6 | from src.functions.array.resize_images_3d_nearest_neighbor import resize_images_3d_nearest_neighbor 7 | from numpy.random import rand 8 | import chainer 9 | import chainer.functions as F 10 | 11 | 12 | class MSDBoundDataset(DatasetMixin): 13 | """Edited MSD dataset for 4-channel labels (for boundary aware networks)""" 14 | def __init__(self, config, split_file): 15 | self.crop_size = eval(config['crop_size']) 16 | self.shift_intensity = config['shift_intensity'] 17 | self.random_flip = config['random_flip'] 18 | self.target_label = config['target_label'] 19 | self.nb_class = config['nb_labels'] 20 | self.unet_cropping = config['unet_cropping'] 21 | self.random_scale = config['random_scale'] 22 | self.training = True 23 | with open(split_file) as f: 24 | self.split_list = [line.rstrip() for line in f] 25 | image_list = [os.path.join(config['image_path'], name) for name in self.split_list] 26 | label_list = [os.path.join(config['label_path'], name) for name in self.split_list] 27 | edge_list = [os.path.join(config['edge_path'], 'edge_'+str(name)[-3:]) 28 | for name in self.split_list] 29 | self.edge_label = config['edge_label'] 30 | assert len(image_list) == len(label_list) 31 | self.pair_list = list(zip(image_list, label_list, edge_list)) 32 | self.nb_copies = config['structseg_nb_copies'] 33 | self.image_file_format = config['image_file_format'] 34 | self.ignore_path = config['ignore_path'] 35 | self.nested_label = config['nested_label'] 36 | 37 | def __len__(self): 38 | return len(self.pair_list) * self.nb_copies 39 | 40 | def _get_image(self, i): 41 | i = i % len(self.pair_list) 42 | image = read_image(self.pair_list[i][0], self.image_file_format) 43 | image = image.transpose((3, 0, 1, 2)) 44 | return image 45 | 46 | def _get_label(self, i): 47 | i = i % len(self.pair_list) 48 | label = read_image(self.pair_list[i][1], 'nii.gz') 49 | if self.training and (self.ignore_path 
is not None): 50 | # use samples without labels and use subtask branch only 51 | ignore_list = open(self.ignore_path).read().splitlines() 52 | if self.pair_list[i][1][-9:] not in ignore_list: 53 | # giving NaN label when the sample is not in the ignore_list 54 | # ignore_list: list of samples which we DO NOT ignore labels 55 | label = np.empty(label.shape) 56 | label[:] = np.nan 57 | if self.unet_cropping: 58 | label = label[20:-20, 20:-20, 20:-20] 59 | return label 60 | 61 | def _get_edge(self, i): 62 | i = i % len(self.pair_list) 63 | edge = read_image(self.pair_list[i][2], 'nii.gz') 64 | edge = edge[0, :, :, :, :] 65 | if self.unet_cropping: 66 | edge = edge[20:-20, 20:-20, 20:-20] 67 | return edge 68 | 69 | def _crop(self, x, y): 70 | return random_crop_pair(x, y, self.crop_size) 71 | 72 | def _change_intensity(self, x): 73 | for i in range(len(x)): 74 | diff = 2 * self.shift_intensity * rand() - self.shift_intensity 75 | x[i] += diff 76 | return x 77 | 78 | @staticmethod 79 | def _flip(x, y): 80 | if np.random.random() < 0.5: 81 | x = x[:, ::-1, :, :] 82 | y = y[:, ::-1, :, :] 83 | if np.random.random() < 0.5: 84 | x = x[:, :, ::-1, :] 85 | y = y[:, :, ::-1, :] 86 | if np.random.random() < 0.5: 87 | x = x[:, :, :, ::-1] 88 | y = y[:, :, :, ::-1] 89 | return x, y 90 | 91 | @staticmethod 92 | def _binarize_label(y, t_idx): 93 | return (y >= t_idx).astype(np.int32) 94 | 95 | @staticmethod 96 | def _binarize_brats(y, dc=1): 97 | if (np.sum(y) < 0) or (np.isnan(y).any()): 98 | # when using ignore list (semi-supervised VAEseg) 99 | return y 100 | else: 101 | if dc == 1: 102 | whole_tumor = (y >= 1).astype(np.int32) 103 | return whole_tumor 104 | elif dc == 2: 105 | tumor_core = ((y == 3) + (y == 2)).astype(np.int32) 106 | return tumor_core 107 | elif dc == 3: 108 | enhancing_tumor = (y == 3).astype(np.int32) 109 | return enhancing_tumor 110 | else: 111 | print('Error for binarization') 112 | return y 113 | 114 | @staticmethod 115 | def _rand_scale(x, y): 116 | # 
random scale the image [0.9,1.1) for augmentation 117 | xx = x.reshape((1,)+x.shape) 118 | y = y.astype(np.float32) 119 | 120 | rand_scale = 0.9+rand()*0.2 121 | orig_shape = np.array(xx.shape) 122 | scaled_shape = orig_shape[2:]*rand_scale 123 | scaled_shape = np.round(scaled_shape).astype(int) 124 | 125 | resized_x = resize_images_3d(xx, scaled_shape) 126 | resized_y = resize_images_3d_nearest_neighbor(y.reshape((1,)+y.shape), 127 | scaled_shape)[0, :, :, :, :] 128 | resized_y = chainer.Variable(resized_y) 129 | if rand_scale < 1: 130 | # zero-padding if the image is scaled down 131 | pad_w = orig_shape[2] - scaled_shape[0] 132 | pad_h = orig_shape[3] - scaled_shape[1] 133 | pad_d = orig_shape[4] - scaled_shape[2] 134 | 135 | w_half = int(round(pad_w/2)) 136 | h_half = int(round(pad_h/2)) 137 | d_half = int(round(pad_d/2)) 138 | 139 | pad_width_x = ((0, 0), (0, 0), (w_half, pad_w-w_half), 140 | (h_half, pad_h-h_half), (d_half, pad_d-d_half)) 141 | pad_width_y = ((0, 0), (w_half, pad_w-w_half), 142 | (h_half, pad_h-h_half), (d_half, pad_d-d_half)) 143 | resized_x = F.pad(resized_x, pad_width=pad_width_x, mode='constant', constant_values=0.) 144 | resized_y = F.pad(resized_y, pad_width=pad_width_y, mode='constant', constant_values=0.) 
145 | assert resized_x.shape[2:] == resized_y.shape[1:] 146 | resized_x = resized_x[0, :, :, :, :] 147 | 148 | return resized_x.data, resized_y.data.astype(np.int32) 149 | 150 | def get_example(self, i): 151 | x = self._get_image(i) 152 | y = self._get_label(i) 153 | if self.edge_label: 154 | y = y.reshape((1,)+y.shape) 155 | ye = self._get_edge(i) 156 | y = np.append(y, ye, axis=0) 157 | else: 158 | y = y.reshape((1,)+y.shape) 159 | 160 | if self.random_scale: 161 | x, y = self._rand_scale(x, y) 162 | if self.training: 163 | x, y = self._crop(x, y) 164 | if self.shift_intensity > 0.: 165 | x = self._change_intensity(x) 166 | if self.random_flip: 167 | x, y = self._flip(x, y) 168 | if self.target_label: 169 | assert self.nb_class == 2 170 | y = self._binarize_label(y, self.target_label) 171 | elif self.nested_label: 172 | ys = [] 173 | for i in range(self.nb_class-1): 174 | ys.append(self._binarize_brats(y[0, :, :, :], i+1)) 175 | if self.edge_label: 176 | return x.astype(np.float32),\ 177 | np.array(ys).astype(np.int32), y[1:, :, :, :].astype(np.int32) 178 | else: 179 | return x.astype(np.float32), np.array(ys).astype(np.int32) 180 | 181 | if self.edge_label: 182 | return x.astype(np.float32), y[0, :, :, :].astype(np.int32),\ 183 | y[1:, :, :, :].astype(np.int32) 184 | else: 185 | return x.astype(np.float32), y[0, :, :, :].astype(np.int32) 186 | -------------------------------------------------------------------------------- /src/datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import nibabel as nib 5 | import imageio 6 | from skimage.transform import resize 7 | from sklearn.model_selection import KFold, train_test_split 8 | 9 | 10 | def read_image(path, file_format='nii.gz'): 11 | """ read image array from path 12 | Args: 13 | path (str) : path to directory which images are stored. 
14 | file_format (str) : type of reading file {'npy','npz','jpg','png','nii'(3d)} 15 | Returns: 16 | image (np.ndarray) : image array 17 | """ 18 | path = path + '.' + file_format 19 | if file_format == 'npy': 20 | image = np.load(path) 21 | elif file_format == 'npz': 22 | image = np.load(path)['arr_0'] 23 | elif file_format in ('png', 'jpg'): 24 | image = np.array(imageio.imread(path)) 25 | elif file_format == 'dcm': 26 | image = np.array(imageio.volread(path, 'DICOM')) 27 | elif file_format in ('nii', 'nii.gz'): 28 | image = nib.load(path).get_data() 29 | else: 30 | raise ValueError('invalid --input_type : {}'.format(file_format)) 31 | 32 | return image 33 | 34 | 35 | def random_crop_pair(image, label, crop_size=(160, 192, 128)): 36 | H, W, D = crop_size 37 | 38 | if image.shape[1] >= H: 39 | top = random.randint(0, image.shape[1] - H) 40 | else: 41 | raise ValueError('shape of image needs to be larger than output shape') 42 | h_slice = slice(top, top + H) 43 | 44 | if image.shape[2] >= W: 45 | left = random.randint(0, image.shape[2] - W) 46 | else: 47 | raise ValueError('shape of image needs to be larger than output shape') 48 | w_slice = slice(left, left + W) 49 | 50 | if image.shape[3] >= D: 51 | rear = random.randint(0, image.shape[3] - D) 52 | else: 53 | raise ValueError('shape of image needs to be larger than output shape') 54 | d_slice = slice(rear, rear + D) 55 | 56 | image = image[:, h_slice, w_slice, d_slice] 57 | if label.ndim == 4: 58 | label = label[:, h_slice, w_slice, d_slice] 59 | else: 60 | label = label[h_slice, w_slice, d_slice] 61 | return image, label 62 | 63 | 64 | def compute_stats(root_path, arr_type='nii.gz', modality='mri'): 65 | files = os.listdir(root_path) 66 | means = [] 67 | stds = [] 68 | for file_i in files: 69 | img = read_image(os.path.join(root_path, file_i), arr_type) 70 | if len(img.shape) == 3: 71 | img = np.expand_dims(img, axis=-1) 72 | img = img.transpose((3, 0, 1, 2)) 73 | if modality == 'ct': 74 | np.clip(img, -1000., 
1000., out=img) 75 | img += 1000. 76 | mean = [] 77 | std = [] 78 | for i in range(len(img)): 79 | mean.append(np.mean(img[i][img[i].nonzero()])) 80 | if modality == 'ct': 81 | mean[i] -= 1000. 82 | std.append(np.std(img[i][img[i].nonzero()])) 83 | means.append(mean) 84 | stds.append(std) 85 | return np.mean(np.array(means), axis=0), np.mean(np.array(stds), axis=0) 86 | 87 | 88 | def normalize(root_path, arr_type='nii.gz', modality='mri'): 89 | files = os.listdir(root_path) 90 | if arr_type == 'nii.gz': 91 | files = ['.'.join(file_i.split('.')[:-2]) for file_i in files] 92 | else: 93 | files = [os.path.splitext(file_i)[0] for file_i in files] 94 | base_name = os.path.basename(root_path) 95 | normalized_path = root_path.replace(base_name, base_name+"_normalized") 96 | if not os.path.exists(normalized_path): 97 | os.makedirs(normalized_path, exist_ok=True) 98 | for file_i in files: 99 | img = read_image(os.path.join(root_path, file_i), arr_type) 100 | if len(img.shape) == 3: 101 | img = np.expand_dims(img, axis=-1) 102 | img = img.transpose((3, 0, 1, 2)).astype(np.float32) 103 | if modality == 'ct': 104 | np.clip(img, -1000., 1000., out=img) 105 | img += 1000. 
106 | for i in range(len(img)): 107 | mean = np.mean(img[i][img[i].nonzero()]) 108 | std = np.std(img[i][img[i].nonzero()]) 109 | img[i] = (img[i] - mean) / std 110 | img = img.transpose((1, 2, 3, 0)) 111 | np.savez_compressed(os.path.join(normalized_path, file_i.replace(arr_type, 'npz')), img) 112 | 113 | 114 | def resample_isotropic_nifti(root_path, arr_type='nii.gz', is_label=False): 115 | files = os.listdir(root_path) 116 | if arr_type == 'nii.gz': 117 | files = ['.'.join(file_i.split('.')[:-2]) for file_i in files] 118 | else: 119 | files = [os.path.splitext(file_i)[0] for file_i in files] 120 | base_name = os.path.basename(root_path) 121 | resampled_path = root_path.replace(base_name, base_name+"_resampled") 122 | order = 0 if is_label else 1 123 | if not os.path.exists(resampled_path): 124 | os.makedirs(resampled_path, exist_ok=True) 125 | for file_i in files: 126 | img = read_image(os.path.join(root_path, file_i), 'nii.gz') 127 | spacing = nib.load(os.path.join(root_path, file_i) + '.nii.gz').header['pixdim'][:4] 128 | img = img.transpose((3, 0, 1, 2)).astype(np.float32) 129 | img_shape = img.shape 130 | target_shape = tuple([int(img_shape[i] / spacing[i]) for i in range(1, 4)]) 131 | resampled_img = [] 132 | for i in range(len(img)): 133 | resampled_img.append(resize(img[i], target_shape, order, mode='constant')) 134 | resampled_img = np.array(resampled_img).transpose((1, 2, 3, 0)) 135 | np.savez_compressed( 136 | os.path.join(resampled_path, file_i.replace('nii.gz', 'npz')), 137 | resampled_img) 138 | 139 | 140 | def split_kfold(root_path, arr_type='nii.gz', n_splits=5, random_state=42): 141 | data_list = os.listdir(root_path) 142 | if arr_type == 'nii.gz': 143 | data_list = ['.'.join(data.split('.')[:-2]) for data in data_list] 144 | else: 145 | data_list = [os.path.splitext(data)[0] for data in data_list] 146 | save_path = root_path.replace(os.path.basename(root_path), 'split_list') 147 | if not os.path.exists(save_path): 148 | os.makedirs(save_path, exist_ok=True) 
149 | # k-fold 150 | kf = KFold( 151 | n_splits=n_splits, shuffle=True, random_state=random_state) 152 | for i, (train_index, val_test_index) in enumerate(kf.split(data_list)): 153 | val_index, test_index = \ 154 | train_test_split(val_test_index, test_size=0.5, shuffle=True, random_state=random_state) 155 | val_index = sorted(val_index) 156 | test_index = sorted(test_index) 157 | f = open(os.path.join(save_path, 'train_list_cv{}.txt'.format(i)), "w") 158 | for j in train_index: 159 | f.write(str(data_list[j]) + "\n") 160 | f.close() 161 | f = open(os.path.join(save_path, 'validation_list_cv{}.txt'.format(i)), "w") 162 | for j in val_index: 163 | f.write(str(data_list[j]) + "\n") 164 | f.close() 165 | f = open(os.path.join(save_path, 'test_list_cv{}.txt'.format(i)), "w") 166 | for j in test_index: 167 | f.write(str(data_list[j]) + "\n") 168 | f.close() 169 | -------------------------------------------------------------------------------- /src/functions/array/__init__.py: -------------------------------------------------------------------------------- 1 | from src.functions.array.resize_images_3d import resize_images_3d # NOQA 2 | -------------------------------------------------------------------------------- /src/functions/array/resize_images_3d.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from chainer.backends import cuda 4 | from chainer import function_node 5 | from chainer.utils import type_check 6 | 7 | 8 | class ResizeImages3D(function_node.FunctionNode): 9 | 10 | def __init__(self, output_shape): 11 | self.out_H = output_shape[0] 12 | self.out_W = output_shape[1] 13 | self.out_D = output_shape[2] 14 | 15 | def check_type_forward(self, in_types): 16 | n_in = in_types.size() 17 | type_check.expect(n_in == 1) 18 | 19 | x_type = in_types[0] 20 | type_check.expect( 21 | x_type.dtype.char == 'f', 22 | x_type.ndim == 5 23 | ) 24 | 25 | def forward(self, inputs): 26 | x, = inputs 27 | xp = 
cuda.get_array_module(x) 28 | 29 | B, C, H, W, D = x.shape 30 | 31 | v_1d = xp.linspace(0, H - 1, num=self.out_H) 32 | u_1d = xp.linspace(0, W - 1, num=self.out_W) 33 | t_1d = xp.linspace(0, D - 1, num=self.out_D) 34 | grid = xp.meshgrid(v_1d, u_1d, t_1d, indexing='ij') 35 | v = grid[0].ravel() 36 | u = grid[1].ravel() 37 | t = grid[2].ravel() 38 | 39 | v0 = xp.floor(v).astype(numpy.int32) 40 | v0 = v0.clip(0, H - 2) 41 | u0 = xp.floor(u).astype(numpy.int32) 42 | u0 = u0.clip(0, W - 2) 43 | t0 = xp.floor(t).astype(numpy.int32) 44 | t0 = t0.clip(0, D - 2) 45 | 46 | y = 0 47 | for i in range(2): 48 | for j in range(2): 49 | for k in range(2): 50 | wv = xp.abs(v0 + (1 - i) - v) 51 | wu = xp.abs(u0 + (1 - j) - u) 52 | wt = xp.abs(t0 + (1 - k) - t) 53 | w = (wv * wu * wt).astype(x.dtype, copy=False) 54 | y += w[None, None, :] * x[:, :, v0 + i, u0 + j, t0 + k] 55 | y = y.reshape(B, C, self.out_H, self.out_W, self.out_D) 56 | return y, 57 | 58 | def backward(self, indexes, grad_outputs): 59 | return ResizeImagesGrad3D( 60 | self.inputs[0].shape, 61 | (self.out_H, self.out_W, self.out_D)).apply(grad_outputs) 62 | 63 | 64 | class ResizeImagesGrad3D(function_node.FunctionNode): 65 | 66 | def __init__(self, input_shape, output_shape): 67 | self.out_H = output_shape[0] 68 | self.out_W = output_shape[1] 69 | self.out_D = output_shape[2] 70 | self.input_shape = input_shape 71 | 72 | def check_type_forward(self, in_types): 73 | n_in = in_types.size() 74 | type_check.expect(n_in == 1) 75 | 76 | x_type = in_types[0] 77 | type_check.expect( 78 | x_type.dtype.char == 'f', 79 | x_type.ndim == 5 80 | ) 81 | 82 | def forward(self, inputs): 83 | xp = cuda.get_array_module(*inputs) 84 | gy, = inputs 85 | 86 | B, C, H, W, D = self.input_shape 87 | 88 | v_1d = xp.linspace(0, H - 1, num=self.out_H) 89 | u_1d = xp.linspace(0, W - 1, num=self.out_W) 90 | t_1d = xp.linspace(0, D - 1, num=self.out_D) 91 | grid = xp.meshgrid(v_1d, u_1d, t_1d, indexing='ij') 92 | v = grid[0].ravel() 93 | u = 
grid[1].ravel() 94 | t = grid[2].ravel() 95 | 96 | v0 = xp.floor(v).astype(numpy.int32) 97 | v0 = v0.clip(0, H - 2) 98 | u0 = xp.floor(u).astype(numpy.int32) 99 | u0 = u0.clip(0, W - 2) 100 | t0 = xp.floor(t).astype(numpy.int32) 101 | t0 = t0.clip(0, D - 2) 102 | 103 | if xp is numpy: 104 | scatter_add = numpy.add.at 105 | else: 106 | scatter_add = cuda.cupyx.scatter_add 107 | 108 | gx = xp.zeros(self.input_shape, dtype=gy.dtype) 109 | gy = gy.reshape(B, C, -1) 110 | for i in range(2): 111 | for j in range(2): 112 | for k in range(2): 113 | wv = xp.abs(v0 + (1 - i) - v) 114 | wu = xp.abs(u0 + (1 - j) - u) 115 | wt = xp.abs(t0 + (1 - k) - t) 116 | w = (wv * wu * wt).astype(gy.dtype, copy=False) 117 | scatter_add( 118 | gx, 119 | (slice(None), slice(None), v0 + i, u0 + j, t0 + k), 120 | gy * w) 121 | return gx, 122 | 123 | def backward(self, indexes, grad_outputs): 124 | return ResizeImages3D( 125 | (self.out_H, self.out_W, self.out_D)).apply(grad_outputs) 126 | 127 | 128 | def resize_images_3d(x, output_shape): 129 | """Resize images to the given shape. 130 | This function resizes 3D data to :obj:`output_shape`. 131 | Currently, only bilinear interpolation is supported as the sampling method. 132 | Notation: here is a notation for dimensionalities. 133 | - :math:`n` is the batch size. 134 | - :math:`c_I` is the number of the input channels. 135 | - :math:`h`, :math:`w` and :math:`d` are the height, width and depth of the 136 | input image, respectively. 137 | - :math:`h_O`, :math:`w_O` and :math:`d_O` are the height, width and depth 138 | of the output image. 139 | Args: 140 | x (~chainer.Variable): 141 | Input variable of shape :math:`(n, c_I, h, w, d)`. 142 | output_shape (tuple): 143 | This is a tuple of length 3 whose values are :obj:`(h_O, w_O, d_O)` 144 | Returns: 145 | ~chainer.Variable: Resized image whose shape is \ 146 | :math:`(n, c_I, h_O, w_O, d_O)`. 
147 | """ 148 | return ResizeImages3D(output_shape).apply((x,))[0] 149 | -------------------------------------------------------------------------------- /src/functions/array/resize_images_3d_nearest_neighbor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Tuple 3 | 4 | from chainer.backends import cuda 5 | from chainer.types import NdArray 6 | 7 | 8 | def resize_images_3d_nearest_neighbor(x: NdArray, output_shape: Tuple[int, int, int]) -> NdArray: 9 | # Note: this function is not differentiable 10 | # Resize from (H, W, D) to (H', W', D') 11 | out_H, out_W, out_D = output_shape 12 | xp = cuda.get_array_module(x) 13 | 14 | B, C, H, W, D = x.shape 15 | v_1d = xp.linspace(0, H - 1, num=out_H) 16 | u_1d = xp.linspace(0, W - 1, num=out_W) 17 | t_1d = xp.linspace(0, D - 1, num=out_D) 18 | grid = xp.meshgrid(v_1d, u_1d, t_1d, indexing='ij') 19 | v = grid[0].ravel() # (H'W'D',) 20 | u = grid[1].ravel() 21 | t = grid[2].ravel() 22 | 23 | v0 = xp.floor(v).astype(np.int32) 24 | v0 = v0.clip(0, H - 2) 25 | u0 = xp.floor(u).astype(np.int32) 26 | u0 = u0.clip(0, W - 2) 27 | t0 = xp.floor(t).astype(np.int32) 28 | t0 = t0.clip(0, D - 2) 29 | 30 | ws, xs = [], [] 31 | for i in range(2): 32 | for j in range(2): 33 | for k in range(2): 34 | wv = xp.abs(v0 + (1 - i) - v) 35 | wu = xp.abs(u0 + (1 - j) - u) 36 | wt = xp.abs(t0 + (1 - k) - t) 37 | w = wv * wu * wt 38 | ws.append(w) 39 | xs.append(x[:, :, v0 + i, u0 + j, t0 + k]) 40 | ws = xp.stack(ws) # (8, H'W'D') 41 | xs = xp.stack(xs) # (8, B, C, H'W'D') 42 | xs = xs.transpose((0, 3, 1, 2)) # (8, H'W'D', B, C) 43 | 44 | target_indices = xp.argmax(ws, axis=0) # (H'W'D',) 45 | y = xs[target_indices, np.arange(out_H * out_W * out_D)] # (H'W'D', B, C) 46 | y = y.transpose((1, 2, 0)).reshape((B, C, out_H, out_W, out_D)) # (B, C, H', W', D') 47 | return y 48 | -------------------------------------------------------------------------------- 
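For reference, the trilinear sampling scheme that `ResizeImages3D.forward` implements above (a `linspace` coordinate grid, floor-and-clip to the second-to-last index, then an eight-corner weighted sum) can be sketched in plain NumPy. This is a simplified CPU-only illustration of the same math, not code from the repository:

```python
import numpy as np

def trilinear_resize_3d(x: np.ndarray, output_shape) -> np.ndarray:
    """Trilinear resize of a (B, C, H, W, D) array, mirroring ResizeImages3D.forward."""
    out_H, out_W, out_D = output_shape
    B, C, H, W, D = x.shape
    # fractional sample coordinates in the source volume
    v = np.linspace(0, H - 1, num=out_H)
    u = np.linspace(0, W - 1, num=out_W)
    t = np.linspace(0, D - 1, num=out_D)
    v, u, t = [g.ravel() for g in np.meshgrid(v, u, t, indexing='ij')]
    # lower corner indices, clipped so that the +1 neighbor stays in bounds
    v0 = np.floor(v).astype(np.int32).clip(0, H - 2)
    u0 = np.floor(u).astype(np.int32).clip(0, W - 2)
    t0 = np.floor(t).astype(np.int32).clip(0, D - 2)
    y = 0.
    for i in range(2):
        for j in range(2):
            for k in range(2):
                # weight of each of the 8 surrounding voxels
                w = (np.abs(v0 + (1 - i) - v)
                     * np.abs(u0 + (1 - j) - u)
                     * np.abs(t0 + (1 - k) - t))
                y = y + w[None, None, :] * x[:, :, v0 + i, u0 + j, t0 + k]
    return y.reshape(B, C, out_H, out_W, out_D)

vol = np.arange(2 * 1 * 4 * 4 * 4, dtype=np.float32).reshape(2, 1, 4, 4, 4)
up = trilinear_resize_3d(vol, (8, 8, 8))
assert up.shape == (2, 1, 8, 8, 8)
# resizing to the input's own shape is the identity
assert np.allclose(trilinear_resize_3d(vol, (4, 4, 4)), vol)
```

Resizing a volume to its own shape reproduces the input exactly, which is a quick sanity check for the corner-weighting logic; the Chainer version differs only in running on `cupy` arrays and defining the matching scatter-add backward pass.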
/src/functions/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from src.functions.evaluation.dice_coefficient import dice_coefficient # NOQA 2 | from src.functions.evaluation.dice_coefficient import mean_dice_coefficient # NOQA 3 | -------------------------------------------------------------------------------- /src/functions/evaluation/dice_coefficient.py: -------------------------------------------------------------------------------- 1 | from six import moves 2 | import numpy as np 3 | 4 | from chainer import function 5 | from chainer import functions as F 6 | from chainer.backends import cuda 7 | from chainer.utils import type_check 8 | 9 | 10 | class DiceCoefficient(function.Function): 11 | 12 | def __init__(self, ret_nan=True, dataset='task8hepatic', eps=1e-7, is_brats=False): 13 | self.ret_nan = ret_nan # return NaN if union==0 14 | self.dataset = dataset 15 | self.eps = eps 16 | self.is_brats = is_brats 17 | 18 | def check_type_forward(self, in_types): 19 | type_check.expect(in_types.size() == 2) 20 | x_type, t_type = in_types 21 | 22 | type_check.expect( 23 | x_type.dtype.kind == 'f', 24 | t_type.dtype == np.int32 25 | ) 26 | 27 | t_ndim = type_check.eval(t_type.ndim) 28 | type_check.expect( 29 | x_type.ndim >= t_type.ndim, 30 | x_type.shape[0] == t_type.shape[0], 31 | x_type.shape[2: t_ndim + 1] == t_type.shape[1:] 32 | ) 33 | for i in moves.range(t_ndim + 1, type_check.eval(x_type.ndim)): 34 | type_check.expect(x_type.shape[i] == 1) 35 | 36 | def forward(self, inputs): 37 | """ 38 | compute average Dice coefficient between two label images 39 | Math: 40 | DC = \frac{2|A\cap B|}{|A|+|B|} 41 | Args: 42 | inputs ((array_like, array_like)): 43 | Input pair (prediction, ground_truth) 44 | Returns: 45 | dice (float) 46 | """ 47 | xp = cuda.get_array_module(*inputs) 48 | y, t = inputs 49 | number_class = y.shape[1] 50 | y = y.argmax(axis=1).reshape(t.shape) 51 | axis = tuple(range(t.ndim)) 52 | 53 | 
dice = xp.zeros(number_class, dtype=xp.float32) 54 | for i in range(number_class): 55 | if self.is_brats: 56 | if i == 1: 57 | # Enhancing tumor 58 | y_match = xp.equal(y, 3).astype(xp.float32) 59 | t_match = xp.equal(t, 3).astype(xp.float32) 60 | elif i == 2: 61 | # Tumor core 62 | y_match = (xp.equal(y, 2)+xp.equal(y, 3)).astype(xp.float32) 63 | t_match = (xp.equal(t, 2)+xp.equal(t, 3)).astype(xp.float32) 64 | elif i == 3: 65 | # Whole Tumor 66 | y_match = (xp.equal(y, 1)+xp.equal(y, 2)+xp.equal(y, 3)).astype(xp.float32) 67 | t_match = (xp.equal(t, 1)+xp.equal(t, 2)+xp.equal(t, 3)).astype(xp.float32) 68 | else: 69 | y_match = xp.equal(y, i).astype(xp.float32) 70 | t_match = xp.equal(t, i).astype(xp.float32) 71 | else: 72 | y_match = xp.equal(y, i).astype(xp.float32) 73 | t_match = xp.equal(t, i).astype(xp.float32) 74 | 75 | intersect = xp.sum(y_match * t_match, axis=axis) 76 | union = xp.sum(y_match, axis=axis) + xp.sum(t_match, axis=axis) 77 | if union == 0.: 78 | intersect += 0.5*self.eps 79 | union += self.eps 80 | dice[i] = 0.0 81 | if self.ret_nan: 82 | dice[i] = np.nan 83 | else: 84 | dice[i] = 2.0 * (intersect / union) 85 | 86 | return xp.asarray(dice, dtype=xp.float32), 87 | 88 | 89 | def dice_coefficient(y, t, ret_nan=False, dataset='task8hepatic', eps=1e-7, is_brats=False): 90 | return DiceCoefficient(ret_nan=ret_nan, dataset=dataset, eps=eps, is_brats=is_brats)(y, t) 91 | 92 | 93 | def mean_dice_coefficient(dice_coefficients, ret_nan=True): 94 | if ret_nan: 95 | xp = cuda.get_array_module(dice_coefficients) 96 | selector = ~xp.isnan(dice_coefficients.data) 97 | dice_coefficients = F.get_item(dice_coefficients, selector) 98 | return F.mean(dice_coefficients, keepdims=True) 99 | -------------------------------------------------------------------------------- /src/functions/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from src.functions.loss.dice_loss import dice_loss # NOQA 2 | from 
src.functions.loss.dice_loss import softmax_dice_loss # NOQA 3 | from src.functions.loss.generalized_dice_loss import generalized_dice_loss # NOQA 4 | from src.functions.loss.generalized_dice_loss import softmax_generalized_dice_loss # NOQA 5 | 6 | from src.functions.loss.dice_loss import DiceLoss # NOQA 7 | from src.functions.loss.generalized_dice_loss import GeneralizedDiceLoss # NOQA 8 | -------------------------------------------------------------------------------- /src/functions/loss/boundary_bce.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer import function 3 | from chainer.backends import cuda 4 | from chainer.utils import type_check, force_array 5 | from chainer.functions import softmax_cross_entropy 6 | 7 | 8 | class BoundaryBCE(function.Function): 9 | def __init__(self, eps=1e-7): 10 | self.eps = eps 11 | 12 | def check_type_forward(self, in_types): 13 | type_check.expect(in_types.size() == 2) 14 | x_type, t_type = in_types 15 | 16 | type_check.expect( 17 | x_type.dtype.kind == 'f', 18 | 19 | x_type.shape[0] == t_type.shape[0], 20 | x_type.shape[2:] == t_type.shape[2:], 21 | ) 22 | 23 | @staticmethod 24 | def _check_input_values(x, t): 25 | if not (((0 <= t) & 26 | (t < x.shape[1]))).all(): 27 | msg = ('Each label `t` need to satisfy ' 28 | '`0 <= t < x.shape[1]`') 29 | raise ValueError(msg) 30 | 31 | def forward(self, inputs): 32 | xp = cuda.get_array_module(*inputs) 33 | x, t = inputs 34 | if chainer.is_debug(): 35 | self._check_input_values(x, t) 36 | num_class = t.shape[1] 37 | bound_loss = 0. 
38 | for i in range(1, num_class): 39 | # convert label to a (0,1)label for each class 40 | tt = t[0, i, :, :, :].astype(xp.int32) 41 | tt = tt.reshape(-1) 42 | beta = xp.sum(tt)/tt.size 43 | class_weight = xp.stack([1-beta, beta]) 44 | 45 | xx = x[:, [0, i], :, :, :] 46 | xx = xx.reshape(-1, 2) 47 | bound_loss += softmax_cross_entropy(xx, tt, class_weight=class_weight) 48 | return force_array(bound_loss.data, ), 49 | 50 | 51 | def boundary_bce(x, t, eps=1e-7): 52 | return BoundaryBCE(eps)(x, t) 53 | -------------------------------------------------------------------------------- /src/functions/loss/cpc_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chainer 3 | from chainer import function 4 | import chainer.functions as F 5 | from chainer.backends import cuda 6 | from chainer.utils import type_check, force_array 7 | import random 8 | 9 | 10 | class CPCLoss(function.Function): 11 | """ Contrastive loss from Contrastive Predictive Coding, 12 | arXiv:1905.09272 """ 13 | def __init__(self, upper=True, cpc_pattern='updown', num_neg=20, eps=1e-7): 14 | self.eps = eps 15 | # number of negative samples 16 | self.num_neg = num_neg 17 | self.upper = upper 18 | self.cpc_pattern = cpc_pattern 19 | 20 | def check_type_forward(self, in_types): 21 | type_check.expect(in_types.size() == 2) 22 | x_type, t_type = in_types 23 | 24 | type_check.expect( 25 | x_type.dtype.kind == 'f', 26 | t_type.dtype.kind == 'f', 27 | x_type.shape[0] == t_type.shape[0] 28 | ) 29 | 30 | @staticmethod 31 | def _check_input_values(x, t): 32 | if not (0 <= t): 33 | msg = ('Each label `t` need to satisfy ' 34 | '`0 <= t `') 35 | raise ValueError(msg) 36 | 37 | def forward(self, inputs): 38 | xp = cuda.get_array_module(*inputs) 39 | x, t = inputs 40 | b, c, gl0, gl1, gl2 = t.shape # gl: grid length 41 | cut_l = int(gl2/2) 42 | 43 | if chainer.is_debug(): 44 | self._check_input_values(x, t) 45 | 46 | if self.cpc_pattern == 
'ichimatsu': 47 | correct_z = t[:, :, 2:gl0+2:4, 2:gl1+2:4, 1:gl2:2] 48 | neg_z = t[:, :, :, :, 0:gl2:2] 49 | else: 50 | if self.upper: 51 | correct_z = t[:, :, :, :, -cut_l:] 52 | neg_z = t[:, :, :, :, :-cut_l] 53 | else: 54 | correct_z = t[:, :, :, :, :cut_l] 55 | neg_z = t[:, :, :, :, cut_l:] 56 | 57 | if self.cpc_pattern == 'ichimatsu': 58 | xx = F.reshape(x, (c, -1)) 59 | correct_z = F.reshape(correct_z, (c, -1)) 60 | neg_z = F.reshape(neg_z, (c, -1)) 61 | # selecting num_neg negative samples to concat with correct samples of z 62 | selection = random.sample(set(np.arange(neg_z.shape[1])), self.num_neg) 63 | else: 64 | xx = F.reshape(x, (c, gl0*gl1*cut_l)) 65 | correct_z = F.reshape(correct_z, (c, gl0*gl1*cut_l)) 66 | neg_z = F.reshape(neg_z, (c, gl0*gl1*(gl2-cut_l))) 67 | # selecting num_neg negative samples to concat with correct samples of z 68 | selection = random.sample(set(np.arange(gl0*gl1*(gl2-cut_l))), self.num_neg) 69 | 70 | z_hat_T = F.transpose(xx) 71 | neg_z = neg_z[:, selection] 72 | z = F.concat((correct_z, neg_z), axis=1) 73 | 74 | ip = xp.dot(z_hat_T.data, z.data) 75 | if self.cpc_pattern == 'ichimatsu': 76 | cpc = F.softmax_cross_entropy(ip, xp.arange(12), reduce='mean') 77 | else: 78 | cpc = F.softmax_cross_entropy(ip, xp.arange(gl0*gl1*cut_l), reduce='mean') 79 | return force_array(xp.mean(cpc.data), dtype=xp.float32), 80 | 81 | 82 | def cpc_loss(x, t, upper=True, cpc_pattern='updown', num_neg=20, eps=1e-7): 83 | return CPCLoss(upper, cpc_pattern, num_neg, eps)(x, t) 84 | -------------------------------------------------------------------------------- /src/functions/loss/dice_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chainer 3 | from chainer import function 4 | from chainer.backends import cuda 5 | from 
chainer.utils import type_check, force_array 6 | from chainer.functions.activation import softmax 7 | from src.utils.encode_one_hot_vector import encode_one_hot_vector 8 | 9 | 10 | class DiceLoss(function.Function): 11 | def __init__(self, eps=1e-7, weight=False, encode=True): 12 | self.eps = eps 13 | self.intersect = None 14 | self.union = None 15 | self.weighted_dice = weight 16 | self.encode = encode 17 | 18 | def check_type_forward(self, in_types): 19 | type_check.expect(in_types.size() == 2) 20 | x_type, t_type = in_types 21 | 22 | type_check.expect( 23 | x_type.dtype.kind == 'f', 24 | t_type.dtype == np.int32, 25 | 26 | x_type.shape[0] == t_type.shape[0], 27 | x_type.shape[-3:] == t_type.shape[-3:], 28 | ) 29 | 30 | @staticmethod 31 | def _check_input_values(x, t): 32 | if not (((0 <= t) & 33 | (t < x.shape[1]))).all(): 34 | msg = ('Each label `t` need to satisfy ' 35 | '`0 <= t < x.shape[1]`') 36 | raise ValueError(msg) 37 | 38 | def forward(self, inputs): 39 | xp = cuda.get_array_module(*inputs) 40 | x, t = inputs 41 | if chainer.is_debug(): 42 | self._check_input_values(x, t) 43 | if self.encode: 44 | t = encode_one_hot_vector(t, x.shape[1]) 45 | axis = (0,) + tuple(range(2, x.ndim)) 46 | self.intersect = xp.sum((x * t), axis=axis) 47 | self.union = xp.sum((x * x), axis=axis) + xp.sum((t * t), axis=axis) 48 | dice = (2. * self.intersect + self.eps) / (self.union + self.eps) 49 | if self.weighted_dice: 50 | cw = xp.array([ 51 | 1., 1., 1., 1., 0.5, 0.5, 0.8, 0.8, 0.5, 0.8, 0.8, 52 | 0.8, 0.5, 0.5, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 1., 1., 1.], dtype='float32') 53 | dice = dice*cw*x.shape[1]/xp.sum(cw) 54 | return force_array(xp.mean(1. 
- dice), dtype=xp.float32), 55 | 56 | def backward(self, inputs, grad_outputs): 57 | xp = cuda.get_array_module(*inputs) 58 | x, t = inputs 59 | nb_class = x.shape[1] 60 | t = encode_one_hot_vector(t, nb_class) 61 | 62 | gx = xp.zeros_like(x) 63 | gloss = grad_outputs[0] 64 | cw = xp.array([ 65 | 1., 1., 1., 1., 0.5, 0.5, 0.8, 0.8, 0.5, 0.8, 0.8, 66 | 0.8, 0.5, 0.5, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 1., 1., 1.], dtype='float32') 67 | for i, w in zip(range(nb_class), cw): 68 | x_i = x[:, i] 69 | t_i = t[:, i] 70 | intersect = self.intersect[i] 71 | union = self.union[i] 72 | 73 | numerator = xp.multiply(union + self.eps, t_i) - \ 74 | xp.multiply(2. * intersect + self.eps, x_i) 75 | denominator = xp.power(union + self.eps, 2) 76 | d_dice = 2 * xp.divide(numerator, denominator).astype(xp.float32) 77 | if self.weighted_dice: 78 | gx[:, i] = d_dice*w*nb_class/xp.sum(cw) 79 | else: 80 | gx[:, i] = d_dice 81 | 82 | gx *= gloss / nb_class 83 | return -gx.astype(xp.float32), None 84 | 85 | 86 | def dice_loss(x, t, eps=1e-7, weight=False, encode=True): 87 | return DiceLoss(eps, weight, encode)(x, t) 88 | 89 | 90 | def softmax_dice_loss(x, t, eps=1e-7, weight=False, encode=True): 91 | x1 = softmax.softmax(x, axis=1) 92 | return dice_loss(x1, t, eps, weight, encode) 93 | -------------------------------------------------------------------------------- /src/functions/loss/focal_loss.py: -------------------------------------------------------------------------------- 1 | import chainer.functions as F 2 | from chainer.backends.cuda import get_array_module 3 | 4 | 5 | def focal_loss(x, t, gamma=0.5, eps=1e-6): 6 | xp = get_array_module(t) 7 | 8 | p = F.softmax(x) 9 | p = F.clip(p, x_min=eps, x_max=1-eps) 10 | log_p = F.log_softmax(x) 11 | t_onehot = xp.eye(x.shape[1])[t.ravel()] 12 | 13 | loss_sce = -1 * t_onehot * log_p 14 | loss_focal = F.sum(loss_sce * (1. 
- p) ** gamma, axis=1) 15 | 16 | return F.mean(loss_focal) 17 | -------------------------------------------------------------------------------- /src/functions/loss/generalized_dice_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chainer 3 | from chainer import function 4 | from chainer.backends import cuda 5 | from chainer.utils import type_check, force_array 6 | from chainer.functions.activation import softmax 7 | from src.utils.encode_one_hot_vector import encode_one_hot_vector 8 | 9 | 10 | class GeneralizedDiceLoss(function.Function): 11 | def __init__(self, eps=1e-7): 12 | # avoid zero division error 13 | self.eps = eps 14 | self.w = None 15 | self.intersect = None 16 | self.union = None 17 | 18 | def check_type_forward(self, in_types): 19 | type_check.expect(in_types.size() == 2) 20 | x_type, t_type = in_types 21 | 22 | type_check.expect( 23 | x_type.dtype.kind == 'f', 24 | t_type.dtype == np.int32, 25 | t_type.ndim == x_type.ndim - 1, 26 | 27 | x_type.shape[0] == t_type.shape[0], 28 | x_type.shape[2:] == t_type.shape[1:], 29 | ) 30 | 31 | @staticmethod 32 | def _check_input_values(x, t): 33 | if not (((0 <= t) & 34 | (t < x.shape[1]))).all(): 35 | msg = ('Each label `t` need to satisfy ' 36 | '`0 <= t < x.shape[1]`') 37 | raise ValueError(msg) 38 | 39 | def forward(self, inputs): 40 | xp = cuda.get_array_module(*inputs) 41 | x, t = inputs 42 | if chainer.is_debug(): 43 | self._check_input_values(x, t) 44 | # one-hot encoding of ground truth 45 | t = encode_one_hot_vector(t, x.shape[1]) 46 | # compute weight, intersection, and union 47 | axis = (0,) + tuple(range(2, x.ndim)) 48 | sum_t = xp.sum(t, axis=axis) 49 | # avoid zero division error 50 | sum_t[sum_t == 0] = 1 51 | self.w = 1. 
/ (sum_t*sum_t) 52 | self.intersect = xp.sum(xp.multiply(xp.sum((x * t), axis=axis), self.w)) 53 | self.union = xp.sum(xp.multiply( 54 | xp.sum(x, axis=axis) + xp.sum(t, axis=axis), self.w)) 55 | # compute dice loss (weighted sums over classes, so dice is a scalar) 56 | dice = (2. * self.intersect + self.eps) / (self.union + self.eps) 57 | return force_array(1. - dice, dtype=xp.float32), 58 | 59 | def backward(self, inputs, grad_outputs): 60 | xp = cuda.get_array_module(*inputs) 61 | x, t = inputs 62 | nb_class = x.shape[1] 63 | t = encode_one_hot_vector(t, nb_class) 64 | 65 | gx = xp.zeros_like(x) 66 | gloss = grad_outputs[0] 67 | for i in range(nb_class): 68 | t_i = t[:, i] 69 | intersect = self.intersect 70 | union = self.union 71 | w = self.w[i] 72 | numerator = xp.multiply( 73 | union + self.eps, t_i) - intersect + self.eps 74 | denominator = xp.power(union + self.eps, 2) 75 | d_dice = 2 * xp.divide( 76 | w * numerator, denominator).astype(xp.float32) 77 | gx[:, i] = d_dice 78 | 79 | gx *= gloss 80 | return -gx.astype(xp.float32), None 81 | 82 | 83 | def generalized_dice_loss(x, t, eps=1e-7): 84 | return GeneralizedDiceLoss(eps)(x, t) 85 | 86 | 87 | def softmax_generalized_dice_loss(x, t, eps=1e-7): 88 | x1 = softmax.softmax(x, axis=1) 89 | return generalized_dice_loss(x1, t, eps) 90 | -------------------------------------------------------------------------------- /src/functions/loss/mixed_dice_loss.py: -------------------------------------------------------------------------------- 1 | import chainer.functions as F 2 | from src.functions.loss import softmax_dice_loss 3 | from src.functions.loss import focal_loss 4 | import cupy 5 | 6 | 7 | def struct_dice_loss_plus_cross_entropy(x, t, w=0.5): 8 | cw = cupy.array([ 9 | 1., 1., 1., 1., 0.5, 0.5, 0.8, 0.8, 0.5, 0.8, 0.8, 10 | 0.8, 0.5, 0.5, 0.7, 0.7, 0.7, 0.7, 0.6, 0.6, 1., 1., 1.], dtype='float32') 11 | return w * softmax_dice_loss(x, t) + (1 - w) * F.softmax_cross_entropy(x, t, class_weight=cw) 12 | 13 | 14 | def dice_loss_plus_cross_entropy(x, t, w=0.5, encode=True): 15 |
return w * softmax_dice_loss(x, t, encode=encode) + (1 - w) * F.softmax_cross_entropy(x, t) 16 | 17 | 18 | def dice_loss_plus_focal_loss(x, t, w=0.5): 19 | return w * softmax_dice_loss(x, t) + (1 - w) * focal_loss(x, t) 20 | -------------------------------------------------------------------------------- /src/links/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/links/__init__.py -------------------------------------------------------------------------------- /src/links/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/links/model/__init__.py -------------------------------------------------------------------------------- /src/links/model/resize_images_3d.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from chainer.backends import cuda 3 | from chainer import function_node 4 | from chainer.utils import type_check 5 | 6 | 7 | class ResizeImages3D(function_node.FunctionNode): 8 | 9 | def __init__(self, output_shape): 10 | self.out_H = output_shape[0] 11 | self.out_W = output_shape[1] 12 | self.out_D = output_shape[2] 13 | 14 | def check_type_forward(self, in_types): 15 | n_in = in_types.size() 16 | type_check.expect(n_in == 1) 17 | 18 | x_type = in_types[0] 19 | type_check.expect( 20 | x_type.dtype.char == 'f', 21 | x_type.ndim == 5 22 | ) 23 | 24 | def forward(self, inputs): 25 | x, = inputs 26 | xp = cuda.get_array_module(x) 27 | 28 | B, C, H, W, D = x.shape 29 | 30 | u_1d = xp.linspace(0, W - 1, num=self.out_W) 31 | v_1d = xp.linspace(0, H - 1, num=self.out_H) 32 | t_1d = xp.linspace(0, D - 1, num=self.out_D) 33 | grid = xp.meshgrid(u_1d, v_1d, 
t_1d) 34 | u = grid[0].ravel() 35 | v = grid[1].ravel() 36 | t = grid[2].ravel() 37 | 38 | u0 = xp.floor(u).astype(numpy.int32) 39 | u0 = u0.clip(0, W - 2) 40 | u1 = u0 + 1 41 | v0 = xp.floor(v).astype(numpy.int32) 42 | v0 = v0.clip(0, H - 2) 43 | v1 = v0 + 1 44 | t0 = xp.floor(t).astype(numpy.int32) 45 | t0 = t0.clip(0, D - 2) 46 | t1 = t0 + 1 47 | 48 | # weights 49 | w1 = (u1 - u) * (v1 - v) * (t1 - t) 50 | w2 = (u - u0) * (v1 - v) * (t1 - t) 51 | w3 = (u1 - u) * (v - v0) * (t1 - t) 52 | w4 = (u - u0) * (v - v0) * (t1 - t) 53 | w5 = (u1 - u) * (v1 - v) * (t - t0) 54 | w6 = (u - u0) * (v1 - v) * (t - t0) 55 | w7 = (u1 - u) * (v - v0) * (t - t0) 56 | w8 = (u - u0) * (v - v0) * (t - t0) 57 | w1 = w1.astype(x.dtype) 58 | w2 = w2.astype(x.dtype) 59 | w3 = w3.astype(x.dtype) 60 | w4 = w4.astype(x.dtype) 61 | w5 = w5.astype(x.dtype) 62 | w6 = w6.astype(x.dtype) 63 | w7 = w7.astype(x.dtype) 64 | w8 = w8.astype(x.dtype) 65 | 66 | y = (w1[None, None, :] * x[:, :, v0, u0, t0] + 67 | w2[None, None, :] * x[:, :, v0, u1, t0] + 68 | w3[None, None, :] * x[:, :, v1, u0, t0] + 69 | w4[None, None, :] * x[:, :, v1, u1, t0] + 70 | w5[None, None, :] * x[:, :, v0, u0, t1] + 71 | w6[None, None, :] * x[:, :, v0, u1, t1] + 72 | w7[None, None, :] * x[:, :, v1, u0, t1] + 73 | w8[None, None, :] * x[:, :, v1, u1, t1]) 74 | y = y.reshape(B, C, self.out_H, self.out_W, self.out_D) 75 | return y, 76 | 77 | def backward(self, indexes, grad_outputs): 78 | return ResizeImagesGrad3D( 79 | self.inputs[0].shape, 80 | (self.out_H, self.out_W, self.out_D)).apply(grad_outputs) 81 | 82 | 83 | class ResizeImagesGrad3D(function_node.FunctionNode): 84 | 85 | def __init__(self, input_shape, output_shape): 86 | self.out_H = output_shape[0] 87 | self.out_W = output_shape[1] 88 | self.out_D = output_shape[2] 89 | self.input_shape = input_shape 90 | 91 | def check_type_forward(self, in_types): 92 | n_in = in_types.size() 93 | type_check.expect(n_in == 1) 94 | 95 | x_type = in_types[0] 96 | type_check.expect( 97 | 
x_type.dtype.char == 'f', 98 | x_type.ndim == 5 99 | ) 100 | 101 | def forward(self, inputs): 102 | xp = cuda.get_array_module(*inputs) 103 | gy, = inputs 104 | 105 | B, C, H, W, D = self.input_shape 106 | 107 | u_1d = xp.linspace(0, W - 1, num=self.out_W) 108 | v_1d = xp.linspace(0, H - 1, num=self.out_H) 109 | t_1d = xp.linspace(0, D - 1, num=self.out_D) 110 | grid = xp.meshgrid(u_1d, v_1d, t_1d) 111 | u = grid[0].ravel() 112 | v = grid[1].ravel() 113 | t = grid[2].ravel() 114 | 115 | u0 = xp.floor(u).astype(numpy.int32) 116 | u0 = u0.clip(0, W - 2) 117 | u1 = u0 + 1 118 | v0 = xp.floor(v).astype(numpy.int32) 119 | v0 = v0.clip(0, H - 2) 120 | v1 = v0 + 1 121 | t0 = xp.floor(t).astype(numpy.int32) 122 | t0 = t0.clip(0, D - 2) 123 | t1 = t0 + 1 124 | 125 | # weights 126 | wu0 = u - u0 127 | wu1 = u1 - u 128 | wv0 = v - v0 129 | wv1 = v1 - v 130 | wt0 = t - t0 131 | wt1 = t1 - t 132 | wu0 = wu0.astype(gy.dtype) 133 | wu1 = wu1.astype(gy.dtype) 134 | wv0 = wv0.astype(gy.dtype) 135 | wv1 = wv1.astype(gy.dtype) 136 | wt0 = wt0.astype(gy.dtype) 137 | wt1 = wt1.astype(gy.dtype) 138 | 139 | # --- gx 140 | if xp is numpy: 141 | scatter_add = numpy.add.at 142 | else: 143 | scatter_add = cuda.cupyx.scatter_add 144 | 145 | gx = xp.zeros(self.input_shape, dtype=gy.dtype) 146 | gy = gy.reshape(B, C, -1) 147 | scatter_add(gx, (slice(None), slice(None), v0, u0, t0), 148 | gy * wu1 * wv1 * wt1) 149 | scatter_add(gx, (slice(None), slice(None), v0, u1, t0), 150 | gy * wu0 * wv1 * wt1) 151 | scatter_add(gx, (slice(None), slice(None), v1, u0, t0), 152 | gy * wu1 * wv0 * wt1) 153 | scatter_add(gx, (slice(None), slice(None), v1, u1, t0), 154 | gy * wu0 * wv0 * wt1) 155 | scatter_add(gx, (slice(None), slice(None), v0, u0, t1), 156 | gy * wu1 * wv1 * wt0) 157 | scatter_add(gx, (slice(None), slice(None), v0, u1, t1), 158 | gy * wu0 * wv1 * wt0) 159 | scatter_add(gx, (slice(None), slice(None), v1, u0, t1), 160 | gy * wu1 * wv0 * wt0) 161 | scatter_add(gx, (slice(None), slice(None), v1, u1, 
t1), 162 | gy * wu0 * wv0 * wt0) 163 | return gx, 164 | 165 | def backward(self, indexes, grad_outputs): 166 | return ResizeImages3D( 167 | (self.out_H, self.out_W, self.out_D)).apply(grad_outputs) 168 | 169 | 170 | def resize_images_3d(x, output_shape): 171 | """Resize images to the given shape. 172 | This function resizes 3D data to :obj:`output_shape`. 173 | Currently, only trilinear interpolation is supported as the sampling method. 174 | Notation: the following notation is used for dimensionalities. 175 | - :math:`n` is the batch size. 176 | - :math:`c_I` is the number of the input channels. 177 | - :math:`h`, :math:`w` and :math:`d` are the height, width and depth of the 178 | input image, respectively. 179 | - :math:`h_O`, :math:`w_O` and :math:`d_O` are the height, width and depth 180 | of the output image. 181 | Args: 182 | x (~chainer.Variable): 183 | Input variable of shape :math:`(n, c_I, h, w, d)`. 184 | output_shape (tuple): 185 | This is a tuple of length 3 whose values are :obj:`(h_O, w_O, d_O)`. 186 | Returns: 187 | ~chainer.Variable: Resized image whose shape is \ 188 | :math:`(n, c_I, h_O, w_O, d_O)`.
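A minimal NumPy-only sketch of the trilinear sampling that `ResizeImages3D.forward` performs (the function name here is hypothetical; the real implementation above operates on Chainer variables and CuPy arrays on GPU):

```python
import numpy as np

def trilinear_resize_3d(x, output_shape):
    """Resize a (B, C, H, W, D) array to (B, C, h_O, w_O, d_O) with
    trilinear interpolation, mirroring ResizeImages3D.forward."""
    B, C, H, W, D = x.shape
    oH, oW, oD = output_shape
    # Sampling coordinates in the input volume (same linspace grid as above).
    u, v, t = [g.ravel() for g in np.meshgrid(
        np.linspace(0, W - 1, num=oW),
        np.linspace(0, H - 1, num=oH),
        np.linspace(0, D - 1, num=oD))]
    u0 = np.floor(u).astype(np.int32).clip(0, W - 2)
    v0 = np.floor(v).astype(np.int32).clip(0, H - 2)
    t0 = np.floor(t).astype(np.int32).clip(0, D - 2)
    u1, v1, t1 = u0 + 1, v0 + 1, t0 + 1
    # Blend the eight corners of the voxel cell around each sample point,
    # each weighted by its trilinear weight.
    y = sum(w[None, None, :] * x[:, :, vv, uu, tt]
            for w, vv, uu, tt in [
                ((u1 - u) * (v1 - v) * (t1 - t), v0, u0, t0),
                ((u - u0) * (v1 - v) * (t1 - t), v0, u1, t0),
                ((u1 - u) * (v - v0) * (t1 - t), v1, u0, t0),
                ((u - u0) * (v - v0) * (t1 - t), v1, u1, t0),
                ((u1 - u) * (v1 - v) * (t - t0), v0, u0, t1),
                ((u - u0) * (v1 - v) * (t - t0), v0, u1, t1),
                ((u1 - u) * (v - v0) * (t - t0), v1, u0, t1),
                ((u - u0) * (v - v0) * (t - t0), v1, u1, t1)])
    return y.reshape(B, C, oH, oW, oD)
```

Resizing a volume to its own shape reproduces the input exactly (the sample points fall on integer coordinates, so one corner weight is 1 and the rest are 0), which is a quick sanity check for the weight computation.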
189 | """ 190 | return ResizeImages3D(output_shape).apply((x,))[0] 191 | -------------------------------------------------------------------------------- /src/links/model/vaeseg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chainer 3 | import chainer.links as L 4 | import chainer.functions as F 5 | from chainer.links import GroupNormalization 6 | from src.functions.array import resize_images_3d 7 | from chainer.utils.conv_nd import im2col_nd 8 | 9 | 10 | class CNR(chainer.Chain): 11 | """Convolution, Normalize, then ReLU""" 12 | 13 | def __init__( 14 | self, 15 | channels, 16 | norm=GroupNormalization, 17 | down_sampling=False, 18 | comm=None 19 | ): 20 | super(CNR, self).__init__() 21 | with self.init_scope(): 22 | if down_sampling: 23 | self.c = L.Convolution3D(None, channels, 3, 2, 1) 24 | else: 25 | self.c = L.Convolution3D(None, channels, 3, 1, 1) 26 | if norm.__name__ == 'MultiNodeBatchNormalization': 27 | self.n = norm(channels, comm, eps=1e-5) 28 | elif norm.__name__ == 'BatchNormalization': 29 | self.n = norm(channels, eps=1e-5) 30 | elif norm.__name__ == 'GroupNormalization': 31 | self.n = norm(groups=8, size=channels) 32 | else: 33 | self.n = norm(channels) 34 | 35 | def forward(self, x): 36 | h = F.relu(self.n(self.c(x))) 37 | return h 38 | 39 | 40 | class NRC(chainer.Chain): 41 | """Normalize, ReLU, then Convolution""" 42 | def __init__( 43 | self, 44 | in_channels, 45 | out_channels, 46 | norm=GroupNormalization, 47 | down_sampling=False, 48 | comm=None 49 | ): 50 | super(NRC, self).__init__() 51 | with self.init_scope(): 52 | if norm.__name__ == 'MultiNodeBatchNormalization': 53 | self.n = norm(in_channels, comm, eps=1e-5) 54 | elif norm.__name__ == 'BatchNormalization': 55 | self.n = norm(in_channels, eps=1e-5) 56 | elif norm.__name__ == 'GroupNormalization': 57 | self.n = norm(groups=8, size=in_channels) 58 | else: 59 | self.n = norm(in_channels) 60 | if down_sampling: 61 | 
self.c = L.Convolution3D(None, out_channels, 3, 2, 1) 62 | else: 63 | self.c = L.Convolution3D(None, out_channels, 3, 1, 1) 64 | 65 | def forward(self, x): 66 | h = self.c(F.relu(self.n(x))) 67 | return h 68 | 69 | 70 | class ResBlock(chainer.Chain): 71 | 72 | def __init__( 73 | self, 74 | channels, 75 | norm=GroupNormalization, 76 | bn_first=True, 77 | comm=None, 78 | concat_mode=False 79 | ): 80 | super(ResBlock, self).__init__() 81 | with self.init_scope(): 82 | if bn_first: 83 | if concat_mode: 84 | channels *= 2 85 | self.block1 = NRC(channels, channels, norm, False, comm) 86 | self.block2 = NRC(channels, channels, norm, False, comm) 87 | else: 88 | self.block1 = CNR(channels, norm, False, comm) 89 | self.block2 = CNR(channels, norm, False, comm) 90 | 91 | def forward(self, x): 92 | h = self.block1(x) 93 | h = self.block2(h) 94 | return h + x 95 | 96 | 97 | class DownBlock(chainer.Chain): 98 | """down sample (conv stride2), then ResBlock * num of layers""" 99 | 100 | def __init__( 101 | self, 102 | channels, 103 | norm=GroupNormalization, 104 | down_sample=True, 105 | n_blocks=1, 106 | bn_first=True, 107 | comm=None 108 | ): 109 | self.down_sample = down_sample 110 | self.n_blocks = n_blocks 111 | super(DownBlock, self).__init__() 112 | with self.init_scope(): 113 | if down_sample: 114 | self.d = L.Convolution3D(None, channels, 3, 2, 1) 115 | else: 116 | self.d = L.Convolution3D(None, channels, 3, 1, 1) 117 | for i in range(n_blocks): 118 | layer = ResBlock(channels, norm, bn_first, comm) 119 | setattr(self, 'block{}'.format(i), layer) 120 | 121 | def forward(self, x): 122 | h = self.d(x) 123 | for i in range(self.n_blocks): 124 | h = getattr(self, 'block{}'.format(i))(h) 125 | return h 126 | 127 | 128 | class VD(chainer.Chain): 129 | """VAE (down to 256dim.) 
see Table1""" 130 | def __init__( 131 | self, 132 | channels, 133 | norm=GroupNormalization, 134 | bn_first=True, 135 | ndim_latent=128, 136 | comm=None 137 | ): 138 | super(VD, self).__init__() 139 | with self.init_scope(): 140 | if bn_first: 141 | self.b = NRC(channels, 16, norm, True, comm) 142 | else: 143 | self.b = CNR(16, norm, True, comm) 144 | self.d_mu = L.Linear(None, ndim_latent) 145 | self.d_ln_var = L.Linear(None, ndim_latent) 146 | 147 | def forward(self, x): 148 | h = self.b(x) 149 | mu = self.d_mu(h) 150 | ln_var = self.d_ln_var(h) 151 | return mu, ln_var 152 | 153 | 154 | class Encoder(chainer.Chain): 155 | 156 | def __init__( 157 | self, 158 | base_channels=32, 159 | norm=GroupNormalization, 160 | bn_first=True, 161 | ndim_latent=128, 162 | comm=None 163 | ): 164 | super(Encoder, self).__init__() 165 | with self.init_scope(): 166 | self.enc_initconv = L.Convolution3D(None, base_channels, 3, 1, 1) 167 | self.enc_block0 = DownBlock(base_channels, norm, False, 1, bn_first, comm) 168 | self.enc_block1 = DownBlock(2*base_channels, norm, True, 2, bn_first, comm) 169 | self.enc_block2 = DownBlock(4*base_channels, norm, True, 2, bn_first, comm) 170 | self.enc_block3 = DownBlock(8*base_channels, norm, True, 4, bn_first, comm) 171 | 172 | def forward(self, x): 173 | h = F.dropout(self.enc_initconv(x), ratio=0.2) 174 | hs = [] 175 | for i in range(4): 176 | h = getattr(self, 'enc_block{}'.format(i))(h) 177 | hs.append(h) 178 | return hs 179 | 180 | 181 | class UpBlock(chainer.Chain): 182 | 183 | def __init__( 184 | self, 185 | channels, 186 | norm=GroupNormalization, 187 | bn_first=True, 188 | mode='sum', 189 | comm=None 190 | ): 191 | self.mode = mode 192 | concat_mode = True if mode == 'concat' else False 193 | super(UpBlock, self).__init__() 194 | with self.init_scope(): 195 | self.c1 = L.Convolution3D(None, channels, 1, 1, 0) 196 | self.rb = ResBlock(channels, norm, bn_first, comm, concat_mode) 197 | 198 | def forward(self, xd, xu=None): 199 | xd_shape = 
xd.shape[2:] 200 | out_shape = tuple(i * 2 for i in xd_shape) 201 | h = resize_images_3d(self.c1(xd), output_shape=out_shape) 202 | if xu is not None: 203 | if self.mode == 'sum': 204 | h += xu 205 | elif self.mode == 'concat': 206 | h = F.concat((h, xu), axis=1) 207 | h = self.rb(h) 208 | return h 209 | 210 | 211 | class Decoder(chainer.Chain): 212 | 213 | def __init__( 214 | self, 215 | base_channels, 216 | out_channels, 217 | norm=GroupNormalization, 218 | bn_first=True, 219 | mode='sum', 220 | comm=None 221 | ): 222 | super(Decoder, self).__init__() 223 | with self.init_scope(): 224 | for i in range(3): 225 | layer = UpBlock(2**i * base_channels, norm, bn_first, mode, comm) 226 | setattr(self, 'dec_block{}'.format(i), layer) 227 | self.dec_end = L.Convolution3D(None, out_channels, 1, 1, 0) 228 | 229 | def forward(self, hs, bs=None): 230 | # bs: output from boundary stream 231 | y = hs[-1] 232 | for i in reversed(range(3)): 233 | h = hs[i] 234 | y = getattr(self, 'dec_block{}'.format(i))(y, h) 235 | if bs is not None: 236 | y = F.concat((y, bs), axis=1) 237 | y = self.dec_end(y) 238 | return y 239 | 240 | 241 | class VU(chainer.Chain): 242 | 243 | def __init__( 244 | self, 245 | input_shape 246 | ): 247 | self.bottom_shape = tuple(i // 16 for i in input_shape) 248 | bottom_size = 16 * np.prod(list(self.bottom_shape)) 249 | self.output_shape = tuple(i // 8 for i in input_shape) 250 | super(VU, self).__init__() 251 | with self.init_scope(): 252 | self.dense = L.Linear(None, bottom_size) 253 | self.conv1 = L.Convolution3D(None, 256, 1, 1, 0) 254 | 255 | def forward(self, x): 256 | x_shape = x.shape 257 | h = F.relu(self.dense(x)) 258 | h = F.reshape(h, (x_shape[0], 16) + self.bottom_shape) 259 | h = resize_images_3d(self.conv1(h), output_shape=self.output_shape) 260 | return h 261 | 262 | 263 | class VAE(chainer.Chain): 264 | 265 | def __init__( 266 | self, 267 | in_channels, 268 | base_channels, 269 | norm=GroupNormalization, 270 | bn_first=True, 271 | 
input_shape=(160, 192, 128), 272 | comm=None 273 | ): 274 | super(VAE, self).__init__() 275 | with self.init_scope(): 276 | self.vu = VU(input_shape) 277 | for i in range(3): 278 | layer = UpBlock(2**i * base_channels, norm, bn_first, None, comm) 279 | setattr(self, 'vae_block{}'.format(i), layer) 280 | self.vae_end = L.Convolution3D(None, in_channels, 1, 1, 0) 281 | 282 | def forward(self, x): 283 | h = self.vu(x) 284 | for i in reversed(range(3)): 285 | h = getattr(self, 'vae_block{}'.format(i))(h) 286 | h = self.vae_end(h) 287 | return h 288 | 289 | 290 | def divide_img(x, grid_size=32): 291 | b, c, x1, x2, x3 = x.shape 292 | kersize = grid_size # kernel size 293 | ssize = int(grid_size*0.5) # stride size 294 | gl0 = int(x1/ssize-1) # grid length 295 | gl1 = int(x2/ssize-1) 296 | gl2 = int(x3/ssize-1) 297 | 298 | if type(x) == chainer.variable.Variable: 299 | h = im2col_nd( 300 | x.data, ksize=(kersize, kersize, kersize), 301 | stride=(ssize, ssize, ssize), pad=(0, 0, 0)) 302 | else: 303 | h = im2col_nd( 304 | x, ksize=(kersize, kersize, kersize), 305 | stride=(ssize, ssize, ssize), pad=(0, 0, 0)) 306 | h = F.reshape(h, (1, 4, kersize, kersize, kersize, gl0*gl1*gl2)) 307 | h = F.transpose(h, axes=(5, 1, 2, 3, 4, 0)) 308 | h = F.reshape(h, (gl0*gl1*gl2, 4, kersize, kersize, kersize)) 309 | return h 310 | 311 | 312 | class CPCPredictor(chainer.Chain): 313 | 314 | def __init__( 315 | self, 316 | base_channels=256, 317 | norm=GroupNormalization, 318 | bn_first=True, 319 | grid_size=32, 320 | input_shape=(160, 160, 128), 321 | upper=True, # whether to predict the upper half in the CPC task 322 | cpc_pattern='updown', 323 | comm=None 324 | ): 325 | x1, x2, x3 = input_shape 326 | ssize = int(grid_size*0.5) 327 | self.gl0 = int(x1/ssize-1) 328 | self.gl1 = int(x2/ssize-1) 329 | self.gl2 = int(x3/ssize-1) 330 | self.cut_l = int(self.gl2/2) 331 | self.base_channels = base_channels 332 | self.upper = upper 333 | self.cpc_pattern = cpc_pattern 334 | 335 | 
super(CPCPredictor, self).__init__() 336 | with self.init_scope(): 337 | for i in range(8): 338 | layer = ResBlock(base_channels, norm, bn_first, comm) 339 | setattr(self, 'pred_block{}'.format(i), layer) 340 | self.pred1 = L.Convolution3D( 341 | None, base_channels, 342 | ksize=(1, 1, 1), stride=1, pad=0) 343 | 344 | def forward(self, x): 345 | h = F.transpose(x, axes=(1, 0)) 346 | h = F.reshape(h, (1, self.base_channels, self.gl0, self.gl1, self.gl2)) 347 | if self.cpc_pattern == 'ichimatsu': 348 | hs = h[:, :, 0:self.gl0:4, 0:self.gl1:4, 0:6:2] 349 | else: 350 | if self.upper: 351 | hs = h[:, :, :, :, :self.cut_l] 352 | else: 353 | hs = h[:, :, :, :, -self.cut_l:] 354 | for i in range(8): 355 | hs = getattr(self, 'pred_block{}'.format(i))(hs) 356 | hs = self.pred1(hs) 357 | return hs 358 | 359 | 360 | class Attention(chainer.Chain): 361 | """concatenate inputs from the attention boundary stream, 362 | and from the Encoder. Then apply a 1*1 conv and a sigmoid activation.""" 363 | def __init__( 364 | self, 365 | comm=None 366 | ): 367 | super(Attention, self).__init__() 368 | with self.init_scope(): 369 | self.c = L.Convolution3D(None, out_channels=1, ksize=1, stride=1, pad=0) 370 | 371 | def forward(self, s, m): 372 | alpha = F.concat((s, m), axis=1) 373 | alpha = F.sigmoid(self.c(alpha)) 374 | o = s*alpha 375 | return o 376 | 377 | 378 | class BoundaryStream(chainer.Chain): 379 | """Combination of upblock and attention for the boundary stream""" 380 | def __init__( 381 | self, 382 | base_channels, 383 | out_channels, 384 | norm=GroupNormalization, 385 | bn_first=True, 386 | mode='sum', 387 | comm=None 388 | ): 389 | super(BoundaryStream, self).__init__() 390 | with self.init_scope(): 391 | self.c = L.Convolution3D(None, 2**2*base_channels, 1, 1, 0) 392 | self.rb = ResBlock(2**2*base_channels, norm, bn_first, comm) 393 | self.att = Attention(comm) 394 | for i in range(3): 395 | layer = UpBlock(2**i * base_channels, norm, bn_first, mode, comm) 396 | setattr(self, 
'dec_block{}'.format(i), layer) 397 | att = Attention(comm) 398 | setattr(self, 'att_block{}'.format(i), att) 399 | self.dec_end = L.Convolution3D(None, out_channels, ksize=1, stride=1, pad=0) 400 | 401 | def forward(self, hs): 402 | bs = hs[-1] 403 | bs = self.att(self.rb(self.c(bs)), hs[-1]) 404 | for i in reversed(range(3)): 405 | h = hs[i] 406 | bs = getattr(self, 'dec_block{}'.format(i))(bs) 407 | bs = getattr(self, 'att_block{}'.format(i))(bs, h) 408 | y = self.dec_end(bs) 409 | return y, bs 410 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/training/__init__.py -------------------------------------------------------------------------------- /src/training/extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/training/extensions/__init__.py -------------------------------------------------------------------------------- /src/training/extensions/boundseg_evaluator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import cupy 3 | import math 4 | 5 | import chainer 6 | import chainer.functions as F 7 | from chainer.backends import cuda 8 | from chainer.training.extensions import Evaluator 9 | from chainer.dataset import convert 10 | from chainer import Reporter 11 | from chainer import reporter as reporter_module 12 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient 13 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy 14 | 15 | 16 | class BoundSegEvaluator(Evaluator): 17 | 18 | def __init__( 19 | 
self, 20 | config, 21 | iterator, 22 | target, 23 | converter=convert.concat_examples, 24 | device=None 25 | ): 26 | super().__init__(iterator, target, 27 | converter, device, 28 | None, None) 29 | self.nested_label = config['nested_label'] 30 | self.seg_lossfun = eval(config['seg_lossfun']) 31 | self.dataset = config['dataset_name'] 32 | self.nb_labels = config['nb_labels'] 33 | self.crop_size = eval(config['crop_size']) 34 | self.is_brats = config['is_brats'] 35 | 36 | def compute_loss(self, y, t): 37 | if self.nested_label: 38 | loss = 0. 39 | b, c, h, w, d = t.shape 40 | for i in range(c): 41 | loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 42 | else: 43 | loss = self.seg_lossfun(y, t) 44 | return loss 45 | 46 | def compute_accuracy(self, y, t): 47 | if self.nested_label: 48 | b, c, h, w, d = t.shape 49 | y = F.reshape(y, (b, 2, h*c, w, d)) 50 | t = F.reshape(t, (b, h*c, w, d)) 51 | return F.accuracy(y, t) 52 | 53 | def compute_dice_coef(self, y, t): 54 | if self.nested_label: 55 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...])) 56 | for i in range(1, t.shape[1]): 57 | dices = dice_coefficient(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 58 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0) 59 | else: 60 | dice = dice_coefficient(y, t, is_brats=self.is_brats) 61 | return dice 62 | 63 | def evaluate(self): 64 | summary = reporter_module.DictSummary() 65 | iterator = self._iterators['main'] 66 | enc = self._targets['enc'] 67 | dec = self._targets['dec'] 68 | bound = self._targets['bound'] 69 | reporter = Reporter() 70 | observer = object() 71 | reporter.add_observer(self.default_name, observer) 72 | 73 | if hasattr(iterator, 'reset'): 74 | iterator.reset() 75 | it = iterator 76 | else: 77 | it = copy.copy(iterator) 78 | 79 | for batch in it: 80 | x, t, te = self.converter(batch, self.device) 81 | with chainer.no_backprop_mode(), chainer.using_config('train', False): 82 | if self.dataset == 'msd_bound': 83 | h, w, 
d = x.shape[2:] 84 | hc, wc, dc = self.crop_size 85 | if self.nested_label: 86 | y = cupy.zeros((1, 2*(self.nb_labels-1), h, w, d), dtype='float32') 87 | else: 88 | y = cupy.zeros((1, self.nb_labels, h, w, d), dtype='float32') 89 | s = 128 # stride 90 | ker = 256 # kernel size 91 | dker = dc # kernel size for depth 92 | ds = int(dker*0.5) # stride for depth (int, so it can be used in slices) 93 | dsteps = int(math.floor((d-dker)/ds) + 1) 94 | steps = round((h - ker)/s + 1) 95 | for i in range(steps): 96 | for j in range(steps): 97 | for k in range(dsteps): 98 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] 99 | hhs = enc(xx) 100 | yye, bbs = bound(hhs) 101 | yy = dec(hhs, bbs) 102 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] += yy.data 103 | # for the bottom depth part of the image 104 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] 105 | hhs = enc(xx) 106 | yye, bbs = bound(hhs) 107 | yy = dec(hhs, bbs) 108 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] += yy.data 109 | else: 110 | hs = enc(x) 111 | ye, bs = bound(hs) 112 | y = dec(hs, bs) 113 | seg_loss = self.compute_loss(y, t) 114 | accuracy = self.compute_accuracy(y, t) 115 | dice = self.compute_dice_coef(y, t) 116 | mean_dice = mean_dice_coefficient(dice) 117 | 118 | weighted_loss = seg_loss 119 | 120 | observation = {} 121 | with reporter.scope(observation): 122 | reporter.report({ 123 | 'loss/seg': seg_loss, 124 | 'loss/total': weighted_loss, 125 | 'acc': accuracy, 126 | 'mean_dc': mean_dice 127 | }, observer) 128 | xp = cuda.get_array_module(y) 129 | for i in range(len(dice)): 130 | if not xp.isnan(dice.data[i]): 131 | reporter.report({'dc_{}'.format(i): dice[i]}, observer) 132 | summary.add(observation) 133 | return summary.compute_mean() 134 | -------------------------------------------------------------------------------- /src/training/extensions/cpcseg_evaluator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import cupy 3 | import chainer 4 | import chainer.functions as
F 5 | from chainer.backends import cuda 6 | from chainer.training.extensions import Evaluator 7 | from chainer.dataset import convert 8 | from chainer import Reporter 9 | from chainer import reporter as reporter_module 10 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient 11 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy 12 | 13 | 14 | class CPCSegEvaluator(Evaluator): 15 | 16 | def __init__( 17 | self, 18 | config, 19 | iterator, 20 | target, 21 | converter=convert.concat_examples, 22 | device=None, 23 | ): 24 | super().__init__(iterator, target, 25 | converter, device, 26 | None, None) 27 | self.nested_label = config['nested_label'] 28 | self.seg_lossfun = eval(config['seg_lossfun']) 29 | self.rec_loss_weight = config['vaeseg_rec_loss_weight'] 30 | self.kl_loss_weight = config['vaeseg_kl_loss_weight'] 31 | self.grid_size = config['grid_size'] 32 | self.base_channels = config['base_channels'] 33 | self.cpc_loss_weight = config['cpc_vaeseg_cpc_loss_weight'] 34 | self.cpc_pattern = config['cpc_pattern'] 35 | self.is_brats = config['is_brats'] 36 | self.dataset = config['dataset_name'] 37 | self.nb_labels = config['nb_labels'] 38 | self.crop_size = eval(config['crop_size']) 39 | 40 | def compute_loss(self, y, t): 41 | if self.nested_label: 42 | loss = 0. 
43 | b, c, h, w, d = t.shape 44 | for i in range(c): 45 | loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 46 | else: 47 | loss = self.seg_lossfun(y, t) 48 | return loss 49 | 50 | def compute_accuracy(self, y, t): 51 | if self.nested_label: 52 | b, c, h, w, d = t.shape 53 | y = F.reshape(y, (b, 2, h*c, w, d)) 54 | t = F.reshape(t, (b, h*c, w, d)) 55 | return F.accuracy(y, t) 56 | 57 | def compute_dice_coef(self, y, t): 58 | if self.nested_label: 59 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...])) 60 | for i in range(1, t.shape[1]): 61 | dices = dice_coefficient(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 62 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0) 63 | else: 64 | dice = dice_coefficient(y, t, is_brats=self.is_brats) 65 | return dice 66 | 67 | def evaluate(self): 68 | summary = reporter_module.DictSummary() 69 | iterator = self._iterators['main'] 70 | enc = self._targets['enc'] 71 | dec = self._targets['dec'] 72 | reporter = Reporter() 73 | observer = object() 74 | reporter.add_observer(self.default_name, observer) 75 | 76 | if hasattr(iterator, 'reset'): 77 | iterator.reset() 78 | it = iterator 79 | else: 80 | it = copy.copy(iterator) 81 | 82 | for batch in it: 83 | x, t = self.converter(batch, self.device) 84 | with chainer.no_backprop_mode(), chainer.using_config('train', False): 85 | if self.dataset == 'msd_bound': 86 | # evaluation method for BRATS dataset only 87 | h, w, d = x.shape[2:] 88 | hc, wc, dc = self.crop_size 89 | if self.nested_label: 90 | y = cupy.zeros((1, 2*(self.nb_labels-1), h, w, d), dtype='float32') 91 | else: 92 | y = cupy.zeros((1, self.nb_labels, h, w, d), dtype='float32') 93 | hker = hc # kernel size 94 | hs = int(0.5*hker) # stride 95 | wker = wc 96 | wc = int(0.5*wker) 97 | dker = dc # kernel size for depth 98 | for i in range(2): 99 | for j in range(2): 100 | for k in range(2): 101 | xx = x[:, :, -i*hker:min(hker*(i+1), h), 102 | -j*wker:min(wker*(j+1), w), 
-k*dker:min(dker*(k+1), d)] 103 | hs = enc(xx) 104 | yy = dec(hs) 105 | y[:, :, -i*hker:min(hker*(i+1), h), 106 | -j*wker:min(wker*(j+1), w), 107 | -k*dker:min(dker*(k+1), d)] += yy.data 108 | 109 | else: 110 | hs = enc(x) 111 | y = dec(hs) 112 | seg_loss = self.compute_loss(y, t) 113 | accuracy = self.compute_accuracy(y, t) 114 | dice = self.compute_dice_coef(y, t) 115 | mean_dice = mean_dice_coefficient(dice) 116 | 117 | observation = {} 118 | with reporter.scope(observation): 119 | reporter.report({ 120 | 'loss/seg': seg_loss, 121 | 'acc': accuracy, 122 | 'mean_dc': mean_dice 123 | }, observer) 124 | xp = cuda.get_array_module(y) 125 | for i in range(len(dice)): 126 | if not xp.isnan(dice.data[i]): 127 | reporter.report({'dc_{}'.format(i): dice[i]}, observer) 128 | summary.add(observation) 129 | return summary.compute_mean() 130 | -------------------------------------------------------------------------------- /src/training/extensions/encdec_seg_evaluator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import cupy 3 | import math 4 | import chainer 5 | import chainer.functions as F 6 | from chainer.backends import cuda 7 | from chainer.training.extensions import Evaluator 8 | from chainer.dataset import convert 9 | from chainer import Reporter 10 | from chainer import reporter as reporter_module 11 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient 12 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy 13 | 14 | 15 | class EncDecSegEvaluator(Evaluator): 16 | 17 | def __init__( 18 | self, 19 | config, 20 | iterator, 21 | target, 22 | converter=convert.concat_examples, 23 | device=None, 24 | ): 25 | super().__init__(iterator, target, 26 | converter, device, 27 | None, None) 28 | self.nested_label = config['nested_label'] 29 | self.seg_lossfun = eval(config['seg_lossfun']) 30 | self.dataset = config['dataset_name'] 31 | self.nb_labels = config['nb_labels'] 
32 | self.crop_size = eval(config['crop_size']) 33 | self.is_brats = config['is_brats'] 34 | 35 | def compute_loss(self, y, t): 36 | if self.nested_label: 37 | loss = 0. 38 | b, c, h, w, d = t.shape 39 | for i in range(c): 40 | loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 41 | else: 42 | loss = self.seg_lossfun(y, t) 43 | return loss 44 | 45 | def compute_accuracy(self, y, t): 46 | if self.nested_label: 47 | b, c, h, w, d = t.shape 48 | y = F.reshape(y, (b, 2, h*c, w, d)) 49 | t = F.reshape(t, (b, h*c, w, d)) 50 | return F.accuracy(y, t) 51 | 52 | def compute_dice_coef(self, y, t): 53 | if self.nested_label: 54 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...])) 55 | for i in range(1, t.shape[1]): 56 | dices = dice_coefficient(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 57 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0) 58 | else: 59 | dice = dice_coefficient(y, t, is_brats=self.is_brats) 60 | return dice 61 | 62 | def evaluate(self): 63 | summary = reporter_module.DictSummary() 64 | iterator = self._iterators['main'] 65 | enc = self._targets['enc'] 66 | dec = self._targets['dec'] 67 | reporter = Reporter() 68 | observer = object() 69 | reporter.add_observer(self.default_name, observer) 70 | 71 | if hasattr(iterator, 'reset'): 72 | iterator.reset() 73 | it = iterator 74 | else: 75 | it = copy.copy(iterator) 76 | 77 | for batch in it: 78 | x, t = self.converter(batch, self.device) 79 | with chainer.no_backprop_mode(), chainer.using_config('train', False): 80 | if self.dataset == 'msd_bound': 81 | h, w, d = x.shape[2:] 82 | hc, wc, dc = self.crop_size 83 | if self.nested_label: 84 | y = cupy.zeros((1, 2*(self.nb_labels-1), h, w, d), dtype='float32') 85 | else: 86 | y = cupy.zeros((1, self.nb_labels, h, w, d), dtype='float32') 87 | s = 128 # stride 88 | ker = 256 # kernel size 89 | dker = dc # kernel size for depth 90 | ds = int(dker*0.5) # stride for depth (must be an int to be used as a slice index) 91 | dsteps = int(math.floor((d-dker)/ds) + 1) 92 | steps =
round((h - ker)/s + 1) 93 | for i in range(steps): 94 | for j in range(steps): 95 | for k in range(dsteps): 96 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] 97 | hs = enc(xx) 98 | yy = dec(hs) 99 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] += yy.data 100 | # for the bottom depth part of the image 101 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] 102 | hs = enc(xx) 103 | yy = dec(hs) 104 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] += yy.data 105 | else: 106 | hs = enc(x) 107 | y = dec(hs) 108 | seg_loss = self.compute_loss(y, t) 109 | accuracy = self.compute_accuracy(y, t) 110 | dice = self.compute_dice_coef(y, t) 111 | mean_dice = mean_dice_coefficient(dice) 112 | weighted_loss = seg_loss 113 | 114 | observation = {} 115 | with reporter.scope(observation): 116 | reporter.report({ 117 | 'loss/seg': seg_loss, 118 | 'loss/total': weighted_loss, 119 | 'acc': accuracy, 120 | 'mean_dc': mean_dice 121 | }, observer) 122 | xp = cuda.get_array_module(y) 123 | for i in range(len(dice)): 124 | if not xp.isnan(dice.data[i]): 125 | reporter.report({'dc_{}'.format(i): dice[i]}, observer) 126 | summary.add(observation) 127 | return summary.compute_mean() 128 | -------------------------------------------------------------------------------- /src/training/extensions/vaeseg_evaluator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import cupy 3 | import math 4 | import chainer 5 | import chainer.functions as F 6 | from chainer.backends import cuda 7 | from chainer.training.extensions import Evaluator 8 | from chainer.dataset import convert 9 | from chainer import Reporter 10 | from chainer import reporter as reporter_module 11 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient 12 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy 13 | 14 | 15 | class VAESegEvaluator(Evaluator): 16 | 17 | def __init__( 18 | self, 19 | config, 20 | iterator, 21 | 
target, 22 | converter=convert.concat_examples, 23 | device=None, 24 | ): 25 | super().__init__(iterator, target, 26 | converter, device, 27 | None, None) 28 | self.nested_label = config['nested_label'] 29 | self.seg_lossfun = eval(config['seg_lossfun']) 30 | self.rec_loss_weight = config['vaeseg_rec_loss_weight'] 31 | self.kl_loss_weight = config['vaeseg_kl_loss_weight'] 32 | self.dataset = config['dataset_name'] 33 | self.nb_labels = config['nb_labels'] 34 | self.crop_size = eval(config['crop_size']) 35 | self.is_brats = config['is_brats'] 36 | 37 | def compute_loss(self, y, t): 38 | if self.nested_label: 39 | loss = 0. 40 | b, c, h, w, d = t.shape 41 | for i in range(c): 42 | loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 43 | else: 44 | loss = self.seg_lossfun(y, t) 45 | return loss 46 | 47 | def compute_accuracy(self, y, t): 48 | if self.nested_label: 49 | b, c, h, w, d = t.shape 50 | y = F.reshape(y, (b, 2, h*c, w, d)) 51 | t = F.reshape(t, (b, h*c, w, d)) 52 | return F.accuracy(y, t) 53 | 54 | def compute_dice_coef(self, y, t): 55 | if self.nested_label: 56 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...])) 57 | for i in range(1, t.shape[1]): 58 | dices = dice_coefficient(y[:, 2*i:2*(i+1), ...], t[:, i, ...], dataset=self.dataset) 59 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0) 60 | else: 61 | dice = dice_coefficient(y, t, is_brats=self.is_brats) 62 | return dice 63 | 64 | def evaluate(self): 65 | summary = reporter_module.DictSummary() 66 | iterator = self._iterators['main'] 67 | enc = self._targets['enc'] 68 | dec = self._targets['dec'] 69 | reporter = Reporter() 70 | observer = object() 71 | reporter.add_observer(self.default_name, observer) 72 | 73 | if hasattr(iterator, 'reset'): 74 | iterator.reset() 75 | it = iterator 76 | else: 77 | it = copy.copy(iterator) 78 | 79 | for batch in it: 80 | x, t = self.converter(batch, self.device) 81 | with chainer.no_backprop_mode(), 
chainer.using_config('train', False): 82 | if self.dataset == 'msd_bound': 83 | h, w, d = x.shape[2:] 84 | hc, wc, dc = self.crop_size 85 | if self.nested_label: 86 | y = cupy.zeros((1, 2*(self.nb_labels-1), h, w, d), dtype='float32') 87 | else: 88 | y = cupy.zeros((1, self.nb_labels, h, w, d), dtype='float32') 89 | s = 128 # stride 90 | ker = 256 # kernel size 91 | dker = dc # kernel size for depth 92 | ds = int(dker*0.5) # stride for depth (must be an int to be used as a slice index) 93 | dsteps = int(math.floor((d-dker)/ds) + 1) 94 | steps = round((h - ker)/s + 1) 95 | for i in range(steps): 96 | for j in range(steps): 97 | for k in range(dsteps): 98 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] 99 | hs = enc(xx) 100 | yy = dec(hs) 101 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, ds*k:dker+ds*k] += yy.data 102 | # for the bottom depth part of the image 103 | xx = x[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] 104 | hs = enc(xx) 105 | yy = dec(hs) 106 | y[:, :, s*i:ker+s*i, s*j:ker+s*j, -dker:] += yy.data 107 | 108 | else: 109 | hs = enc(x) 110 | y = dec(hs) 111 | seg_loss = self.compute_loss(y, t) 112 | accuracy = self.compute_accuracy(y, t) 113 | dice = self.compute_dice_coef(y, t) 114 | mean_dice = mean_dice_coefficient(dice) 115 | weighted_loss = seg_loss 116 | 117 | observation = {} 118 | with reporter.scope(observation): 119 | reporter.report({ 120 | 'loss/seg': seg_loss, 121 | 'loss/total': weighted_loss, 122 | 'acc': accuracy, 123 | 'mean_dc': mean_dice 124 | }, observer) 125 | xp = cuda.get_array_module(y) 126 | for i in range(len(dice)): 127 | if not xp.isnan(dice.data[i]): 128 | reporter.report({'dc_{}'.format(i): dice[i]}, observer) 129 | summary.add(observation) 130 | return summary.compute_mean() 131 | -------------------------------------------------------------------------------- /src/training/updaters/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/pfnet-research/label-efficient-brain-tumor-segmentation/aad80ed7acb510a3147bb11c3910d2e17fb355d1/src/training/updaters/__init__.py -------------------------------------------------------------------------------- /src/training/updaters/boundseg_updater.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.functions as F 3 | from chainer.backends import cuda 4 | from chainer import reporter 5 | from chainer import Variable 6 | from chainer.training import StandardUpdater 7 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient 8 | from src.functions.loss import softmax_dice_loss 9 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy 10 | from src.functions.loss.boundary_bce import boundary_bce 11 | 12 | 13 | class BoundSegUpdater(StandardUpdater): 14 | 15 | def __init__(self, config, **kwargs): 16 | self.nested_label = config['nested_label'] 17 | self.seg_lossfun = eval(config['seg_lossfun']) 18 | self.optimizer_name = config['optimizer'] 19 | self.init_lr = config['init_lr'] 20 | self.nb_epoch = config['epoch'] 21 | self.init_weight = config['init_encoder'] 22 | self.enc_freeze = config['enc_freeze'] 23 | self.edge_label = config['edge_label'] 24 | super(BoundSegUpdater, self).__init__(**kwargs) 25 | 26 | def get_optimizer_and_model(self, key): 27 | optimizer = self.get_optimizer(key) 28 | return optimizer, optimizer.target 29 | 30 | def get_batch(self): 31 | batch = self.get_iterator('main').next() 32 | batchsize = len(batch) 33 | in_arrays = self.converter(batch, self.device) 34 | return Variable(in_arrays[0]), in_arrays[1], in_arrays[2], batchsize 35 | 36 | def report_scores(self, y, t): 37 | with chainer.no_backprop_mode(): 38 | if self.nested_label: 39 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...])) 40 | for i in range(1, t.shape[1]): 41 | dices = dice_coefficient(y[:, 2 * i:2 * (i 
+ 1), ...], t[:, i, ...]) 42 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0) 43 | else: 44 | dice = dice_coefficient(y, t) 45 | mean_dice = mean_dice_coefficient(dice) 46 | 47 | if self.nested_label: 48 | b, c, h, w, d = t.shape 49 | y = F.reshape(y, (b, 2, h * c, w, d)) 50 | t = F.reshape(t, (b, h * c, w, d)) 51 | accuracy = F.accuracy(y, t) 52 | 53 | reporter.report({ 54 | 'acc': accuracy, 55 | 'mean_dc': mean_dice 56 | }) 57 | xp = cuda.get_array_module(y) 58 | for i in range(len(dice)): 59 | if not xp.isnan(dice.data[i]): 60 | reporter.report({'dc_{}'.format(i): dice[i]}) 61 | 62 | def update_core(self): 63 | opt_e, enc = self.get_optimizer_and_model('enc') 64 | opt_d, dec = self.get_optimizer_and_model('dec') 65 | opt_b, bound = self.get_optimizer_and_model('bound') 66 | 67 | if self.is_new_epoch: 68 | decay_rate = (1. - float(self.epoch / self.nb_epoch)) ** 0.9 69 | if self.optimizer_name == 'Adam': 70 | if self.init_weight is not None: 71 | opt_e.alpha = self.init_lr*self.enc_freeze * decay_rate 72 | # small learning rate for encoder 73 | else: 74 | opt_e.alpha = self.init_lr * decay_rate 75 | opt_d.alpha = self.init_lr * decay_rate 76 | opt_b.alpha = self.init_lr * decay_rate 77 | else: 78 | if self.init_weight is not None: 79 | opt_e.lr = self.init_lr*self.enc_freeze * decay_rate 80 | # small learning rate for encoder 81 | else: 82 | opt_e.lr = self.init_lr * decay_rate 83 | opt_d.lr = self.init_lr * decay_rate 84 | opt_b.lr = self.init_lr * decay_rate 85 | 86 | x, t, te, batchsize = self.get_batch() 87 | 88 | hs = enc(x) 89 | ye, bs = bound(hs) # ye: output for the edge loss, bs: output for the decoder 90 | y = dec(hs, bs) 91 | 92 | if self.nested_label: 93 | seg_loss = 0.
94 | for i in range(t.shape[1]): 95 | seg_loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 96 | else: 97 | seg_loss = self.seg_lossfun(y, t) 98 | 99 | ye_ = ye[:, 1:, :, :, :] # exclude the background channel (keep the Variable so gradients flow through the edge loss) 100 | te_ = te[:, 1:, :, :, :] 101 | if self.nested_label: 102 | edge_loss = 0. 103 | bce_loss = 0. 104 | for i in range(te_.shape[1]): 105 | edge_loss += softmax_dice_loss(ye[:, 2*i:2*(i+1), ...], te_[:, i, ...]) 106 | bce_loss += boundary_bce(ye[:, 2*i:2*(i+1), ...], te[:, [0, i+1], ...]) 107 | else: 108 | edge_loss = softmax_dice_loss(ye_, te_, encode=False) 109 | bce_loss = boundary_bce(ye, te) 110 | 111 | opt_e.target.cleargrads() 112 | opt_d.target.cleargrads() 113 | opt_b.target.cleargrads() 114 | 115 | weighted_loss = seg_loss + edge_loss + bce_loss 116 | weighted_loss.backward() 117 | 118 | opt_e.update() 119 | opt_d.update() 120 | opt_b.update() 121 | 122 | self.report_scores(y, t) 123 | reporter.report({ 124 | 'loss/seg': seg_loss, 125 | 'loss/total': weighted_loss, 126 | 'loss/bound': edge_loss, 127 | 'loss/bce': bce_loss 128 | }) 129 | -------------------------------------------------------------------------------- /src/training/updaters/cpcseg_updater.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chainer 3 | import chainer.functions as F 4 | from chainer.backends import cuda 5 | from chainer import reporter 6 | from chainer import Variable 7 | from chainer.training import StandardUpdater 8 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy 9 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient 10 | from src.functions.loss.cpc_loss import cpc_loss 11 | from src.links.model.vaeseg import divide_img 12 | 13 | 14 | class CPCSegUpdater(StandardUpdater): 15 | 16 | def __init__(self, config, **kwargs): 17 | self.nested_label = config['nested_label'] 18 | self.seg_lossfun =
eval(config['seg_lossfun']) 19 | self.k = config['vaeseg_nb_sampling'] 20 | self.rec_loss_weight = config['vaeseg_rec_loss_weight'] 21 | self.kl_loss_weight = config['vaeseg_kl_loss_weight'] 22 | self.optimizer_name = config['optimizer'] 23 | self.init_lr = config['init_lr'] 24 | self.nb_epoch = config['epoch'] 25 | self.pretrain = config['pretrain'] 26 | self.init_weight = config['init_encoder'] 27 | self.grid_size = config['grid_size'] 28 | self.base_channels = config['base_channels'] 29 | self.cpc_loss_weight = config['cpc_vaeseg_cpc_loss_weight'] 30 | self.enc_freeze = config['enc_freeze'] 31 | self.cpc_pattern = config['cpc_pattern'] 32 | self.idle_weight = config['vae_idle_weight'] 33 | super(CPCSegUpdater, self).__init__(**kwargs) 34 | 35 | def get_optimizer_and_model(self, key): 36 | optimizer = self.get_optimizer(key) 37 | return optimizer, optimizer.target 38 | 39 | def get_batch(self): 40 | batch = self.get_iterator('main').next() 41 | batchsize = len(batch) 42 | in_arrays = self.converter(batch, self.device) 43 | return Variable(in_arrays[0]), in_arrays[1], batchsize 44 | 45 | def report_scores(self, y, t): 46 | with chainer.no_backprop_mode(): 47 | if self.nested_label: 48 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...])) 49 | for i in range(1, t.shape[1]): 50 | dices = dice_coefficient(y[:, 2 * i:2 * (i + 1), ...], t[:, i, ...]) 51 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0) 52 | else: 53 | dice = dice_coefficient(y, t) 54 | mean_dice = mean_dice_coefficient(dice) 55 | if self.nested_label: 56 | b, c, h, w, d = t.shape 57 | y = F.reshape(y, (b, 2, h * c, w, d)) 58 | t = F.reshape(t, (b, h * c, w, d)) 59 | accuracy = F.accuracy(y, t) 60 | 61 | reporter.report({ 62 | 'acc': accuracy, 63 | 'mean_dc': mean_dice 64 | }) 65 | xp = cuda.get_array_module(y) 66 | for i in range(len(dice)): 67 | if not xp.isnan(dice.data[i]): 68 | reporter.report({'dc_{}'.format(i): dice[i]}) 69 | 70 | def update_core(self): 71 | 
opt_e, enc = self.get_optimizer_and_model('enc') 72 | opt_d, dec = self.get_optimizer_and_model('dec') 73 | 74 | opt_p1, cpcpred1 = self.get_optimizer_and_model('cpcpred1') 75 | 76 | if self.is_new_epoch: 77 | decay_rate = (1. - float(self.epoch / self.nb_epoch)) ** 0.9 78 | if self.optimizer_name == 'Adam': 79 | if self.init_weight is not None: 80 | opt_e.alpha = self.init_lr*self.enc_freeze * decay_rate 81 | else: 82 | opt_e.alpha = self.init_lr * decay_rate 83 | opt_d.alpha = self.init_lr * decay_rate 84 | opt_p1.alpha = self.init_lr * decay_rate 85 | else: 86 | if self.init_weight is not None: 87 | opt_e.lr = self.init_lr*self.enc_freeze * decay_rate 88 | else: 89 | opt_e.lr = self.init_lr * decay_rate 90 | opt_d.lr = self.init_lr * decay_rate 91 | opt_p1.lr = self.init_lr * decay_rate 92 | 93 | x, t, batchsize = self.get_batch() 94 | 95 | b, c, x1, x2, x3 = x.shape 96 | ssize = int(self.grid_size*0.5) # stride size 97 | gl0 = int(x1/ssize-1) # grid length 98 | gl1 = int(x2/ssize-1) 99 | gl2 = int(x3/ssize-1) 100 | 101 | hs = enc(x) 102 | h2 = divide_img(x) 103 | h2 = enc(h2) 104 | h2 = F.average(h2[-1], axis=(2, 3, 4)) 105 | cpc_t = F.transpose(h2, axes=(1, 0)) 106 | cpc_t = F.reshape(cpc_t, (1, self.base_channels*8, gl0, gl1, gl2)) 107 | 108 | y1 = cpcpred1(h2) 109 | cpc_loss1 = cpc_loss(y1, cpc_t, upper=True, cpc_pattern=self.cpc_pattern) 110 | cpc_loss_tot = cpc_loss1 111 | 112 | y = dec(hs) 113 | 114 | opt_e.target.cleargrads() 115 | opt_d.target.cleargrads() 116 | opt_p1.target.cleargrads() 117 | 118 | xp = cuda.get_array_module(t) 119 | if xp.sum(t) < -(2**31): 120 | weighted_loss = self.idle_weight * self.cpc_loss_weight * cpc_loss_tot 121 | weighted_loss.backward() 122 | t = xp.zeros(t.shape, dtype=np.int32) 123 | 124 | else: 125 | if self.nested_label: 126 | seg_loss = 0. 
127 | for i in range(t.shape[1]): 128 | seg_loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 129 | else: 130 | seg_loss = self.seg_lossfun(y, t) 131 | 132 | weighted_loss = seg_loss + self.cpc_loss_weight*cpc_loss_tot 133 | weighted_loss.backward() 134 | 135 | opt_e.update() 136 | opt_d.update() 137 | opt_p1.update() 138 | 139 | self.report_scores(y, t) 140 | reporter.report({ 141 | 'loss/cpc': cpc_loss_tot, 142 | 'loss/total': weighted_loss 143 | }) 144 | -------------------------------------------------------------------------------- /src/training/updaters/encdec_seg_updater.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.functions as F 3 | from chainer.backends import cuda 4 | from chainer import reporter 5 | from chainer import Variable 6 | from chainer.training import StandardUpdater 7 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient 8 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy 9 | 10 | 11 | class EncDecSegUpdater(StandardUpdater): 12 | 13 | def __init__(self, config, **kwargs): 14 | self.nested_label = config['nested_label'] 15 | self.seg_lossfun = eval(config['seg_lossfun']) 16 | self.optimizer_name = config['optimizer'] 17 | self.init_lr = config['init_lr'] 18 | self.nb_epoch = config['epoch'] 19 | self.init_weight = config['init_encoder'] 20 | self.enc_freeze = config['enc_freeze'] 21 | super(EncDecSegUpdater, self).__init__(**kwargs) 22 | 23 | def get_optimizer_and_model(self, key): 24 | optimizer = self.get_optimizer(key) 25 | return optimizer, optimizer.target 26 | 27 | def get_batch(self): 28 | batch = self.get_iterator('main').next() 29 | batchsize = len(batch) 30 | in_arrays = self.converter(batch, self.device) 31 | return Variable(in_arrays[0]), in_arrays[1], batchsize 32 | 33 | def report_scores(self, y, t): 34 | with chainer.no_backprop_mode(): 35 | if self.nested_label: 36 | dice = 
mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...])) 37 | for i in range(1, t.shape[1]): 38 | dices = dice_coefficient(y[:, 2 * i:2 * (i + 1), ...], t[:, i, ...]) 39 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0) 40 | else: 41 | dice = dice_coefficient(y, t) 42 | mean_dice = mean_dice_coefficient(dice) 43 | if self.nested_label: 44 | b, c, h, w, d = t.shape 45 | y = F.reshape(y, (b, 2, h * c, w, d)) 46 | t = F.reshape(t, (b, h * c, w, d)) 47 | accuracy = F.accuracy(y, t) 48 | 49 | reporter.report({ 50 | 'acc': accuracy, 51 | 'mean_dc': mean_dice 52 | }) 53 | xp = cuda.get_array_module(y) 54 | for i in range(len(dice)): 55 | if not xp.isnan(dice.data[i]): 56 | reporter.report({'dc_{}'.format(i): dice[i]}) 57 | 58 | def update_core(self): 59 | opt_e, enc = self.get_optimizer_and_model('enc') 60 | opt_d, dec = self.get_optimizer_and_model('dec') 61 | 62 | if self.is_new_epoch: 63 | decay_rate = (1. - float(self.epoch / self.nb_epoch)) ** 0.9 64 | if self.optimizer_name == 'Adam': 65 | if self.init_weight is not None: 66 | opt_e.alpha = self.init_lr*self.enc_freeze * decay_rate 67 | else: 68 | opt_e.alpha = self.init_lr * decay_rate 69 | opt_d.alpha = self.init_lr * decay_rate 70 | else: 71 | if self.init_weight is not None: 72 | opt_e.lr = self.init_lr*self.enc_freeze * decay_rate 73 | else: 74 | opt_e.lr = self.init_lr * decay_rate 75 | opt_d.lr = self.init_lr * decay_rate 76 | 77 | x, t, batchsize = self.get_batch() 78 | 79 | hs = enc(x) 80 | 81 | y = dec(hs) 82 | 83 | if self.nested_label: 84 | seg_loss = 0. 
85 | for i in range(t.shape[1]): 86 | seg_loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 87 | else: 88 | seg_loss = self.seg_lossfun(y, t) 89 | 90 | opt_e.target.cleargrads() 91 | opt_d.target.cleargrads() 92 | 93 | weighted_loss = seg_loss 94 | weighted_loss.backward() 95 | 96 | opt_e.update() 97 | opt_d.update() 98 | 99 | self.report_scores(y, t) 100 | reporter.report({ 101 | 'loss/seg': seg_loss, 102 | 'loss/total': weighted_loss 103 | }) 104 | -------------------------------------------------------------------------------- /src/training/updaters/vaeseg_updater.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import chainer 3 | import chainer.functions as F 4 | from chainer.backends import cuda 5 | from chainer import reporter 6 | from chainer import Variable 7 | from chainer.training import StandardUpdater 8 | from src.functions.evaluation import dice_coefficient, mean_dice_coefficient 9 | from src.functions.loss.mixed_dice_loss import dice_loss_plus_cross_entropy 10 | 11 | 12 | class VAESegUpdater(StandardUpdater): 13 | 14 | def __init__(self, config, **kwargs): 15 | self.nested_label = config['nested_label'] 16 | self.seg_lossfun = eval(config['seg_lossfun']) 17 | self.k = config['vaeseg_nb_sampling'] 18 | self.rec_loss_weight = config['vaeseg_rec_loss_weight'] 19 | self.kl_loss_weight = config['vaeseg_kl_loss_weight'] 20 | self.optimizer_name = config['optimizer'] 21 | self.init_lr = config['init_lr'] 22 | self.nb_epoch = config['epoch'] 23 | self.pretrain = config['pretrain'] 24 | self.init_weight = config['init_encoder'] 25 | self.enc_freeze = config['enc_freeze'] 26 | self.idle_weight = config['vae_idle_weight'] 27 | super(VAESegUpdater, self).__init__(**kwargs) 28 | 29 | def get_optimizer_and_model(self, key): 30 | optimizer = self.get_optimizer(key) 31 | return optimizer, optimizer.target 32 | 33 | def get_batch(self): 34 | batch = self.get_iterator('main').next() 35 | 
batchsize = len(batch) 36 | in_arrays = self.converter(batch, self.device) 37 | return Variable(in_arrays[0]), in_arrays[1], batchsize 38 | 39 | def report_scores(self, y, t): 40 | with chainer.no_backprop_mode(): 41 | if self.nested_label: 42 | dice = mean_dice_coefficient(dice_coefficient(y[:, 0:2, ...], t[:, 0, ...])) 43 | for i in range(1, t.shape[1]): 44 | dices = dice_coefficient(y[:, 2 * i:2 * (i + 1), ...], t[:, i, ...]) 45 | dice = F.concat((dice, mean_dice_coefficient(dices)), axis=0) 46 | else: 47 | dice = dice_coefficient(y, t) 48 | mean_dice = mean_dice_coefficient(dice) 49 | if self.nested_label: 50 | b, c, h, w, d = t.shape 51 | y = F.reshape(y, (b, 2, h * c, w, d)) 52 | t = F.reshape(t, (b, h * c, w, d)) 53 | accuracy = F.accuracy(y, t) 54 | 55 | reporter.report({ 56 | 'acc': accuracy, 57 | 'mean_dc': mean_dice 58 | }) 59 | xp = cuda.get_array_module(y) 60 | for i in range(len(dice)): 61 | if not xp.isnan(dice.data[i]): 62 | reporter.report({'dc_{}'.format(i): dice[i]}) 63 | 64 | def update_core(self): 65 | opt_e, enc = self.get_optimizer_and_model('enc') 66 | opt_em, emb = self.get_optimizer_and_model('emb') 67 | opt_d, dec = self.get_optimizer_and_model('dec') 68 | opt_v, vae = self.get_optimizer_and_model('vae') 69 | 70 | if self.is_new_epoch: 71 | decay_rate = (1. 
- float(self.epoch / self.nb_epoch)) ** 0.9 72 | if self.optimizer_name == 'Adam': 73 | if self.init_weight is not None: 74 | opt_e.alpha = self.init_lr*self.enc_freeze * decay_rate 75 | else: 76 | opt_e.alpha = self.init_lr * decay_rate 77 | opt_em.alpha = self.init_lr * decay_rate 78 | opt_d.alpha = self.init_lr * decay_rate 79 | opt_v.alpha = self.init_lr * decay_rate 80 | else: 81 | if self.init_weight is not None: 82 | opt_e.lr = self.init_lr*self.enc_freeze * decay_rate 83 | else: 84 | opt_e.lr = self.init_lr * decay_rate 85 | opt_em.lr = self.init_lr * decay_rate 86 | opt_d.lr = self.init_lr * decay_rate 87 | opt_v.lr = self.init_lr * decay_rate 88 | 89 | x, t, batchsize = self.get_batch() 90 | 91 | hs = enc(x) 92 | mu, ln_var = emb(hs[-1]) 93 | latent_size = np.prod(list(mu.shape)) 94 | 95 | kl_loss = F.gaussian_kl_divergence(mu, ln_var, reduce='sum') / latent_size 96 | 97 | rec_loss = 0. 98 | for i in range(self.k): 99 | z = F.gaussian(mu, ln_var) 100 | rec_x = vae(z) 101 | rec_loss += F.mean_squared_error(x, rec_x) 102 | rec_loss /= self.k 103 | 104 | opt_e.target.cleargrads() 105 | opt_em.target.cleargrads() 106 | opt_v.target.cleargrads() 107 | opt_d.target.cleargrads() 108 | xp = cuda.get_array_module(t) 109 | if xp.sum(t) < -(2**31): 110 | # if label is NaN, optimize encoder, VAE only 111 | weighted_loss = self.idle_weight*( 112 | self.rec_loss_weight 113 | * rec_loss + self.kl_loss_weight * kl_loss) 114 | t = xp.zeros(t.shape, dtype=np.int32) 115 | y = dec(hs) 116 | 117 | else: 118 | y = dec(hs) 119 | if self.nested_label: 120 | seg_loss = 0. 
121 | for i in range(t.shape[1]): 122 | seg_loss += self.seg_lossfun(y[:, 2*i:2*(i+1), ...], t[:, i, ...]) 123 | else: 124 | seg_loss = self.seg_lossfun(y, t) 125 | if self.pretrain: 126 | weighted_loss = self.rec_loss_weight * rec_loss + self.kl_loss_weight * kl_loss 127 | else: 128 | weighted_loss = seg_loss + self.rec_loss_weight * rec_loss \ 129 | + self.kl_loss_weight * kl_loss 130 | 131 | weighted_loss.backward() 132 | opt_e.update() 133 | opt_em.update() 134 | opt_v.update() 135 | opt_d.update() 136 | 137 | self.report_scores(y, t) 138 | reporter.report({ 139 | 'loss/rec': rec_loss, 140 | 'loss/kl': kl_loss, 141 | 'loss/total': weighted_loss 142 | }) 143 | -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import time 4 | import yaml 5 | 6 | 7 | default_config = { 8 | 'dataset_name': 'msd_bound', 9 | 'image_path': '/PATH_TO_NORMALIZED_IMAGES/imagesTr_normalized', 10 | 'label_path': '/PATH_TO_LABELS/labelsTr', 11 | 'image_file_format': 'npz', 12 | 'label_file_format': 'nii', 13 | 'train_list_path': '/PATH_TO_TRAINING_LIST/train_list_cv0.txt', 14 | 'validation_list_path': '/PATH_TO_VALIDATION_LIST/validation_list_cv0.txt', 15 | 'test_list_path': '/PATH_TO_TEST_LIST/test_list_cv0.txt', 16 | 'target_label': 0, 17 | 'nested_label': False, 18 | 'crop_size': '(160,192,128)', 19 | 'shift_intensity': 0.1, 20 | 'random_flip': True, 21 | 22 | 'segmentor_name': 'vaeseg', 23 | 'in_channels': 4, 24 | 'base_channels': 32, 25 | 'nb_labels': 4, 26 | 'vaeseg_norm': 'GroupNormalization', 27 | 'vaeseg_bn_first': True, 28 | 'vaeseg_ndim_latent': 128, 29 | 'vaeseg_skip_connect_mode': 'sum', 30 | 'unet_cropping': False, 31 | 'dv_num_trans_layers': 12, 32 | 33 | 'seg_lossfun': 'softmax_dice_loss', 34 | 'auxiliary_weights': None, 35 | 'vaeseg_nb_sampling': 1, 36 | 'vaeseg_rec_loss_weight': 0.1, 37 | 
'vaeseg_kl_loss_weight': 0.1, 38 | 'structseg_nb_copies': 1, 39 | 40 | 'mn': True, 41 | 'gpu_start_id': 0, 42 | 'loaderjob': 2, 43 | 'batchsize': 1, 44 | 'val_batchsize': 2, 45 | 'epoch': 200, 46 | 'optimizer': 'Adam', 47 | 'init_lr': 1e-4, 48 | 'lr_reduction_ratio': 0.99, 49 | 'lr_reduction_interval': (1, 'epoch'), 50 | 'weight_decay': 1e-5, 51 | 'report_interval': 10, 52 | 'eval_interval': 2, 53 | 'snapshot_interval': 20, 54 | 'init_segmentor': None, 55 | 'init_encoder': None, 56 | 'init_decoder': None, 57 | 'init_vae': None, 58 | 'resume': None, 59 | 60 | 'pretrain': False, 61 | 'grid_size': 32, 62 | 'init_embedder': None, 63 | 'cpc_vaeseg_cpc_loss_weight': 0.001, 64 | 'init_cpcpred': None, 65 | 'enc_freeze': 0.01, 66 | 'cpc_pattern': 'updown', 67 | 'random_scale': False, 68 | 'edge_path': '/PATH_TO_EDGE_LABELS/edgesTr', 69 | 'edge_file_format': 'npy', 70 | 'edge_label': False, 71 | 'print_each_dc': True, 72 | 'is_brats': False, 73 | 'ignore_path': None, 74 | 'vae_idle_weight': 1 75 | } 76 | 77 | 78 | def overwrite_config( 79 | input_cfg, 80 | dump_yaml_dir=None 81 | ): 82 | output_cfg = copy.copy(default_config) 83 | for key, val in input_cfg.items(): 84 | if key not in output_cfg: 85 | raise ValueError('Unknown configuration key: {}'.format(key)) 86 | output_cfg[key] = val 87 | if dump_yaml_dir is not None: 88 | os.makedirs(dump_yaml_dir, exist_ok=True) 89 | cur_time = time.strftime("%Y-%m-%d--%H-%M-%S", time.gmtime()) 90 | dump_yaml_path = os.path.join( 91 | dump_yaml_dir, '{}.yaml'.format(cur_time)) 92 | with open(dump_yaml_path, 'w') as f: 93 | yaml.dump(output_cfg, f) 94 | return output_cfg 95 | -------------------------------------------------------------------------------- /src/utils/encode_one_hot_vector.py: -------------------------------------------------------------------------------- 1 | from chainer.backends import cuda 2 | 3 | 4 | def _encode_one_hot_vector_core(x, nb_class): 5 | xp = cuda.get_array_module(x) 6 | batch, h, w, d = x.shape 7 | 8 | 
res = xp.zeros((batch, nb_class, h, w, d), dtype=xp.float32) 9 | x = x.reshape(batch, -1) 10 | for i in range(batch): 11 | y = xp.identity(nb_class, dtype=xp.float32)[x[i]] 12 | res[i] = xp.swapaxes(y, 0, 1).reshape((nb_class, h, w, d)) 13 | return res 14 | 15 | 16 | def encode_one_hot_vector(x, nb_class): 17 | if isinstance(x, cuda.ndarray): 18 | with x.device: 19 | return _encode_one_hot_vector_core(x, nb_class) 20 | else: 21 | return _encode_one_hot_vector_core(x, nb_class) 22 | -------------------------------------------------------------------------------- /src/utils/setup_helpers.py: -------------------------------------------------------------------------------- 1 | from chainer.links import BatchNormalization, GroupNormalization 2 | from chainermn.links import MultiNodeBatchNormalization 3 | from chainer.functions import softmax_cross_entropy 4 | from chainer.optimizers import Adam 5 | from chainer.iterators import MultiprocessIterator, SerialIterator 6 | from chainer.optimizer import WeightDecay 7 | from chainer import serializers 8 | from chainer import training 9 | from chainer.training import extensions 10 | from chainer.backends.cuda import get_device_from_id 11 | import chainermn 12 | from src.datasets.msd_bound import MSDBoundDataset 13 | from src.links.model.vaeseg import BoundaryStream, CPCPredictor, Decoder, Encoder, VAE, VD 14 | from src.training.updaters.vaeseg_updater import VAESegUpdater 15 | from src.training.extensions.vaeseg_evaluator import VAESegEvaluator 16 | from src.training.updaters.encdec_seg_updater import EncDecSegUpdater 17 | from src.training.extensions.encdec_seg_evaluator import EncDecSegEvaluator 18 | from src.training.updaters.boundseg_updater import BoundSegUpdater 19 | from src.training.extensions.boundseg_evaluator import BoundSegEvaluator 20 | from src.training.updaters.cpcseg_updater import CPCSegUpdater 21 | from src.training.extensions.cpcseg_evaluator import CPCSegEvaluator 22 | 23 | 24 | def 
_setup_communicator(config, gpu_start_id=0): 25 | if config['mn']: 26 | comm = chainermn.create_communicator('pure_nccl') 27 | is_master = (comm.rank == 0) 28 | device = comm.intra_rank + gpu_start_id 29 | else: 30 | comm = None 31 | is_master = True 32 | device = gpu_start_id 33 | return comm, is_master, device 34 | 35 | 36 | def _setup_datasets(config, comm, is_master): 37 | if is_master: 38 | if config['dataset_name'] == 'msd_bound': 39 | train_data = MSDBoundDataset(config, config['train_list_path']) 40 | validation_data = MSDBoundDataset(config, config['validation_list_path']) 41 | test_data = MSDBoundDataset(config, config['test_list_path']) 42 | validation_data.random_scale = False 43 | test_data.random_scale = False 44 | validation_data.shift_intensity = 0 45 | test_data.shift_intensity = 0 46 | validation_data.random_flip = False 47 | test_data.random_flip = False 48 | validation_data.nb_copies = 1 49 | test_data.nb_copies = 1 50 | validation_data.training = False 51 | test_data.training = False 52 | else: 53 | raise ValueError('Unknown dataset_name: {}'.format(config['dataset_name'])) 54 | print('Training dataset size: {}'.format(len(train_data))) 55 | print('Validation dataset size: {}'.format(len(validation_data))) 56 | print('Test dataset size: {}'.format(len(test_data))) 57 | else: 58 | train_data = None 59 | validation_data = None 60 | test_data = None 61 | 62 | # scatter dataset 63 | if comm is not None: 64 | train_data = chainermn.scatter_dataset(train_data, comm, shuffle=True) 65 | validation_data = chainermn.scatter_dataset(validation_data, comm, shuffle=True) 66 | test_data = chainermn.scatter_dataset(test_data, comm, shuffle=True) 67 | 68 | return train_data, validation_data, test_data 69 | 70 | 71 | def _setup_vae_segmentor(config, comm=None): 72 | in_channels = config['in_channels'] 73 | base_channels = config['base_channels'] 74 | out_channels = config['nb_labels'] 75 | nested_label = config['nested_label'] 76 | norm = 
eval(config['vaeseg_norm']) 77 | bn_first = config['vaeseg_bn_first'] 78 | ndim_latent = config['vaeseg_ndim_latent'] 79 | mode = config['vaeseg_skip_connect_mode'] 80 | input_shape = eval(config['crop_size']) 81 | if nested_label: 82 | out_channels = 2 * (out_channels - 1) 83 | 84 | encoder = Encoder( 85 | base_channels=base_channels, 86 | norm=norm, 87 | bn_first=bn_first, 88 | ndim_latent=ndim_latent, 89 | comm=comm 90 | ) 91 | 92 | embedder = VD( 93 | channels=8*base_channels, 94 | norm=norm, 95 | bn_first=bn_first, 96 | ndim_latent=ndim_latent, 97 | comm=comm 98 | ) 99 | 100 | decoder = Decoder( 101 | base_channels=base_channels, 102 | out_channels=out_channels, 103 | norm=norm, 104 | bn_first=bn_first, 105 | mode=mode, 106 | comm=comm 107 | ) 108 | 109 | vae = VAE( 110 | in_channels=in_channels, 111 | base_channels=base_channels, 112 | norm=norm, 113 | bn_first=bn_first, 114 | input_shape=input_shape, 115 | comm=comm 116 | ) 117 | 118 | return encoder, embedder, decoder, vae 119 | 120 | 121 | def _setup_vae_segmentor_only(config, comm=None): 122 | base_channels = config['base_channels'] 123 | out_channels = config['nb_labels'] 124 | nested_label = config['nested_label'] 125 | norm = eval(config['vaeseg_norm']) 126 | bn_first = config['vaeseg_bn_first'] 127 | ndim_latent = config['vaeseg_ndim_latent'] 128 | mode = config['vaeseg_skip_connect_mode'] 129 | if nested_label: 130 | out_channels = 2 * (out_channels - 1) 131 | 132 | encoder = Encoder( 133 | base_channels=base_channels, 134 | norm=norm, 135 | bn_first=bn_first, 136 | ndim_latent=ndim_latent, 137 | comm=comm 138 | ) 139 | 140 | decoder = Decoder( 141 | base_channels=base_channels, 142 | out_channels=out_channels, 143 | norm=norm, 144 | bn_first=bn_first, 145 | mode=mode, 146 | comm=comm 147 | ) 148 | 149 | return encoder, decoder 150 | 151 | 152 | def _setup_cpc_segmentor(config, comm=None): 153 | base_channels = config['base_channels'] 154 | out_channels = config['nb_labels'] 155 | nested_label = 
config['nested_label'] 156 | norm = eval(config['vaeseg_norm']) 157 | bn_first = config['vaeseg_bn_first'] 158 | ndim_latent = config['vaeseg_ndim_latent'] 159 | mode = config['vaeseg_skip_connect_mode'] 160 | input_shape = eval(config['crop_size']) 161 | grid_size = config['grid_size'] 162 | cpc_pattern = config['cpc_pattern'] 163 | if nested_label: 164 | out_channels = 2 * (out_channels - 1) 165 | 166 | encoder = Encoder( 167 | base_channels=base_channels, 168 | norm=norm, 169 | bn_first=bn_first, 170 | ndim_latent=ndim_latent, 171 | comm=comm 172 | ) 173 | 174 | decoder = Decoder( 175 | base_channels=base_channels, 176 | out_channels=out_channels, 177 | norm=norm, 178 | bn_first=bn_first, 179 | mode=mode, 180 | comm=comm 181 | ) 182 | 183 | cpcpred1 = CPCPredictor( 184 | base_channels=base_channels*8, 185 | norm=norm, 186 | bn_first=bn_first, 187 | grid_size=grid_size, 188 | input_shape=input_shape, 189 | upper=True, 190 | cpc_pattern=cpc_pattern, 191 | comm=comm 192 | ) 193 | return encoder, decoder, cpcpred1 194 | 195 | 196 | def _setup_bound_segmentor(config, comm=None): 197 | base_channels = config['base_channels'] 198 | out_channels = config['nb_labels'] 199 | nested_label = config['nested_label'] 200 | norm = eval(config['vaeseg_norm']) 201 | bn_first = config['vaeseg_bn_first'] 202 | mode = config['vaeseg_skip_connect_mode'] 203 | ndim_latent = config['vaeseg_ndim_latent'] 204 | if nested_label: 205 | out_channels = 2 * (out_channels - 1) 206 | 207 | encoder = Encoder( 208 | base_channels=base_channels, 209 | norm=norm, 210 | bn_first=bn_first, 211 | ndim_latent=ndim_latent, 212 | comm=comm 213 | ) 214 | 215 | decoder = Decoder( 216 | base_channels=base_channels, 217 | out_channels=out_channels, 218 | norm=norm, 219 | bn_first=bn_first, 220 | mode=mode, 221 | comm=comm 222 | ) 223 | 224 | boundary = BoundaryStream( 225 | base_channels=base_channels, 226 | out_channels=out_channels, 227 | norm=norm, 228 | comm=comm 229 | ) 230 | 231 | return encoder, 
decoder, boundary 232 | 233 | 234 | def _setup_iterators(config, batch_size, train_data, validation_data, test_data): 235 | if isinstance(config['loaderjob'], int) and config['loaderjob'] > 1: 236 | train_iterator = MultiprocessIterator( 237 | train_data, batch_size, n_processes=config['loaderjob']) 238 | validation_iterator = MultiprocessIterator( 239 | validation_data, batch_size, n_processes=config['loaderjob'], 240 | repeat=False, shuffle=False) 241 | test_iterator = MultiprocessIterator( 242 | test_data, batch_size, n_processes=config['loaderjob'], 243 | repeat=False, shuffle=False) 244 | else: 245 | train_iterator = SerialIterator(train_data, batch_size) 246 | validation_iterator = SerialIterator( 247 | validation_data, batch_size, repeat=False, shuffle=False) 248 | test_iterator = SerialIterator( 249 | test_data, batch_size, repeat=False, shuffle=False) 250 | 251 | return train_iterator, validation_iterator, test_iterator 252 | 253 | 254 | # Optimizer 255 | def _setup_optimizer(config, model, comm): 256 | optimizer_name = config['optimizer'] 257 | lr = float(config['init_lr']) 258 | weight_decay = float(config['weight_decay']) 259 | if optimizer_name == 'Adam': 260 | optimizer = Adam(alpha=lr, weight_decay_rate=weight_decay) 261 | elif optimizer_name in \ 262 | ('SGD', 'MomentumSGD', 'CorrectedMomentumSGD', 'RMSprop'): 263 | optimizer = eval(optimizer_name)(lr=lr) 264 | else: 265 | raise ValueError('Invalid optimizer: {}'.format(optimizer_name)) 266 | if comm is not None: 267 | optimizer = chainermn.create_multi_node_optimizer(optimizer, comm) 268 | optimizer.setup(model) 269 | # add_hook requires setup(); Adam applies decay via weight_decay_rate 270 | if optimizer_name != 'Adam' and weight_decay > 0.: 271 | optimizer.add_hook(WeightDecay(weight_decay)) 272 | return optimizer 273 | 274 | 275 | # Updater 276 | def _setup_updater(config, device, train_iterator, optimizers): 277 | updater_kwargs = dict() 278 | updater_kwargs['iterator'] = train_iterator 279 | updater_kwargs['optimizer'] = optimizers 280 | updater_kwargs['device'] = device 281 | 282 | if
config['segmentor_name'] == 'vaeseg': 283 | return VAESegUpdater(config, **updater_kwargs) 284 | elif config['segmentor_name'] == 'encdec_seg': 285 | return EncDecSegUpdater(config, **updater_kwargs) 286 | elif config['segmentor_name'] == 'boundseg': 287 | return BoundSegUpdater(config, **updater_kwargs) 288 | elif config['segmentor_name'] == 'cpcseg': 289 | return CPCSegUpdater(config, **updater_kwargs) 290 | else: 291 | return training.StandardUpdater(**updater_kwargs) 292 | 293 | 294 | def _setup_extensions(config, trainer, optimizers, logging_counts, logging_attributes): 295 | if config['segmentor_name'] == 'vaeseg': 296 | trainer.extend(extensions.dump_graph('loss/total', out_name="segmentor.dot")) 297 | elif config['segmentor_name'] == 'encdec_seg': 298 | trainer.extend(extensions.dump_graph('loss/seg', out_name="segmentor.dot")) 299 | elif config['segmentor_name'] == 'boundseg': 300 | trainer.extend(extensions.dump_graph('loss/seg', out_name="segmentor.dot")) 301 | elif config['segmentor_name'] == 'cpcseg': 302 | trainer.extend(extensions.dump_graph('loss/total', out_name="segmentor.dot")) 303 | else: 304 | trainer.extend(extensions.dump_graph('main/loss', out_name="segmentor.dot")) 305 | 306 | # Report 307 | repo_trigger = (config['report_interval'], 'iteration') 308 | trainer.extend( 309 | extensions.LogReport( 310 | trigger=repo_trigger 311 | ) 312 | ) 313 | trainer.extend( 314 | extensions.PrintReport(logging_counts + logging_attributes), 315 | trigger=repo_trigger 316 | ) 317 | trainer.extend( 318 | extensions.ProgressBar() 319 | ) 320 | 321 | snap_trigger = (config['snapshot_interval'], 'epoch') 322 | trainer.extend( 323 | extensions.snapshot(filename='snapshot_epoch_{.updater.epoch}'), 324 | trigger=snap_trigger 325 | ) 326 | for k, v in optimizers.items(): 327 | trainer.extend( 328 | extensions.snapshot_object(v.target, k+'_epoch_{.updater.epoch}'), 329 | trigger=snap_trigger 330 | ) 331 | 332 | for attr in logging_attributes: 333 | trainer.extend( 
334 | extensions.PlotReport([attr, 'val/' + attr, 'test/' + attr], 'epoch', 335 | file_name=attr.replace('/', '_') + '.png') 336 | ) 337 | 338 | 339 | # Trainer 340 | def setup_trainer(config, out, batch_size, epoch, gpu_start_id): 341 | 342 | comm, is_master, device = _setup_communicator(config, gpu_start_id) 343 | 344 | train_data, validation_data, test_data = _setup_datasets(config, comm, is_master) 345 | 346 | if config['segmentor_name'] == 'vaeseg': 347 | encoder, embedder, decoder, vae = _setup_vae_segmentor(config, comm) 348 | # load weights 349 | if config['init_encoder'] is not None: 350 | serializers.load_npz(config['init_encoder'], encoder) 351 | if config['init_embedder'] is not None: 352 | serializers.load_npz(config['init_embedder'], embedder) 353 | if config['init_decoder'] is not None: 354 | serializers.load_npz(config['init_decoder'], decoder) 355 | if config['init_vae'] is not None: 356 | serializers.load_npz(config['init_vae'], vae) 357 | if device is not None: 358 | get_device_from_id(device).use() 359 | encoder.to_gpu() 360 | embedder.to_gpu() 361 | decoder.to_gpu() 362 | vae.to_gpu() 363 | 364 | opt_enc = _setup_optimizer(config, encoder, comm) 365 | opt_emb = _setup_optimizer(config, embedder, comm) 366 | opt_dec = _setup_optimizer(config, decoder, comm) 367 | opt_vae = _setup_optimizer(config, vae, comm) 368 | optimizers = {'enc': opt_enc, 'emb': opt_emb, 'dec': opt_dec, 'vae': opt_vae} 369 | 370 | elif config['segmentor_name'] == 'cpcseg': 371 | 372 | encoder, decoder, cpcpred1 = _setup_cpc_segmentor(config, comm) 373 | # load weights 374 | if config['init_encoder'] is not None: 375 | serializers.load_npz(config['init_encoder'], encoder) 376 | if config['init_decoder'] is not None: 377 | serializers.load_npz(config['init_decoder'], decoder) 378 | if config['init_cpcpred'] is not None: 379 | serializers.load_npz(config['init_cpcpred'], cpcpred1) 380 | if device is not None: 381 | get_device_from_id(device).use() 382 | encoder.to_gpu() 383 | 
decoder.to_gpu() 384 | cpcpred1.to_gpu() 385 | 386 | opt_enc = _setup_optimizer(config, encoder, comm) 387 | opt_dec = _setup_optimizer(config, decoder, comm) 388 | opt_p1 = _setup_optimizer(config, cpcpred1, comm) 389 | optimizers = {'enc': opt_enc, 'dec': opt_dec, 'cpcpred1': opt_p1} 390 | 391 | elif config['segmentor_name'] == 'encdec_seg': 392 | encoder, decoder = _setup_vae_segmentor_only(config, comm) 393 | # load weights 394 | if config['init_encoder'] is not None: 395 | serializers.load_npz(config['init_encoder'], encoder) 396 | if config['init_decoder'] is not None: 397 | serializers.load_npz(config['init_decoder'], decoder) 398 | if device is not None: 399 | get_device_from_id(device).use() 400 | encoder.to_gpu() 401 | decoder.to_gpu() 402 | 403 | opt_enc = _setup_optimizer(config, encoder, comm) 404 | opt_dec = _setup_optimizer(config, decoder, comm) 405 | optimizers = {'enc': opt_enc, 'dec': opt_dec} 406 | 407 | elif config['segmentor_name'] == 'boundseg': 408 | encoder, decoder, boundary = _setup_bound_segmentor(config, comm) 409 | # load weights 410 | if config['init_encoder'] is not None: 411 | serializers.load_npz(config['init_encoder'], encoder) 412 | if config['init_decoder'] is not None: 413 | serializers.load_npz(config['init_decoder'], decoder) 414 | if device is not None: 415 | get_device_from_id(device).use() 416 | encoder.to_gpu() 417 | decoder.to_gpu() 418 | boundary.to_gpu() 419 | 420 | opt_enc = _setup_optimizer(config, encoder, comm) 421 | opt_dec = _setup_optimizer(config, decoder, comm) 422 | opt_bound = _setup_optimizer(config, boundary, comm) 423 | optimizers = {'enc': opt_enc, 'dec': opt_dec, 'bound': opt_bound} 424 | else: 425 | raise ValueError('Unknown segmentor_name: {}'.format(config['segmentor_name'])) 426 | 427 | train_iterator, validation_iterator, test_iterator = \ 428 | _setup_iterators(config, batch_size, train_data, validation_data, test_data) 429 | 430 | logging_counts = ['epoch', 'iteration'] 431 | if config['segmentor_name'] == 'vaeseg': 432 | logging_attributes = \ 433 | ['loss/rec', 'loss/kl',
'loss/total', 'acc', 434 | 'mean_dc', 'val/mean_dc', 'test/mean_dc'] 435 | if config['print_each_dc']: 436 | for i in range(0, config['nb_labels']): 437 | logging_attributes.append('dc_{}'.format(i)) 438 | logging_attributes.append('val/dc_{}'.format(i)) 439 | logging_attributes.append('test/dc_{}'.format(i)) 440 | 441 | elif config['segmentor_name'] == 'cpcseg': 442 | logging_attributes = \ 443 | ['loss/total', 'acc', 'loss/cpc'] 444 | for i in range(0, config['nb_labels']): 445 | logging_attributes.append('dc_{}'.format(i)) 446 | logging_attributes.append('val/dc_{}'.format(i)) 447 | logging_attributes.append('test/dc_{}'.format(i)) 448 | 449 | elif config['segmentor_name'] == 'encdec_seg': 450 | logging_attributes = \ 451 | ['loss/seg', 'loss/total', 'acc'] 452 | if config['print_each_dc']: 453 | for i in range(0, config['nb_labels']): 454 | logging_attributes.append('dc_{}'.format(i)) 455 | logging_attributes.append('val/dc_{}'.format(i)) 456 | logging_attributes.append('test/dc_{}'.format(i)) 457 | 458 | elif config['segmentor_name'] == 'boundseg': 459 | logging_attributes = \ 460 | ['loss/seg', 'loss/total', 'acc', 'loss/bound', 'loss/bce'] 461 | if config['print_each_dc']: 462 | for i in range(0, config['nb_labels']): 463 | logging_attributes.append('dc_{}'.format(i)) 464 | logging_attributes.append('val/dc_{}'.format(i)) 465 | logging_attributes.append('test/dc_{}'.format(i)) 466 | 467 | else: 468 | logging_attributes = ['main/loss', 'main/acc'] 469 | for i in range(1, config['nb_labels']): 470 | logging_attributes.append('main/dc_{}'.format(i)) 471 | logging_attributes.append('val/main/dc_{}'.format(i)) 472 | logging_attributes.append('test/main/dc_{}'.format(i)) 473 | 474 | updater = _setup_updater(config, device, train_iterator, optimizers) 475 | 476 | trainer = training.Trainer(updater, (epoch, 'epoch'), out=out) 477 | 478 | if is_master: 479 | _setup_extensions(config, trainer, optimizers, logging_counts, logging_attributes) 480 | 481 | if 
config['segmentor_name'] == 'vaeseg': 482 | targets = {'enc': encoder, 'emb': embedder, 'dec': decoder, 'vae': vae} 483 | val_evaluator = VAESegEvaluator(config, validation_iterator, targets, device=device) 484 | test_evaluator = VAESegEvaluator(config, test_iterator, targets, device=device) 485 | 486 | elif config['segmentor_name'] == 'cpcseg': 487 | targets = {'enc': encoder, 'dec': decoder, 'cpcpred1': cpcpred1} 488 | val_evaluator = CPCSegEvaluator(config, validation_iterator, targets, device=device) 489 | test_evaluator = CPCSegEvaluator(config, test_iterator, targets, device=device) 490 | 491 | elif config['segmentor_name'] == 'encdec_seg': 492 | targets = {'enc': encoder, 'dec': decoder} 493 | val_evaluator = EncDecSegEvaluator(config, validation_iterator, targets, device=device) 494 | test_evaluator = EncDecSegEvaluator(config, test_iterator, targets, device=device) 495 | 496 | elif config['segmentor_name'] == 'boundseg': 497 | targets = {'enc': encoder, 'dec': decoder, 'bound': boundary} 498 | val_evaluator = BoundSegEvaluator(config, validation_iterator, targets, device=device) 499 | test_evaluator = BoundSegEvaluator(config, test_iterator, targets, device=device) 500 | 501 | val_evaluator.default_name = 'val' 502 | test_evaluator.default_name = 'test' 503 | 504 | if comm is not None: 505 | val_evaluator = chainermn.create_multi_node_evaluator(val_evaluator, comm) 506 | test_evaluator = chainermn.create_multi_node_evaluator(test_evaluator, comm) 507 | trainer.extend(val_evaluator, trigger=(config['eval_interval'], 'epoch')) 508 | trainer.extend(test_evaluator, trigger=(config['eval_interval'], 'epoch')) 509 | 510 | # Resume 511 | if config['resume'] is not None: 512 | serializers.load_npz(config['resume'], trainer) 513 | 514 | return trainer 515 | --------------------------------------------------------------------------------
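For reference, the per-sample loop in `src/utils/encode_one_hot_vector.py` (identity-matrix indexing, `swapaxes`, `reshape`) can be sanity-checked against a fully vectorized NumPy equivalent. The sketch below is not part of the repository; `encode_one_hot_vector_np` is a hypothetical helper name, and it covers only the CPU path (the repository version also dispatches to CuPy via `chainer.backends.cuda`):

```python
import numpy as np


def encode_one_hot_vector_np(x, nb_class):
    """Vectorized NumPy equivalent of the loop in
    src/utils/encode_one_hot_vector.py: maps an integer label volume
    of shape (batch, h, w, d) to a float32 one-hot volume of shape
    (batch, nb_class, h, w, d)."""
    eye = np.identity(nb_class, dtype=np.float32)
    # eye[x] has shape (batch, h, w, d, nb_class); move the class axis
    # next to the batch axis to match the channel-first layout.
    return np.moveaxis(eye[x], -1, 1)


# Example: a (1, 1, 2, 2) label volume with labels {0, 1, 2}.
x = np.array([[[[0, 1], [2, 1]]]], dtype=np.int32)
one_hot = encode_one_hot_vector_np(x, 3)  # shape (1, 3, 1, 2, 2)
```

Each voxel contributes exactly one 1.0 across the class axis, so `one_hot.sum()` equals the number of voxels.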