├── .gitignore
├── LICENSE
├── README.md
├── model.py
├── preprocess.py
├── segment.py
├── segment_proba.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

dataset_train.csv

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Ryo Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Implementation of VoxResNet

This repository contains an implementation of Deep Voxelwise Residual Networks for Volumetric Brain Segmentation (VoxResNet) [1].

Note that this is not an official implementation.


# Requirements
- python 3.6
- chainer v2
- cupy
- dipy
- nibabel
- numpy
- pandas
- SimpleITK


# Preparing dataset
1. Download the [Internet Brain Segmentation Repository (IBSR)](https://www.nitrc.org/frs/download.php/5731/IBSR_V2.0_nifti_stripped.tgz) (or another dataset you want to try)
2. Preprocess the dataset
- training dataset
`$ python preprocess.py -i /path/to/IBSR/dataset -s IBSR_01 IBSR_02 IBSR_03 IBSR_04 IBSR_05 --input_image_suffix _ana_strip.nii.gz --output_image_suffix _preprocessed.nii.gz --label_suffix _segTRI_ana.nii.gz -f dataset_train.json --n_classes 4 --zooms 1. 1. 1.`
- test dataset
`$ python preprocess.py -i /path/to/IBSR/dataset -s IBSR_11 IBSR_12 IBSR_13 IBSR_14 IBSR_15 --input_image_suffix _ana_strip.nii.gz --output_image_suffix _preprocessed.nii.gz --label_suffix _segTRI_ana.nii.gz -f dataset_test.json --n_classes 4 --zooms 1. 1. 1.`


# Train VoxResNet
`$ python train.py -g 0 -f dataset_train.json`


# Test VoxResNet
`$ python segment.py -g 0 -i dataset_test.json -m vrn.npz -o _segTRI_predict.nii.gz`

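# Evaluate the result (example)

The following is a minimal sketch of how the predicted segmentation could be compared against the ground-truth labels using `dice_coefficients` from `utils.py`. The subject name and file suffixes are just the ones produced by the preprocessing and segmentation commands above, so adjust them to your setup. (`segment_proba.py` can be run the same way as `segment.py` if you want per-class probability maps instead of hard labels.)

```python
import numpy as np

from utils import load_nifti, dice_coefficients

# files written by preprocess.py and segment.py for one test subject
label = load_nifti("IBSR_11/IBSR_11_segTRI_ana.nii.gz").astype(np.int32)
pred = load_nifti("IBSR_11/IBSR_11_segTRI_predict.nii.gz").astype(np.int32)

# one Dice coefficient per class (for IBSR's tri-class labels:
# background, CSF, gray matter, white matter)
print(dice_coefficients(pred, label, labels=range(4)))
```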

# Reference

[1] Chen, Hao, et al. "VoxResNet: Deep Voxelwise Residual Networks for Volumetric Brain Segmentation." arXiv preprint arXiv:1608.05895 (2016).

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import chainer
import chainer.functions as F
import chainer.links as L


class VoxResModule(chainer.Chain):
    """
    Voxel Residual Module
    input
        BatchNormalization, ReLU
        Conv 64, 3x3x3
        BatchNormalization, ReLU
        Conv 64, 3x3x3
    output
    """

    def __init__(self):
        initW = chainer.initializers.HeNormal(scale=0.01)
        super().__init__()

        with self.init_scope():
            self.bnorm1 = L.BatchNormalization(size=64)
            self.conv1 = L.ConvolutionND(3, 64, 64, 3, pad=1, initialW=initW)
            self.bnorm2 = L.BatchNormalization(size=64)
            self.conv2 = L.ConvolutionND(3, 64, 64, 3, pad=1, initialW=initW)

    def __call__(self, x):
        h = F.relu(self.bnorm1(x))
        h = self.conv1(h)
        h = F.relu(self.bnorm2(h))
        h = self.conv2(h)
        return h + x


class VoxResNet(chainer.Chain):
    """Voxel Residual Network"""

    def __init__(self, in_channels=1, n_classes=4):
        init = chainer.initializers.HeNormal(scale=0.01)
        super().__init__()

        with self.init_scope():
            self.conv1a = L.ConvolutionND(
                3, in_channels, 32, 3, pad=1, initialW=init)
            self.bnorm1a = L.BatchNormalization(32)
            self.conv1b = L.ConvolutionND(
                3, 32, 32, 3, pad=1, initialW=init)
            self.bnorm1b = L.BatchNormalization(32)
            self.conv1c = L.ConvolutionND(
                3, 32, 64, 3, stride=2, pad=1, initialW=init)
            self.voxres2 = VoxResModule()
            self.voxres3 = VoxResModule()
            self.bnorm3 = L.BatchNormalization(64)
            self.conv4 = L.ConvolutionND(
                3, 64, 64, 3, stride=2, pad=1, initialW=init)
            self.voxres5 = VoxResModule()
            self.voxres6 = VoxResModule()
            self.bnorm6 = L.BatchNormalization(64)
            self.conv7 = L.ConvolutionND(
                3, 64, 64, 3, stride=2, pad=1, initialW=init)
            self.voxres8 = VoxResModule()
            self.voxres9 = VoxResModule()
            self.c1deconv = L.DeconvolutionND(
                3, 32, 32, 3, pad=1, initialW=init)
            self.c1conv = L.ConvolutionND(
                3, 32, n_classes, 3, pad=1, initialW=init)
            self.c2deconv = L.DeconvolutionND(
                3, 64, 64, 4, stride=2, pad=1, initialW=init)
            self.c2conv = L.ConvolutionND(
                3, 64, n_classes, 3, pad=1, initialW=init)
            self.c3deconv = L.DeconvolutionND(
                3, 64, 64, 6, stride=4, pad=1, initialW=init)
            self.c3conv = L.ConvolutionND(
                3, 64, n_classes, 3, pad=1, initialW=init)
            self.c4deconv = L.DeconvolutionND(
                3, 64, 64, 10, stride=8, pad=1, initialW=init)
            self.c4conv = L.ConvolutionND(
                3, 64, n_classes, 3, pad=1, initialW=init)

    def __call__(self, x, train=False):
        """
        calculate output of VoxResNet given input x

        Parameters
        ----------
        x : (batch_size, in_channels, xlen, ylen, zlen) ndarray
            image to perform semantic segmentation on

        Returns
        -------
        proba : (batch_size, n_classes, xlen, ylen, zlen) ndarray
            probability of each voxel belonging to each class;
            if train is True, a list of logits is returned instead
        """
        with chainer.using_config("train", train):
            h = self.conv1a(x)
            h = F.relu(self.bnorm1a(h))
            h = self.conv1b(h)
            c1 = F.clipped_relu(self.c1deconv(h))
            c1 = self.c1conv(c1)

            h = F.relu(self.bnorm1b(h))
            h = self.conv1c(h)
            h = self.voxres2(h)
            h = self.voxres3(h)
            c2 = F.clipped_relu(self.c2deconv(h))
            c2 = self.c2conv(c2)

            h = F.relu(self.bnorm3(h))
            h = self.conv4(h)
            h = self.voxres5(h)
            h = self.voxres6(h)
            c3 = F.clipped_relu(self.c3deconv(h))
            c3 = self.c3conv(c3)

            h = F.relu(self.bnorm6(h))
            h = self.conv7(h)
            h = self.voxres8(h)
            h = self.voxres9(h)
            c4 = F.clipped_relu(self.c4deconv(h))
            c4 = self.c4conv(c4)

            c = c1 + c2 + c3 + c4

        if train:
            return [c1, c2, c3, c4, c]
        else:
            return F.softmax(c)

--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
import argparse
import json
import os

from dipy.align.reslice import reslice
import nibabel as nib
import numpy as np
import pandas as pd
from scipy.ndimage.filters import gaussian_filter
import SimpleITK as sitk


def preprocess_img(inputfile, output_preprocessed, zooms):
    img = nib.load(inputfile)
    data = img.get_data()
    affine = img.affine
    zoom = img.header.get_zooms()[:3]
    data, affine = reslice(data, affine, zoom, zooms, 1)
    data = np.squeeze(data)
    data = np.pad(data, [(0, 256 - len_) for len_ in data.shape], "constant")

    data_sub = data - gaussian_filter(data, sigma=1)
    img = sitk.GetImageFromArray(np.copy(data_sub))
    img = sitk.AdaptiveHistogramEqualization(img)
    data_clahe = sitk.GetArrayFromImage(img)[:, :, :, None]
    data = np.concatenate((data_clahe, data[:, :, :, None]), 3)
    data = (data - np.mean(data, (0, 1, 2))) / np.std(data, (0, 1, 2))
    assert data.ndim == 4, data.ndim
    assert np.allclose(np.mean(data, (0, 1, 2)), 0.), np.mean(data, (0, 1, 2))
    assert np.allclose(np.std(data, (0, 1, 2)), 1.), np.std(data, (0, 1, 2))
    data = np.float32(data)

    img = nib.Nifti1Image(data, affine)
    nib.save(img, output_preprocessed)


def preprocess_label(inputfile,
                     output_label,
                     n_classes,
                     zooms,
                     df=None,
                     input_key=None,
                     output_key=None):
    img = nib.load(inputfile)
    data = img.get_data()
    affine = img.affine
    zoom = img.header.get_zooms()[:3]
    data, affine = reslice(data, affine, zoom, zooms, 0)
    data = np.squeeze(data)
    data = np.pad(data, [(0, 256 - len_) for len_ in data.shape], "constant")

    if df is not None:
        tmp = np.zeros_like(data)
        for target, source in zip(df[output_key], df[input_key]):
            tmp[np.where(data == source)] = target
        data = tmp
    data = np.int32(data)
    assert np.max(data) < n_classes
    img = nib.Nifti1Image(data, affine)
    nib.save(img, output_label)


def main():
    parser = argparse.ArgumentParser(description="preprocess dataset")
    parser.add_argument(
        "--input_directory", "-i", type=str,
        help="directory of original dataset"
    )
    parser.add_argument(
        "--subjects", "-s", type=str, nargs="*", action="store",
        help="subjects to be preprocessed"
    )
    parser.add_argument(
        "--weights", "-w", type=int, nargs="*", action="store",
        help="sample weight for each subject"
    )
    parser.add_argument(
        "--input_image_suffix", type=str,
        help="suffix of input images"
    )
    parser.add_argument(
        "--output_image_suffix", type=str,
        help="suffix of output images"
    )
    parser.add_argument(
        "--label_suffix", type=str,
        help="suffix of labels"
    )
    parser.add_argument(
        "--output_file", "-f", type=str, default="dataset.json",
        help="json file of preprocessed dataset, default=dataset.json"
    )
    parser.add_argument(
        "--label_file", "-l", type=str, default=None,
        help="csv file with label translation rule, default=None"
    )
    parser.add_argument(
        "--input_key", type=str, default=None,
        help="specifies column for input of label translation, default=None"
    )
    parser.add_argument(
        "--output_key", type=str, default=None,
        help="specifies column for output of label translation, default=None"
    )
    parser.add_argument(
        "--n_classes", type=int,
        help="number of classes to classify"
    )
    parser.add_argument(
        "--zooms", type=float, nargs="*", action="store", default=[1., 1., 1.],
        help="zooming resolution"
    )
    args = parser.parse_args()
    if args.weights is None:
        args.weights = [1. for _ in args.subjects]
    assert len(args.subjects) == len(args.weights)
    print(args)

    if args.label_file is None:
        df = None
    else:
        df = pd.read_csv(args.label_file)

    dataset = {"in_channels": 2, "n_classes": args.n_classes}
    dataset_list = []

    for subject, weight in zip(args.subjects, args.weights):
        if not os.path.exists(subject):
            os.makedirs(subject)
        filedict = {"subject": subject, "weight": weight}

        if args.input_image_suffix is not None:
            filedict["image"] = os.path.join(
                subject,
                subject + args.output_image_suffix
            )
            preprocess_img(
                os.path.join(
                    args.input_directory,
                    subject,
                    subject + args.input_image_suffix
                ),
                filedict["image"],
                args.zooms
            )

        if args.label_suffix is not None:
            filedict["label"] = os.path.join(
                subject,
                subject + args.label_suffix
            )
            preprocess_label(
                os.path.join(
                    args.input_directory,
                    subject,
                    subject + args.label_suffix
                ),
                filedict["label"],
                args.n_classes,
                args.zooms,
                df=df,
                input_key=args.input_key,
                output_key=args.output_key
            )

        dataset_list.append(filedict)
    dataset["data"] = dataset_list

    with open(args.output_file, "w") as f:
        json.dump(dataset, f, indent=4, sort_keys=True)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/segment.py:
--------------------------------------------------------------------------------
import argparse
import json
import os

import chainer
import numpy as np
import nibabel as nib
import pandas as pd

from model import VoxResNet
from utils import load_nifti, feedforward


def main():
    parser = argparse.ArgumentParser(description="segment with VoxResNet")
    parser.add_argument(
        "--input_file", "-i", type=str,
        help="input json file of test dataset")
    parser.add_argument(
        "--output_suffix", "-o", type=str,
        help="result of the segmentation")
    parser.add_argument(
        "--model", "-m", type=str,
        help="a file containing parameters of trained VoxResNet")
    parser.add_argument(
        "--input_shape", type=int, nargs="*", action="store",
        default=[80, 80, 80],
        help="input patch shape of VoxResNet, default=[80, 80, 80]")
    parser.add_argument(
        "--output_shape", type=int, nargs="*", action="store",
        default=[60, 60, 60],
        help="output patch shape of VoxResNet, default=[60, 60, 60]")
    parser.add_argument(
        "--gpu", "-g", default=-1, type=int,
        help="negative value indicates no gpu, default=-1")
    parser.add_argument(
        "--n_tiles", type=int, nargs="*", action="store",
        default=[5, 5, 5],
        help="number of tiles along each axis, default=[5, 5, 5]")
    args = parser.parse_args()
    print(args)

    with open(args.input_file) as f:
        dataset = json.load(f)
    test_df = pd.DataFrame(dataset["data"])

    vrn = VoxResNet(dataset["in_channels"], dataset["n_classes"])
    chainer.serializers.load_npz(args.model, vrn)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        vrn.to_gpu()

    for image_path, subject in zip(test_df["image"], test_df["subject"]):
        image, affine = load_nifti(image_path, with_affine=True)
        output = feedforward(
            vrn,
            image,
            args.input_shape,
            args.output_shape,
            args.n_tiles,
            dataset["n_classes"]
        )
        y = np.argmax(output, axis=0)
        nib.save(
            nib.Nifti1Image(np.int32(y), affine),
            os.path.join(
                os.path.dirname(image_path),
                subject + args.output_suffix))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/segment_proba.py:
--------------------------------------------------------------------------------
import argparse
import json
import os

import chainer
import numpy as np
import nibabel as nib
import pandas as pd

from model import VoxResNet
from utils import load_nifti, feedforward


def main():
    parser = argparse.ArgumentParser(
        description="calculate class probabilities with VoxResNet")
    parser.add_argument(
        "--input_file", "-i", type=str,
        help="input json file of test dataset"
    )
    parser.add_argument(
        "--output_suffix", "-o", type=str,
        help="result of the segmentation"
    )
    parser.add_argument(
        "--model", "-m", type=str,
        help="a file containing parameters of trained VoxResNet"
    )
    parser.add_argument(
        "--input_shape", type=int, nargs="*", action="store",
        default=[80, 80, 80],
        help="input patch shape of VoxResNet, default=[80, 80, 80]"
    )
    parser.add_argument(
        "--output_shape", type=int, nargs="*", action="store",
        default=[60, 60, 60],
        help="output patch shape of VoxResNet, default=[60, 60, 60]"
    )
    parser.add_argument(
        "--gpu", "-g", default=-1, type=int,
        help="negative value indicates no gpu, default=-1"
    )
    parser.add_argument(
        "--n_tiles", type=int, nargs="*", action="store",
        default=[5, 5, 5],
        help="number of tiles along each axis, default=[5, 5, 5]"
    )
    args = parser.parse_args()
    print(args)

    with open(args.input_file) as f:
        dataset = json.load(f)
    test_df = pd.DataFrame(dataset["data"])

    vrn = VoxResNet(dataset["in_channels"], dataset["n_classes"])
    chainer.serializers.load_npz(args.model, vrn)

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        vrn.to_gpu()

    for image_path, subject in zip(test_df["image"], test_df["subject"]):
        image, affine = load_nifti(image_path, with_affine=True)
        output = feedforward(
            vrn,
            image,
            args.input_shape,
            args.output_shape,
            args.n_tiles,
            dataset["n_classes"]
        )

        # overlapping tiles were summed, so renormalize to probabilities
        output /= np.sum(output, axis=0, keepdims=True)

        nib.save(
            nib.Nifti1Image(np.float32(output).transpose(1, 2, 3, 0), affine),
            os.path.join(
                os.path.dirname(image_path),
                subject + args.output_suffix
            )
        )


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import json

import chainer
import chainer.functions as F
import numpy as np
import pandas as pd

from model import VoxResNet
from utils import load_sample, load_nifti, dice_coefficients, feedforward


def validate(model, df, input_shape, output_shape, n_tiles, n_classes):
    dice_coefs = []
    for image_path, label_path in zip(df["image"], df["label"]):
        image = load_nifti(image_path)
        label = load_nifti(label_path)
        output = feedforward(
            model,
            image,
            input_shape,
            output_shape,
            n_tiles,
            n_classes)
        y = np.int32(np.argmax(output, axis=0))
        dice_coefs.append(
            dice_coefficients(y, label, labels=range(n_classes))
        )
    dice_coefs = np.array(dice_coefs)
    return np.mean(dice_coefs, axis=0)


def main():
    parser = argparse.ArgumentParser(description="train VoxResNet")
    parser.add_argument(
        "--iteration", "-i", default=10000, type=int,
        help="number of iterations, default=10000")
    parser.add_argument(
        "--monitor_step", "-s", default=1000, type=int,
        help="number of steps to monitor, default=1000")
    parser.add_argument(
        "--gpu", "-g", default=-1, type=int,
        help="negative value indicates no gpu, default=-1")
    parser.add_argument(
        "--input_file", "-f", type=str, default="dataset.json",
        help="json file of training dataset, default=dataset.json")
    parser.add_argument(
        "--validation_file", "-v", type=str,
        help="json file for validation dataset")
    parser.add_argument(
        "--n_batch", type=int, default=1,
        help="batch size, default=1")
    parser.add_argument(
        "--input_shape", type=int, nargs='*', action="store",
        default=[80, 80, 80],
        help="shape of input for the network, default=[80, 80, 80]")
    parser.add_argument(
        "--output_shape", type=int, nargs="*", action="store",
        default=[60, 60, 60],
        help="shape of output of the network, default=[60, 60, 60]")
    parser.add_argument(
        "--n_tiles", type=int, nargs="*", action="store",
        default=[5, 5, 5],
        help="number of tiles along each axis, default=[5, 5, 5]")
    parser.add_argument(
        '--out', '-o', default='vrn.npz', type=str,
        help='parameters of trained model, default=vrn.npz')
    parser.add_argument(
        "--learning_rate", "-r", default=1e-3, type=float,
        help="update rate, default=1e-3")
    parser.add_argument(
        "--weight_decay", "-w", default=0.0005, type=float,
        help="coefficient of l2norm weight penalty, default=0.0005")
    args = parser.parse_args()
    print(args)

    with open(args.input_file) as f:
        dataset = json.load(f)
    train_df = pd.DataFrame(dataset["data"])
    if args.validation_file is not None:
        with open(args.validation_file) as f:
            dataset_val = json.load(f)
        df_val = pd.DataFrame(dataset_val["data"])
        val_score = 0

    vrn = VoxResNet(dataset["in_channels"], dataset["n_classes"])
    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        vrn.to_gpu()

    optimizer = chainer.optimizers.Adam(alpha=args.learning_rate)
    optimizer.use_cleargrads()
    optimizer.setup(vrn)
    optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
    # slices that crop each logit down to the central output_shape region;
    # integer division is required because slice indices must be int
    slices_in = tuple([slice(None), slice(None)] + [
        slice((len_in - len_out) // 2, len_in - (len_in - len_out) // 2)
        for len_out, len_in in zip(args.output_shape, args.input_shape)
    ])
    for i in range(args.iteration):
        vrn.cleargrads()
        image, label = load_sample(
            train_df,
            args.n_batch,
            args.input_shape,
            args.output_shape
        )
        x_train = vrn.xp.asarray(image)
        y_train = vrn.xp.asarray(label)
        logits = vrn(x_train, train=True)
        logits = [logit[slices_in] for logit in logits]
        loss = F.softmax_cross_entropy(logits[-1], y_train)
        for logit in logits[:-1]:
            loss += F.softmax_cross_entropy(logit, y_train)
        loss.backward()
        optimizer.update()
        if i % args.monitor_step == 0:
            accuracy = float(F.accuracy(logits[-1], y_train).data)
            print(
                f"step {i:5d}, accuracy {accuracy:.02f}, loss {float(loss.data):g}"
            )

            if args.validation_file is not None:
                dice_coefs = validate(
                    vrn,
                    df_val,
                    args.input_shape,
                    args.output_shape,
                    args.n_tiles,
                    dataset["n_classes"]
                )
                mean_dice_coefs = np.mean(dice_coefs)
                if mean_dice_coefs > val_score:
                    chainer.serializers.save_npz(args.out, vrn)
                    print(f"step {i:5d}, saved model")
                    val_score = mean_dice_coefs
                print(
                    f"step {i:5d}",
                    f"val/dice {mean_dice_coefs:.02f}",
                    *[f"val/dice{j} {dice:.02f}" for j, dice in enumerate(dice_coefs)],
                    sep=", "
                )

    if args.validation_file is not None:
        dice_coefs = validate(
            vrn,
            df_val,
            args.input_shape,
            args.output_shape,
            args.n_tiles,
            dataset["n_classes"]
        )
        mean_dice_coefs = np.mean(dice_coefs)
        if mean_dice_coefs > val_score:
            chainer.serializers.save_npz(args.out, vrn)
            print(f"step {args.iteration:5d}, saved model")
        print(
            f"step {args.iteration:5d}",
            f"val/dice {mean_dice_coefs:.02f}",
            *[f"val/dice{j} {dice:.02f}" for j, dice in enumerate(dice_coefs)],
            sep=", "
        )
    else:
        chainer.serializers.save_npz(args.out, vrn)
        print(f"step {args.iteration:5d}, saved model")


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import itertools

import chainer
import numpy as np
import nibabel as nib


def load_nifti(filename, with_affine=False):
    """
    load image from NIFTI file

    Parameters
    ----------
    filename : str
        filename of NIFTI file
    with_affine : bool
        if True, returns affine parameters

    Returns
    -------
    data : np.ndarray
        image data
    """
    img = nib.load(filename)
    data = img.get_data()
    data = np.copy(data, order="C")
    if with_affine:
        return data, img.affine
    return data


def load_sample(df, n, input_shape, output_shape):
    """
    randomly sample patch images from DataFrame

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing name of image files
    n : int
        number of patches to extract
    input_shape : list
        shape of input patches to extract
    output_shape : list
        shape of output patches to extract

    Returns
    -------
    images : (n, n_channels, input_shape[0], input_shape[1], ...) ndarray
        input patches
    labels : (n, output_shape[0], output_shape[1], ...) ndarray
        label patches
    """
    N = len(df)
    if "weight" in list(df):
        # sample subjects in proportion to their weights
        weights = np.asarray(df["weight"], dtype=np.float64)
        weights /= np.sum(weights)
        indices = np.random.choice(N, n, replace=True, p=weights)
    else:
        indices = np.random.choice(N, n, replace=True)
    image_files = df["image"][indices]
    label_files = df["label"][indices]
    images = []
    labels = []
    for image_file, label_file in zip(image_files, label_files):
        image = load_nifti(image_file)
        label = load_nifti(label_file).astype(np.int32)
        # candidate patch centers: foreground voxels far enough from the borders
        mask = np.int32(label > 0)
        slices = [slice(len_ // 2, -len_ // 2) for len_ in input_shape]
        mask[tuple(slices)] *= 2
        indices = np.where(mask > 1.5)
        i = np.random.choice(len(indices[0]))
        input_slices = [
            slice(index[i] - len_ // 2, index[i] + len_ // 2)
            for index, len_ in zip(indices, input_shape)
        ]
        output_slices = [
            slice(index[i] - len_ // 2, index[i] + len_ // 2)
            for index, len_ in zip(indices, output_shape)
        ]
        image_patch = image[tuple(input_slices)]
        label_patch = label[tuple(output_slices)]
        image_patch = image_patch.transpose(3, 0, 1, 2)
        images.append(image_patch)
        labels.append(label_patch)
    images = np.array(images)
    labels = np.array(labels)
    return images, labels


def crop_patch(image, center, shape):
    """
    crop patch from an image

    Parameters
    ----------
    image : (xlen, ylen, zlen, n_channels) np.ndarray
        input image to extract patch from
    center : [x, y, z] iterable
        center index of a patch
    shape : iterable
        shape of patch

    Returns
    -------
    patch : (n_channels, xlen, ylen, zlen) np.ndarray
        extracted patch
    """
    mini = [c - len_ // 2 for c, len_ in zip(center, shape)]
    maxi = [c + len_ // 2 for c, len_ in zip(center, shape)]
    if all(m >= 0 for m in mini) and all(m < img_len for m, img_len in zip(maxi, image.shape)):
        slices = [slice(mi, ma) for mi, ma in zip(mini, maxi)]
    else:
        # the patch sticks out of the volume: clip the indices so edge voxels repeat
        slices = [
            np.clip(range(mi, ma), 0, img_len - 1)
            for mi, ma, img_len in zip(mini, maxi, image.shape)
        ]
        slices = np.meshgrid(*slices, indexing="ij")
    patch = image[tuple(slices)]
    patch = patch.transpose(3, 0, 1, 2)
    return patch


def dice_coefficients(label1, label2, labels=None):
    if labels is None:
        labels = np.unique(np.hstack((np.unique(label1), np.unique(label2))))
    dice_coefs = []
    for label in labels:
        match1 = (label1 == label)
        match2 = (label2 == label)
        denominator = 0.5 * (np.sum(match1.astype(np.float64)) + np.sum(match2.astype(np.float64)))
        numerator = np.sum(np.logical_and(match1, match2).astype(np.float64))
        if denominator == 0:  # neither volume contains this label
            dice_coefs.append(0.)
        else:
            dice_coefs.append(numerator / denominator)
    return dice_coefs


def feedforward(model, image, input_shape, output_shape, n_tiles, n_classes):
    centers = [[] for _ in range(3)]
    for img_len, len_out, center, n_tile in zip(image.shape,
                                                output_shape,
                                                centers,
                                                n_tiles):
        if img_len >= len_out * n_tile:
            raise ValueError(
                f"{img_len} must be smaller than {len_out} x {n_tile}"
            )

        # evenly spaced, overlapping tile centers along this axis
        stride = (img_len - len_out) // (n_tile - 1)
        center.append(len_out // 2)
        for i in range(n_tile - 2):
            center.append(center[-1] + stride)
        center.append(img_len - len_out // 2)
    output = np.zeros((n_classes,) + image.shape[:-1])
    for x, y, z in itertools.product(*centers):
        patch = crop_patch(image, [x, y, z], input_shape)
        patch = np.expand_dims(patch, 0)
        patch = model.xp.asarray(patch)
        slices_out = tuple([slice(None)] + [
            slice(center - len_out // 2, center + len_out // 2)
            for len_out, center in zip(output_shape, [x, y, z])
        ])
        # the negative stop crops the (len_in - len_out) // 2 margin from both ends
        slices_in = tuple([0, slice(None)] + [
            slice((len_in - len_out) // 2, (len_out - len_in) // 2)
            for len_out, len_in in zip(output_shape, input_shape)
        ])
        with chainer.no_backprop_mode():
            output[slices_out] += chainer.cuda.to_cpu(model(patch).data)[slices_in]

    return output

--------------------------------------------------------------------------------
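Usage note: the following is a condensed sketch of what segment.py does, for running a
trained model on a single preprocessed volume programmatically. The file names are
placeholders that assume preprocess.py and train.py were run as described in the README.

    import chainer
    import numpy as np
    import nibabel as nib

    from model import VoxResNet
    from utils import load_nifti, feedforward

    # restore the trained network (preprocess.py writes 2-channel images; 4 classes)
    vrn = VoxResNet(in_channels=2, n_classes=4)
    chainer.serializers.load_npz("vrn.npz", vrn)

    # tile the volume with overlapping 80^3 input patches, keeping the central
    # 60^3 output of each patch, 5 tiles along every axis
    image, affine = load_nifti("IBSR_11/IBSR_11_preprocessed.nii.gz", with_affine=True)
    proba = feedforward(vrn, image, [80, 80, 80], [60, 60, 60], [5, 5, 5], 4)
    label = np.int32(np.argmax(proba, axis=0))
    nib.save(nib.Nifti1Image(label, affine), "IBSR_11/IBSR_11_segTRI_predict.nii.gz")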