├── README.md ├── data ├── HGG │ └── Brats17_2013_2_1 │ │ ├── Brats17_2013_2_1_seg.nii.gz │ │ ├── Brats17_2013_2_1_t1.nii.gz │ │ ├── Brats17_2013_2_1_t1ce.nii.gz │ │ ├── Brats17_2013_2_1_t2.nii.gz │ │ └── Brats17_2013_2_1_flair.nii.gz ├── LGG │ └── Brats17_2013_0_1 │ │ ├── Brats17_2013_0_1_seg.nii.gz │ │ ├── Brats17_2013_0_1_t1.nii.gz │ │ ├── Brats17_2013_0_1_t1ce.nii.gz │ │ ├── Brats17_2013_0_1_t2.nii.gz │ │ └── Brats17_2013_0_1_flair.nii.gz ├── HGGTrimmed │ └── Brats17_2013_2_1 │ │ └── Brats17_2013_2_1_t1ce.nii.gz ├── LGGTrimmed │ └── Brats17_2013_0_1 │ │ └── Brats17_2013_0_1_t1ce.nii.gz ├── HGGSegTrimmed │ └── Brats17_2013_2_1 │ │ └── Brats17_2013_2_1_t1ce.nii.gz └── LGGSegTrimmed │ └── Brats17_2013_0_1 │ └── Brats17_2013_0_1_t1ce.nii.gz ├── src ├── pre_paras.json ├── hyper_paras.json ├── auto_run.sh ├── DataSplit │ ├── validset.csv │ ├── testset.csv │ └── trainset.csv ├── btc.py ├── btc_models.py ├── btc_test.py ├── btc_train.py ├── btc_preprocess.py └── btc_dataset.py ├── .gitignore └── LICENSE.md /README.md: -------------------------------------------------------------------------------- 1 | # Brain Tumor Classification 2 | -------------------------------------------------------------------------------- /data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_seg.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_seg.nii.gz -------------------------------------------------------------------------------- /data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_t1.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_t1.nii.gz -------------------------------------------------------------------------------- /data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_t1ce.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_t1ce.nii.gz -------------------------------------------------------------------------------- /data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_t2.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_t2.nii.gz -------------------------------------------------------------------------------- /data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_seg.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_seg.nii.gz -------------------------------------------------------------------------------- /data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_t1.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_t1.nii.gz -------------------------------------------------------------------------------- /data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_t1ce.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_t1ce.nii.gz 
-------------------------------------------------------------------------------- /data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_t2.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_t2.nii.gz -------------------------------------------------------------------------------- /data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_flair.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/HGG/Brats17_2013_2_1/Brats17_2013_2_1_flair.nii.gz -------------------------------------------------------------------------------- /data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_flair.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/LGG/Brats17_2013_0_1/Brats17_2013_0_1_flair.nii.gz -------------------------------------------------------------------------------- /data/HGGTrimmed/Brats17_2013_2_1/Brats17_2013_2_1_t1ce.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/HGGTrimmed/Brats17_2013_2_1/Brats17_2013_2_1_t1ce.nii.gz -------------------------------------------------------------------------------- /data/LGGTrimmed/Brats17_2013_0_1/Brats17_2013_0_1_t1ce.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/LGGTrimmed/Brats17_2013_0_1/Brats17_2013_0_1_t1ce.nii.gz -------------------------------------------------------------------------------- /data/HGGSegTrimmed/Brats17_2013_2_1/Brats17_2013_2_1_t1ce.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/HGGSegTrimmed/Brats17_2013_2_1/Brats17_2013_2_1_t1ce.nii.gz -------------------------------------------------------------------------------- /data/LGGSegTrimmed/Brats17_2013_0_1/Brats17_2013_0_1_t1ce.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/quqixun/BTClassification/HEAD/data/LGGSegTrimmed/Brats17_2013_0_1/Brats17_2013_0_1_t1ce.nii.gz -------------------------------------------------------------------------------- /src/pre_paras.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_dir": "data", 3 | "hgg_in": "HGG", 4 | "lgg_in": "LGG", 5 | "hgg_out": "HGGSegTrimmed", 6 | "lgg_out": "LGGSegTrimmed", 7 | "volume_type": "t1ce", 8 | "is_mask": true, 9 | "non_mask_coeff": 0.333, 10 | "processes_num": -1, 11 | "pre_split": true, 12 | "pre_trainset_path": "DataSplit/trainset.csv", 13 | "pre_validset_path": "DataSplit/validset.csv", 14 | "pre_testset_path": "DataSplit/testset.csv", 15 | "train_prop": 0.6, 16 | "valid_prop": 0.2, 17 | "random_state": 0, 18 | "save_split": false, 19 | "save_split_dir": "DataSplit", 20 | "data_format": ".nii.gz", 21 | "paras_json_path": "hyper_paras.json", 22 | "weights_save_dir": "weights", 23 | "save_best_weights": true, 24 | "logs_save_dir": "logs", 25 | "results_save_dir": "results", 26 | "test_weights": "last", 27 | "pred_trainset": true 28 | } -------------------------------------------------------------------------------- /src/hyper_paras.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "paras-1": { 3 | "comment": "baseline", 4 | "model_name": "pyramid", 5 | "input_shape": [112, 96, 96, 1], 6 | "pooling": "max", 7 | "l2_coeff": 5e-5, 8 | "drop_rate": 0.5, 9 | "bn_momentum": 0.9, 10 | "initializer": "glorot_uniform", 11 | "optimizer": "adam", 12 | "lr_start": 1e-3, 13 | "epochs_num": 100, 14 | "batch_size": 16 15 | }, 16 | "paras-2": { 17 | "comment": "another set of hyperparameters", 18 | "model_name": "pyramid", 19 | "input_shape": [112, 96, 96, 1], 20 | "pooling": "max", 21 | "l2_coeff": 5e-5, 22 | "drop_rate": 0.5, 23 | "bn_momentum": 0.9, 24 | "initializer": "glorot_uniform", 25 | "optimizer": "adam", 26 | "lr_start": 1e-3, 27 | "epochs_num": 100, 28 | "batch_size": 16 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | *.log 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | 57 | # PyBuilder 58 | target/ 59 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2017 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
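Note: the run script below selects one of the hyperparameter sets defined in src/hyper_paras.json by name. A minimal sketch of that lookup, assuming the working directory is src (the repository's own version is the load_paras helper in btc_train.py and btc_test.py):

    import json

    # Load all hyperparameter sets and select one by name, e.g. "paras-1"
    with open("hyper_paras.json") as f:
        paras = json.load(f)["paras-1"]
    print(paras["model_name"], paras["lr_start"])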
23 | -------------------------------------------------------------------------------- /src/auto_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | 4 | # Brain Tumor Classification 5 | # Commands for training and testing models. 6 | # Author: Qixun QU 7 | # Copyleft: MIT License 8 | 9 | # ,,, ,,, 10 | # ;" '; ;' ", 11 | # ; @.ss$$$$$$s.@ ; 12 | # `s$$$$$$$$$$$$$$$' 13 | # $$$$$$$$$$$$$$$$$$ 14 | # $$$$P""Y$$$Y""W$$$$$ 15 | # $$$$ p"$$$"q $$$$$ 16 | # $$$$ .$$$$$. $$$$' 17 | # $$$DaU$$O$$DaU$$$' 18 | # '$$$$'.^.'$$$$' 19 | # '&$$$$$&' 20 | 21 | 22 | # 23 | # Section 1 24 | # 25 | # Train and test model 26 | # Command: 27 | # python btc.py --paras=paras_name 28 | # Parameters: 29 | # - paras: the name of a hyperparameter set in hyper_paras.json 30 | 31 | # Using enhanced tumor regions to train model: 32 | # In pre_paras.json, set 33 | # "hgg_out": "HGGSegTrimmed" 34 | # "lgg_out": "LGGSegTrimmed" 35 | 36 | # Using non-enhanced tumor regions to train model: 37 | # In pre_paras.json, set 38 | # "hgg_out": "HGGTrimmed" 39 | # "lgg_out": "LGGTrimmed" 40 | 41 | python btc.py --paras=paras-1 42 | # python btc.py --paras=paras-2 43 | 44 | 45 | # 46 | # Section 2 47 | # 48 | # Train and test model separately 49 | # Commands: 50 | # python btc_train.py --paras=paras_name 51 | # python btc_test.py --paras=paras_name 52 | # Same parameter as in Section 1 53 | 54 | # python btc_train.py --paras=paras-1 55 | # python btc_test.py --paras=paras-1 56 | -------------------------------------------------------------------------------- /src/DataSplit/validset.csv: -------------------------------------------------------------------------------- 1 | ID,label 2 | Brats17_TCIA_606_1_t1ce,1 3 | Brats17_CBICA_AQJ_1_t1ce,1 4 | Brats17_CBICA_AUR_1_t1ce,1 5 | Brats17_2013_2_1_t1ce,1 6 | Brats17_CBICA_APY_1_t1ce,1 7 | Brats17_2013_20_1_t1ce,1 8 | Brats17_CBICA_ABN_1_t1ce,1 9 | Brats17_2013_26_1_t1ce,1 10 | Brats17_CBICA_ASN_1_t1ce,1 11 | Brats17_CBICA_AZH_1_t1ce,1 12 | Brats17_TCIA_412_1_t1ce,1 13 | Brats17_TCIA_280_1_t1ce,1 14 | Brats17_CBICA_ANI_1_t1ce,1 15 | Brats17_2013_23_1_t1ce,1 16 | Brats17_TCIA_607_1_t1ce,1 17 | Brats17_CBICA_AAP_1_t1ce,1 18 | Brats17_TCIA_401_1_t1ce,1 19 | Brats17_CBICA_ARZ_1_t1ce,1 20 | Brats17_CBICA_AAG_1_t1ce,1 21 | Brats17_TCIA_429_1_t1ce,1 22 | Brats17_CBICA_ATV_1_t1ce,1 23 | Brats17_CBICA_APR_1_t1ce,1 24 | Brats17_CBICA_ALX_1_t1ce,1 25 | Brats17_CBICA_ABY_1_t1ce,1 26 | Brats17_CBICA_AUN_1_t1ce,1 27 | Brats17_TCIA_117_1_t1ce,1 28 | Brats17_CBICA_AOP_1_t1ce,1 29 | Brats17_CBICA_AXO_1_t1ce,1 30 | Brats17_CBICA_APZ_1_t1ce,1 31 | Brats17_TCIA_314_1_t1ce,1 32 | Brats17_TCIA_322_1_t1ce,1 33 | Brats17_TCIA_180_1_t1ce,1 34 | Brats17_CBICA_AXL_1_t1ce,1 35 | Brats17_2013_17_1_t1ce,1 36 | Brats17_TCIA_135_1_t1ce,1 37 | Brats17_TCIA_436_1_t1ce,1 38 | Brats17_TCIA_478_1_t1ce,1 39 | Brats17_TCIA_331_1_t1ce,1 40 | Brats17_TCIA_335_1_t1ce,1 41 | Brats17_TCIA_430_1_t1ce,1 42 | Brats17_TCIA_147_1_t1ce,1 43 | Brats17_TCIA_138_1_t1ce,1 44 | Brats17_TCIA_490_1_t1ce,0 45 | Brats17_TCIA_621_1_t1ce,0 46 | Brats17_TCIA_650_1_t1ce,0 47 | Brats17_2013_8_1_t1ce,0 48 | Brats17_TCIA_653_1_t1ce,0 49 | Brats17_2013_24_1_t1ce,0 50 | Brats17_2013_16_1_t1ce,0 51 | Brats17_TCIA_346_1_t1ce,0 52 | Brats17_TCIA_428_1_t1ce,0 53 | Brats17_TCIA_618_1_t1ce,0 54 | Brats17_TCIA_101_1_t1ce,0 55 | Brats17_2013_6_1_t1ce,0 56 | Brats17_TCIA_644_1_t1ce,0 57 | Brats17_TCIA_387_1_t1ce,0 58 | Brats17_TCIA_177_1_t1ce,0 59 |
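Note: every file in src/DataSplit shares the same two-column schema: ID names a subject's t1ce volume and label is 1 for HGG, 0 for LGG. A minimal sketch of reading a split back, assuming pandas is available (the repository's own loading logic lives in btc_dataset.py, which this dump does not show in full):

    import pandas as pd

    # Columns are "ID" and "label"; label 1 marks HGG, 0 marks LGG
    validset = pd.read_csv("DataSplit/validset.csv")
    hgg_ids = validset.loc[validset["label"] == 1, "ID"].tolist()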
-------------------------------------------------------------------------------- /src/DataSplit/testset.csv: -------------------------------------------------------------------------------- 1 | ID,label 2 | Brats17_TCIA_199_1_t1ce,1 3 | Brats17_TCIA_437_1_t1ce,1 4 | Brats17_TCIA_277_1_t1ce,1 5 | Brats17_CBICA_ATD_1_t1ce,1 6 | Brats17_TCIA_184_1_t1ce,1 7 | Brats17_TCIA_390_1_t1ce,1 8 | Brats17_CBICA_ASU_1_t1ce,1 9 | Brats17_CBICA_ARF_1_t1ce,1 10 | Brats17_TCIA_278_1_t1ce,1 11 | Brats17_CBICA_ASW_1_t1ce,1 12 | Brats17_TCIA_131_1_t1ce,1 13 | Brats17_CBICA_ASA_1_t1ce,1 14 | Brats17_TCIA_332_1_t1ce,1 15 | Brats17_CBICA_ABE_1_t1ce,1 16 | Brats17_TCIA_328_1_t1ce,1 17 | Brats17_CBICA_AQU_1_t1ce,1 18 | Brats17_CBICA_AYI_1_t1ce,1 19 | Brats17_CBICA_BFP_1_t1ce,1 20 | Brats17_TCIA_309_1_t1ce,1 21 | Brats17_TCIA_211_1_t1ce,1 22 | Brats17_CBICA_ASE_1_t1ce,1 23 | Brats17_TCIA_283_1_t1ce,1 24 | Brats17_TCIA_242_1_t1ce,1 25 | Brats17_TCIA_372_1_t1ce,1 26 | Brats17_CBICA_AQT_1_t1ce,1 27 | Brats17_CBICA_ATX_1_t1ce,1 28 | Brats17_TCIA_394_1_t1ce,1 29 | Brats17_TCIA_603_1_t1ce,1 30 | Brats17_TCIA_265_1_t1ce,1 31 | Brats17_TCIA_208_1_t1ce,1 32 | Brats17_TCIA_168_1_t1ce,1 33 | Brats17_CBICA_AZD_1_t1ce,1 34 | Brats17_CBICA_AVG_1_t1ce,1 35 | Brats17_CBICA_ANG_1_t1ce,1 36 | Brats17_TCIA_218_1_t1ce,1 37 | Brats17_TCIA_343_1_t1ce,1 38 | Brats17_CBICA_AXW_1_t1ce,1 39 | Brats17_CBICA_AME_1_t1ce,1 40 | Brats17_CBICA_BFB_1_t1ce,1 41 | Brats17_CBICA_AQG_1_t1ce,1 42 | Brats17_TCIA_448_1_t1ce,1 43 | Brats17_CBICA_AVV_1_t1ce,1 44 | Brats17_TCIA_175_1_t1ce,0 45 | Brats17_TCIA_141_1_t1ce,0 46 | Brats17_TCIA_103_1_t1ce,0 47 | Brats17_TCIA_629_1_t1ce,0 48 | Brats17_TCIA_298_1_t1ce,0 49 | Brats17_TCIA_633_1_t1ce,0 50 | Brats17_TCIA_325_1_t1ce,0 51 | Brats17_TCIA_625_1_t1ce,0 52 | Brats17_TCIA_109_1_t1ce,0 53 | Brats17_TCIA_152_1_t1ce,0 54 | Brats17_2013_1_1_t1ce,0 55 | Brats17_TCIA_413_1_t1ce,0 56 | Brats17_TCIA_299_1_t1ce,0 57 | Brats17_TCIA_451_1_t1ce,0 58 | Brats17_TCIA_630_1_t1ce,0 59 | -------------------------------------------------------------------------------- /src/btc.py: -------------------------------------------------------------------------------- 1 | # Brain Tumor Classification 2 | # Main script containing the whole process. 3 | # Author: Qixun QU 4 | # Copyleft: MIT License 5 | 6 | # ,,, ,,, 7 | # ;" '; ;' ", 8 | # ; @.ss$$$$$$s.@ ; 9 | # `s$$$$$$$$$$$$$$$' 10 | # $$$$$$$$$$$$$$$$$$ 11 | # $$$$P""Y$$$Y""W$$$$$ 12 | # $$$$ p"$$$"q $$$$$ 13 | # $$$$ .$$$$$. $$$$' 14 | # $$$DaU$$O$$DaU$$$' 15 | # '$$$$'.^.'$$$$' 16 | # '&$$$$$&' 17 | 18 | 19 | from __future__ import print_function 20 | 21 | 22 | import os 23 | import json 24 | import argparse 25 | from btc_test import BTCTest 26 | from btc_train import BTCTrain 27 | from btc_dataset import BTCDataset 28 | from btc_preprocess import BTCPreprocess 29 | 30 | 31 | def main(hyper_paras_name): 32 | '''MAIN 33 | 34 | Main process of Brain Tumor Classification. 35 | -1- Split dataset for training, validating and testing. 36 | -2- Train model. 37 | -3- Test model. 38 | 39 | Inputs: 40 | ------- 41 | 42 | - hyper_paras_name: string, the name of the hyperparameters set, 43 | which can be found in hyper_paras.json. 44 | 45 | ''' 46 | 47 | # Basic settings in pre_paras.json, including 48 | # 1. directory paths for input and output 49 | # 2.
necessary information for splitting dataset 50 | pre_paras_path = "pre_paras.json" 51 | pre_paras = json.load(open(pre_paras_path)) 52 | 53 | # Get root path of input data 54 | parent_dir = os.path.dirname(os.getcwd()) 55 | data_dir = os.path.join(parent_dir, pre_paras["data_dir"]) 56 | 57 | # Set directories of input images 58 | hgg_in_dir = os.path.join(data_dir, pre_paras["hgg_in"]) 59 | lgg_in_dir = os.path.join(data_dir, pre_paras["lgg_in"]) 60 | 61 | # Set output directories to save preprocessed images 62 | hgg_out_dir = os.path.join(data_dir, pre_paras["hgg_out"]) 63 | lgg_out_dir = os.path.join(data_dir, pre_paras["lgg_out"]) 64 | 65 | # Set directory to save weights 66 | weights_save_dir = os.path.join(parent_dir, pre_paras["weights_save_dir"]) 67 | # Set directory to save training and validation logs 68 | logs_save_dir = os.path.join(parent_dir, pre_paras["logs_save_dir"]) 69 | # Set directory to save metrics 70 | results_save_dir = os.path.join(parent_dir, pre_paras["results_save_dir"]) 71 | 72 | # Preprocessing to enhance tumor regions 73 | prep = BTCPreprocess([hgg_in_dir, lgg_in_dir], 74 | [hgg_out_dir, lgg_out_dir], 75 | pre_paras["volume_type"]) 76 | prep.run(is_mask=pre_paras["is_mask"], 77 | non_mask_coeff=pre_paras["non_mask_coeff"], 78 | processes=pre_paras["processes_num"]) 79 | 80 | # Split dataset 81 | data = BTCDataset(hgg_out_dir, lgg_out_dir, 82 | volume_type=pre_paras["volume_type"], 83 | train_prop=pre_paras["train_prop"], 84 | valid_prop=pre_paras["valid_prop"], 85 | random_state=pre_paras["random_state"], 86 | pre_trainset_path=pre_paras["pre_trainset_path"], 87 | pre_validset_path=pre_paras["pre_validset_path"], 88 | pre_testset_path=pre_paras["pre_testset_path"], 89 | data_format=pre_paras["data_format"]) 90 | data.run(pre_split=pre_paras["pre_split"], 91 | save_split=pre_paras["save_split"], 92 | save_split_dir=pre_paras["save_split_dir"]) 93 | 94 | # Training the model using enhanced tumor regions 95 | train = BTCTrain(paras_name=hyper_paras_name, 96 | paras_json_path=pre_paras["paras_json_path"], 97 | weights_save_dir=weights_save_dir, 98 | logs_save_dir=logs_save_dir, 99 | save_best_weights=pre_paras["save_best_weights"]) 100 | train.run(data) 101 | 102 | # Testing the model 103 | test = BTCTest(paras_name=hyper_paras_name, 104 | paras_json_path=pre_paras["paras_json_path"], 105 | weights_save_dir=weights_save_dir, 106 | results_save_dir=results_save_dir, 107 | test_weights=pre_paras["test_weights"], 108 | pred_trainset=pre_paras["pred_trainset"]) 109 | test.run(data) 110 | 111 | return 112 | 113 | 114 | if __name__ == "__main__": 115 | 116 | # Command line 117 | # python btc.py --paras=paras-1 118 | 119 | parser = argparse.ArgumentParser() 120 | 121 | # Set json file path to extract hyperparameters 122 | help_str = "Select a set of hyper-parameters in hyper_paras.json" 123 | parser.add_argument("--paras", action="store", default="paras-1", 124 | dest="hyper_paras_name", help=help_str) 125 | 126 | args = parser.parse_args() 127 | main(args.hyper_paras_name) 128 | -------------------------------------------------------------------------------- /src/DataSplit/trainset.csv: -------------------------------------------------------------------------------- 1 | ID,label 2 | Brats17_CBICA_AOO_1_t1ce,1 3 | Brats17_CBICA_ASG_1_t1ce,1 4 | Brats17_CBICA_AXN_1_t1ce,1 5 | Brats17_CBICA_ANZ_1_t1ce,1 6 | Brats17_CBICA_ABM_1_t1ce,1 7 | Brats17_CBICA_ATF_1_t1ce,1 8 | Brats17_CBICA_AQA_1_t1ce,1 9 | Brats17_TCIA_234_1_t1ce,1 10 | Brats17_CBICA_AYA_1_t1ce,1 11 |
Brats17_TCIA_190_1_t1ce,1 12 | Brats17_TCIA_235_1_t1ce,1 13 | Brats17_CBICA_BHM_1_t1ce,1 14 | Brats17_TCIA_479_1_t1ce,1 15 | Brats17_CBICA_AAL_1_t1ce,1 16 | Brats17_TCIA_406_1_t1ce,1 17 | Brats17_TCIA_396_1_t1ce,1 18 | Brats17_CBICA_BHB_1_t1ce,1 19 | Brats17_TCIA_300_1_t1ce,1 20 | Brats17_TCIA_222_1_t1ce,1 21 | Brats17_TCIA_186_1_t1ce,1 22 | Brats17_TCIA_425_1_t1ce,1 23 | Brats17_TCIA_411_1_t1ce,1 24 | Brats17_CBICA_AQZ_1_t1ce,1 25 | Brats17_TCIA_111_1_t1ce,1 26 | Brats17_TCIA_118_1_t1ce,1 27 | Brats17_CBICA_ANP_1_t1ce,1 28 | Brats17_TCIA_321_1_t1ce,1 29 | Brats17_TCIA_491_1_t1ce,1 30 | Brats17_TCIA_361_1_t1ce,1 31 | Brats17_TCIA_203_1_t1ce,1 32 | Brats17_TCIA_192_1_t1ce,1 33 | Brats17_2013_19_1_t1ce,1 34 | Brats17_CBICA_ATP_1_t1ce,1 35 | Brats17_CBICA_ASH_1_t1ce,1 36 | Brats17_TCIA_296_1_t1ce,1 37 | Brats17_CBICA_AQY_1_t1ce,1 38 | Brats17_TCIA_498_1_t1ce,1 39 | Brats17_CBICA_AAB_1_t1ce,1 40 | Brats17_TCIA_473_1_t1ce,1 41 | Brats17_TCIA_171_1_t1ce,1 42 | Brats17_TCIA_205_1_t1ce,1 43 | Brats17_TCIA_231_1_t1ce,1 44 | Brats17_CBICA_AXJ_1_t1ce,1 45 | Brats17_CBICA_AUQ_1_t1ce,1 46 | Brats17_2013_10_1_t1ce,1 47 | Brats17_CBICA_AWI_1_t1ce,1 48 | Brats17_CBICA_AOZ_1_t1ce,1 49 | Brats17_TCIA_151_1_t1ce,1 50 | Brats17_CBICA_AWH_1_t1ce,1 51 | Brats17_2013_25_1_t1ce,1 52 | Brats17_TCIA_226_1_t1ce,1 53 | Brats17_2013_27_1_t1ce,1 54 | Brats17_TCIA_167_1_t1ce,1 55 | Brats17_TCIA_338_1_t1ce,1 56 | Brats17_2013_18_1_t1ce,1 57 | Brats17_TCIA_105_1_t1ce,1 58 | Brats17_CBICA_AVJ_1_t1ce,1 59 | Brats17_2013_21_1_t1ce,1 60 | Brats17_TCIA_374_1_t1ce,1 61 | Brats17_2013_22_1_t1ce,1 62 | Brats17_TCIA_474_1_t1ce,1 63 | Brats17_TCIA_605_1_t1ce,1 64 | Brats17_CBICA_AYW_1_t1ce,1 65 | Brats17_TCIA_247_1_t1ce,1 66 | Brats17_TCIA_378_1_t1ce,1 67 | Brats17_TCIA_608_1_t1ce,1 68 | Brats17_CBICA_AQD_1_t1ce,1 69 | Brats17_TCIA_179_1_t1ce,1 70 | Brats17_TCIA_377_1_t1ce,1 71 | Brats17_TCIA_319_1_t1ce,1 72 | Brats17_2013_3_1_t1ce,1 73 | Brats17_TCIA_221_1_t1ce,1 74 | Brats17_TCIA_257_1_t1ce,1 75 | Brats17_TCIA_471_1_t1ce,1 76 | Brats17_2013_11_1_t1ce,1 77 | Brats17_CBICA_AQQ_1_t1ce,1 78 | Brats17_TCIA_121_1_t1ce,1 79 | Brats17_TCIA_469_1_t1ce,1 80 | Brats17_CBICA_ABB_1_t1ce,1 81 | Brats17_TCIA_133_1_t1ce,1 82 | Brats17_CBICA_BHK_1_t1ce,1 83 | Brats17_2013_13_1_t1ce,1 84 | Brats17_2013_4_1_t1ce,1 85 | Brats17_2013_12_1_t1ce,1 86 | Brats17_CBICA_AOD_1_t1ce,1 87 | Brats17_CBICA_ALU_1_t1ce,1 88 | Brats17_TCIA_370_1_t1ce,1 89 | Brats17_CBICA_ASK_1_t1ce,1 90 | Brats17_TCIA_375_1_t1ce,1 91 | Brats17_CBICA_AWG_1_t1ce,1 92 | Brats17_CBICA_AQN_1_t1ce,1 93 | Brats17_2013_14_1_t1ce,1 94 | Brats17_CBICA_ARW_1_t1ce,1 95 | Brats17_CBICA_ATB_1_t1ce,1 96 | Brats17_CBICA_AMH_1_t1ce,1 97 | Brats17_CBICA_AXM_1_t1ce,1 98 | Brats17_TCIA_165_1_t1ce,1 99 | Brats17_TCIA_162_1_t1ce,1 100 | Brats17_CBICA_AYU_1_t1ce,1 101 | Brats17_CBICA_ASY_1_t1ce,1 102 | Brats17_TCIA_149_1_t1ce,1 103 | Brats17_TCIA_455_1_t1ce,1 104 | Brats17_TCIA_460_1_t1ce,1 105 | Brats17_TCIA_201_1_t1ce,1 106 | Brats17_TCIA_274_1_t1ce,1 107 | Brats17_TCIA_419_1_t1ce,1 108 | Brats17_CBICA_ALN_1_t1ce,1 109 | Brats17_TCIA_499_1_t1ce,1 110 | Brats17_CBICA_AQR_1_t1ce,1 111 | Brats17_TCIA_409_1_t1ce,1 112 | Brats17_TCIA_368_1_t1ce,1 113 | Brats17_TCIA_198_1_t1ce,1 114 | Brats17_CBICA_ABO_1_t1ce,1 115 | Brats17_CBICA_ASO_1_t1ce,1 116 | Brats17_CBICA_AQP_1_t1ce,1 117 | Brats17_2013_7_1_t1ce,1 118 | Brats17_TCIA_113_1_t1ce,1 119 | Brats17_TCIA_150_1_t1ce,1 120 | Brats17_CBICA_AQV_1_t1ce,1 121 | Brats17_CBICA_AOH_1_t1ce,1 122 | Brats17_TCIA_444_1_t1ce,1 123 | Brats17_2013_5_1_t1ce,1 124 | 
Brats17_CBICA_AQO_1_t1ce,1 125 | Brats17_TCIA_290_1_t1ce,1 126 | Brats17_CBICA_AXQ_1_t1ce,1 127 | Brats17_CBICA_ASV_1_t1ce,1 128 | Brats17_2013_9_1_t1ce,0 129 | Brats17_TCIA_202_1_t1ce,0 130 | Brats17_TCIA_130_1_t1ce,0 131 | Brats17_TCIA_637_1_t1ce,0 132 | Brats17_TCIA_470_1_t1ce,0 133 | Brats17_TCIA_408_1_t1ce,0 134 | Brats17_TCIA_310_1_t1ce,0 135 | Brats17_2013_28_1_t1ce,0 136 | Brats17_TCIA_255_1_t1ce,0 137 | Brats17_TCIA_410_1_t1ce,0 138 | Brats17_TCIA_276_1_t1ce,0 139 | Brats17_TCIA_307_1_t1ce,0 140 | Brats17_TCIA_402_1_t1ce,0 141 | Brats17_TCIA_261_1_t1ce,0 142 | Brats17_TCIA_615_1_t1ce,0 143 | Brats17_TCIA_462_1_t1ce,0 144 | Brats17_TCIA_639_1_t1ce,0 145 | Brats17_TCIA_449_1_t1ce,0 146 | Brats17_TCIA_254_1_t1ce,0 147 | Brats17_TCIA_442_1_t1ce,0 148 | Brats17_2013_15_1_t1ce,0 149 | Brats17_TCIA_632_1_t1ce,0 150 | Brats17_TCIA_241_1_t1ce,0 151 | Brats17_TCIA_620_1_t1ce,0 152 | Brats17_2013_0_1_t1ce,0 153 | Brats17_TCIA_393_1_t1ce,0 154 | Brats17_TCIA_640_1_t1ce,0 155 | Brats17_TCIA_623_1_t1ce,0 156 | Brats17_TCIA_312_1_t1ce,0 157 | Brats17_TCIA_493_1_t1ce,0 158 | Brats17_TCIA_634_1_t1ce,0 159 | Brats17_TCIA_466_1_t1ce,0 160 | Brats17_TCIA_645_1_t1ce,0 161 | Brats17_TCIA_654_1_t1ce,0 162 | Brats17_TCIA_624_1_t1ce,0 163 | Brats17_TCIA_420_1_t1ce,0 164 | Brats17_TCIA_330_1_t1ce,0 165 | Brats17_TCIA_266_1_t1ce,0 166 | Brats17_TCIA_351_1_t1ce,0 167 | Brats17_2013_29_1_t1ce,0 168 | Brats17_TCIA_642_1_t1ce,0 169 | Brats17_TCIA_249_1_t1ce,0 170 | Brats17_TCIA_282_1_t1ce,0 171 | Brats17_TCIA_480_1_t1ce,0 172 | Brats17_TCIA_628_1_t1ce,0 173 | -------------------------------------------------------------------------------- /src/btc_models.py: -------------------------------------------------------------------------------- 1 | # Brain Tumor Classification 2 | # Construct 3D Multi-Scale CNN. 3 | # Author: Qixun QU 4 | # Copyleft: MIT License 5 | 6 | # ,,, ,,, 7 | # ;" '; ;' ", 8 | # ; @.ss$$$$$$s.@ ; 9 | # `s$$$$$$$$$$$$$$$' 10 | # $$$$$$$$$$$$$$$$$$ 11 | # $$$$P""Y$$$Y""W$$$$$ 12 | # $$$$ p"$$$"q $$$$$ 13 | # $$$$ .$$$$$. $$$$' 14 | # $$$DaU$$O$$DaU$$$' 15 | # '$$$$'.^.'$$$$' 16 | # '&$$$$$&' 17 | 18 | 19 | from __future__ import print_function 20 | 21 | 22 | from keras.layers import * 23 | from keras.models import Model 24 | from keras.regularizers import l2 25 | 26 | 27 | class BTCModels(object): 28 | 29 | def __init__(self, 30 | model_name="pyramid", 31 | input_shape=[112, 96, 96, 1], 32 | pooling="max", 33 | l2_coeff=5e-5, 34 | drop_rate=0.5, 35 | bn_momentum=0.9, 36 | initializer="glorot_uniform"): 37 | '''__INIT__ 38 | 39 | Initialization to generate the model. 40 | 41 | Inputs: 42 | ------- 43 | 44 | - model_name: string, select the model; in this project, 45 | the only choice is "pyramid". 46 | - input_shape: list, dimensions of input data, 47 | [112, 96, 96, 1] is required. 48 | - pooling: string, pooling methods, "max" for max pooling, 49 | "avg" for average pooling. Default is "max". 50 | - l2_coeff: float, coefficient of L2 penalty. Default is 5e-5. 51 | - drop_rate: float, dropout rate, default is 0.5. 52 | - bn_momentum: float, momentum of batch normalization, 53 | default is 0.9. 54 | - initializer: string, method to initialize parameters, 55 | default is "glorot_uniform".
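Example (a brief usage sketch; the __main__ block at the bottom of this file does the same with all arguments spelled out):
model = BTCModels(model_name="pyramid", input_shape=[112, 96, 96, 1]).model
model.summary()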
56 | 57 | ''' 58 | 59 | # Set parameters 60 | self.input_shape = input_shape 61 | self.pooling = pooling 62 | self.l2_coeff = l2_coeff 63 | self.drop_rate = drop_rate 64 | self.bn_momentum = bn_momentum 65 | self.initializer = initializer 66 | 67 | # Build pyramid model, which is referred to as 68 | # 3D Multi-Scale CNN in this project 69 | if model_name == "pyramid": 70 | self.model = self._pyramid() 71 | 72 | return 73 | 74 | def _conv3d(self, inputs, filter_num, filter_size, 75 | strides=(1, 1, 1), name=None): 76 | '''_CONV3D 77 | 78 | Construct a convolutional layer. 79 | 80 | Inputs: 81 | ------- 82 | 83 | - inputs: input tensor, it should be original input, 84 | or the output from previous layer. 85 | - filter_num: int, the number of filters. 86 | - filter_size: int or int list, the dimensions of filters. 87 | - strides: int tuple with length 3, the stride step in 88 | each dimension. 89 | - name: string, layer's name. 90 | 91 | Output: 92 | ------- 93 | 94 | - output tensor from convolutional layer. 95 | 96 | ''' 97 | 98 | return Convolution3D(filter_num, filter_size, 99 | strides=strides, 100 | kernel_initializer=self.initializer, 101 | kernel_regularizer=l2(self.l2_coeff), 102 | activation="relu", 103 | padding="same", 104 | name=name)(inputs) 105 | 106 | def _dense(self, inputs, units, activation="relu", name=None): 107 | '''_DENSE 108 | 109 | Construct a dense layer. 110 | 111 | Inputs: 112 | ------- 113 | 114 | - inputs: input tensor, the output from previous layer. 115 | - units: int, number of neurons in this layer. 116 | - activation: string, activation function, default is "relu". 117 | - name: string, layer's name. 118 | 119 | Output: 120 | ------- 121 | 122 | - output tensor from the dense layer. 123 | 124 | ''' 125 | 126 | return Dense(units, 127 | kernel_initializer=self.initializer, 128 | kernel_regularizer=l2(self.l2_coeff), 129 | activation=activation, 130 | name=name)(inputs) 131 | 132 | def _extract_features(self, inputs, name=None): 133 | '''_EXTRACT_FEATURES 134 | 135 | Extract features from input tensor by: 136 | - Pooling (max or avg) in size 7*6*6. 137 | - Flatten + Batch normalization + Dropout. 138 | - Dense + Batch normalization. 139 | 140 | Inputs: 141 | ------- 142 | 143 | - inputs: input tensor, the output from each scale. 144 | - name: string, prefix of layer's name. 145 | 146 | Output: 147 | ------- 148 | 149 | - fc1: tensor in size of 256, features extracted from input. 150 | 151 | ''' 152 | 153 | # Pooling (max or avg) in size of 7*6*6 154 | if self.pooling == "max": 155 | pool = MaxPooling3D 156 | elif self.pooling == "avg": 157 | pool = AveragePooling3D 158 | fts_pool = pool((7, 6, 6), name=name + "_pre_pool")(inputs) 159 | 160 | # Flatten + Batch normalization + Dropout 161 | fts_flt = Flatten(name=name + "_pre_flt")(fts_pool) 162 | fts_bn = BatchNormalization(momentum=self.bn_momentum, name=name + "_pre_bn")(fts_flt) 163 | fts_dp = Dropout(self.drop_rate, name=name + "_pre_dp")(fts_bn) 164 | 165 | # Dense + Batch normalization 166 | fc1 = self._dense(fts_dp, 256, "relu", name) 167 | fc1 = BatchNormalization(momentum=self.bn_momentum, name=name + "_bn")(fc1) 168 | 169 | return fc1 170 | 171 | def _pyramid(self): 172 | '''_PYRAMID 173 | 174 | Build and return 3D Multi-Scale CNN. 175 | 176 | Output: 177 | ------- 178 | 179 | - model: Keras Models instance, 3D Multi-Scale CNN.
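Note: conv1-conv5 downsample the 112*96*96 input to 7*6*6; the four scales (conv5-conv8) produced on the way back up are each reduced to a 256-dimensional feature vector, concatenated into 1024 features, and classified by a two-way softmax.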
180 | 181 | ''' 182 | 183 | # Input layer 184 | inputs = Input(shape=self.input_shape) 185 | # 112 * 96 * 96 * 1 186 | 187 | # Conv1 + BN 188 | conv1 = self._conv3d(inputs, 32, 5, strides=(2, 2, 2), name="conv1") 189 | conv1_bn = BatchNormalization(momentum=self.bn_momentum, name="conv1_bn")(conv1) 190 | # 56 * 48 * 48 * 32 191 | 192 | # Conv2 + Max Pooling + BN 193 | conv2 = self._conv3d(conv1_bn, 64, 3, name="conv2") 194 | conv2_mp = MaxPooling3D((2, 2, 2), strides=(2, 2, 2), name="conv2_mp")(conv2) 195 | conv2_bn = BatchNormalization(momentum=self.bn_momentum, name="conv2_bn")(conv2_mp) 196 | # 28 * 24 * 24 * 64 197 | 198 | # Conv3 + Max Pooling + BN 199 | conv3 = self._conv3d(conv2_bn, 128, 3, name="conv3") 200 | conv3_mp = MaxPooling3D((2, 2, 2), strides=(2, 2, 2), name="conv3_mp")(conv3) 201 | conv3_bn = BatchNormalization(momentum=self.bn_momentum, name="conv3_bn")(conv3_mp) 202 | # 14 * 12 * 12 * 128 203 | 204 | # Conv4 + Max Pooling + BN 205 | conv4 = self._conv3d(conv3_bn, 256, 3, name="conv4") 206 | conv4_mp = MaxPooling3D((2, 2, 2), strides=(2, 2, 2), name="conv4_mp")(conv4) 207 | conv4_bn = BatchNormalization(momentum=self.bn_momentum, name="conv4_bn")(conv4_mp) 208 | # 7 * 6 * 6 * 256 209 | 210 | # Conv5 (Scale1) 211 | conv5 = self._conv3d(conv4_bn, 256, 3, name="conv5") 212 | # 7 * 6 * 6 * 256 213 | # Upsampling1 214 | conv5_up = UpSampling3D((2, 2, 2), name="conv5_up")(conv5) 215 | # 14 * 12 * 12 * 256 216 | 217 | # Conv4 ADD Upsampling1 + BN 218 | sum1 = Add(name="sum1")([conv4, conv5_up]) 219 | sum1_bn = BatchNormalization(momentum=self.bn_momentum, name="sum1_bn")(sum1) 220 | 221 | # Conv6 (Scale2) 222 | conv6 = self._conv3d(sum1_bn, 128, 3, name="conv6") 223 | # 14 * 12 * 12 * 128 224 | # Upsampling2 225 | conv6_up = UpSampling3D((2, 2, 2), name="conv6_up")(conv6) 226 | # 28 * 24 * 24 * 128 227 | 228 | # Conv3 ADD Upsampling2 + BN 229 | sum2 = Add(name="sum2")([conv3, conv6_up]) 230 | sum2_bn = BatchNormalization(momentum=self.bn_momentum, name="sum2_bn")(sum2) 231 | 232 | # Conv7 (Scale3) 233 | conv7 = self._conv3d(sum2_bn, 64, 3, name="conv7") 234 | # 28 * 24 * 24 * 64 235 | # Upsampling3 236 | conv7_up = UpSampling3D((2, 2, 2), name="conv7_up")(conv7) 237 | # 56 * 48 * 48 * 64 238 | 239 | # Conv2 ADD Upsampling3 + BN 240 | sum3 = Add(name="sum3")([conv2, conv7_up]) 241 | sum3_bn = BatchNormalization(momentum=self.bn_momentum, name="sum3_bn")(sum3) 242 | 243 | # Conv8 (Scale4) 244 | conv8 = self._conv3d(sum3_bn, 32, 3, name="conv8") 245 | # 56 * 48 * 48 * 32 246 | 247 | # Extract features from Scale1 248 | fts1 = self._extract_features(conv5, name="fc1_1") # 256 --> 256 249 | # Extract features from Scale2 250 | fts2 = self._extract_features(conv6, name="fc1_2") # 1024 --> 256 251 | # Extract features from Scale3 252 | fts3 = self._extract_features(conv7, name="fc1_3") # 4096 --> 256 253 | # Extract features from Scale4 254 | fts4 = self._extract_features(conv8, name="fc1_4") # 16384 --> 256 255 | 256 | # Fuse features of 4 scales + Dropout + Dense (256) + BN 257 | fts = Concatenate(name="fts_all")([fts1, fts2, fts3, fts4]) # 1024 258 | fts_dp = Dropout(rate=self.drop_rate, name="fts_all_dp")(fts) 259 | fc2 = self._dense(fts_dp, 256, "relu", name="fc2") 260 | fc2_bn = BatchNormalization(momentum=self.bn_momentum, name="fc2_bn")(fc2) 261 | 262 | # Output layer 263 | fc3 = self._dense(fc2_bn, 2, "softmax", name="fc3") # 2 264 | model = Model(inputs=inputs, outputs=fc3) 265 | return model 266 | 267 | 268 | if __name__ == "__main__": 269 | 270 | # A test to print
model's architecture. 271 | 272 | from keras.optimizers import Adam 273 | 274 | model = BTCModels(model_name="pyramid", 275 | input_shape=[112, 96, 96, 1], 276 | pooling="max", 277 | l2_coeff=5e-5, 278 | drop_rate=0.5, 279 | bn_momentum=0.9, 280 | initializer="glorot_uniform").model 281 | model.compile(loss="categorical_crossentropy", 282 | optimizer=Adam(lr=1e-3), 283 | metrics=["accuracy"]) 284 | model.summary() 285 | -------------------------------------------------------------------------------- /src/btc_test.py: -------------------------------------------------------------------------------- 1 | # Brain Tumor Classification 2 | # Test 3D Multi-Scale CNN. 3 | # Author: Qixun QU 4 | # Copyleft: MIT License 5 | 6 | # ,,, ,,, 7 | # ;" '; ;' ", 8 | # ; @.ss$$$$$$s.@ ; 9 | # `s$$$$$$$$$$$$$$$' 10 | # $$$$$$$$$$$$$$$$$$ 11 | # $$$$P""Y$$$Y""W$$$$$ 12 | # $$$$ p"$$$"q $$$$$ 13 | # $$$$ .$$$$$. $$$$' 14 | # $$$DaU$$O$$DaU$$$' 15 | # '$$$$'.^.'$$$$' 16 | # '&$$$$$&' 17 | 18 | 19 | from __future__ import print_function 20 | 21 | 22 | import os 23 | import json 24 | import shutil 25 | import argparse 26 | import numpy as np 27 | import pandas as pd 28 | 29 | from keras import backend as K 30 | from btc_models import BTCModels 31 | from sklearn.metrics import (log_loss, 32 | roc_curve, 33 | recall_score, 34 | roc_auc_score, 35 | precision_score, 36 | confusion_matrix) 37 | 38 | 39 | class BTCTest(object): 40 | 41 | def __init__(self, 42 | paras_name, 43 | paras_json_path, 44 | weights_save_dir, 45 | results_save_dir, 46 | test_weights="last", 47 | pred_trainset=False): 48 | '''__INIT__ 49 | 50 | Set configurations before testing the model. 51 | 52 | Inputs: 53 | ------- 54 | 55 | - paras_name: string, name of hyperparameters set, 56 | can be found in hyper_paras.json. 57 | - paras_json_path: string, path of file which provides 58 | hyperparameters, "hyper_paras.json" 59 | in this project. 60 | - weights_save_dir: string, directory path where the 61 | trained model is saved. 62 | - results_save_dir: string, directory to save results. 63 | - test_weights: string, which weights to use for testing, 64 | weights from the "last" epoch or weights 65 | from the "best" epoch. 66 | - pred_trainset: boolean, whether to evaluate the model on 67 | the training set, default is False. 68 | 69 | ''' 70 | 71 | if not os.path.isdir(weights_save_dir): 72 | raise IOError("Model directory does not exist.") 73 | 74 | self.paras_name = paras_name 75 | self.results_save_dir = results_save_dir 76 | self.weights = test_weights 77 | self.pred_trainset = pred_trainset 78 | 79 | # Load hyperparameters 80 | self.paras = self.load_paras(paras_json_path, paras_name) 81 | self._load_paras() 82 | 83 | self.weights_path = os.path.join(weights_save_dir, 84 | paras_name, test_weights + ".h5") 85 | self.results_dir = os.path.join(results_save_dir, paras_name) 86 | self.create_dir(self.results_dir, rm=False) 87 | 88 | return 89 | 90 | def _load_paras(self): 91 | '''_LOAD_PARAS 92 | 93 | Load hyperparameters from hyper_paras.json. 94 | 95 | ''' 96 | 97 | self.model_name = self.paras["model_name"] 98 | self.batch_size = self.paras["batch_size"] 99 | return 100 | 101 | def _load_model(self): 102 | '''_LOAD_MODEL 103 | 104 | Create 3D Multi-Scale CNN. 105 | 106 | ''' 107 | 108 | self.model = BTCModels(model_name=self.model_name).model 109 | return 110 | 111 | def _pred_evaluate(self, x, y, dataset): 112 | '''_PRED_EVALUATE 113 | 114 | Predict input data and evaluate performance, including: 115 | - Accuracy. ----------| 116 | - Log loss. ----------| 117 | - Precision.
---------|--> *_*_res.csv 118 | - Recall. ------------| 119 | - ROC AUC. -----------| 120 | - Confusion matrix. --| 121 | - ROC curve. ------------> *_*_roc_curve.npy 122 | 123 | Inputs: 124 | ------- 125 | 126 | - x: numpy ndarray, input images. 127 | - y: numpy ndarray, ground truth labels. 128 | - dataset: string, indicates which set to use, 129 | "train", "valid" or "test". 130 | 131 | Outputs: 132 | -------- 133 | 134 | - [dataset]_[self.weights]_res.csv 135 | - [dataset]_[self.weights]_roc_curve.npy 136 | 137 | ''' 138 | 139 | # Helper function to compute metrics 140 | # true_y: ground truth labels 141 | # pred_y: predicted labels 142 | def acc(true_y, pred_y): 143 | return (true_y == pred_y).all(axis=1).mean() 144 | 145 | def loss(true_y, pred_y): 146 | return log_loss(true_y, pred_y, normalize=True) 147 | 148 | def precision(true_y, pred_y, label): 149 | return precision_score(true_y, pred_y, pos_label=label) 150 | 151 | def recall(true_y, pred_y, label): 152 | return recall_score(true_y, pred_y, pos_label=label) 153 | 154 | print("Dataset to be predicted: " + dataset) 155 | 156 | # Obtain predictions of input data 157 | pred = self.model.predict(x, self.batch_size, 0) 158 | 159 | # Ground truth labels 160 | arg_y = np.argmax(y, axis=1) 161 | arg_y = np.reshape(arg_y, (-1, 1)) 162 | 163 | # Indices for HGG and LGG subjects 164 | hgg = np.where(arg_y == 1)[0] 165 | lgg = np.where(arg_y == 0)[0] 166 | 167 | # Predicted labels 168 | arg_pred = np.argmax(pred, axis=1) 169 | arg_pred = np.reshape(arg_pred, (-1, 1)) 170 | 171 | # Generate ROC curve 172 | roc_line = roc_curve(arg_y, pred[:, 1], pos_label=1) 173 | # Compute confusion matrix 174 | tn, fp, fn, tp = confusion_matrix(arg_y, arg_pred).ravel() 175 | 176 | # A dictionary contains all results to be written 177 | # in [dataset]_[self.weights]_res.csv 178 | results = {"name": self.paras_name, 179 | "acc": acc(arg_y, arg_pred), 180 | "hgg_acc": acc(arg_y[hgg], arg_pred[hgg]), 181 | "lgg_acc": acc(arg_y[lgg], arg_pred[lgg]), 182 | "loss": loss(y, pred), 183 | "hgg_loss": loss(y[hgg], pred[hgg]), 184 | "lgg_loss": loss(y[lgg], pred[lgg]), 185 | "hgg_precision": precision(arg_y, arg_pred, 1), 186 | "lgg_precision": precision(arg_y, arg_pred, 0), 187 | "hgg_recall": recall(arg_y, arg_pred, 1), 188 | "lgg_recall": recall(arg_y, arg_pred, 0), 189 | "roc_auc": roc_auc_score(arg_y, pred[:, 1]), 190 | "tn": tn, "fp": fp, "fn": fn, "tp": tp} 191 | 192 | # Create pandas DataFrame, and reorder columns 193 | res_df = pd.DataFrame(data=results, index=[0]) 194 | res_df = res_df[["name", "acc", "hgg_acc", "lgg_acc", 195 | "loss", "hgg_loss", "lgg_loss", 196 | "hgg_precision", "hgg_recall", 197 | "lgg_precision", "lgg_recall", 198 | "roc_auc", "tn", "fp", "fn", "tp"]] 199 | 200 | # Save results to [dataset]_[self.weights]_res.csv 201 | root_name = [dataset, self.weights] 202 | res_csv_name = "_".join(root_name + ["res.csv"]) 203 | res_csv_path = os.path.join(self.results_dir, res_csv_name) 204 | res_df.to_csv(res_csv_path, index=False) 205 | 206 | # Save ROC curve to [dataset]_[self.weights]_roc_curve.npy 207 | roc_line_name = "_".join(root_name + ["roc_curve.npy"]) 208 | roc_line_path = os.path.join(self.results_dir, roc_line_name) 209 | np.save(roc_line_path, roc_line) 210 | 211 | return 212 | 213 | def run(self, data): 214 | '''RUN 215 | 216 | Test the model using the given data. 217 | 218 | Input: 219 | ------ 220 | 221 | - data: a BTCDataset instance, including features and 222 | labels of training, validation and testing sets.
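Typical usage (mirroring main() at the bottom of this file):
test = BTCTest("paras-1", "hyper_paras.json", weights_save_dir, results_save_dir)
test.run(data)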
223 | 224 | ''' 225 | 226 | print("\nTesting the model.\n") 227 | 228 | # Load model and weights 229 | self._load_model() 230 | self.model.load_weights(self.weights_path) 231 | 232 | if self.pred_trainset: 233 | # Predict and evaluate on training set 234 | self._pred_evaluate(data.train_x, data.train_y, "train") 235 | 236 | # Predict and evaluate on validation set 237 | self._pred_evaluate(data.valid_x, data.valid_y, "valid") 238 | # Predict and evaluate on testing set 239 | self._pred_evaluate(data.test_x, data.test_y, "test") 240 | 241 | # Destroy the current TF graph 242 | K.clear_session() 243 | 244 | return 245 | 246 | @staticmethod 247 | def load_paras(paras_json_path, paras_name): 248 | '''LOAD_PARAS 249 | 250 | Load hyperparameters from JSON file. 251 | See hyper_paras.json. 252 | 253 | Inputs: 254 | ------- 255 | 256 | - paras_name: string, name of hyperparameters set, 257 | can be found in hyper_paras.json. 258 | - paras_json_path: string, path of file which provides 259 | hyperparameters, "hyper_paras.json" 260 | in this project. 261 | 262 | Output: 263 | ------- 264 | 265 | - A dictionary of hyperparameters. 266 | 267 | ''' 268 | 269 | paras = json.load(open(paras_json_path)) 270 | return paras[paras_name] 271 | 272 | @staticmethod 273 | def create_dir(dir_path, rm=True): 274 | '''CREATE_DIR 275 | 276 | Create directory. 277 | 278 | Inputs: 279 | ------- 280 | 281 | - dir_path: string, path of new directory. 282 | - rm: boolean, remove existing directory or not. 283 | 284 | ''' 285 | 286 | if os.path.isdir(dir_path): 287 | if rm: 288 | shutil.rmtree(dir_path) 289 | os.makedirs(dir_path) 290 | else: 291 | os.makedirs(dir_path) 292 | return 293 | 294 | 295 | def main(hyper_paras_name): 296 | '''MAIN 297 | 298 | Main process to test the model. 299 | 300 | Inputs: 301 | ------- 302 | 303 | - hyper_paras_name: string, the name of the hyperparameters set, 304 | which can be found in hyper_paras.json. 305 | 306 | ''' 307 | 308 | from btc_dataset import BTCDataset 309 | 310 | # Basic settings in pre_paras.json, including 311 | # 1. directory paths for input and output 312 | # 2.
necessary information for splitting dataset 313 | pre_paras_path = "pre_paras.json" 314 | pre_paras = json.load(open(pre_paras_path)) 315 | 316 | # Get root path of input data 317 | parent_dir = os.path.dirname(os.getcwd()) 318 | data_dir = os.path.join(parent_dir, pre_paras["data_dir"]) 319 | 320 | # Set directories of preprocessed images 321 | hgg_dir = os.path.join(data_dir, pre_paras["hgg_out"]) 322 | lgg_dir = os.path.join(data_dir, pre_paras["lgg_out"]) 323 | 324 | # Set directory to save weights 325 | weights_save_dir = os.path.join(parent_dir, pre_paras["weights_save_dir"]) 326 | # Set directory to save results 327 | results_save_dir = os.path.join(parent_dir, pre_paras["results_save_dir"]) 328 | 329 | # Partition dataset 330 | data = BTCDataset(hgg_dir, lgg_dir, 331 | volume_type=pre_paras["volume_type"], 332 | pre_trainset_path=pre_paras["pre_trainset_path"], 333 | pre_validset_path=pre_paras["pre_validset_path"], 334 | pre_testset_path=pre_paras["pre_testset_path"], 335 | data_format=pre_paras["data_format"]) 336 | data.run(pre_split=pre_paras["pre_split"], 337 | save_split=pre_paras["save_split"], 338 | save_split_dir=pre_paras["save_split_dir"]) 339 | 340 | # Test the model 341 | test = BTCTest(paras_name=hyper_paras_name, 342 | paras_json_path=pre_paras["paras_json_path"], 343 | weights_save_dir=weights_save_dir, 344 | results_save_dir=results_save_dir, 345 | test_weights=pre_paras["test_weights"], 346 | pred_trainset=pre_paras["pred_trainset"]) 347 | test.run(data) 348 | 349 | return 350 | 351 | 352 | if __name__ == "__main__": 353 | 354 | # Command line 355 | # python btc_test.py --paras=paras-1 356 | 357 | parser = argparse.ArgumentParser() 358 | 359 | # Set json file path to extract hyperparameters 360 | help_str = "Select a set of hyper-parameters in hyper_paras.json." 361 | parser.add_argument("--paras", action="store", default="paras-1", 362 | dest="hyper_paras_name", help=help_str) 363 | 364 | args = parser.parse_args() 365 | main(args.hyper_paras_name) 366 | -------------------------------------------------------------------------------- /src/btc_train.py: -------------------------------------------------------------------------------- 1 | # Brain Tumor Classification 2 | # Train 3D Multi-Scale CNN. 3 | # Author: Qixun QU 4 | # Copyleft: MIT License 5 | 6 | # ,,, ,,, 7 | # ;" '; ;' ", 8 | # ; @.ss$$$$$$s.@ ; 9 | # `s$$$$$$$$$$$$$$$' 10 | # $$$$$$$$$$$$$$$$$$ 11 | # $$$$P""Y$$$Y""W$$$$$ 12 | # $$$$ p"$$$"q $$$$$ 13 | # $$$$ .$$$$$. $$$$' 14 | # $$$DaU$$O$$DaU$$$' 15 | # '$$$$'.^.'$$$$' 16 | # '&$$$$$&' 17 | 18 | 19 | from __future__ import print_function 20 | 21 | 22 | import os 23 | import json 24 | import shutil 25 | import argparse 26 | from btc_models import BTCModels 27 | 28 | from keras import backend as K 29 | from keras.optimizers import Adam 30 | from keras.callbacks import (CSVLogger, 31 | TensorBoard, 32 | ModelCheckpoint, 33 | LearningRateScheduler) 34 | 35 | 36 | class BTCTrain(object): 37 | 38 | def __init__(self, 39 | paras_name, 40 | paras_json_path, 41 | weights_save_dir, 42 | logs_save_dir, 43 | save_best_weights=True): 44 | '''__INIT__ 45 | 46 | Initialization before training the model. 47 | 48 | Inputs: 49 | ------- 50 | 51 | - paras_name: string, name of hyperparameters set, 52 | can be found in hyper_paras.json. 53 | - paras_json_path: string, path of file which provides 54 | hyperparameters, "hyper_paras.json" 55 | in this project. 56 | - weights_save_dir: string, directory path where the 57 | trained model is saved.
58 | - logs_save_dir: string, directory path where the 59 | logs of the training process are saved. 60 | - save_best_weights: boolean, whether to save the model with the 61 | best validation accuracy. Default is True. 62 | 63 | ''' 64 | 65 | # Dataset: training, validation and test 66 | self.data = None 67 | 68 | # Whether to save the model which provides the best validation accuracy 69 | self.save_best_weights = save_best_weights 70 | 71 | # Load hyperparameters 72 | self.paras = self.load_paras(paras_json_path, paras_name) 73 | self._load_paras() 74 | 75 | # Create folder for saving weights 76 | self.weights_dir = os.path.join(weights_save_dir, paras_name) 77 | self.create_dir(self.weights_dir) 78 | 79 | # Create folder for saving training logs 80 | self.logs_dir = os.path.join(logs_save_dir, paras_name) 81 | self.create_dir(self.logs_dir) 82 | 83 | # Initialize file names for weights at the last or best epoch 84 | self.last_weights_path = os.path.join(self.weights_dir, "last.h5") 85 | self.best_weights_path = os.path.join(self.weights_dir, "best.h5") 86 | 87 | # CSV file path for writing learning curves 88 | self.curves_path = os.path.join(self.logs_dir, "curves.csv") 89 | 90 | return 91 | 92 | def _load_paras(self): 93 | '''_LOAD_PARAS 94 | 95 | Load hyperparameters from hyper_paras.json. 96 | 97 | ''' 98 | 99 | # Parameters to construct model 100 | self.model_name = self.paras["model_name"] 101 | self.input_shape = self.paras["input_shape"] 102 | self.pooling = self.paras["pooling"] 103 | self.l2_coeff = self.paras["l2_coeff"] 104 | self.drop_rate = self.paras["drop_rate"] 105 | self.bn_momentum = self.paras["bn_momentum"] 106 | self.initializer = self.paras["initializer"] 107 | 108 | # Parameters to train model 109 | self.optimizer = self.paras["optimizer"] 110 | self.lr_start = self.paras["lr_start"] 111 | self.epochs_num = self.paras["epochs_num"] 112 | self.batch_size = self.paras["batch_size"] 113 | return 114 | 115 | def _load_model(self): 116 | '''_LOAD_MODEL 117 | 118 | Create 3D Multi-Scale CNN. 119 | 120 | ''' 121 | 122 | self.model = BTCModels(model_name=self.model_name, 123 | input_shape=self.input_shape, 124 | pooling=self.pooling, 125 | l2_coeff=self.l2_coeff, 126 | drop_rate=self.drop_rate, 127 | bn_momentum=self.bn_momentum, 128 | initializer=self.initializer).model 129 | return 130 | 131 | def _set_optimizer(self): 132 | '''_SET_OPTIMIZER 133 | 134 | Set optimizer according to the given parameter. 135 | Use "Adam" in this project. 136 | 137 | ''' 138 | 139 | if self.optimizer == "adam": 140 | self.opt_fcn = Adam(lr=self.lr_start) 141 | return 142 | 143 | def _set_lr_scheduler(self, epoch): 144 | '''_SET_LR_SCHEDULER 145 | 146 | Learning rate scheduler for training process. 147 | LR: [init] * 40 + [init * 0.1] * 30 + [init * 0.01] * 30 148 | 149 | Input: 150 | ------ 151 | 152 | - epoch: int, nth training epoch. 153 | 154 | Output: 155 | ------- 156 | 157 | - Learning rate, a float, for nth training epoch. 158 | 159 | ''' 160 | 161 | lrs = [self.lr_start] * 40 + \ 162 | [self.lr_start * 0.1] * 30 + \ 163 | [self.lr_start * 0.01] * 30 164 | print("Learning rate:", lrs[epoch]) 165 | 166 | return lrs[epoch] 167 | 168 | def _set_callbacks(self): 169 | '''_SET_CALLBACKS 170 | 171 | Set callback functions while training model. 172 | -1- Save learning curves while training. 173 | -2- Set learning rate scheduler. 174 | -3- Add support for TensorBoard. 175 | -4- Save best model while training.
(optional) 176 | 177 | ''' 178 | 179 | # Save learning curves in csv file while training 180 | csv_logger = CSVLogger(self.curves_path, 181 | append=True, separator=",") 182 | 183 | # Set learning rate scheduler 184 | lr_scheduler = LearningRateScheduler(self._set_lr_scheduler) 185 | 186 | # Add support for TensorBoard 187 | tb = TensorBoard(log_dir=self.logs_dir, 188 | batch_size=self.batch_size) 189 | self.callbacks = [csv_logger, lr_scheduler, tb] 190 | 191 | if self.save_best_weights: 192 | # Save best model while training 193 | checkpoint = ModelCheckpoint(filepath=self.best_weights_path, 194 | monitor="val_loss", 195 | verbose=0, 196 | save_best_only=True) 197 | self.callbacks += [checkpoint] 198 | 199 | return 200 | 201 | def _print_score(self): 202 | '''_PRINT_SCORE 203 | 204 | Print out metrics (loss and accuracy) of 205 | training, validation and testing sets. 206 | 207 | ''' 208 | 209 | # Helper function to compute and print metrics 210 | def evaluate(x, y, data_str): 211 | score = self.model.evaluate(x, y, self.batch_size, 0) 212 | print(data_str + " Set: Loss: {0:.4f}, Accuracy: {1:.4f}".format( 213 | score[0], score[1])) 214 | return 215 | 216 | evaluate(self.data.train_x, self.data.train_y, "Training") 217 | evaluate(self.data.valid_x, self.data.valid_y, "Validation") 218 | evaluate(self.data.test_x, self.data.test_y, "Testing") 219 | 220 | return 221 | 222 | def run(self, data): 223 | '''RUN 224 | 225 | Train the model using the given data. 226 | 227 | Input: 228 | ------ 229 | 230 | - data: a BTCDataset instance, including features and 231 | labels of training, validation and testing sets. 232 | 233 | ''' 234 | 235 | print("\nTraining the model.\n") 236 | 237 | self.data = data 238 | 239 | # Configurations of model and optimizer 240 | self._load_model() 241 | self._set_optimizer() 242 | 243 | # Compile model and print its structure 244 | self.model.compile(loss="categorical_crossentropy", 245 | optimizer=self.opt_fcn, 246 | metrics=["accuracy"]) 247 | self.model.summary() 248 | 249 | self._set_callbacks() 250 | # Train model 251 | self.model.fit(self.data.train_x, self.data.train_y, 252 | batch_size=self.batch_size, 253 | epochs=self.epochs_num, 254 | validation_data=(self.data.valid_x, 255 | self.data.valid_y), 256 | shuffle=True, 257 | callbacks=self.callbacks) 258 | 259 | # Save model in last epoch 260 | self.model.save(self.last_weights_path) 261 | # Print metrics 262 | self._print_score() 263 | 264 | # Destroy the current TF graph 265 | K.clear_session() 266 | 267 | return 268 | 269 | @staticmethod 270 | def load_paras(paras_json_path, paras_name): 271 | '''LOAD_PARAS 272 | 273 | Load hyperparameters from JSON file. 274 | See hyper_paras.json. 275 | 276 | Inputs: 277 | ------- 278 | 279 | - paras_name: string, name of hyperparameters set, 280 | can be found in hyper_paras.json. 281 | - paras_json_path: string, path of file which provides 282 | hyperparameters, "hyper_paras.json" 283 | in this project. 284 | 285 | Output: 286 | ------- 287 | 288 | - A dictionary of hyperparameters. 289 | 290 | ''' 291 | 292 | paras = json.load(open(paras_json_path)) 293 | return paras[paras_name] 294 | 295 | @staticmethod 296 | def create_dir(dir_path, rm=True): 297 | '''CREATE_DIR 298 | 299 | Create directory. 300 | 301 | Inputs: 302 | ------- 303 | 304 | - dir_path: string, path of new directory. 305 | - rm: boolean, remove existing directory or not.
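Example: with the default rm=True, create_dir("weights/paras-1") removes an existing directory and recreates it empty; with rm=False, an existing directory is left untouched.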
306 | 307 | ''' 308 | 309 | if os.path.isdir(dir_path): 310 | if rm: 311 | shutil.rmtree(dir_path) 312 | os.makedirs(dir_path) 313 | else: 314 | os.makedirs(dir_path) 315 | return 316 | 317 | 318 | def main(hyper_paras_name): 319 | '''MAIN 320 | 321 | Main process to train model. 322 | 323 | Inputs: 324 | ------- 325 | 326 | - hyper_paras_name: string, the name of the hyperparameters set, 327 | which can be found in hyper_paras.json. 328 | 329 | ''' 330 | 331 | from btc_dataset import BTCDataset 332 | 333 | # Basic settings in pre_paras.json, including 334 | # 1. directory paths for input and output 335 | # 2. necessary information for splitting dataset 336 | pre_paras_path = "pre_paras.json" 337 | pre_paras = json.load(open(pre_paras_path)) 338 | 339 | # Get root path of input data 340 | parent_dir = os.path.dirname(os.getcwd()) 341 | data_dir = os.path.join(parent_dir, pre_paras["data_dir"]) 342 | 343 | # Set directories of preprocessed images 344 | hgg_dir = os.path.join(data_dir, pre_paras["hgg_out"]) 345 | lgg_dir = os.path.join(data_dir, pre_paras["lgg_out"]) 346 | 347 | # Set directory to save weights 348 | weights_save_dir = os.path.join(parent_dir, pre_paras["weights_save_dir"]) 349 | # Set directory to save training and validation logs 350 | logs_save_dir = os.path.join(parent_dir, pre_paras["logs_save_dir"]) 351 | 352 | # Partition dataset 353 | data = BTCDataset(hgg_dir, lgg_dir, 354 | volume_type=pre_paras["volume_type"], 355 | pre_trainset_path=pre_paras["pre_trainset_path"], 356 | pre_validset_path=pre_paras["pre_validset_path"], 357 | pre_testset_path=pre_paras["pre_testset_path"], 358 | data_format=pre_paras["data_format"]) 359 | data.run(pre_split=pre_paras["pre_split"], 360 | save_split=pre_paras["save_split"], 361 | save_split_dir=pre_paras["save_split_dir"]) 362 | 363 | # Train the model 364 | train = BTCTrain(paras_name=hyper_paras_name, 365 | paras_json_path=pre_paras["paras_json_path"], 366 | weights_save_dir=weights_save_dir, 367 | logs_save_dir=logs_save_dir, 368 | save_best_weights=pre_paras["save_best_weights"]) 369 | train.run(data) 370 | 371 | 372 | if __name__ == "__main__": 373 | 374 | # Command line 375 | # python btc_train.py --paras=paras-1 376 | 377 | parser = argparse.ArgumentParser() 378 | 379 | # Set json file path to extract hyperparameters 380 | help_str = "Select a set of hyper-parameters in hyper_paras.json." 381 | parser.add_argument("--paras", action="store", default="paras-1", 382 | dest="hyper_paras_name", help=help_str) 383 | 384 | args = parser.parse_args() 385 | main(args.hyper_paras_name) 386 | -------------------------------------------------------------------------------- /src/btc_preprocess.py: -------------------------------------------------------------------------------- 1 | # Brain Tumor Classification 2 | # Enhance tumor region in each image. 3 | # Author: Qixun QU 4 | # Copyleft: MIT License 5 | 6 | # ,,, ,,, 7 | # ;" '; ;' ", 8 | # ; @.ss$$$$$$s.@ ; 9 | # `s$$$$$$$$$$$$$$$' 10 | # $$$$$$$$$$$$$$$$$$ 11 | # $$$$P""Y$$$Y""W$$$$$ 12 | # $$$$ p"$$$"q $$$$$ 13 | # $$$$ .$$$$$.
 14 | #   $$$DaU$$O$$DaU$$$'
 15 | #    '$$$$'.^.'$$$$'
 16 | #       '&$$$$$&'
 17 | 
 18 | 
 19 | from __future__ import print_function
 20 | 
 21 | 
 22 | import os
 23 | import warnings
 24 | import numpy as np
 25 | import nibabel as nib
 26 | 
 27 | from multiprocessing import Pool, cpu_count
 28 | from scipy.ndimage.interpolation import zoom
 29 | 
 30 | 
 31 | # Ignore the warning caused by SciPy
 32 | warnings.simplefilter("ignore", UserWarning)
 33 | 
 34 | 
 35 | # Helper function to run in multiple processes
 36 | def unwrap_preprocess(arg, **kwarg):
 37 |     return BTCPreprocess._preprocess(*arg, **kwarg)
 38 | 
 39 | 
 40 | class BTCPreprocess(object):
 41 | 
 42 |     def __init__(self, input_dirs, output_dirs, volume_type="t1ce"):
 43 |         '''__INIT__
 44 | 
 45 |         Generates paths for preprocessing.
 46 |         Variables:
 47 |         - self.in_paths: a list containing the path of each input image.
 48 |         - self.out_paths: a list providing the path for each output image.
 49 |         - self.mask_paths: a list containing the mask path of each input image.
 50 | 
 51 |         Inputs:
 52 |         -------
 53 | 
 54 |         - input_dirs: a list of two paths, [hgg_input_dir, lgg_input_dir],
 55 |           the directories which hold input images of
 56 |           HGG and LGG subjects.
 57 |         - output_dirs: a list of two paths, [hgg_output_dir, lgg_output_dir],
 58 |           the output directory for every subject in HGG and LGG.
 59 |         - volume_type: string, type of brain volume, one of "t1ce", "t1", "t2"
 60 |           or "flair". Default is "t1ce".
 61 | 
 62 |         '''
 63 | 
 64 |         self.in_paths, self.out_paths, self.mask_paths = \
 65 |             self.generate_paths(input_dirs, output_dirs, volume_type)
 66 | 
 67 |         return
 68 | 
 69 |     def run(self, is_mask=True, non_mask_coeff=0.333, processes=-1):
 70 |         '''RUN
 71 | 
 72 |         Function to map the task to multiple processes.
 73 | 
 74 |         Inputs:
 75 |         -------
 76 | 
 77 |         - is_mask: boolean, if True, enhance tumor region.
 78 |           Default is True.
 79 |         - non_mask_coeff: float from 0 to 1, the coefficient of
 80 |           voxels in non-tumor region. Default is 0.333.
 81 |         - processes: int, the number of processes used. Default is -1,
 82 |           which means all available CPU cores are used.
 83 | 
 84 |         '''
 85 | 
 86 |         print("\nPreprocessing the samples in BraTS dataset.\n")
 87 |         num = len(self.in_paths)
 88 | 
 89 |         # Generate parameters
 90 |         paras = zip([self] * num, self.in_paths, self.out_paths, self.mask_paths,
 91 |                     [is_mask] * num, [non_mask_coeff] * num)
 92 | 
 93 |         # Set the number of processes
 94 |         if processes == -1 or processes > cpu_count():
 95 |             processes = cpu_count()
 96 | 
 97 |         # Map task
 98 |         pool = Pool(processes=processes)
 99 |         pool.map(unwrap_preprocess, paras)
100 | 
101 |         return
102 | 
103 |     def _preprocess(self, in_path, to_path, mask_path,
104 |                     is_mask=True, non_mask_coeff=0.333):
105 |         '''_PREPROCESS
106 | 
107 |         For each input image, four steps are done:
108 |         -1- If is_mask, enhance tumor region.
109 |         -2- Remove background.
110 |         -3- Resize image.
111 |         -4- Save image.
112 | 
113 |         Inputs:
114 |         -------
115 | 
116 |         - in_path: string, path of input image.
117 |         - to_path: string, path of output image.
118 |         - mask_path: string, path of the mask of input image.
119 |         - is_mask: boolean, if True, enhance tumor region.
120 |           Default is True.
121 |         - non_mask_coeff: float from 0 to 1, the coefficient of
122 |           voxels in non-tumor region. Default is 0.333.
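        The four steps above amount to the following sketch (the
        file names here are placeholders, not paths from the repo):

            volume = self.load_nii("in/t1ce.nii.gz")        # load
            if is_mask:
                mask = self.load_nii("in/seg.nii.gz")
                volume = self.segment(volume, mask, 0.333)  # -1- enhance tumor
            volume = self.trim(volume)                      # -2- remove background
            volume = self.resize(volume, [112, 112, 96])    # -3- resize
            self.save2nii("out/t1ce.nii.gz", volume)        # -4- save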
123 | 
124 |         '''
125 | 
126 |         try:
127 |             print("Preprocessing on: " + in_path)
128 |             # Load image
129 |             volume = self.load_nii(in_path)
130 |             if is_mask:
131 |                 # Enhance tumor region
132 |                 mask = self.load_nii(mask_path)
133 |                 volume = self.segment(volume, mask, non_mask_coeff)
134 |             # Remove background
135 |             volume = self.trim(volume)
136 |             # Resize image
137 |             volume = self.resize(volume, [112, 112, 96])
138 |             # Save image
139 |             self.save2nii(to_path, volume)
140 |         except RuntimeError:
141 |             print("\tFailed to rescale: " + in_path)
142 |             return
143 | 
144 |         return
145 | 
146 |     @staticmethod
147 |     def generate_paths(in_dirs, out_dirs, volume_type=None):
148 |         '''GENERATE_PATHS
149 | 
150 |         Generates three lists of file paths for preprocessing.
151 | 
152 |         Inputs:
153 |         -------
154 | 
155 |         - in_dirs: a list of two paths, [hgg_input_dir, lgg_input_dir],
156 |           the directories which hold input images of
157 |           HGG and LGG subjects.
158 |         - out_dirs: a list of two paths, [hgg_output_dir, lgg_output_dir],
159 |           the output directory for every subject in HGG and LGG.
160 |         - volume_type: string, type of brain volume, one of "t1ce", "t1",
161 |           "t2" or "flair". Default is None, which keeps every volume type.
162 | 
163 |         Outputs:
164 |         --------
165 | 
166 |         - in_paths: a list containing the path of each input image.
167 |         - out_paths: a list providing the path for each output image.
168 |         - mask_paths: a list containing the mask path of each input image.
169 | 
170 |         '''
171 | 
172 |         # Function to create new directory
173 |         # according to given path
174 |         def create_dir(path):
175 |             if not os.path.isdir(path):
176 |                 os.makedirs(path)
177 |             return
178 | 
179 |         in_paths, out_paths, mask_paths = [], [], []
180 |         for in_dir, out_dir in zip(in_dirs, out_dirs):
181 |             # For HGG or LGG subjects
182 |             if not os.path.isdir(in_dir):
183 |                 print("Input folder {} does not exist.".format(in_dir))
184 |                 continue
185 | 
186 |             # Create output folder for HGG or LGG subjects
187 |             create_dir(out_dir)
188 | 
189 |             for subject in os.listdir(in_dir):
190 |                 # For each subject in HGG or LGG
191 |                 subject_dir = os.path.join(in_dir, subject)
192 |                 subject2dir = os.path.join(out_dir, subject)
193 |                 # Create folder for output
194 |                 create_dir(subject2dir)
195 | 
196 |                 scan_names = os.listdir(subject_dir)
197 |                 # Get path of mask file
198 |                 for scan_name in scan_names:
199 |                     if "seg" in scan_name:
200 |                         scan_mask_path = os.path.join(subject_dir, scan_name)
201 | 
202 |                 for scan_name in scan_names:
203 |                     if "seg" in scan_name:
204 |                         continue
205 | 
206 |                     if volume_type is not None:
207 |                         if volume_type not in scan_name:
208 |                             continue
209 | 
210 |                     # Once the target volume is found, save its path
211 |                     # together with the paths of its output and mask
212 |                     in_paths.append(os.path.join(subject_dir, scan_name))
213 |                     out_paths.append(os.path.join(subject2dir, scan_name))
214 |                     mask_paths.append(scan_mask_path)
215 | 
216 |         return in_paths, out_paths, mask_paths
217 | 
218 |     @staticmethod
219 |     def load_nii(path):
220 |         '''LOAD_NII
221 | 
222 |         Load image into a numpy ndarray from a NIfTI file.
223 | 
224 |         Input:
225 |         ------
226 | 
227 |         - path: string, path of input image.
228 | 
229 |         Output:
230 |         ------
231 | 
232 |         - A numpy array of the input image.
233 | 
234 |         '''
235 | 
236 |         return np.rot90(nib.load(path).get_data(), 3)
237 | 
238 |     @staticmethod
239 |     def segment(volume, mask, non_mask_coeff=0.333):
240 |         '''SEGMENT
241 | 
242 |         Enhance tumor region by suppressing non-tumor region
243 |         with a coefficient.
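        In other words, for every voxel v (a sketch of the rule
        with the default coefficient):

            segged[v] = volume[v]          if mask[v] > 0
            segged[v] = 0.333 * volume[v]  if mask[v] == 0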
244 | 
245 |         Inputs:
246 |         -------
247 | 
248 |         - volume: numpy ndarray, input image.
249 |         - mask: numpy ndarray, mask with segmentation labels.
250 |         - non_mask_coeff: float from 0 to 1, the coefficient of
251 |           voxels in non-tumor region. Default is 0.333.
252 | 
253 |         Output:
254 |         -------
255 | 
256 |         - segged: numpy ndarray, tumor enhanced image.
257 | 
258 |         '''
259 | 
260 |         # Set background to 0
261 |         if np.min(volume) != 0:
262 |             volume -= np.min(volume)
263 | 
264 |         # Suppress non-tumor region
265 |         non_mask_idx = np.where(mask == 0)
266 |         segged = np.copy(volume)
267 |         segged[non_mask_idx] = segged[non_mask_idx] * non_mask_coeff
268 | 
269 |         return segged
270 | 
271 |     @staticmethod
272 |     def trim(volume):
273 |         '''TRIM
274 | 
275 |         Remove unnecessary background around the brain.
276 | 
277 |         Input:
278 |         ------
279 | 
280 |         - volume: numpy ndarray, input image.
281 | 
282 |         Output:
283 |         -------
284 | 
285 |         - trimmed: numpy ndarray, image without unwanted background.
286 | 
287 |         '''
288 | 
289 |         # Get indices of slices that have brain's voxels
290 |         non_zero_slices = [i for i in range(volume.shape[-1])
291 |                            if np.sum(volume[..., i]) > 0]
292 |         # Remove slices that only have background
293 |         volume = volume[..., non_zero_slices]
294 | 
295 |         # In each slice, find the minimum area around the brain;
296 |         # coordinates of each area are saved
297 |         row_begins, row_ends = [], []
298 |         col_begins, col_ends = [], []
299 |         for i in range(volume.shape[-1]):
300 |             non_zero_pixels = np.where(volume[..., i] > 0)
301 |             row_begins.append(np.min(non_zero_pixels[0]))
302 |             row_ends.append(np.max(non_zero_pixels[0]))
303 |             col_begins.append(np.min(non_zero_pixels[1]))
304 |             col_ends.append(np.max(non_zero_pixels[1]))
305 | 
306 |         # Find the maximum area from all minimum areas
307 |         row_begin, row_end = min(row_begins), max(row_ends)
308 |         col_begin, col_end = min(col_begins), max(col_ends)
309 | 
310 |         # Generate a minimum square area that includes the maximum area
311 |         rows_num = row_end - row_begin
312 |         cols_num = col_end - col_begin
313 |         more_col_len = rows_num - cols_num
314 |         more_col_len_left = more_col_len // 2
315 |         more_col_len_right = more_col_len - more_col_len_left
316 |         col_begin -= more_col_len_left
317 |         col_end += more_col_len_right
318 |         len_of_side = rows_num + 1
319 | 
320 |         # Remove unwanted background
321 |         trimmed = np.zeros([len_of_side, len_of_side, volume.shape[-1]])
322 |         for i in range(volume.shape[-1]):
323 |             trimmed[..., i] = volume[row_begin:row_end + 1,
324 |                                      col_begin:col_end + 1, i]
325 |         return trimmed
326 | 
327 |     @staticmethod
328 |     def resize(volume, target_shape=[112, 112, 96]):
329 |         '''RESIZE
330 | 
331 |         Resize input image to target shape.
332 |         -1- Resize to [112, 112, 96].
333 |         -2- Crop image to [112, 96, 96].
334 | 
335 |         '''
336 | 
337 |         # Shape of input image
338 |         old_shape = list(volume.shape)
339 | 
340 |         # Resize image
341 |         factor = [n / float(o) for n, o in zip(target_shape, old_shape)]
342 |         resized = zoom(volume, zoom=factor, order=1, prefilter=False)
343 | 
344 |         # Crop image
345 |         resized = resized[:, 8:104, :]
346 | 
347 |         return resized
348 | 
349 |     @staticmethod
350 |     def save2nii(to_path, volume):
351 |         '''SAVE2NII
352 | 
353 |         Save a numpy ndarray to a NIfTI image.
354 | 
355 |         Inputs:
356 |         -------
357 | 
358 |         - to_path: string, path of output image.
359 |         - volume: numpy ndarray, preprocessed image.
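        Example (an illustrative call; the array and the output path
        are made up for this sketch, and the output folder is assumed
        to exist):

            arr = np.random.rand(112, 96, 96) * 100
            BTCPreprocess.save2nii("out/sample.nii.gz", arr)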
360 | 
361 |         '''
362 |         # Rotate image to standard space
363 |         volume = volume.astype(np.int16)
364 |         volume = np.rot90(volume, 3)
365 | 
366 |         # Convert to NIfTI
367 |         volume_nii = nib.Nifti1Image(volume, np.eye(4))
368 |         # Save image
369 |         nib.save(volume_nii, to_path)
370 | 
371 |         return
372 | 
373 | 
374 | if __name__ == "__main__":
375 | 
376 |     # Set path for input directory
377 |     parent_dir = os.path.dirname(os.getcwd())
378 |     data_dir = os.path.join(parent_dir, "data")
379 |     hgg_input_dir = os.path.join(data_dir, "HGG")
380 |     lgg_input_dir = os.path.join(data_dir, "LGG")
381 |     input_dirs = [hgg_input_dir, lgg_input_dir]
382 | 
383 |     # Generate Enhanced Tumor
384 |     is_mask = True
385 |     non_mask_coeff = 0.333
386 |     # Set path for output directory
387 |     hgg_output_dir = os.path.join(data_dir, "HGGSegTrimmed")
388 |     lgg_output_dir = os.path.join(data_dir, "LGGSegTrimmed")
389 |     output_dirs = [hgg_output_dir, lgg_output_dir]
390 | 
391 |     prep = BTCPreprocess(input_dirs, output_dirs, "t1ce")
392 |     prep.run(non_mask_coeff=non_mask_coeff,
393 |              is_mask=is_mask, processes=-1)
394 | 
395 |     # Generate Non-Enhanced Tumor
396 |     is_mask = False
397 |     # Set path for output directory
398 |     hgg_output_dir = os.path.join(data_dir, "HGGTrimmed")
399 |     lgg_output_dir = os.path.join(data_dir, "LGGTrimmed")
400 |     output_dirs = [hgg_output_dir, lgg_output_dir]
401 | 
402 |     prep = BTCPreprocess(input_dirs, output_dirs, "t1ce")
403 |     prep.run(is_mask=is_mask, processes=-1)
404 | 
--------------------------------------------------------------------------------
/src/btc_dataset.py:
--------------------------------------------------------------------------------
  1 | # Brain Tumor Classification
  2 | # Load and split dataset into training set,
  3 | # validation set and testing set.
  4 | # Author: Qixun QU
  5 | # Copyleft: MIT License
  6 | 
  7 | #     ,,,         ,,,
  8 | #   ;"   ';     ;'   ",
  9 | #   ;  @.ss$$$$$$s.@  ;
 10 | #   `s$$$$$$$$$$$$$$$'
 11 | #   $$$$$$$$$$$$$$$$$$
 12 | #  $$$$P""Y$$$Y""W$$$$$
 13 | #  $$$$  p"$$$"q  $$$$$
 14 | #  $$$$  .$$$$$.  $$$$'
 15 | #   $$$DaU$$O$$DaU$$$'
 16 | #    '$$$$'.^.'$$$$'
 17 | #       '&$$$$$&'
 18 | 
 19 | 
 20 | from __future__ import print_function
 21 | 
 22 | 
 23 | import os
 24 | import numpy as np
 25 | import pandas as pd
 26 | import nibabel as nib
 27 | from random import seed, shuffle
 28 | from keras.utils import to_categorical
 29 | 
 30 | 
 31 | class BTCDataset(object):
 32 | 
 33 |     def __init__(self,
 34 |                  hgg_dir, lgg_dir,
 35 |                  volume_type="t1ce",
 36 |                  train_prop=0.6,
 37 |                  valid_prop=0.2,
 38 |                  random_state=0,
 39 |                  is_augment=True,
 40 |                  pre_trainset_path=None,
 41 |                  pre_validset_path=None,
 42 |                  pre_testset_path=None,
 43 |                  data_format=".nii.gz"):
 44 |         '''__INIT__
 45 | 
 46 |         Initialize configurations for loading
 47 |         and partitioning dataset.
 48 | 
 49 |         Important variables:
 50 |         - train_x, train_y
 51 |         - valid_x, valid_y
 52 |         - test_x, test_y
 53 |         (x: brain images, y: labels)
 54 | 
 55 |         Inputs:
 56 |         -------
 57 | 
 58 |         - hgg_dir: string, path of the directory containing HGG subjects.
 59 |         - lgg_dir: string, path of the directory containing LGG subjects.
 60 |         - volume_type: string, type of brain tissue, "t1ce", "flair",
 61 |           "t1" or "t2". Default is "t1ce".
 62 |         - train_prop: float between 0 and 1, proportion of training
 63 |           data to whole dataset. Default is 0.6.
 64 |         - valid_prop: float between 0 and 1, proportion of validation
 65 |           data to whole dataset. Default is 0.2.
 66 |         - random_state: int, seed for reproducible dataset
 67 |           partitioning. Default is 0.
 68 |         - is_augment: boolean, if True, augment LGG scans in the
 69 |           training set by flipping images from left to right.
 70 |           Default is True.
 71 |         - pre_trainset_path, pre_validset_path, pre_testset_path:
 72 |           string, paths of csv files giving information on subjects
 73 |           (IDs and labels) in training, validation and testing sets.
 74 |         - data_format: string, format of brain images.
 75 |           Default is ".nii.gz".
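        Example (an illustrative construction; the directory
        arguments are placeholders matching the sample layout
        shipped under data/):

            data = BTCDataset("data/HGGSegTrimmed", "data/LGGSegTrimmed",
                              volume_type="t1ce",
                              train_prop=0.6, valid_prop=0.2,
                              random_state=0)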
 76 | 
 77 |         '''
 78 | 
 79 |         self.hgg_dir = hgg_dir
 80 |         self.lgg_dir = lgg_dir
 81 |         self.volume_type = volume_type
 82 | 
 83 |         self.train_prop = train_prop
 84 |         self.valid_prop = valid_prop
 85 |         self.random_state = int(random_state)
 86 |         self.is_augment = is_augment
 87 | 
 88 |         self.pre_trainset = pre_trainset_path
 89 |         self.pre_validset = pre_validset_path
 90 |         self.pre_testset = pre_testset_path
 91 |         self.data_format = data_format
 92 | 
 93 |         self.train_x, self.train_y = None, None
 94 |         self.valid_x, self.valid_y = None, None
 95 |         self.test_x, self.test_y = None, None
 96 | 
 97 |         return
 98 | 
 99 |     def run(self, pre_split=True,
100 |             save_split=False,
101 |             save_split_dir=None):
102 |         '''RUN
103 | 
104 |         Partition dataset.
105 | 
106 |         Inputs:
107 |         -------
108 | 
109 |         - pre_split: boolean, if True, read csv files to get information
110 |           of partitions that have already been split. Default is True.
111 |         - save_split: boolean, if True, save partition to csv files.
112 |           Default is False.
113 |         - save_split_dir: string, path of directory to save partition
114 |           information. It is used only if save_split is True.
115 |           Default is None.
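        Example (both calls mirror the tests at the bottom of
        this file):

            data.run(pre_split=True)            # reuse DataSplit/*.csv
            # or build and persist a fresh split:
            data.run(pre_split=False, save_split=True,
                     save_split_dir="DataSplit")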
116 | 
117 |         '''
118 | 
119 |         print("\nSplitting dataset into training, validation and testing sets.\n")
120 | 
121 |         # Load partition's information from csv files
122 |         # or generate new partitions
123 |         trainset, validset, testset = \
124 |             self._get_pre_datasplit() if pre_split else \
125 |             self._get_new_datasplit()
126 | 
127 |         # Load images according to partition information
128 |         self._load_dataset(trainset, validset, testset)
129 | 
130 |         if save_split and (not pre_split):
131 |             # Save new partitions into csv files
132 |             self.save_split_dir = save_split_dir
133 |             self._save_dataset(trainset, validset, testset)
134 | 
135 |         return
136 | 
137 |     def _get_pre_datasplit(self):
138 |         '''_GET_PRE_DATASPLIT
139 | 
140 |         Load partition information from csv files for
141 |         training set, validation set and testing set.
142 |         In each csv file, information includes:
143 |         - ID: subject's ID.
144 |         - label: subject's label, 1 for HGG and 0 for LGG.
145 | 
146 |         Outputs:
147 |         --------
148 | 
149 |         - trainset, validset, testset: lists of information,
150 |           each element is [subject_path, label].
151 | 
152 |         '''
153 | 
154 |         # Parameters for function to load csv
155 |         paras = {"hgg_dir": self.hgg_dir,
156 |                  "lgg_dir": self.lgg_dir,
157 |                  "data_format": self.data_format,
158 |                  "csv_path": None}
159 | 
160 |         # Load partition of training set
161 |         paras["csv_path"] = self.pre_trainset
162 |         trainset = self.load_datasplit(**paras)
163 | 
164 |         # Load partition of validation set
165 |         paras["csv_path"] = self.pre_validset
166 |         validset = self.load_datasplit(**paras)
167 | 
168 |         # Load partition of testing set
169 |         paras["csv_path"] = self.pre_testset
170 |         testset = self.load_datasplit(**paras)
171 | 
172 |         return trainset, validset, testset
173 | 
174 |     def _get_new_datasplit(self):
175 |         '''_GET_NEW_DATASPLIT
176 | 
177 |         Obtain a new partition of the dataset.
178 |         -1- Generate paths of all subjects.
179 |         -2- Randomly rearrange the path list.
180 |         -3- Partition dataset according to proportions.
181 |         -4- Merge HGG and LGG subjects.
182 | 
183 |         Outputs:
184 |         --------
185 | 
186 |         - trainset, validset, testset: lists of information,
187 |           each element is [subject_path, label].
188 | 
189 |         '''
190 | 
191 |         # Parameters for function to load subjects' paths
192 |         paras = {"label": None,
193 |                  "dir_path": None,
194 |                  "volume_type": self.volume_type,
195 |                  "random_state": self.random_state}
196 | 
197 |         # Load HGG subjects' paths
198 |         paras["label"], paras["dir_path"] = 1, self.hgg_dir
199 |         hgg_subjects = self.get_subjects_path(**paras)
200 | 
201 |         # Load LGG subjects' paths
202 |         paras["label"], paras["dir_path"] = 0, self.lgg_dir
203 |         lgg_subjects = self.get_subjects_path(**paras)
204 | 
205 |         # Parameters for function to partition dataset
206 |         paras = {"subjects": None,
207 |                  "train_prop": self.train_prop,
208 |                  "valid_prop": self.valid_prop}
209 | 
210 |         # Partition HGG subjects into three sets
211 |         paras["subjects"] = hgg_subjects
212 |         hgg_train, hgg_valid, hgg_test = self.split_dataset(**paras)
213 | 
214 |         # Partition LGG subjects into three sets
215 |         paras["subjects"] = lgg_subjects
216 |         lgg_train, lgg_valid, lgg_test = self.split_dataset(**paras)
217 | 
218 |         # Merge HGG and LGG subjects
219 |         trainset = hgg_train + lgg_train
220 |         validset = hgg_valid + lgg_valid
221 |         testset = hgg_test + lgg_test
222 | 
223 |         return trainset, validset, testset
224 | 
225 |     def _load_dataset(self, trainset, validset, testset):
226 |         '''_LOAD_DATASET
227 | 
228 |         Load images and labels for three partitions:
229 |         training set, validation set and testing set.
230 | 
231 |         '''
232 | 
233 |         # Load images and labels of subjects in testing set
234 |         self.test_x, test_y = self.load_data(testset, "test set")
235 |         self.test_y = to_categorical(test_y, num_classes=2)
236 | 
237 |         # Load images and labels of subjects in validation set
238 |         self.valid_x, valid_y = self.load_data(validset, "valid set")
239 |         self.valid_y = to_categorical(valid_y, num_classes=2)
240 | 
241 |         # Load images and labels of subjects in training set
242 |         train_x, train_y = self.load_data(trainset, "train set")
243 | 
244 |         if self.is_augment:
245 |             # Augmentation on LGG subjects
246 |             train_x, train_y = self.augment(train_x, train_y)
247 | 
248 |         self.train_x = train_x
249 |         self.train_y = to_categorical(train_y, num_classes=2)
250 | 
251 |         return
252 | 
253 |     def _save_dataset(self, trainset, validset, testset):
254 |         '''_SAVE_DATASET
255 | 
256 |         Save partition information into csv files.
257 | 
258 |         Outputs:
259 |         --------
260 | 
261 |         - trainset_[random_state].csv
262 |         - validset_[random_state].csv
263 |         - testset_[random_state].csv
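        Example (hypothetical values): with random_state=0 and
        save_split_dir="DataSplit", the partition is written to
        DataSplit/trainset_0.csv, DataSplit/validset_0.csv and
        DataSplit/testset_0.csv.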
264 | 
265 |         '''
266 | 
267 |         # Generate paths for output csv files
268 |         ap = str(self.random_state) + ".csv"
269 |         trainset_path = os.path.join(self.save_split_dir, "trainset_" + ap)
270 |         validset_path = os.path.join(self.save_split_dir, "validset_" + ap)
271 |         testset_path = os.path.join(self.save_split_dir, "testset_" + ap)
272 | 
273 |         # Save information into csv files
274 |         self.save_datasplit(trainset, trainset_path)
275 |         self.save_datasplit(validset, validset_path)
276 |         self.save_datasplit(testset, testset_path)
277 | 
278 |         return
279 | 
280 |     @staticmethod
281 |     def load_datasplit(hgg_dir, lgg_dir, csv_path,
282 |                        data_format=".nii.gz"):
283 |         '''LOAD_DATASPLIT
284 | 
285 |         Load partition information from a given csv file.
286 | 
287 |         Inputs:
288 |         -------
289 | 
290 |         - hgg_dir: string, directory path of HGG subjects.
291 |         - lgg_dir: string, directory path of LGG subjects.
292 |         - csv_path: string, path of csv file which contains
293 |           partition information.
294 |         - data_format: string, format of input images,
295 |           default is ".nii.gz".
296 | 
297 |         Output:
298 |         -------
299 | 
300 |         - info: list of partition information, each element is
301 |           [subject_path, label].
302 | 
303 |         '''
304 | 
305 |         # Load IDs and labels from csv file
306 |         df = pd.read_csv(csv_path)
307 |         IDs = df["ID"].values.tolist()
308 |         labels = df["label"].values.tolist()
309 | 
310 |         info = []
311 |         for ID, label in zip(IDs, labels):
312 |             # Generate directory path of each subject (ID[:-5] strips "_t1ce")
313 |             target_dir = hgg_dir if label else lgg_dir
314 |             path = os.path.join(target_dir, ID[:-5],
315 |                                 ID + data_format)
316 |             info.append([path, label])
317 |         return info
318 | 
319 |     @staticmethod
320 |     def save_datasplit(dataset, to_path):
321 |         '''SAVE_DATASPLIT
322 | 
323 |         Save partition information into a csv file.
324 | 
325 |         Inputs:
326 |         -------
327 | 
328 |         - dataset: list, information of partition, each element
329 |           is [subject_path, label].
330 |         - to_path: string, the path of the csv file to be saved.
331 | 
332 |         Output:
333 |         -------
334 | 
335 |         - A csv table with two columns, "ID" and "label".
336 | 
337 |         '''
338 | 
339 |         IDs, labels = [], []
340 |         for i in dataset:
341 |             # Extract ID from subject's path
342 |             IDs.append(i[0].split("/")[-1].split(".")[0])
343 |             # Extract label
344 |             labels.append(i[1])
345 | 
346 |         # Create pandas DataFrame and save it into csv file
347 |         df = pd.DataFrame(data={"ID": IDs, "label": labels})
348 |         df.to_csv(to_path, index=False)
349 | 
350 |         return
351 | 
352 |     @staticmethod
353 |     def get_subjects_path(dir_path, volume_type, label,
354 |                           random_state=0):
355 |         '''GET_SUBJECTS_PATH
356 | 
357 |         Obtain subjects' paths of HGG or LGG.
358 | 
359 |         Inputs:
360 |         -------
361 | 
362 |         - dir_path: string, directory path of HGG or LGG subjects.
363 |         - volume_type: string, type of brain tissue, "t1ce", "flair",
364 |           "t1" or "t2".
365 |         - label: int, 1 for HGG and 0 for LGG.
366 |         - random_state: int, seed for shuffling the paths list.
367 | 
368 |         Output:
369 |         -------
370 | 
371 |         - subjects_paths: list with two columns, each element is
372 |           [subject_path, label].
373 | 
374 |         '''
375 | 
376 |         # Obtain all subjects' names
377 |         subjects = os.listdir(dir_path)
378 | 
379 |         # Set seed and shuffle list
380 |         # A different seed leads to a differently shuffled list,
381 |         # which changes the subjects in each partition
382 |         seed(random_state)
383 |         shuffle(subjects)
384 | 
385 |         subjects_paths = []
386 |         for subject in subjects:
387 |             subject_dir = os.path.join(dir_path, subject)
388 |             for scan_name in os.listdir(subject_dir):
389 |                 if volume_type not in scan_name:
390 |                     # Not target volume
391 |                     continue
392 | 
393 |                 # Element [scan_path, label]
394 |                 scan_path = os.path.join(subject_dir, scan_name)
395 |                 subjects_paths.append([scan_path, label])
396 | 
397 |         return subjects_paths
398 | 
399 |     @staticmethod
400 |     def split_dataset(subjects, train_prop=0.6, valid_prop=0.2):
401 |         '''SPLIT_DATASET
402 | 
403 |         Partition dataset into three parts according
404 |         to proportions.
405 | 
406 |         Inputs:
407 |         -------
408 | 
409 |         - subjects: list with two columns, information of all
410 |           subjects, each element is [subject_path, label].
411 |         - train_prop: float between 0 and 1, proportion of training
412 |           data to whole dataset. Default is 0.6.
413 |         - valid_prop: float between 0 and 1, proportion of validation
414 |           data to whole dataset. Default is 0.2.
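        Worked example (hypothetical numbers): with 10 subjects,
        train_prop=0.6 and valid_prop=0.2, the code below computes
        valid_idx = round(10 * 0.2) = 2 and
        train_valid_idx = round(10 * 0.8) = 8, so the (already
        shuffled) list is cut into validset = subjects[:2],
        trainset = subjects[2:8] and testset = subjects[8:].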
415 | 
416 |         Outputs:
417 |         --------
418 | 
419 |         - trainset, validset, testset: partition information, including
420 |           subjects' paths and labels.
421 | 
422 |         '''
423 | 
424 |         subj_num = len(subjects)
425 | 
426 |         # Extract subjects for testing set
427 |         train_valid_num = subj_num * (train_prop + valid_prop)
428 |         train_valid_idx = int(round(train_valid_num))
429 |         testset = subjects[train_valid_idx:]
430 | 
431 |         # Extract subjects for validation set
432 |         valid_idx = int(round(subj_num * valid_prop))
433 |         validset = subjects[:valid_idx]
434 |         # Extract subjects for training set
435 |         trainset = subjects[valid_idx:train_valid_idx]
436 | 
437 |         return trainset, validset, testset
438 | 
439 |     @staticmethod
440 |     def load_data(dataset, mode):
441 |         '''LOAD_DATA
442 | 
443 |         Load images from partition information.
444 | 
445 |         Inputs:
446 |         -------
447 | 
448 |         - dataset: list with two columns, [subject_path, label].
449 |         - mode: string, indicates which partition, "train set",
450 |           "valid set" or "test set".
451 | 
452 |         Outputs:
453 |         --------
454 | 
455 |         - x: numpy ndarray in shape [n, 112, 96, 96, 1], where n is
456 |           the number of scans in one partition. Input images.
457 |         - y: numpy ndarray in shape [n, 1]. Labels of subjects.
458 | 
459 |         '''
460 | 
461 |         x, y = [], []
462 |         print("Loading {} data ...".format(mode))
463 |         for subject in dataset:
464 |             volume_path, label = subject[0], subject[1]
465 |             # Load image and rotate it to standard space
466 |             volume = nib.load(volume_path).get_data()
467 |             volume = np.transpose(volume, axes=[1, 0, 2])
468 |             volume = np.flipud(volume)
469 | 
470 |             # Extract mean and std from brain voxels only
471 |             volume_obj = volume[volume > 0]
472 |             obj_mean = np.mean(volume_obj)
473 |             obj_std = np.std(volume_obj)
474 |             # Normalize whole image
475 |             volume = (volume - obj_mean) / obj_std
476 | 
477 |             volume = np.expand_dims(volume, axis=3)
478 |             x.append(volume.astype(np.float32))
479 |             y.append(label)
480 | 
481 |         x = np.array(x)
482 |         y = np.array(y).reshape((-1, 1))
483 | 
484 |         return x, y
485 | 
486 |     @staticmethod
487 |     def augment(train_x, train_y):
488 |         '''AUGMENT
489 | 
490 |         Augment the LGG subjects in the training set
491 |         by flipping each image from left to right.
492 | 
493 |         Inputs:
494 |         -------
495 | 
496 |         - train_x: numpy ndarray, images array of training set.
497 |         - train_y: numpy ndarray, labels of training set.
498 | 
499 |         Outputs:
500 |         --------
501 | 
502 |         - train_x: training images plus a flipped copy of each LGG scan.
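        - train_y: correspondingly augmented labels of the training set.

        Example (hypothetical shapes): if train_x holds 100 scans of
        which 40 are LGG (label 0), each LGG scan gains one left-right
        flipped copy, so the returned train_x contains 140 scans and
        train_y has shape (140, 1).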
503 | 
504 |         '''
505 | 
506 |         print("Do Augmentation on LGG Samples ...")
507 |         train_x_aug, train_y_aug = [], []
508 |         for i in range(len(train_y)):
509 |             train_x_aug.append(train_x[i])
510 |             train_y_aug.append(train_y[i])
511 |             if train_y[i] == 0:
512 |                 # Flip image if it is LGG
513 |                 train_x_aug.append(np.fliplr(train_x[i]))
514 |                 train_y_aug.append(np.array([0]))
515 |         train_x = np.array(train_x_aug)
516 |         train_y = np.array(train_y_aug).reshape((-1, 1))
517 | 
518 |         return train_x, train_y
519 | 
520 | 
521 | if __name__ == "__main__":
522 | 
523 |     import gc
524 | 
525 |     parent_dir = os.path.dirname(os.getcwd())
526 | 
527 |     # Set directory for input images (separated subjects)
528 |     data_dir = os.path.join(parent_dir, "data")
529 |     hgg_dir = os.path.join(data_dir, "HGGSegTrimmed")
530 |     lgg_dir = os.path.join(data_dir, "LGGSegTrimmed")
531 | 
532 |     # Test 1
533 |     # Load and split dataset
534 |     data = BTCDataset(hgg_dir, lgg_dir,
535 |                       volume_type="t1ce",
536 |                       train_prop=0.6,
537 |                       valid_prop=0.2,
538 |                       random_state=0)
539 |     data.run(pre_split=False,
540 |              save_split=True,
541 |              save_split_dir="DataSplit")
542 |     print(data.train_x.shape, data.train_y.shape)
543 |     del data
544 |     gc.collect()
545 | 
546 |     # Test 2
547 |     # Load dataset which has been split
548 |     data = BTCDataset(hgg_dir, lgg_dir,
549 |                       volume_type="t1ce",
550 |                       pre_trainset_path="DataSplit/trainset.csv",
551 |                       pre_validset_path="DataSplit/validset.csv",
552 |                       pre_testset_path="DataSplit/testset.csv")
553 |     data.run(pre_split=True)
554 |     print(data.train_x.shape, data.train_y.shape)
555 |     del data
556 |     gc.collect()
557 | 
--------------------------------------------------------------------------------