├── monaifbs ├── src │ ├── train │ │ ├── __init__.py │ │ └── monai_dynunet_training.py │ ├── utils │ │ ├── __init__.py │ │ ├── custom_inferer.py │ │ ├── custom_transform.py │ │ └── custom_losses.py │ ├── inference │ │ ├── __init__.py │ │ └── monai_dynunet_inference.py │ └── __init__.py ├── __about__.py ├── __init__.py ├── config │ ├── mock_train_file_list_for_dynUnet_training.txt │ ├── mock_valid_file_list_for_dynUnet_training.txt │ ├── monai_dynUnet_inference_config.yml │ └── monai_dynUnet_training_config.yml └── fetal_brain_seg.py ├── .gitattributes ├── .gitignore ├── requirements.txt ├── setup.py ├── LICENSE └── README.md /monaifbs/src/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /monaifbs/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /monaifbs/src/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /monaifbs/src/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .inference import * 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.nii 2 | *.nii.gz 3 | *.tar.gz 4 | *.deb 5 | *.zip 6 | *.pyc 7 | *iterate.dat 8 | *.egg-info 9 | *.idea 10 | *.wiki 11 | /.project 12 | .DS_Store 13 | **/.DS_Store 14 | __pycache__ 15 | **/__pycache__ 16 | models/ 17 | -------------------------------------------------------------------------------- /monaifbs/__about__.py: -------------------------------------------------------------------------------- 1 | __author__ = "Marta B.M. Ranzini" 2 | __email__ = "martabm.ranzini@gmail.com" 3 | __license__ = "Apache2.0" 4 | __version__ = "0.1.0" 5 | __summary__ = "MonaiFBS is a research-focused toolkit for " \ 6 | "fetal brain segmentation in 2D ultra-fast MRI." 
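# Added note (not in the original file): these metadata fields are consumed in two places shown
# later in this repository: setup.py loads them via exec() to populate the package metadata, and
# monaifbs/__init__.py re-exports them, so that e.g. `import monaifbs; monaifbs.__version__`
# returns "0.1.0".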
7 | -------------------------------------------------------------------------------- /monaifbs/__init__.py: -------------------------------------------------------------------------------- 1 | from monaifbs.__about__ import ( 2 | __author__, 3 | __email__, 4 | __license__, 5 | __summary__, 6 | __version__, 7 | ) 8 | 9 | __all__ = [ 10 | "__author__", 11 | "__email__", 12 | "__license__", 13 | "__summary__", 14 | "__version__", 15 | ] 16 | -------------------------------------------------------------------------------- /monaifbs/config/mock_train_file_list_for_dynUnet_training.txt: -------------------------------------------------------------------------------- 1 | /path/to/file/for/subj1_img.nii.gz,/path/to/file/for/subj1_seg.nii.gz 2 | /path/to/file/for/subj2_img.nii.gz,/path/to/file/for/subj2_seg.nii.gz 3 | /path/to/file/for/subj3_img.nii.gz,/path/to/file/for/subj3_seg.nii.gz 4 | /path/to/file/for/subj4_img.nii.gz,/path/to/file/for/subj4_seg.nii.gz 5 | /path/to/file/for/subj5_img.nii.gz,/path/to/file/for/subj5_seg.nii.gz -------------------------------------------------------------------------------- /monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt: -------------------------------------------------------------------------------- 1 | /path/to/file/for/subj1_img.nii.gz,/path/to/file/for/subj1_seg.nii.gz 2 | /path/to/file/for/subj2_img.nii.gz,/path/to/file/for/subj2_seg.nii.gz 3 | /path/to/file/for/subj3_img.nii.gz,/path/to/file/for/subj3_seg.nii.gz 4 | /path/to/file/for/subj4_img.nii.gz,/path/to/file/for/subj4_seg.nii.gz 5 | /path/to/file/for/subj5_img.nii.gz,/path/to/file/for/subj5_seg.nii.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=2.2.2 2 | natsort>=5.3.0 3 | nibabel>=2.4.1 4 | nipype>=1.0.3 5 | nose>=1.3.7 6 | nsol>=0.1.14 7 | numpy>=1.14.2,!=1.16.0 8 | pandas>=0.22.0 9 | pydicom>=1.2.0 10 | pysitk>=0.2.19 11 | scikit_image>=0.14.1 12 | scipy>=1.0.1 13 | seaborn>=0.8.1 14 | SimpleITK>=1.2.0 15 | six>=1.11.0 16 | torch>=1.4.0 17 | torch-summary>=1.3.2 18 | monai==0.3.0 19 | pyyaml>=5.3.1 20 | pytorch-ignite>=0.3.0 21 | tensorboard>=2.2.1: -------------------------------------------------------------------------------- /monaifbs/config/monai_dynUnet_inference_config.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Marta Bianca Maria Ranzini and contributors 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
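# Added note (not in the original file): this is the default inference configuration. It is loaded
# by monaifbs/fetal_brain_seg.py and by monaifbs/src/inference/monai_dynunet_inference.py whenever
# no --config_file argument is given on the command line.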
11 | 12 | device: 13 | num_workers: 1 # number of workers to use in pytorch for multi-processing 14 | 15 | inference: 16 | nr_out_channels: 2 # number of channels in the network output 17 | inplane_size: [448, 512] # 2D patch size, slices are either randomly cropped or padded to this dimension based on their size 18 | spacing: [0.8, 0.8, -1.0] # images are resampled to this spacing in mm (use -1.0 to preserve the original spacing in given direction) 19 | batch_size_inference: 1 # batch size at inferece, 1 is recommended 20 | probability_threshold: 0.5 # probability threshold to convert network output predictions to hard label 21 | model_to_load: "default" # path to pretrained network to be used for inference. If default, model in monaifbs/models/checkpoint_dynUnet_DiceXent.pt is used 22 | 23 | 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | ## 2 | # \file setup.py 3 | # 4 | # Instructions: 5 | # 1) `pip install -e .` 6 | # All python packages and command line tools are then installed during 7 | # 8 | # \author Marta B.M. Ranzini 9 | # \date November 2020 10 | # 11 | 12 | 13 | import re 14 | import os 15 | import sys 16 | from setuptools import setup, find_packages 17 | 18 | 19 | about = {} 20 | with open(os.path.join("monaifbs", "__about__.py")) as fp: 21 | exec(fp.read(), about) 22 | 23 | 24 | with open("README.md", "r") as fh: 25 | long_description = fh.read() 26 | 27 | 28 | def install_requires(fname="requirements.txt"): 29 | with open(fname) as f: 30 | content = f.readlines() 31 | content = [x.strip() for x in content] 32 | return content 33 | 34 | 35 | setup(name='MONAIfbs', 36 | version=about["__version__"], 37 | description=about["__summary__"], 38 | long_description=long_description, 39 | long_description_content_type="text/markdown", 40 | url='https://github.com/gift-surg/MONAIfbs', 41 | author=about["__author__"], 42 | author_email=about["__email__"], 43 | license=about["__license__"], 44 | packages=find_packages(), 45 | install_requires=install_requires(), 46 | zip_safe=False, 47 | keywords='Fetal brain segmentation with dynUnet', 48 | classifiers=[ 49 | 'Intended Audience :: Developers', 50 | 'Intended Audience :: Healthcare Industry', 51 | 'Intended Audience :: Science/Research', 52 | 53 | 'License :: OSI Approved :: Apache 2.0', 54 | 55 | 'Topic :: Software Development :: Build Tools', 56 | 'Topic :: Scientific/Engineering :: Medical Science Apps.', 57 | 58 | 'Programming Language :: Python', 59 | 'Programming Language :: Python :: 3', 60 | ], 61 | ) 62 | -------------------------------------------------------------------------------- /monaifbs/config/monai_dynUnet_training_config.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Marta Bianca Maria Ranzini and contributors 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
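# Added note (not in the original file): this is the default configuration for training the 2D
# dynUNet (monaifbs/src/train/monai_dynunet_training.py in the directory tree above). The files
# monaifbs/config/mock_*_file_list_for_dynUnet_training.txt illustrate the expected
# "image_path,segmentation_path" format of the training and validation file lists.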
11 | 12 | device: 13 | num_workers: 4 # number of workers to use in pytorch for multi-processing 14 | 15 | training: 16 | seg_labels: [0, 1] # labels to consider in the input ground truth segmentations 17 | inplane_size: [448, 512] # 2D patch size, slices are either randomly cropped or padded to this dimension based on their size 18 | spacing: [0.8, 0.8, -1.0] # images are resampled to this spacing in mm (use -1.0 to preserve the original spacing in given direction) 19 | loss_type: "dynDiceCELoss" # loss function to use. See monaifbs.src.training.monai_dynunet_training.choose_loss_function for options 20 | pow_dice: 1 # if loss has Dice loss term, defines the power to raise the Dice to. I.e. loss = Dice^pow 21 | batch_size_train: 14 # batch size for training 22 | batch_size_valid: 1 # batch size for validation, 1 is recommended 23 | nr_train_epochs: 20000 # number of training epochs [note: epoch is defined as extracting 1 patch per subject, so in 2D training only 1 slice is viewed per subject at each "epoch"] 24 | validation_every_n_epochs: 20 # number of training epochs to run in between validation loops 25 | lr: 1e-2 # initial learning rate [note: LR scheduler is (1 - epoch / nr_train_epochs) ** 0.9] 26 | manual_seed: 0 # set manual seed for determinism 27 | model_to_load: null # path to (pre-trained) network to load for continuing training. If null, training is restarted from scratch 28 | 29 | output: 30 | max_nr_models_saved: 1 # Maximum number of models to save in the output folders (older models are deleted) 31 | val_image_to_tensorboad: False # Output the segmentation results for validation on tensorboard [note: memory-consuming] 32 | 33 | log: 34 | message: "Training a 2D dynUNet with MONAI using Dice + Xent loss." # Logging message to briefly describe the experiment 35 | -------------------------------------------------------------------------------- /monaifbs/fetal_brain_seg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Marta Bianca Maria Ranzini and contributors 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | ## 13 | # \file fetal_brain_seg.py 14 | # \brief Script to apply automated fetal brain segmentation using a pre-trained dynUNet model in MONAI 15 | # Integrated within the NiftyMIC package, it performs the generation of brain masks for the 16 | # Super-Resolution Reconstruction. 17 | # This script is the default call by the executable niftymic_segment_fetal_brains if no other 18 | # fetal brain segmentation tool is specified. 
19 | # 20 | # \author Marta B M Ranzini (marta.ranzini@kcl.ac.uk) 21 | # \date November 2020 22 | # 23 | 24 | import os 25 | import argparse 26 | import yaml 27 | import monaifbs 28 | 29 | from monaifbs.src.inference.monai_dynunet_inference import run_inference 30 | 31 | if __name__ == '__main__': 32 | 33 | parser = argparse.ArgumentParser(description='Run fetal brain segmentation using MONAI DynUNet.') 34 | parser.add_argument('--input_names', 35 | dest='input_names', 36 | metavar='input_names', 37 | type=str, 38 | nargs='+', 39 | help='input filenames to be automatically segmented', 40 | required=True) 41 | parser.add_argument('--segment_output_names', 42 | dest='segment_output_names', 43 | metavar='segment_output_names', 44 | type=str, 45 | nargs='+', 46 | help='output filenames where to store the segmentation masks', 47 | required=True) 48 | parser.add_argument('--config_file', 49 | dest='config_file', 50 | metavar='config_file', 51 | type=str, 52 | help='config file containing network information for inference', 53 | default=None) 54 | args = parser.parse_args() 55 | 56 | # check existence of config file and read it 57 | config_file = args.config_file 58 | if config_file is None: 59 | config_file = os.path.join(*[os.path.dirname(monaifbs.__file__), 60 | "config", "monai_dynUnet_inference_config.yml"]) 61 | if not os.path.isfile(config_file): 62 | raise FileNotFoundError('Expected config file: {} not found'.format(config_file)) 63 | with open(config_file) as f: 64 | print("*** Config file") 65 | print(config_file) 66 | config = yaml.load(f, Loader=yaml.FullLoader) 67 | 68 | if config['inference']['model_to_load'] == "default": 69 | config['inference']['model_to_load'] = os.path.join(*[os.path.dirname(monaifbs.__file__), 70 | "models", "checkpoint_dynUnet_DiceXent.pt"]) 71 | 72 | assert len(args.input_names) == len(args.segment_output_names), "The numbers of input output filenames do not match" 73 | 74 | # loop over all input files and run inference for each of them 75 | for img, seg in zip(args.input_names, args.segment_output_names): 76 | 77 | # set the output folder and add to the config file 78 | out_folder = os.path.dirname(seg) 79 | if not out_folder: 80 | out_folder = os.getcwd() 81 | if not os.path.exists(out_folder): 82 | os.makedirs(out_folder) 83 | config['output'] = {'out_postfix': 'seg', 'out_dir': out_folder} 84 | 85 | # run inference 86 | run_inference(input_data=img, config_info=config) 87 | 88 | # recover the filename generated by the inference code (as defined by MONAI output) 89 | img_filename = os.path.basename(img) 90 | flag_zip = 0 91 | if 'gz' in img_filename: 92 | img_filename = img_filename[:-7] 93 | flag_zip = 1 94 | else: 95 | img_filename = img_filename[:-4] 96 | out_filename = img_filename + '_' + config['output']['out_postfix'] + '.nii.gz' if flag_zip \ 97 | else img_filename + '_' + config['output']['out_postfix'] + '.nii' 98 | out_filename = os.path.join(*[out_folder, img_filename, out_filename]) 99 | 100 | # check existence of segmentation file 101 | if not os.path.exists(out_filename): 102 | raise FileNotFoundError("Network output file {} not found, " 103 | "check if the segmentation pipeline has failed".format(out_filename)) 104 | 105 | # rename file with the indicated output name 106 | os.rename(out_filename, seg) 107 | if os.path.exists(seg): 108 | os.rmdir(os.path.join(out_folder, img_filename)) 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- 
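# Usage sketch (added, not part of the original sources). From the command line the script above is
# typically invoked as:
#   python fetal_brain_seg.py --input_names <img1.nii.gz> [<img2.nii.gz> ...] \
#       --segment_output_names <mask1.nii.gz> [<mask2.nii.gz> ...]
# The snippet below shows the equivalent programmatic use of run_inference(), mirroring the logic of
# fetal_brain_seg.py; it assumes the pretrained model has been placed at the default location
# monaifbs/models/checkpoint_dynUnet_DiceXent.pt, and the image/output paths are placeholders.

import os
import yaml
import monaifbs
from monaifbs.src.inference.monai_dynunet_inference import run_inference

# load the default inference configuration shipped with the package
config_path = os.path.join(os.path.dirname(monaifbs.__file__),
                           "config", "monai_dynUnet_inference_config.yml")
with open(config_path) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# resolve the "default" model entry to the bundled checkpoint and set the output options
config['inference']['model_to_load'] = os.path.join(os.path.dirname(monaifbs.__file__),
                                                    "models", "checkpoint_dynUnet_DiceXent.pt")
config['output'] = {'out_postfix': 'seg', 'out_dir': '/path/to/output_folder'}  # placeholder path

# run the segmentation; MONAI's SegmentationSaver writes the result to
# /path/to/output_folder/subj1_img/subj1_img_seg.nii.gz (fetal_brain_seg.py then renames it)
run_inference(input_data='/path/to/subj1_img.nii.gz', config_info=config)  # placeholder path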
/monaifbs/src/utils/custom_inferer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Marta Bianca Maria Ranzini and contributors 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | ## 13 | # \file custom_inferer.py 14 | # \brief contains a series of classes to adapt the MONAI SlidingWindowInferer to the case of feeding slices 15 | # from a 3D volume into a 2D network. 16 | # Adapted from the MONAI class SlidingWindowInferer 17 | # https://github.com/Project-MONAI/MONAI/blob/releases/0.3.0/monai/inferers/inferer.py 18 | # 19 | # \author Marta B M Ranzini (marta.ranzini@kcl.ac.uk) 20 | # \date November 2020 21 | 22 | import copy 23 | import torch 24 | from typing import Union 25 | 26 | from monai.inferers.utils import sliding_window_inference 27 | from monai.inferers import Inferer 28 | from monai.utils import BlendMode 29 | 30 | 31 | class Predict2DFrom3D: 32 | """ 33 | Crop 2D slices from 3D inputs and perform 2D predictions. 34 | Args: 35 | predictor (Network): trained network to perform the prediction 36 | """ 37 | def __init__(self, 38 | predictor): 39 | self.predictor = predictor 40 | 41 | def __call__(self, data): 42 | """ 43 | Callable function to perform the prediction on input data given the defined predictor (network) after 44 | squeezing dimensions = 1. The removed dimension is added back after the prediction. 45 | Args: 46 | data: torch.tensor, model input data for inference. 47 | :return: 48 | """ 49 | # squeeze dimensions equal to 1 50 | orig_size = list(data.shape) 51 | data_size = list(data.shape[2:]) 52 | for idx_dim in range(2, 2+len(data_size)): 53 | if data_size[idx_dim-2] == 1: 54 | data = torch.squeeze(data, dim=idx_dim) 55 | predictions = self.predictor(data) # batched patch segmentation 56 | new_size = copy.deepcopy(orig_size) 57 | new_size[1] = predictions.shape[1] # keep original data shape, but take channel dimension from the prediction 58 | predictions = torch.reshape(predictions, new_size) 59 | return predictions 60 | 61 | 62 | class SlidingWindowInferer2D(Inferer): 63 | """ 64 | Sliding window method for model inference, 65 | with `sw_batch_size` windows for every model.forward(). 66 | Modified from monai.inferers.SlidingWindowInferer to squeeze the extra dimension derived from cropping slices from a 67 | 3D volume. In other words, reduces the input from [B, C, H, W, 1] to [B, C, H, W] for the forward pass through the 68 | network and then reshapes it back to [B, C, H, W, 1], before stitching all the patches back together. 69 | 70 | Args: 71 | roi_size (list, tuple): the window size to execute SlidingWindow evaluation. 72 | If it has non-positive components, the corresponding `inputs` size will be used. 73 | if the components of the `roi_size` are non-positive values, the transform will use the 74 | corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted 75 | to `(32, 64)` if the second spatial dimension size of img is `64`. 
76 | sw_batch_size: the batch size to run window slices. 77 | overlap: Amount of overlap between scans. 78 | mode: {``"constant"``, ``"gaussian"``} 79 | How to blend output of overlapping windows. Defaults to ``"constant"``. 80 | 81 | - ``"constant``": gives equal weight to all predictions. 82 | - ``"gaussian``": gives less weight to predictions on edges of windows. 83 | 84 | Note: 85 | the "sw_batch_size" here is to run a batch of window slices of 1 input image, 86 | not batch size of input images. 87 | 88 | """ 89 | 90 | def __init__( 91 | self, roi_size, sw_batch_size: int = 1, overlap: float = 0.25, mode: Union[BlendMode, str] = BlendMode.CONSTANT 92 | ): 93 | Inferer.__init__(self) 94 | self.roi_size = roi_size 95 | self.sw_batch_size = sw_batch_size 96 | self.overlap = overlap 97 | self.mode: BlendMode = BlendMode(mode) 98 | 99 | def __call__(self, inputs: torch.Tensor, network): 100 | """ 101 | Unified callable function API of Inferers. 102 | 103 | Args: 104 | inputs (torch.tensor): model input data for inference. 105 | network (Network): target model to execute inference. 106 | 107 | """ 108 | # convert the network to a callable that squeezes 3D slices to 2D before performing the network prediction 109 | predictor_2d = Predict2DFrom3D(network) 110 | return sliding_window_inference(inputs, self.roi_size, self.sw_batch_size, 111 | predictor_2d, self.overlap, self.mode) 112 | 113 | 114 | class SlidingWindowInferer2DWithResize(Inferer): 115 | """ 116 | Sliding window method for model inference, 117 | with `sw_batch_size` windows for every model.forward(). 118 | At inference, it applies a "resize" operation for the first two dimensions to match the network input size. 119 | After the forward pass, the network output is resized back to the original size. 120 | 121 | Args: 122 | roi_size (list, tuple): the window size to execute SlidingWindow evaluation. 123 | If it has non-positive components, the corresponding `inputs` size will be used. 124 | if the components of the `roi_size` are non-positive values, the transform will use the 125 | corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted 126 | to `(32, 64)` if the second spatial dimension size of img is `64`. 127 | sw_batch_size: the batch size to run window slices. 128 | overlap: Amount of overlap between scans. 129 | mode: {``"constant"``, ``"gaussian"``} 130 | How to blend output of overlapping windows. Defaults to ``"constant"``. 131 | 132 | - ``"constant``": gives equal weight to all predictions. 133 | - ``"gaussian``": gives less weight to predictions on edges of windows. 134 | 135 | Note: 136 | the "sw_batch_size" here is to run a batch of window slices of 1 input image, 137 | not batch size of input images. 138 | 139 | """ 140 | 141 | def __init__( 142 | self, roi_size, sw_batch_size: int = 1, overlap: float = 0.25, mode: Union[BlendMode, str] = BlendMode.CONSTANT 143 | ): 144 | Inferer.__init__(self) 145 | self.roi_size = roi_size 146 | self.sw_batch_size = sw_batch_size 147 | self.overlap = overlap 148 | self.mode: BlendMode = BlendMode(mode) 149 | 150 | def __call__(self, inputs: torch.Tensor, network): 151 | """ 152 | Unified callable function API of Inferers. 153 | 154 | Args: 155 | inputs (torch.tensor): model input data for inference. 156 | network (Network): target model to execute inference. 
157 | 158 | """ 159 | # resize the input to the appropriate network input 160 | orig_size = list(inputs.shape) 161 | resized_size = copy.deepcopy(orig_size) 162 | resized_size[2] = self.roi_size[0] 163 | resized_size[3] = self.roi_size[1] 164 | inputs_resize = torch.nn.functional.interpolate(inputs, size=resized_size[2:], mode='trilinear') 165 | 166 | # convert the network to a callable that squeezes 3D slices to 2D before performing the network prediction 167 | predictor_2d = Predict2DFrom3D(network) 168 | outputs = sliding_window_inference(inputs_resize, self.roi_size, self.sw_batch_size, 169 | predictor_2d, self.overlap, self.mode) 170 | 171 | # resize back to original size 172 | outputs = torch.nn.functional.interpolate(outputs, size=orig_size[2:], mode='nearest') 173 | return outputs -------------------------------------------------------------------------------- /monaifbs/src/utils/custom_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Marta Bianca Maria Ranzini and contributors 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | ## 13 | # \file custom_transform.py 14 | # \brief contains a series of custom dict transforms to be used in MONAI data preparation for the dynUnet model 15 | # 16 | # \author Marta B M Ranzini (marta.ranzini@kcl.ac.uk) 17 | # \date November 2020 18 | 19 | import numpy as np 20 | import copy 21 | from typing import Dict, Hashable, Mapping, Optional, Sequence, Union 22 | 23 | from monai.config import KeysCollection 24 | from monai.transforms import ( 25 | DivisiblePad, MapTransform, Spacing, Spacingd 26 | ) 27 | from monai.utils import ( 28 | NumpyPadMode, 29 | GridSampleMode, 30 | GridSamplePadMode, 31 | InterpolateMode, 32 | ensure_tuple, 33 | ensure_tuple_rep, 34 | fall_back_tuple, 35 | ) 36 | NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str] 37 | GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] 38 | GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] 39 | InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] 40 | 41 | 42 | class ConverToOneHotd(MapTransform): 43 | """ 44 | Convert multi-class label to One Hot Encoding 45 | """ 46 | 47 | def __init__(self, keys, labels): 48 | """ 49 | Args: 50 | keys: keys of the corresponding items to be transformed. 51 | See also: :py:class:`monai.transforms.compose.MapTransform` 52 | labels: list of labels to be converted to one-hot 53 | 54 | """ 55 | super().__init__(keys) 56 | self.labels = labels 57 | 58 | def __call__(self, data): 59 | d = dict(data) 60 | for key in self.keys: 61 | result = list() 62 | for n in self.labels: 63 | result.append(d[key] == n) 64 | d[key] = np.stack(result, axis=0).astype(np.float32) 65 | return d 66 | 67 | 68 | class MinimumPadd(MapTransform): 69 | """ 70 | Pad the input data, so that the spatial sizes are at least of size `k`. 
71 | Dictionary-based wrapper of :py:class:`monai.transforms.DivisiblePad`. 72 | """ 73 | 74 | def __init__( 75 | self, keys: KeysCollection, k: Union[Sequence[int], int], mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT 76 | ) -> None: 77 | """ 78 | Args: 79 | keys: keys of the corresponding items to be transformed. 80 | See also: :py:class:`monai.transforms.compose.MapTransform` 81 | k: the target k for each spatial dimension. 82 | if `k` is negative or 0, the original size is preserved. 83 | if `k` is an int, the same `k` be applied to all the input spatial dimensions. 84 | mode: {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``, ``"mean"``, 85 | ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``} 86 | One of the listed string values or a user supplied function. Defaults to ``"constant"``. 87 | See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html 88 | It also can be a sequence of string, each element corresponds to a key in ``keys``. 89 | See also :py:class:`monai.transforms.SpatialPad` 90 | """ 91 | super().__init__(keys) 92 | self.mode = ensure_tuple_rep(mode, len(self.keys)) 93 | self.k = k 94 | self.padder = DivisiblePad(k=k) 95 | 96 | def __call__(self, data): 97 | d = dict(data) 98 | for key, m in zip(self.keys, self.mode): 99 | spatial_shape = np.array(d[key].shape[1:]) 100 | k = np.array(fall_back_tuple(self.k, (1,) * len(spatial_shape))) 101 | if np.any(spatial_shape < k): 102 | d[key] = self.padder(d[key], mode=m) 103 | return d 104 | 105 | 106 | class InPlaneSpacingd(Spacingd): 107 | """ 108 | Performs the same operation as the MONAI Spacingd transform, but allows to preserve the spacing along some axes, 109 | which should be indicated as -1.0 in the input pixdim. 110 | E.g. pixdim=(0.8, 0.8, -1.0) would change the x-y plane spacing to (0.8, 0.8) while preserving the original 111 | spacing along z. 112 | See also :py:class: `monai.transforms.Spacingd` 113 | """ 114 | def __init__(self, 115 | keys: KeysCollection, 116 | pixdim: Sequence[float], 117 | diagonal: bool = False, 118 | mode: GridSampleModeSequence = GridSampleMode.BILINEAR, 119 | padding_mode: GridSamplePadModeSequence = GridSamplePadMode.BORDER, 120 | align_corners: Union[Sequence[bool], bool] = False, 121 | dtype: Optional[Union[Sequence[np.dtype], np.dtype]] = np.float64, 122 | meta_key_postfix: str = "meta_dict", 123 | ) -> None: 124 | """ 125 | Args 126 | keys: keys of the corresponding items to be transformed. 127 | See also: :py:class:`monai.transforms.compose.MapTransform` 128 | pixdim: output voxel spacing. 129 | diagonal: whether to resample the input to have a diagonal affine matrix. 130 | mode: {``"bilinear"``, ``"nearest"``}. Interpolation mode to calculate output values. 131 | Defaults to ``"bilinear"``. 132 | It also can be a sequence of string, each element corresponds to a key in ``keys``. 133 | padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} 134 | Padding mode for outside grid values. Defaults to ``"border"``. 135 | It also can be a sequence of string, each element corresponds to a key in ``keys``. 136 | align_corners: Geometrically, we consider the pixels of the input as squares rather than points. 137 | See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample 138 | It also can be a sequence of bool, each element corresponds to a key in ``keys``. 139 | dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. 
140 | It also can be a sequence of bool, each element corresponds to a key in ``keys``. 141 | meta_key_postfix: use `key_{postfix}` to to fetch the meta data according to the key data, 142 | default is `meta_dict`, the meta data is a dictionary object. 143 | See also :py:class: `monai.transforms.Spacingd` for more information on the inputs 144 | """ 145 | super().__init__(keys, 146 | pixdim, 147 | diagonal, 148 | mode, 149 | padding_mode, 150 | align_corners, 151 | dtype, 152 | meta_key_postfix) 153 | self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64) 154 | self.diagonal = diagonal 155 | self.dim_to_keep = np.argwhere(self.pixdim == -1.0) 156 | 157 | def __call__(self, 158 | data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]] 159 | ) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]: 160 | d = dict(data) 161 | for idx, key in enumerate(self.keys): 162 | meta_data = d[f"{key}_{self.meta_key_postfix}"] 163 | # set pixdim to original pixdim value where required 164 | current_pixdim = copy.deepcopy(self.pixdim) 165 | original_pixdim = meta_data["pixdim"] 166 | old_pixdim = original_pixdim[1:4] 167 | current_pixdim[self.dim_to_keep] = old_pixdim[self.dim_to_keep] 168 | 169 | # apply the transform 170 | spacing_transform = Spacing(current_pixdim, diagonal=self.diagonal) 171 | 172 | # resample array of each corresponding key 173 | # using affine fetched from d[affine_key] 174 | d[key], _, new_affine = spacing_transform( 175 | data_array=d[key], 176 | affine=meta_data["affine"], 177 | mode=self.mode[idx], 178 | padding_mode=self.padding_mode[idx], 179 | align_corners=self.align_corners[idx], 180 | dtype=self.dtype[idx], 181 | ) 182 | 183 | # store the modified affine 184 | meta_data["affine"] = new_affine 185 | return d 186 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /monaifbs/src/inference/monai_dynunet_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Marta Bianca Maria Ranzini and contributors 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | ## 13 | # \file monai_dynunet_inference.py 14 | # \brief Script to perform automated fetal brain segmentation using a pre-trained dynUNet model in MONAI 15 | # Example config file required by the main function is shown in 16 | # monaifbs/config/monai_dynUnet_inference_config.yml 17 | # Example of model loaded by this evaluation function is stored in 18 | # monaifbs/models/checkpoint_dynUnet_DiceXent.pt 19 | # 20 | # \author Marta B M Ranzini (marta.ranzini@kcl.ac.uk) 21 | # \date November 2020 22 | # 23 | 24 | import os 25 | import sys 26 | import yaml 27 | import argparse 28 | import logging 29 | import torch 30 | 31 | from torch.utils.data import DataLoader 32 | 33 | from monai.config import print_config 34 | from monai.data import DataLoader, Dataset 35 | from monai.networks.nets import DynUNet 36 | from monai.engines import SupervisedEvaluator 37 | from monai.handlers import CheckpointLoader, SegmentationSaver, StatsHandler 38 | from monai.transforms import ( 39 | Compose, 40 | LoadNiftid, 41 | AddChanneld, 42 | NormalizeIntensityd, 43 | ToTensord, 44 | Activationsd, 45 | AsDiscreted, 46 | KeepLargestConnectedComponentd 47 | ) 48 | 49 | import monaifbs 50 | from monaifbs.src.utils.custom_inferer import SlidingWindowInferer2D 51 | from monaifbs.src.utils.custom_transform import InPlaneSpacingd 52 | 53 | 54 | def create_data_list_of_dictionaries(input_files): 55 | """ 56 | Convert the list of input files to be processed in the dictionary format needed for MONAI 57 | Args: 58 | input_files: str or list of strings, filenames of images to be processed 59 | Returns: 60 | full_list: list of dicts, storing the filenames input to the inference pipeline 61 | """ 62 | 63 | print("*** Input data: ") 64 | full_list = [] 65 | # convert to list if single file 66 | if type(input_files) is str: 67 | input_files = [input_files] 68 | for current_f in input_files: 69 | if os.path.isfile(current_f): 70 | print(current_f) 71 | full_list.append({"image": current_f}) 72 | else: 73 | raise FileNotFoundError('Expected image file: {} not found'.format(current_f)) 74 | return full_list 75 | 76 | 77 | def run_inference(input_data, config_info): 78 | """ 79 | Pipeline to run inference with MONAI dynUNet model. The pipeline reads the input filenames, applies the required 80 | preprocessing and creates the pytorch dataloader; it then performs evaluation on each input file using a trained 81 | dynUNet model (random flipping augmentation is applied at inference). 82 | It uses the dynUNet model implemented in the MONAI framework 83 | (https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/nets/dynunet.py) 84 | which is inspired by the nnU-Net framework (https://arxiv.org/abs/1809.10486) 85 | Inference is performed in 2D slice-by-slice, all slices are then recombined together into the 3D volume. 
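    The model checkpoint, patch size and spacing are read from config_info (see
    monaifbs/config/monai_dynUnet_inference_config.yml for the expected structure); the output
    folder and postfix are expected under config_info['output'], as set by the calling scripts.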
86 | 87 | Args: 88 | input_data: str or list of strings, filenames of images to be processed 89 | config_info: dict, contains the configuration parameters to reload the trained model 90 | 91 | """ 92 | 93 | """ 94 | Read input and configuration parameters 95 | """ 96 | 97 | val_files = create_data_list_of_dictionaries(input_data) 98 | 99 | # print MONAI config information 100 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 101 | print("*** MONAI config: ") 102 | print_config() 103 | 104 | # print to log the parameter setups 105 | print("*** Network inference config: ") 106 | print(yaml.dump(config_info)) 107 | 108 | # inference params 109 | nr_out_channels = config_info['inference']['nr_out_channels'] 110 | spacing = config_info["inference"]["spacing"] 111 | prob_thr = config_info['inference']['probability_threshold'] 112 | model_to_load = config_info['inference']['model_to_load'] 113 | if not os.path.exists(model_to_load): 114 | raise FileNotFoundError('Trained model not found') 115 | patch_size = config_info["inference"]["inplane_size"] + [1] 116 | print("Considering patch size = {}".format(patch_size)) 117 | 118 | # set up either GPU or CPU usage 119 | if torch.cuda.is_available(): 120 | print("\n#### GPU INFORMATION ###") 121 | print("Using device number: {}, name: {}".format(torch.cuda.current_device(), torch.cuda.get_device_name())) 122 | current_device = torch.device("cuda:0") 123 | else: 124 | current_device = torch.device("cpu") 125 | print("Using device: {}".format(current_device)) 126 | 127 | """ 128 | Data Preparation 129 | """ 130 | print("*** Preparing data ... ") 131 | # data preprocessing for inference: 132 | # - convert data to right format [batch, channel, dim, dim, dim] 133 | # - resample to the training resolution in-plane (not along z) 134 | # - apply whitening 135 | # - convert to tensor 136 | val_transforms = Compose( 137 | [ 138 | LoadNiftid(keys=["image"]), 139 | AddChanneld(keys=["image"]), 140 | InPlaneSpacingd( 141 | keys=["image"], 142 | pixdim=spacing, 143 | mode="bilinear", 144 | ), 145 | NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), 146 | ToTensord(keys=["image"]), 147 | ] 148 | ) 149 | # create a validation data loader 150 | val_ds = Dataset(data=val_files, transform=val_transforms) 151 | val_loader = DataLoader(val_ds, 152 | batch_size=1, 153 | num_workers=config_info['device']['num_workers']) 154 | 155 | def prepare_batch(batchdata): 156 | assert isinstance(batchdata, dict), "prepare_batch expects dictionary input data." 157 | return ( 158 | (batchdata["image"], batchdata["label"]) 159 | if "label" in batchdata 160 | else (batchdata["image"], None) 161 | ) 162 | 163 | """ 164 | Network preparation 165 | """ 166 | print("*** Preparing network ... 
") 167 | # automatically extracts the strides and kernels based on nnU-Net empirical rules 168 | spacings = spacing[:2] 169 | sizes = patch_size[:2] 170 | strides, kernels = [], [] 171 | while True: 172 | spacing_ratio = [sp / min(spacings) for sp in spacings] 173 | stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)] 174 | kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] 175 | if all(s == 1 for s in stride): 176 | break 177 | sizes = [i / j for i, j in zip(sizes, stride)] 178 | spacings = [i * j for i, j in zip(spacings, stride)] 179 | kernels.append(kernel) 180 | strides.append(stride) 181 | strides.insert(0, len(spacings) * [1]) 182 | kernels.append(len(spacings) * [3]) 183 | 184 | net = DynUNet( 185 | spatial_dims=2, 186 | in_channels=1, 187 | out_channels=nr_out_channels, 188 | kernel_size=kernels, 189 | strides=strides, 190 | upsample_kernel_size=strides[1:], 191 | norm_name="instance", 192 | deep_supervision=True, 193 | deep_supr_num=2, 194 | res_block=False 195 | ).to(current_device) 196 | 197 | """ 198 | Set ignite evaluator to perform inference 199 | """ 200 | print("*** Preparing evaluator ... ") 201 | if nr_out_channels == 1: 202 | do_sigmoid = True 203 | do_softmax = False 204 | elif nr_out_channels > 1: 205 | do_sigmoid = False 206 | do_softmax = True 207 | else: 208 | raise Exception("incompatible number of output channels") 209 | print("Using sigmoid={} and softmax={} as final activation".format(do_sigmoid, do_softmax)) 210 | val_post_transforms = Compose( 211 | [ 212 | Activationsd(keys="pred", sigmoid=do_sigmoid, softmax=do_softmax), 213 | AsDiscreted(keys="pred", argmax=True, threshold_values=True, logit_thresh=prob_thr), 214 | KeepLargestConnectedComponentd(keys="pred", applied_labels=1) 215 | ] 216 | ) 217 | val_handlers = [ 218 | StatsHandler(output_transform=lambda x: None), 219 | CheckpointLoader(load_path=model_to_load, load_dict={"net": net}, map_location=torch.device('cpu')), 220 | SegmentationSaver( 221 | output_dir=config_info['output']['out_dir'], 222 | output_ext='.nii.gz', 223 | output_postfix=config_info['output']['out_postfix'], 224 | batch_transform=lambda batch: batch["image_meta_dict"], 225 | output_transform=lambda output: output["pred"], 226 | ), 227 | ] 228 | 229 | # Define customized evaluator 230 | class DynUNetEvaluator(SupervisedEvaluator): 231 | def _iteration(self, engine, batchdata): 232 | inputs, targets = self.prepare_batch(batchdata) 233 | inputs = inputs.to(engine.state.device) 234 | if targets is not None: 235 | targets = targets.to(engine.state.device) 236 | flip_inputs_1 = torch.flip(inputs, dims=(2,)) 237 | flip_inputs_2 = torch.flip(inputs, dims=(3,)) 238 | flip_inputs_3 = torch.flip(inputs, dims=(2, 3)) 239 | 240 | def _compute_pred(): 241 | pred = self.inferer(inputs, self.network) 242 | # use random flipping as data augmentation at inference 243 | flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,)) 244 | flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,)) 245 | flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3)) 246 | return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4 247 | 248 | # execute forward computation 249 | self.network.eval() 250 | with torch.no_grad(): 251 | if self.amp: 252 | with torch.cuda.amp.autocast(): 253 | predictions = _compute_pred() 254 | else: 255 | predictions = _compute_pred() 256 | return {"image": inputs, "label": targets, "pred": predictions} 257 | 258 | 
evaluator = DynUNetEvaluator( 259 | device=current_device, 260 | val_data_loader=val_loader, 261 | network=net, 262 | prepare_batch=prepare_batch, 263 | inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0), 264 | post_transform=val_post_transforms, 265 | val_handlers=val_handlers, 266 | amp=False, 267 | ) 268 | 269 | """ 270 | Run inference 271 | """ 272 | print("*** Running evaluator ... ") 273 | evaluator.run() 274 | print("Done!") 275 | 276 | return 277 | 278 | 279 | if __name__ == '__main__': 280 | 281 | parser = argparse.ArgumentParser(description='Run inference with dynUnet with MONAI.') 282 | parser.add_argument('--in_files', 283 | dest='in_files', 284 | metavar='in_files', 285 | type=str, 286 | nargs='+', 287 | help='all files to be processed', 288 | required=True) 289 | parser.add_argument('--out_folder', 290 | dest='out_folder', 291 | metavar='out_folder', 292 | type=str, 293 | help='directory where to store the outputs', 294 | required=True) 295 | parser.add_argument('--out_postfix', 296 | dest='out_postfix', 297 | metavar='out_postfix', 298 | type=str, 299 | help='postfix to add to the input names for the output filename', 300 | default='seg') 301 | parser.add_argument('--config_file', 302 | dest='config_file', 303 | metavar='config_file', 304 | type=str, 305 | help='config file containing network information for inference', 306 | default=None) 307 | args = parser.parse_args() 308 | 309 | # check existence of config file and read it 310 | config_file = args.config_file 311 | if config_file is None: 312 | config_file = os.path.join(*[os.path.dirname(monaifbs.__file__), 313 | "config", "monai_dynUnet_inference_config.yml"]) 314 | if not os.path.isfile(config_file): 315 | raise FileNotFoundError('Expected config file: {} not found'.format(config_file)) 316 | with open(config_file) as f: 317 | print("*** Config file") 318 | print(config_file) 319 | config = yaml.load(f, Loader=yaml.FullLoader) 320 | 321 | # read the input files 322 | in_files = args.in_files 323 | 324 | # add the output directory to the config dictionary 325 | config['output'] = {'out_postfix': args.out_postfix, 'out_dir': args.out_folder} 326 | if not os.path.exists(config['output']['out_dir']): 327 | os.makedirs(config['output']['out_dir']) 328 | 329 | if config['inference']['model_to_load'] == "default": 330 | config['inference']['model_to_load'] = os.path.join(*[os.path.dirname(monaifbs.__file__), 331 | "models", "checkpoint_dynUnet_DiceXent.pt"]) 332 | 333 | # run inference with MONAI dynUnet 334 | run_inference(in_files, config) 335 | -------------------------------------------------------------------------------- /monaifbs/src/utils/custom_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Marta Bianca Maria Ranzini and contributors 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
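# Added notes (not in the original file; they refer to monai_dynunet_inference.py above, not to
# custom_losses.py below):
# 1) the inference script can also be run standalone, e.g.
#      python monaifbs/src/inference/monai_dynunet_inference.py \
#          --in_files /path/to/img1.nii.gz /path/to/img2.nii.gz \
#          --out_folder /path/to/output_folder
#    with optional --out_postfix (default "seg") and --config_file (default:
#    monaifbs/config/monai_dynUnet_inference_config.yml).
# 2) worked trace of the kernel/stride derivation loop in run_inference, assuming the default
#    config values spacing = [0.8, 0.8, -1.0] and inplane_size = [448, 512]: the in-plane sizes
#    are halved six times (448x512 -> 7x8), one further halving is possible along the second axis
#    only, and the loop then stops, giving
#      strides = [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [1, 2]]
#      kernels = [[3, 3]] * 8
#      upsample_kernel_size = strides[1:]
#    i.e. the 2D DynUNet is built with 8 resolution levels and 3x3 convolutions throughout.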
11 | 12 | ## 13 | # \file custom_losses.py 14 | # \brief contains a series of loss functions that can be used to train the dynUNet model 15 | # The code is inspired and includes some modifications to the 16 | # DiceLoss implementation in MONAI 17 | # https://github.com/Project-MONAI/MONAI/blob/releases/0.3.0/monai/losses/dice.py 18 | # and the losses included in the dynUNet tutorial in MONAI 19 | # https://github.com/Project-MONAI/tutorials/blob/master/modules/dynunet_tutorial.ipynb 20 | # 21 | # \author Marta B M Ranzini (marta.ranzini@kcl.ac.uk) 22 | # \date November 2020 23 | 24 | import warnings 25 | from typing import Callable, Optional, Union 26 | 27 | import torch 28 | import torch.nn as nn 29 | from torch.nn.modules.loss import _Loss 30 | 31 | from monai.networks.utils import one_hot 32 | from monai.utils import LossReduction 33 | 34 | 35 | class DiceLossExtended(_Loss): 36 | """ 37 | Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks. 38 | Input logits `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]). 39 | Axis N of `input` is expected to have logit predictions for each class rather than being image channels, 40 | while the same axis of `target` can be 1 or N (one-hot format). The `smooth` parameter is a value added to the 41 | intersection and union components of the inter-over-union calculation to smooth results and prevent divide by 0, 42 | this value should be small. The `include_background` class attribute can be set to False for an instance of 43 | DiceLoss to exclude the first category (channel index 0) which is by convention assumed to be background. 44 | If the non-background segmentations are small compared to the total image size they can get overwhelmed by 45 | the signal from the background so excluding it in such cases helps convergence. 46 | Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric Medical Image Segmentation, 47 | 3DV, 2016. 48 | 49 | With respect to monai.losses.DiceLoss, this implementation allows for: 50 | - the use of a "Batch Dice" (batch version) as in the nnUNet implementation. The Dice is computed for the whole 51 | batch (1 value per class channel), as opposed to being computed for each element in the batch and then averaged 52 | across the batch. 53 | - the selection of different smooth terms at numerator and denominator. 54 | - the possibility to define a power term (pow) for the Dice, such as the returned loss is Dice^pow 55 | """ 56 | 57 | def __init__( 58 | self, 59 | include_background: bool = True, 60 | to_onehot_y: bool = False, 61 | sigmoid: bool = False, 62 | softmax: bool = False, 63 | other_act: Optional[Callable] = None, 64 | squared_pred: bool = False, 65 | pow: float = 1., 66 | jaccard: bool = False, 67 | reduction: Union[LossReduction, str] = LossReduction.MEAN, 68 | batch_version: bool = False, 69 | smooth_num: float = 1e-5, 70 | smooth_den: float = 1e-5 71 | ) -> None: 72 | """ 73 | Args: 74 | include_background: if False channel index 0 (background category) is excluded from the calculation. 75 | to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. 76 | sigmoid: if True, apply a sigmoid function to the prediction. 77 | softmax: if True, apply a softmax function to the prediction. 78 | other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute 79 | other activation layers, Defaults to ``None``. 
for example: 80 | `other_act = torch.tanh`. 81 | squared_pred: use squared versions of targets and predictions in the denominator or not. 82 | pow: raise the Dice to the required power (default 1) 83 | jaccard: compute Jaccard Index (soft IoU) instead of dice or not. 84 | reduction: {``"none"``, ``"mean"``, ``"sum"``} 85 | Specifies the reduction to apply to the output. Defaults to ``"mean"``. 86 | - ``"none"``: no reduction will be applied. 87 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output. 88 | - ``"sum"``: the output will be summed. 89 | batch_version: if True, a single Dice value is computed for the whole batch per class. If False, the Dice 90 | is computed per element in the batch and then reduced (sum/average/None) across the batch. 91 | smooth_num: a small constant to be added to the numerator of Dice to avoid nan. 92 | smooth_den: a small constant to be added to the denominator of Dice to avoid nan. 93 | Raises: 94 | TypeError: When ``other_act`` is not an ``Optional[Callable]``. 95 | ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``]. 96 | Incompatible values. 97 | """ 98 | super().__init__(reduction=LossReduction(reduction).value) 99 | if other_act is not None and not callable(other_act): 100 | raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.") 101 | if int(sigmoid) + int(softmax) + int(other_act is not None) > 1: 102 | raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].") 103 | self.include_background = include_background 104 | self.to_onehot_y = to_onehot_y 105 | self.sigmoid = sigmoid 106 | self.softmax = softmax 107 | self.other_act = other_act 108 | self.squared_pred = squared_pred 109 | self.pow = pow 110 | self.jaccard = jaccard 111 | self.batch_version = batch_version 112 | self.smooth_num = smooth_num 113 | self.smooth_den = smooth_den 114 | 115 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 116 | """ 117 | Args: 118 | input: the shape should be BNH[WD]. 119 | target: the shape should be BNH[WD] 120 | Raises: 121 | ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. 
122 | """ 123 | if self.sigmoid: 124 | input = torch.sigmoid(input) 125 | 126 | n_pred_ch = input.shape[1] 127 | if self.softmax: 128 | if n_pred_ch == 1: 129 | warnings.warn("single channel prediction, `softmax=True` ignored.") 130 | else: 131 | input = torch.softmax(input, 1) 132 | 133 | if self.other_act is not None: 134 | input = self.other_act(input) 135 | 136 | if self.to_onehot_y: 137 | if n_pred_ch == 1: 138 | warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") 139 | else: 140 | target = one_hot(target, num_classes=n_pred_ch) 141 | 142 | if not self.include_background: 143 | if n_pred_ch == 1: 144 | warnings.warn("single channel prediction, `include_background=False` ignored.") 145 | else: 146 | # if skipping background, removing first channel 147 | target = target[:, 1:] 148 | input = input[:, 1:] 149 | 150 | assert ( 151 | target.shape == input.shape 152 | ), f"ground truth has differing shape ({target.shape}) from input ({input.shape})" 153 | 154 | if self.batch_version: 155 | # reducing only spatial dimensions and batch (not channels) 156 | reduce_axis = [0] + list(range(2, len(input.shape))) 157 | else: 158 | # reducing only spatial dimensions (not batch nor channels) 159 | reduce_axis = list(range(2, len(input.shape))) 160 | intersection = torch.sum(target * input, dim=reduce_axis) 161 | 162 | if self.squared_pred: 163 | target = torch.pow(target, 2) 164 | input = torch.pow(input, 2) 165 | 166 | ground_o = torch.sum(target, dim=reduce_axis) 167 | pred_o = torch.sum(input, dim=reduce_axis) 168 | 169 | denominator = ground_o + pred_o 170 | 171 | if self.jaccard: 172 | denominator = 2.0 * (denominator - intersection) 173 | 174 | f: torch.Tensor = (1.0 - (2.0 * intersection + self.smooth_num) / (denominator + self.smooth_den)) ** self.pow 175 | 176 | if self.reduction == LossReduction.MEAN.value: 177 | f = torch.mean(f) # the batch and channel average 178 | elif self.reduction == LossReduction.SUM.value: 179 | f = torch.sum(f) # sum over the batch and channel dims 180 | elif self.reduction == LossReduction.NONE.value: 181 | pass # returns [N, n_classes] losses or [n_classes] if batch version 182 | else: 183 | raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') 184 | 185 | return f 186 | 187 | 188 | # CUSTOM LOSSES FROM dynUNet tutorial for dynUNet training 189 | # code from https://github.com/Project-MONAI/tutorials/blob/master/modules/dynunet_tutorial.ipynb 190 | class CrossEntropyLoss(nn.Module): 191 | """ 192 | Compute the multi-channel cross entropy between predictions and ground truth. 193 | """ 194 | def __init__(self): 195 | super().__init__() 196 | self.loss = nn.CrossEntropyLoss() 197 | 198 | def forward(self, y_pred, y_true): 199 | """ 200 | Args: 201 | y_pred: the shape should be BNH[WD]. 202 | y_true: the shape should be BNH[WD] 203 | """ 204 | # CrossEntropyLoss target needs to have shape (B, D, H, W) 205 | # Target from pipeline has shape (B, 1, D, H, W) 206 | y_true = torch.squeeze(y_true, dim=1).long() 207 | return self.loss(y_pred, y_true) 208 | 209 | 210 | class DiceCELoss(nn.Module): 211 | """ 212 | Compute the loss function = Dice + Cross Entropy. 213 | The monaifbs.src.utils.custom_losses.DiceLossExtended class is used to compute the Dice score, which gives 214 | flexibility on the type of Dice to compute (e.g. use the Dice per image averaged across the batch 215 | or the Batch Dice). 
216 | The monaifbs.src.utils.custom_losses.CrossEntropyLoss class is used to compute the cross entropy. 217 | """ 218 | def __init__(self, 219 | include_background: bool = True, 220 | to_onehot_y: bool = True, 221 | sigmoid: bool = False, 222 | softmax: bool = True, 223 | other_act: Optional[Callable] = None, 224 | squared_pred: bool = False, 225 | pow: float = 1., 226 | jaccard: bool = False, 227 | reduction: Union[LossReduction, str] = LossReduction.MEAN, 228 | batch_version: bool = False, 229 | smooth_num: float = 1e-5, 230 | smooth_den: float = 1e-5 231 | ) -> None: 232 | """ 233 | Args: 234 | include_background: if False channel index 0 (background category) is excluded from the calculation. 235 | to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. 236 | sigmoid: if True, apply a sigmoid function to the prediction. 237 | softmax: if True, apply a softmax function to the prediction. 238 | other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute 239 | other activation layers, Defaults to ``None``. for example: 240 | `other_act = torch.tanh`. 241 | squared_pred: use squared versions of targets and predictions in the denominator or not. 242 | pow: raise the Dice to the required power (default 1) 243 | jaccard: compute Jaccard Index (soft IoU) instead of dice or not. 244 | reduction: {``"none"``, ``"mean"``, ``"sum"``} 245 | Specifies the reduction to apply to the output. Defaults to ``"mean"``. 246 | - ``"none"``: no reduction will be applied. 247 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output. 248 | - ``"sum"``: the output will be summed. 249 | batch_version: if True, a single Dice value is computed for the whole batch per class. If False, the Dice 250 | is computed per element in the batch and then reduced (sum/average/None) across the batch. 251 | smooth_num: a small constant to be added to the numerator of Dice to avoid nan. 252 | smooth_den: a small constant to be added to the denominator of Dice to avoid nan. 253 | """ 254 | super().__init__() 255 | self.dice = DiceLossExtended(include_background=include_background, 256 | to_onehot_y=to_onehot_y, 257 | sigmoid=sigmoid, 258 | softmax=softmax, 259 | other_act=other_act, 260 | squared_pred=squared_pred, 261 | pow=pow, 262 | jaccard=jaccard, 263 | reduction=reduction, 264 | batch_version=batch_version, 265 | smooth_num=smooth_num, 266 | smooth_den=smooth_den) 267 | self.cross_entropy = CrossEntropyLoss() 268 | 269 | def forward(self, y_pred, y_true): 270 | """ 271 | Args 272 | y_pred: the shape should be BNH[WD]. 273 | y_true: the shape should be BNH[WD]. 274 | """ 275 | dice = self.dice(y_pred, y_true) 276 | cross_entropy = self.cross_entropy(y_pred, y_true) 277 | return dice + cross_entropy 278 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fetal brain segmentation with MONAI DynUNet 2 | 3 | MONAIfbs (MONAI Fetal Brain Segmentation) is a Pytorch-based toolkit to train and test deep learning models for automated 4 | fetal brain segmentation of HASTE-like MR images. 5 | The toolkit was developed within the [GIFT-Surg][giftsurg] research project, and takes advantage of [MONAI][monai], 6 | a freely available, community-supported, PyTorch-based framework for deep learning in healthcare imaging. 
7 | 
8 | A pre-trained dynUNet model is [provided][dynUnetmodel] and can be directly used for inference on new data using
9 | the script `src/inference/monai_dynunet_inference.py`. Alternatively, the script `fetal_brain_seg.py` provides
10 | the same inference functionality wrapped in an interface suitable for use within the [NiftyMIC][NiftyMIC] package and by the
11 | executable command `niftymic_segment_fetal_brains`. See the sections [Inference][inference_section] and
12 | [Use within NiftyMIC][use_section] below.
13 | 
14 | More information about MONAI dynUNet can be found [here][dynUnettutorial]. This deep learning pipeline is based on the
15 | [nnU-Net][nnunet] self-adapting framework for U-Net-based medical image segmentation.
16 | 
17 | ### Contact information
18 | This package was developed by [Marta B.M. Ranzini][mranzini] at the [Department of Surgical and Interventional Sciences][sie],
19 | [King's College London (KCL)][kcl] (2020).
20 | If you have any questions or comments, please open an issue on GitHub or contact Prof Tom Vercauteren at
21 | `tom.vercauteren@kcl.ac.uk`.
22 | 
23 | ## Important note
24 | Please make sure you download the [pre-trained model][dynUnetmodel] for inference and place it in the MONAIfbs folder as
25 | `/monaifbs/models/checkpoint_dynUnet_DiceXent.pt`.
26 | You can either download it manually from the webpage or use the `zenodo_get` tool from the command line:
27 | ```
28 | pip install zenodo-get
29 | zenodo_get 10.5281/zenodo.4282679
30 | tar xvf models.tar.gz
31 | mv models /monaifbs/
32 | ```
33 | 
34 | Alternatively, you can store the model at a different location, but update the [inference config file][inference_config]
35 | with the right path to the model to load.
36 | 
37 | ## Installation
38 | After installing [git lfs][gitlfs], clone the repository locally using
39 | `git clone https://github.com/gift-surg/MONAIfbs.git`
40 | 
41 | Change to the downloaded directory and install all Python and PyTorch dependencies by running the following commands sequentially:
42 | `pip install -r requirements.txt`
43 | `pip install -e .`
44 | 
45 | *Note*: MONAI and monaifbs require Python >= 3.6.
46 | 
47 | *Note*: A CUDA-compatible GPU is recommended for training. Inference can be run on either GPU or CPU.
48 | 
49 | 
50 | ## Training
51 | A Python script is provided to train a [dynUNet][dynUnettutorial] model using [MONAI][monai].
52 | The dynUNet is based on the [nnU-Net][nnunet] approach, which employs a series of heuristic rules to determine
53 | the optimal kernel sizes, strides and network depth from the training set.
54 | 
55 | The available script trains a 2D dynUNet by randomly sampling 2D slices from the training set. By default, Dice + Cross
56 | Entropy is used as the loss function (other options are available).
57 | Validation during training is also performed: a
58 | whole-volume validation strategy is applied (using a 2D sliding-window approach throughout each 3D image) and
59 | the Mean 3D Dice Score over the validation set is used as the metric for best-model selection.
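For reference, the default Dice + Cross Entropy objective corresponds to the `DiceCELoss` class in
`monaifbs/src/utils/custom_losses.py`, which simply sums a soft Dice term and a cross-entropy term computed on the
same network outputs. The snippet below is a minimal, illustrative sketch of that idea (it assumes 2D logits of
shape `(B, C, H, W)` and one-hot targets of the same shape; it is not the exact implementation, which additionally
supports a "Batch Dice" variant, separate numerator/denominator smoothing and a configurable power for the Dice term):
```
import torch
import torch.nn.functional as F

def dice_ce_sketch(logits, target_onehot, smooth=1e-5):
    """Soft Dice (per image and per class, then averaged) + cross entropy."""
    probs = torch.softmax(logits, dim=1)
    spatial_dims = tuple(range(2, logits.dim()))  # reduce over the spatial dimensions only
    intersection = torch.sum(probs * target_onehot, dim=spatial_dims)
    denominator = torch.sum(probs, dim=spatial_dims) + torch.sum(target_onehot, dim=spatial_dims)
    dice_loss = torch.mean(1.0 - (2.0 * intersection + smooth) / (denominator + smooth))
    ce_loss = F.cross_entropy(logits, target_onehot.argmax(dim=1))
    return dice_loss + ce_loss
```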
60 | 
61 | #### Setting the training to run
62 | 
63 | To run the training with your own data, the following command can be used:
64 | ```
65 | python /monaifbs/src/train/monai_dynunet_training.py \
66 | --train_files_list <path-to-train-file-list> \
67 | --validation_files_list <path-to-validation-file-list> \
68 | --out_folder <path-to-output-folder>
69 | ```
70 | The files `<path-to-train-file-list>` and `<path-to-validation-file-list>` should be either .txt or
71 | .csv files storing pairs of image-segmentation filenames on each line, separated by a comma, as follows:
72 | ```
73 | /path/to/file/for/subj1_img.nii.gz,/path/to/file/for/subj1_seg.nii.gz
74 | /path/to/file/for/subj2_img.nii.gz,/path/to/file/for/subj2_seg.nii.gz
75 | /path/to/file/for/subj3_img.nii.gz,/path/to/file/for/subj3_seg.nii.gz
76 | ...
77 | ```
78 | Examples of the expected file formats are in `config/mock_train_file_list_for_dynUnet_training.txt` and
79 | `config/mock_valid_file_list_for_dynUnet_training.txt`.
80 | 
81 | See `python /monaifbs/src/train/monai_dynunet_training.py -h` for help on additional input arguments.
82 | 
83 | #### Changing the network configurations
84 | By default, the network will be trained with the configuration defined in `config/monai_dynUnet_training_config.yml`.
85 | See [the file][training_config] for a description of the user-defined parameters.
86 | To change the parameter values, create your own YAML config file following the structure [here][training_config]. The
87 | new config file can be passed as an argument when running the training as follows:
88 | ```
89 | python /monaifbs/src/train/monai_dynunet_training.py \
90 | --train_files_list <path-to-train-file-list> \
91 | --validation_files_list <path-to-validation-file-list> \
92 | --out_folder <path-to-output-folder> \
93 | --config_file <path-to-config-file>
94 | ```
95 | When running inference, make sure the config file for inference is also updated accordingly, otherwise the model might
96 | not be correctly reloaded (see the section [Inference][inference_section] below).
97 | 
98 | #### Using the GPU
99 | The code is optimised for use with a single GPU (multi-GPU computation is not supported at present).
100 | To set the GPU to use, run the command `export CUDA_VISIBLE_DEVICES=<gpu-id>` before running the Python commands
101 | described above.
102 | 
103 | #### Understanding the output
104 | The script will generate two subfolders in the indicated output directory:
105 | * a folder with name formatted as `Year-month-day_hours-minutes-seconds_out_postfix`, which stores the results
106 | of the training. `out_postfix` is `monai_dynUnet_2D` by default, but can be changed as an input argument when running the training.
107 | This folder contains:
108 |   * `best_valid_checkpoint_key_metric=####.pt`: the saved PyTorch model performing best on the validation set
109 | (based on Mean 3D Dice Score)
110 |   * `checkpoint_epoch=####.pth`: the latest saved PyTorch model
111 |   * directories `train` and `valid` storing the TensorBoard outputs for training and validation respectively
112 | 
113 | 
114 | 
115 | * a folder named `persistent_cache`. To speed up the computation, the script uses the MONAI
116 | [PersistentDataset][persistent_dataset], which pre-computes and stores to disk all the non-random pre-processing steps
117 | (pre-processing transform outputs). This folder stores the results of these pre-computations.
118 | 
119 | 
120 | *Notes on the persistent cache*:
121 | 1. The persistent cache dataset favours reusability of pre-processed data when multiple runs need to be executed
122 | (e.g. for hyperparameter tuning). To change the location of this persistent cache or to re-use a pre-existing cache,
123 | the option `--cache_dir <path-to-cache-dir>` can be used in the command line when setting the training to run.
124 | 2. 
The persistent cache can take up quite a large amount of storage space, depending on the size of the training set
125 | and on the selected patch size for training (example: with about 400 3D volumes and the default patch size (418, 512),
126 | it took about 30 GB).
127 | 3. Alternatives to the PersistentDataset exist which do not use as much storage space, but they are not currently
128 | implemented in the training script. See this [MONAI tutorial][monai_datasets] for more information.
129 | To integrate other MONAI Datasets into the script, change the `train_ds` and `val_ds` definitions
130 | in `src/train/monai_dynunet_training.py`.
131 | 
132 | ## Inference
133 | **Note**: If using the provided pre-trained model, please make sure you have downloaded it and placed it as expected in
134 | `/monaifbs/models`. More details in the [important note][importantnote] above.
135 | 
136 | Inference can be run with the provided inference script with the following command:
137 | ```
138 | python /monaifbs/src/inference/monai_dynunet_inference.py \
139 | --in_files <path-to-image-1> <path-to-image-2> ... <path-to-image-N> \
140 | --out_folder <path-to-output-folder>
141 | ```
142 | 
143 | By default, this will use the provided [pre-trained model][dynUnetmodel] and the network configuration parameters
144 | reported in this [config file][inference_config]. If you want to specify a different (MONAI dynUNet) trained model,
145 | you can create your own config file indicating the model to load and its network configuration parameters
146 | following the provided [template][inference_config]. Then, you can simply run inference as:
147 | ```
148 | python /monaifbs/src/inference/monai_dynunet_inference.py \
149 | --in_files <path-to-image-1> <path-to-image-2> ... <path-to-image-N> \
150 | --out_folder <path-to-output-folder> \
151 | --config_file <path-to-config-file>
152 | ```
153 | 
154 | ## Use within NiftyMIC (for inference)
155 | The automated segmentation tool was developed in support of the Super-Resolution Reconstruction package [NiftyMIC][NiftyMIC].
156 | By default, NiftyMIC uses monaifbs utilities to automatically generate fetal brain segmentation masks that can be used
157 | in the reconstruction pipeline.
158 | Provided the dependencies for NiftyMIC and monaifbs are installed, create the automatic fetal brain masks of HASTE-like
159 | images with the command:
160 | ```
161 | niftymic_segment_fetal_brains \
162 | --filenames \
163 | nifti/name-of-stack-1.nii.gz \
164 | nifti/name-of-stack-2.nii.gz \
165 | nifti/name-of-stack-N.nii.gz \
166 | --filenames-masks \
167 | seg/name-of-stack-1.nii.gz \
168 | seg/name-of-stack-2.nii.gz \
169 | seg/name-of-stack-N.nii.gz
170 | ```
171 | 
172 | The interface between the niftymic package and the monaifbs inference utilities is defined in `fetal_brain_seg.py`.
173 | Its working scheme is essentially the same as that of `monai_dynunet_inference.py`: it simply provides a wrapper that feeds
174 | the input data to the `run_inference()` function in `monai_dynunet_inference.py`.
175 | It can also be used as a standalone script for inference as follows:
176 | ```
177 | python /monaifbs/fetal_brain_seg.py \
178 | --input_names <path-to-image-1> <path-to-image-2> ... <path-to-image-N> \
179 | --segment_output_names <path-to-seg-1> <path-to-seg-2> ... <path-to-seg-N>
180 | ```
181 | 
182 | ## Troubleshooting
183 | #### Issue with ParallelNative on MacOS
184 | 
185 | A warning message from ParallelNative is shown and the computation gets stuck.
This issue appears to happen only on 186 | MacOS and is known to be linked to PyTorch DataLoader (as reported in https://github.com/pytorch/pytorch/issues/46409) 187 | 188 | Warning message: `[W ParallelNative.cpp:206] Warning: Cannot set number of intraop threads after parallel work has started 189 | or after set_num_threads call when using native parallel backend (function set_num_threads` 190 | 191 | When observed: MacOS, Python 3.6 and Python 3.7, running on CPU. 192 | 193 | Solution: add `OMP_NUM_THREADS=1` before the call of monaifbs scripts. 194 | Example 1: 195 | ``` 196 | OMP_NUM_THREADS=1 python /monaifbs/src/inference/monai_dynunet_inference.py \ 197 | --in_files ... \ 198 | --out_folder 199 | ``` 200 | Example 2: 201 | ``` 202 | OMP_NUM_THREADS=1 niftymic_segment_fetal_brains \ 203 | --filenames \ 204 | nifti/name-of-stack-1.nii.gz \ 205 | nifti/name-of-stack-2.nii.gz \ 206 | nifti/name-of-stack-N.nii.gz \ 207 | --filenames-masks \ 208 | seg/name-of-stack-1.nii.gz \ 209 | seg/name-of-stack-2.nii.gz \ 210 | seg/name-of-stack-N.nii.gz 211 | ``` 212 | 213 | ## Disclaimer 214 | 215 | Not intended for clinical use. 216 | 217 | ## Licensing and Copyright 218 | Copyright (c) 2020 Marta Bianca Maria Ranzini and contributors. 219 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at 220 | 221 | `http://www.apache.org/licenses/LICENSE-2.0` 222 | 223 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 224 | 225 | Other licenses may apply for dependencies. 226 | 227 | ## Acknowledgements 228 | This work is part of the [GIFT-Surg project][giftsurg] and is funded by the Innovative Engineering for Health award ([Wellcome Trust][wellcometrust] [WT101957] and [EPSRC][epsrc] [NS/A000027/1]), 229 | the Wellcome/EPSRC [Centre for Medical Engineering][cme] [WT 203148/Z/16/Z, NS/A000049/1]] and supported by researchers at the NIHR Biomedical Research 230 | Centre based at GSTT NHS Trust and King's College London. 
231 | 232 | [giftsurg]: http://www.gift-surg.ac.uk 233 | [tomvercauteren]: tom.vercauteren@kcl.ac.uk 234 | [kcl]: https://www.kcl.ac.uk 235 | [sie]: https://www.kcl.ac.uk/bmeis/our-departments/surgical-interventional-engineering 236 | [monai]: https://monai.io/ 237 | [installation]: https://github.com/gift-surg/NiftyMIC/wiki/niftymic-installation 238 | [gitlfs]: https://github.com/git-lfs/git-lfs/wiki/Installation 239 | [dynUnetmodel]: https://zenodo.org/record/4282679#.X7fyttvgqL5 240 | [inference_section]: https://github.com/gift-surg/MONAIfbs#inference 241 | [use_section]: https://github.com/gift-surg/MONAIfbs#use-within-niftymic-for-inference 242 | [dynUnettutorial]: https://github.com/Project-MONAI/tutorials/blob/master/modules/dynunet_tutorial.ipynb 243 | [nnunet]: https://arxiv.org/abs/1809.10486 244 | [inference_config]: https://github.com/gift-surg/MONAIfbs/blob/main/monaifbs/config/monai_dynUnet_inference_config.yml 245 | [training_config]: https://github.com/gift-surg/MONAIfbs/blob/main/monaifbs/config/monai_dynUnet_training_config.yml 246 | [mranzini]: https://www.linkedin.com/in/marta-bianca-maria-ranzini 247 | [bsd]: https://opensource.org/licenses/BSD-3-Clause 248 | [persistent_dataset]: https://github.com/Project-MONAI/MONAI/blob/9f51893d162e5650f007dff8e0bcc09f0d9a6680/monai/data/dataset.py#L71 249 | [monai_datasets]: https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb 250 | [wellcometrust]: http://www.wellcome.ac.uk 251 | [epsrc]: http://www.epsrc.ac.uk 252 | [cme]: https://medicalengineering.org.uk/ 253 | [NiftyMIC]: https://github.com/gift-surg/NiftyMIC 254 | [importantnote]: https://github.com/gift-surg/MONAIfbs#important-note -------------------------------------------------------------------------------- /monaifbs/src/train/monai_dynunet_training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Marta Bianca Maria Ranzini and contributors 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
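# For orientation: run_training() below reads all its settings from a YAML config file
# (monaifbs/config/monai_dynUnet_training_config.yml by default). The sketch below lists the
# fields accessed in this script; the values shown are illustrative placeholders, not the
# shipped defaults:
#
#   training:
#     inplane_size: [448, 512]       # in-plane patch size (a third dimension of size 1 is appended)
#     spacing: [0.8, 0.8, 3.5]       # pixdim passed to InPlaneSpacingd (in-plane resampling)
#     seg_labels: [1]                # optional; labels to segment (sets the number of output channels)
#     loss_type: dynDiceCELoss       # see choose_loss_function() for the available options
#     pow_dice: 1.0                  # optional; exponent applied to the Dice term
#     lr: 0.01
#     nr_train_epochs: 1000
#     batch_size_train: 4
#     batch_size_valid: 1            # only 1 is currently supported at validation
#     validation_every_n_epochs: 1
#     manual_seed: null              # optional; if set, enables set_determinism
#     model_to_load: null            # optional; checkpoint to resume training from
#   device:
#     num_workers: 4
#   output:
#     val_image_to_tensorboad: False # key spelled as accessed in the code below
#     max_nr_models_saved: 2
#     # out_dir, out_postfix and (optionally) cache_dir are filled in from the command-line arguments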
11 | 12 | ## 13 | # \file monai_dynunet_training.py 14 | # \brief Script to train a dynUNet model in MONAI for automated segmentation 15 | # Example config file required by the main function is shown in 16 | # monaifbs/config/monai_dynUnet_training_config.yml 17 | # Example of model generated by this training function is stored in 18 | # monaifbs/models/checkpoint_dynUnet_DiceXent.pt 19 | # 20 | # \author Marta B M Ranzini (marta.ranzini@kcl.ac.uk) 21 | # \date November 2020 22 | # 23 | # This code was adapted from the dynUNet tutorial in MONAI 24 | # https://github.com/Project-MONAI/tutorials/blob/master/modules/dynunet_tutorial.ipynb 25 | 26 | import os 27 | import sys 28 | import logging 29 | import yaml 30 | from datetime import datetime 31 | import argparse 32 | from pathlib import Path 33 | 34 | import torch 35 | from torch.nn.functional import interpolate 36 | 37 | from torch.utils.tensorboard import SummaryWriter 38 | from monai.config import print_config 39 | from monai.data import DataLoader, PersistentDataset 40 | from monai.utils import misc, set_determinism 41 | from monai.engines import SupervisedTrainer 42 | from monai.networks.nets import DynUNet 43 | from monai.transforms import ( 44 | Compose, 45 | LoadNiftid, 46 | AddChanneld, 47 | CropForegroundd, 48 | SpatialPadd, 49 | NormalizeIntensityd, 50 | RandSpatialCropd, 51 | RandZoomd, 52 | RandGaussianNoised, 53 | RandGaussianSmoothd, 54 | RandScaleIntensityd, 55 | RandRotated, 56 | RandFlipd, 57 | SqueezeDimd, 58 | ToTensord, 59 | ) 60 | 61 | from monai.engines import SupervisedEvaluator 62 | from monai.handlers import ( 63 | LrScheduleHandler, 64 | StatsHandler, 65 | CheckpointSaver, 66 | MeanDice, 67 | TensorBoardImageHandler, 68 | TensorBoardStatsHandler, 69 | ValidationHandler, 70 | CheckpointLoader 71 | ) 72 | from monai.inferers import SimpleInferer 73 | 74 | from monaifbs.src.utils.custom_transform import InPlaneSpacingd 75 | from monaifbs.src.utils.custom_losses import DiceCELoss, DiceLossExtended 76 | from monaifbs.src.utils.custom_inferer import SlidingWindowInferer2D 77 | 78 | import monaifbs 79 | 80 | 81 | def create_data_list_of_dictionaries(input_file): 82 | """ 83 | Convert the list of input files to be processed in the dictionary format needed for MONAI 84 | Args: 85 | input_file: path to a .txt or .csv file (with no header) storing two-columns filenames: 86 | image filename in the first column and segmentation filename in the second column. 87 | The two columns should be separated by a comma. 88 | Return 89 | full_list: list of dicts, storing the filenames input to the MONAI training pipeline 90 | """ 91 | full_list = [] 92 | with open(input_file, 'r') as data: 93 | for line in data: 94 | # remove newline character if present 95 | line = line.rstrip('\n') if '\n' in line else line 96 | # split image and segmentation filenames 97 | try: 98 | current_f, current_s = line.split(',') 99 | except ValueError as ve: 100 | print('ValueError: {} in function create_data_list_of_dictionaries()'.format(ve)) 101 | print("Incorrect format for file {}. A two-column .txt or .csv file (with no header) is expected, " 102 | "storing the image filenames in the first column and respective segmentation in " 103 | "the second column, separated by a comma. 
Format of each line:" 104 | "/path/to/image.nii.gz,/path/to/seg.nii.gz".format(input_file)) 105 | exit() 106 | if os.path.isfile(current_f) and os.path.isfile(current_s): 107 | full_list.append({"image": current_f, "label": current_s}) 108 | else: 109 | raise FileNotFoundError('Expected image file: {} or segmentation file: {} not found'.format(current_f, 110 | current_s)) 111 | return full_list 112 | 113 | 114 | def choose_loss_function(number_out_channels, config_dict): 115 | """ 116 | Determine what loss function to use based on information in the configuration file. 117 | Current options are: 118 | - dynDiceCELoss = Dice + Xent. The Dice is computed per image and per channel in the batch and then average 119 | across the batch, using smooth terms at numerator and denominator = 1e-5 120 | - dynDiceCELoss_batch = Batch Dice + Xent. A single Dice value per channel is computed across the whole batch, 121 | using smooth terms at numerator and denominator = 1e-5 122 | - Batch_Dice = Batch Dice only, using smooth terms at numerator and denominator = 1e-5 123 | - Dice_Only = Dice only (per image and per channel, then average across the batch). The smooth term at the 124 | numerator is set to 0 as it provides greater training stability 125 | 126 | Args: 127 | number_out_channels: int, determines whether to use sigmoid or softmax as activation 128 | config_dict: dict, contains configuration parameters for sampling, network and training. 129 | See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields. 130 | 131 | Return: 132 | loss_function: callable, selected loss function type. 133 | """ 134 | 135 | # set some parameters for the Dice Loss 136 | do_sigmoid = True 137 | do_softmax = False 138 | if number_out_channels > 1: 139 | do_sigmoid = False 140 | do_softmax = True 141 | pow = 1.0 142 | if 'pow_dice' in config_dict['training']: 143 | pow = config_dict['training']['pow_dice'] 144 | 145 | # define the loss function based on the indications from the config file 146 | loss_type = config_dict['training']['loss_type'] 147 | if loss_type == "dynDiceCELoss": 148 | batch_version = False 149 | loss_fn = DiceCELoss(pow=pow) 150 | print("[LOSS] Using DiceCELoss with batch_version={} and Dice^{}\n".format(batch_version, pow)) 151 | elif loss_type == "dynDiceCELoss_batch": 152 | batch_version = True 153 | loss_fn = DiceCELoss(batch_version=batch_version, pow=pow) 154 | print("[LOSS] Using DiceCELoss with batch_version={} and Dice^{}\n".format(batch_version, pow)) 155 | elif loss_type == "Batch_Dice": 156 | smooth_num = 1e-5 157 | smooth_den = smooth_num 158 | batch_version = True 159 | squared_pred = False 160 | loss_fn = DiceLossExtended(sigmoid=do_sigmoid, softmax=do_softmax, 161 | smooth_num=smooth_num, smooth_den=smooth_den, squared_pred=squared_pred, 162 | batch_version=batch_version) 163 | print("[LOSS] Using Dice Loss - BATCH VERSION, " 164 | "Dice with {} at numerator and {} at denominator, " 165 | "do_sigmoid={}, do_softmax={}, squared_pred={}, " 166 | "batch_version={}\n".format(smooth_num, smooth_den, do_sigmoid, do_softmax, squared_pred, batch_version)) 167 | elif loss_type == "Dice_Only": 168 | smooth_num = 0 169 | smooth_den = smooth_num 170 | batch_version = False 171 | squared_pred = False 172 | loss_fn = DiceLossExtended(sigmoid=do_sigmoid, softmax=do_softmax, 173 | smooth_num=smooth_num, smooth_den=smooth_den, squared_pred=squared_pred, 174 | batch_version=batch_version) 175 | print("[LOSS] Using Dice Loss, " 176 | "Dice with {} at numerator and {} at 
denominator, " 177 | "do_sigmoid={}, do_softmax={}, squared_pred={}, " 178 | "batch_version={}\n".format(smooth_num, smooth_den, do_sigmoid, do_softmax, squared_pred, batch_version)) 179 | else: 180 | raise IOError("Unrecognized loss type") 181 | 182 | return loss_fn 183 | 184 | 185 | def run_training(train_file_list, valid_file_list, config_info): 186 | """ 187 | Pipeline to train a dynUNet segmentation model in MONAI. It is composed of the following main blocks: 188 | * Data Preparation: Extract the filenames and prepare the training/validation processing transforms 189 | * Load Data: Load training and validation data to PyTorch DataLoader 190 | * Network Preparation: Define the network, loss function, optimiser and learning rate scheduler 191 | * MONAI Evaluator: Initialise the dynUNet evaluator, i.e. the class providing utilities to perform validation 192 | during training. Attach handlers to save the best model on the validation set. A 2D sliding window approach 193 | on the 3D volume is used at evaluation. The mean 3D Dice is used as validation metric. 194 | * MONAI Trainer: Initialise the dynUNet trainer, i.e. the class providing utilities to perform the training loop. 195 | * Run training: The MONAI trainer is run, performing training and validation during training. 196 | Args: 197 | train_file_list: .txt or .csv file (with no header) storing two-columns filenames for training: 198 | image filename in the first column and segmentation filename in the second column. 199 | The two columns should be separated by a comma. 200 | See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt for an example of the expected format. 201 | valid_file_list: .txt or .csv file (with no header) storing two-columns filenames for validation: 202 | image filename in the first column and segmentation filename in the second column. 203 | The two columns should be separated by a comma. 204 | See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt for an example of the expected format. 205 | config_info: dict, contains configuration parameters for sampling, network and training. 206 | See monaifbs/config/monai_dynUnet_training_config.yml for an example of the expected fields. 
207 | """ 208 | 209 | """ 210 | Read input and configuration parameters 211 | """ 212 | # print MONAI config information 213 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 214 | print_config() 215 | 216 | # print to log the parameter setups 217 | print(yaml.dump(config_info)) 218 | 219 | # extract network parameters, perform checks/set defaults if not present and print them to log 220 | if 'seg_labels' in config_info['training'].keys(): 221 | seg_labels = config_info['training']['seg_labels'] 222 | else: 223 | seg_labels = [1] 224 | nr_out_channels = len(seg_labels) 225 | print("Considering the following {} labels in the segmentation: {}".format(nr_out_channels, seg_labels)) 226 | patch_size = config_info["training"]["inplane_size"] + [1] 227 | print("Considering patch size = {}".format(patch_size)) 228 | 229 | spacing = config_info["training"]["spacing"] 230 | print("Bringing all images to spacing = {}".format(spacing)) 231 | 232 | if 'model_to_load' in config_info['training'].keys() and config_info['training']['model_to_load'] is not None: 233 | model_to_load = config_info['training']['model_to_load'] 234 | if not os.path.exists(model_to_load): 235 | raise FileNotFoundError("Cannot find model: {}".format(model_to_load)) 236 | else: 237 | print("Loading model from {}".format(model_to_load)) 238 | else: 239 | model_to_load = None 240 | 241 | # set up either GPU or CPU usage 242 | if torch.cuda.is_available(): 243 | print("\n#### GPU INFORMATION ###") 244 | print("Using device number: {}, name: {}\n".format(torch.cuda.current_device(), torch.cuda.get_device_name())) 245 | current_device = torch.device("cuda:0") 246 | else: 247 | current_device = torch.device("cpu") 248 | print("Using device: {}".format(current_device)) 249 | 250 | # set determinism if required 251 | if 'manual_seed' in config_info['training'].keys() and config_info['training']['manual_seed'] is not None: 252 | seed = config_info['training']['manual_seed'] 253 | else: 254 | seed = None 255 | if seed is not None: 256 | print("Using determinism with seed = {}\n".format(seed)) 257 | set_determinism(seed=seed) 258 | 259 | """ 260 | Setup data output directory 261 | """ 262 | out_model_dir = os.path.join(config_info['output']['out_dir'], 263 | datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + 264 | config_info['output']['out_postfix']) 265 | print("Saving to directory {}\n".format(out_model_dir)) 266 | # create cache directory to store results for Persistent Dataset 267 | if 'cache_dir' in config_info['output'].keys(): 268 | out_cache_dir = config_info['output']['cache_dir'] 269 | else: 270 | out_cache_dir = os.path.join(out_model_dir, 'persistent_cache') 271 | persistent_cache: Path = Path(out_cache_dir) 272 | persistent_cache.mkdir(parents=True, exist_ok=True) 273 | 274 | """ 275 | Data preparation 276 | """ 277 | # Read the input files for training and validation 278 | print("*** Loading input data for training...") 279 | 280 | train_files = create_data_list_of_dictionaries(train_file_list) 281 | print("Number of inputs for training = {}".format(len(train_files))) 282 | 283 | val_files = create_data_list_of_dictionaries(valid_file_list) 284 | print("Number of inputs for validation = {}".format(len(val_files))) 285 | 286 | # Define MONAI processing transforms for the training data. 
This includes: 287 | # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3 288 | # - CropForegroundd: Reduce the background from the MR image 289 | # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the 290 | # last direction (lowest resolution) to avoid introducing motion artefact resampling errors 291 | # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed 292 | # - NormalizeIntensityd: Apply whitening 293 | # - RandSpatialCropd: Crop a random patch from the input with size [B, C, N, M, 1] 294 | # - SqueezeDimd: Convert the 3D patch to a 2D one as input to the network (i.e. bring it to size [B, C, N, M]) 295 | # - Apply data augmentation (RandZoomd, RandRotated, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, 296 | # RandFlipd) 297 | # - ToTensor: convert to pytorch tensor 298 | train_transforms = Compose( 299 | [ 300 | LoadNiftid(keys=["image", "label"]), 301 | AddChanneld(keys=["image", "label"]), 302 | CropForegroundd(keys=["image", "label"], source_key="image"), 303 | InPlaneSpacingd( 304 | keys=["image", "label"], 305 | pixdim=spacing, 306 | mode=("bilinear", "nearest"), 307 | ), 308 | SpatialPadd(keys=["image", "label"], spatial_size=patch_size, 309 | mode=["constant", "edge"]), 310 | NormalizeIntensityd(keys=["image"], nonzero=False, channel_wise=True), 311 | RandSpatialCropd(keys=["image", "label"], roi_size=patch_size, random_size=False), 312 | SqueezeDimd(keys=["image", "label"], dim=-1), 313 | RandZoomd( 314 | keys=["image", "label"], 315 | min_zoom=0.9, 316 | max_zoom=1.2, 317 | mode=("bilinear", "nearest"), 318 | align_corners=(True, None), 319 | prob=0.16, 320 | ), 321 | RandRotated(keys=["image", "label"], range_x=90, range_y=90, prob=0.2, 322 | keep_size=True, mode=["bilinear", "nearest"], 323 | padding_mode=["zeros", "border"]), 324 | RandGaussianNoised(keys=["image"], std=0.01, prob=0.15), 325 | RandGaussianSmoothd( 326 | keys=["image"], 327 | sigma_x=(0.5, 1.15), 328 | sigma_y=(0.5, 1.15), 329 | sigma_z=(0.5, 1.15), 330 | prob=0.15, 331 | ), 332 | RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15), 333 | RandFlipd(["image", "label"], spatial_axis=[0, 1], prob=0.5), 334 | ToTensord(keys=["image", "label"]), 335 | ] 336 | ) 337 | 338 | # Define MONAI processing transforms for the validation data 339 | # - Load Nifti files and convert to format Batch x Channel x Dim1 x Dim2 x Dim3 340 | # - CropForegroundd: Reduce the background from the MR image 341 | # - InPlaneSpacingd: Perform in-plane resampling to the desired spacing, but preserve the resolution along the 342 | # last direction (lowest resolution) to avoid introducing motion artefact resampling errors 343 | # - SpatialPadd: Pad the in-plane size to the defined network input patch size [N, M] if needed 344 | # - NormalizeIntensityd: Apply whitening 345 | # - ToTensor: convert to pytorch tensor 346 | # NOTE: The validation data is kept 3D as a 2D sliding window approach is used throughout the volume at inference 347 | val_transforms = Compose( 348 | [ 349 | LoadNiftid(keys=["image", "label"]), 350 | AddChanneld(keys=["image", "label"]), 351 | CropForegroundd(keys=["image", "label"], source_key="image"), 352 | InPlaneSpacingd( 353 | keys=["image", "label"], 354 | pixdim=spacing, 355 | mode=("bilinear", "nearest"), 356 | ), 357 | SpatialPadd(keys=["image", "label"], spatial_size=patch_size, mode=["constant", "edge"]), 358 | NormalizeIntensityd(keys=["image"], nonzero=False, 
channel_wise=True), 359 | ToTensord(keys=["image", "label"]), 360 | ] 361 | ) 362 | 363 | """ 364 | Load data 365 | """ 366 | # create training data loader 367 | train_ds = PersistentDataset(data=train_files, transform=train_transforms, 368 | cache_dir=persistent_cache) 369 | train_loader = DataLoader(train_ds, 370 | batch_size=config_info['training']['batch_size_train'], 371 | shuffle=True, 372 | num_workers=config_info['device']['num_workers']) 373 | check_train_data = misc.first(train_loader) 374 | print("Training data tensor shapes:") 375 | print("Image = {}; Label = {}".format(check_train_data["image"].shape, check_train_data["label"].shape)) 376 | 377 | # create validation data loader 378 | if config_info['training']['batch_size_valid'] != 1: 379 | raise Exception("Batch size different from 1 at validation ar currently not supported") 380 | val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache) 381 | val_loader = DataLoader(val_ds, 382 | batch_size=1, 383 | shuffle=False, 384 | num_workers=config_info['device']['num_workers']) 385 | check_valid_data = misc.first(val_loader) 386 | print("Validation data tensor shapes (Example):") 387 | print("Image = {}; Label = {}\n".format(check_valid_data["image"].shape, check_valid_data["label"].shape)) 388 | 389 | """ 390 | Network preparation 391 | """ 392 | print("*** Preparing the network ...") 393 | # automatically extracts the strides and kernels based on nnU-Net empirical rules 394 | spacings = spacing[:2] 395 | sizes = patch_size[:2] 396 | strides, kernels = [], [] 397 | while True: 398 | spacing_ratio = [sp / min(spacings) for sp in spacings] 399 | stride = [2 if ratio <= 2 and size >= 8 else 1 for (ratio, size) in zip(spacing_ratio, sizes)] 400 | kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio] 401 | if all(s == 1 for s in stride): 402 | break 403 | sizes = [i / j for i, j in zip(sizes, stride)] 404 | spacings = [i * j for i, j in zip(spacings, stride)] 405 | kernels.append(kernel) 406 | strides.append(stride) 407 | strides.insert(0, len(spacings) * [1]) 408 | kernels.append(len(spacings) * [3]) 409 | 410 | # initialise the network 411 | net = DynUNet( 412 | spatial_dims=2, 413 | in_channels=1, 414 | out_channels=nr_out_channels, 415 | kernel_size=kernels, 416 | strides=strides, 417 | upsample_kernel_size=strides[1:], 418 | norm_name="instance", 419 | deep_supervision=True, 420 | deep_supr_num=2, 421 | res_block=False, 422 | ).to(current_device) 423 | print(net) 424 | 425 | # define the loss function 426 | loss_function = choose_loss_function(nr_out_channels, config_info) 427 | 428 | # define the optimiser and the learning rate scheduler 429 | opt = torch.optim.SGD(net.parameters(), lr=float(config_info['training']['lr']), momentum=0.95) 430 | scheduler = torch.optim.lr_scheduler.LambdaLR( 431 | opt, lr_lambda=lambda epoch: (1 - epoch / config_info['training']['nr_train_epochs']) ** 0.9 432 | ) 433 | 434 | """ 435 | MONAI evaluator 436 | """ 437 | print("*** Preparing the dynUNet evaluator engine...\n") 438 | # val_post_transforms = Compose( 439 | # [ 440 | # Activationsd(keys="pred", sigmoid=True), 441 | # ] 442 | # ) 443 | val_handlers = [ 444 | StatsHandler(output_transform=lambda x: None), 445 | TensorBoardStatsHandler(log_dir=os.path.join(out_model_dir, "valid"), 446 | output_transform=lambda x: None, 447 | global_epoch_transform=lambda x: trainer.state.iteration), 448 | CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, save_key_metric=True, 449 | 
file_prefix='best_valid'), 450 | ] 451 | if config_info['output']['val_image_to_tensorboad']: 452 | val_handlers.append(TensorBoardImageHandler(log_dir=os.path.join(out_model_dir, "valid"), 453 | batch_transform=lambda x: (x["image"], x["label"]), 454 | output_transform=lambda x: x["pred"], interval=2)) 455 | 456 | # Define customized evaluator 457 | class DynUNetEvaluator(SupervisedEvaluator): 458 | def _iteration(self, engine, batchdata): 459 | inputs, targets = self.prepare_batch(batchdata) 460 | inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device) 461 | flip_inputs_1 = torch.flip(inputs, dims=(2,)) 462 | flip_inputs_2 = torch.flip(inputs, dims=(3,)) 463 | flip_inputs_3 = torch.flip(inputs, dims=(2, 3)) 464 | 465 | def _compute_pred(): 466 | pred = self.inferer(inputs, self.network) 467 | # use random flipping as data augmentation at inference 468 | flip_pred_1 = torch.flip(self.inferer(flip_inputs_1, self.network), dims=(2,)) 469 | flip_pred_2 = torch.flip(self.inferer(flip_inputs_2, self.network), dims=(3,)) 470 | flip_pred_3 = torch.flip(self.inferer(flip_inputs_3, self.network), dims=(2, 3)) 471 | return (pred + flip_pred_1 + flip_pred_2 + flip_pred_3) / 4 472 | 473 | # execute forward computation 474 | self.network.eval() 475 | with torch.no_grad(): 476 | if self.amp: 477 | with torch.cuda.amp.autocast(): 478 | predictions = _compute_pred() 479 | else: 480 | predictions = _compute_pred() 481 | return {"image": inputs, "label": targets, "pred": predictions} 482 | 483 | evaluator = DynUNetEvaluator( 484 | device=current_device, 485 | val_data_loader=val_loader, 486 | network=net, 487 | inferer=SlidingWindowInferer2D(roi_size=patch_size, sw_batch_size=4, overlap=0.0), 488 | post_transform=None, 489 | key_val_metric={ 490 | "Mean_dice": MeanDice( 491 | include_background=False, 492 | to_onehot_y=True, 493 | mutually_exclusive=True, 494 | output_transform=lambda x: (x["pred"], x["label"]), 495 | ) 496 | }, 497 | val_handlers=val_handlers, 498 | amp=False, 499 | ) 500 | 501 | """ 502 | MONAI trainer 503 | """ 504 | print("*** Preparing the dynUNet trainer engine...\n") 505 | # train_post_transforms = Compose( 506 | # [ 507 | # Activationsd(keys="pred", sigmoid=True), 508 | # ] 509 | # ) 510 | 511 | validation_every_n_epochs = config_info['training']['validation_every_n_epochs'] 512 | epoch_len = len(train_ds) // train_loader.batch_size 513 | validation_every_n_iters = validation_every_n_epochs * epoch_len 514 | 515 | # define event handlers for the trainer 516 | writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train")) 517 | train_handlers = [ 518 | LrScheduleHandler(lr_scheduler=scheduler, print_lr=True), 519 | ValidationHandler(validator=evaluator, interval=validation_every_n_iters, epoch_level=False), 520 | StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), 521 | TensorBoardStatsHandler(summary_writer=writer_train, 522 | log_dir=os.path.join(out_model_dir, "train"), tag_name="Loss", 523 | output_transform=lambda x: x["loss"], 524 | global_epoch_transform=lambda x: trainer.state.iteration), 525 | CheckpointSaver(save_dir=out_model_dir, save_dict={"net": net, "opt": opt}, 526 | save_final=True, 527 | save_interval=2, epoch_level=True, 528 | n_saved=config_info['output']['max_nr_models_saved']), 529 | ] 530 | if model_to_load is not None: 531 | train_handlers.append(CheckpointLoader(load_path=model_to_load, load_dict={"net": net, "opt": opt})) 532 | 533 | # define customized trainer 534 | class 
DynUNetTrainer(SupervisedTrainer): 535 | def _iteration(self, engine, batchdata): 536 | inputs, targets = self.prepare_batch(batchdata) 537 | inputs, targets = inputs.to(engine.state.device), targets.to(engine.state.device) 538 | 539 | def _compute_loss(preds, label): 540 | labels = [label] + [interpolate(label, pred.shape[2:]) for pred in preds[1:]] 541 | return sum([0.5 ** i * self.loss_function(p, l) for i, (p, l) in enumerate(zip(preds, labels))]) 542 | 543 | self.network.train() 544 | self.optimizer.zero_grad() 545 | if self.amp and self.scaler is not None: 546 | with torch.cuda.amp.autocast(): 547 | predictions = self.inferer(inputs, self.network) 548 | loss = _compute_loss(predictions, targets) 549 | self.scaler.scale(loss).backward() 550 | self.scaler.step(self.optimizer) 551 | self.scaler.update() 552 | else: 553 | predictions = self.inferer(inputs, self.network) 554 | loss = _compute_loss(predictions, targets).mean() 555 | loss.backward() 556 | self.optimizer.step() 557 | return {"image": inputs, "label": targets, "pred": predictions, "loss": loss.item()} 558 | 559 | trainer = DynUNetTrainer( 560 | device=current_device, 561 | max_epochs=config_info['training']['nr_train_epochs'], 562 | train_data_loader=train_loader, 563 | network=net, 564 | optimizer=opt, 565 | loss_function=loss_function, 566 | inferer=SimpleInferer(), 567 | post_transform=None, 568 | key_train_metric=None, 569 | train_handlers=train_handlers, 570 | amp=False, 571 | ) 572 | 573 | """ 574 | Run training 575 | """ 576 | print("*** Run training...") 577 | trainer.run() 578 | print("Done!") 579 | 580 | 581 | if __name__ == "__main__": 582 | 583 | parser = argparse.ArgumentParser(description='Run training with dynUnet with MONAI.') 584 | parser.add_argument('--train_files_list', 585 | dest='train_files_list', 586 | metavar='/path/to/train_files_list.txt', 587 | type=str, 588 | help='two-column .txt or .csv file (with no header) containing image filenames and associated ' 589 | 'label filenames for training (image-label filenames should be comma separated). ' 590 | 'Expected format of each line:' 591 | '/path/to/train_image.nii.gz,/path/to/train_label.nii.gz ' 592 | 'See monaifbs/config/mock_train_file_list_for_dynUnet_training.txt as an example', 593 | required=True) 594 | parser.add_argument('--validation_files_list', 595 | dest='valid_files_list', 596 | metavar='/path/to/valid_files_list.txt', 597 | type=str, 598 | help='two-column .txt or .csv file (with no header) containing image filenames and associated ' 599 | 'label filenames for validation (image-label filenames should be comma separated). ' 600 | 'Expected format of each line:' 601 | '/path/to/valid_image.nii.gz,/path/to/valid_label.nii.gz ' 602 | 'See monaifbs/config/mock_valid_file_list_for_dynUnet_training.txt as an example', 603 | required=True) 604 | parser.add_argument('--out_folder', 605 | dest='out_folder', 606 | metavar='/path/to/out_folder', 607 | type=str, 608 | help='directory where to store the outputs of the training', 609 | required=True) 610 | parser.add_argument('--out_postfix', 611 | dest='out_postfix', 612 | metavar='out_postfix', 613 | type=str, 614 | help='postfix to add to the output directory name after datetime stamp', 615 | default='monai_dynUnet_2D') 616 | parser.add_argument('--cache_dir', 617 | dest='cache_dir', 618 | metavar='/path/to/cache_dir', 619 | type=str, 620 | help='Directory where preprocessed data are/will be stored. 
' 621 | 'See MONAI PersistentCacheDataset for more information' 622 | 'https://github.com/Project-MONAI/MONAI/blob/releases/0.3.0/monai/data/dataset.py ' 623 | 'If not provided, it will be created in /path/to/out_folder/persistent_cache', 624 | default=None) 625 | parser.add_argument('--config_file', 626 | dest='config_file', 627 | metavar='/path/to/config_file.yml', 628 | type=str, 629 | help='config file containing network information for training ' 630 | 'The file monaifbs/config/monai_dynUnet_training_config.yml is used by default. ' 631 | 'See that file as an example of the expected structure', 632 | default=None) 633 | args = parser.parse_args() 634 | 635 | # check existence of filenames listing the input data 636 | if not os.path.isfile(args.train_files_list) or os.path.getsize(args.train_files_list) == 0: 637 | raise FileNotFoundError('Expected training file {} not found or empty'.format(args.train_files_list)) 638 | if not os.path.isfile(args.valid_files_list) or os.path.getsize(args.valid_files_list) == 0: 639 | raise FileNotFoundError('Expected validation file {} not found or empty'.format(args.valid_files_list)) 640 | 641 | # check existence of config file and read it 642 | config_file = args.config_file 643 | if config_file is None: 644 | config_file = os.path.join(*[os.path.dirname(monaifbs.__file__), 645 | "config", "monai_dynUnet_training_config.yml"]) 646 | if not os.path.isfile(config_file): 647 | raise FileNotFoundError('Expected config file: {} not found'.format(config_file)) 648 | with open(config_file) as f: 649 | print("*** Config file") 650 | print(config_file) 651 | config = yaml.load(f, Loader=yaml.FullLoader) 652 | 653 | # add the output directory to the config dictionary 654 | config['output']['out_postfix'] = args.out_postfix 655 | config['output']['out_dir'] = args.out_folder 656 | if not os.path.exists(config['output']['out_dir']): 657 | os.makedirs(config['output']['out_dir']) 658 | if args.cache_dir is not None: 659 | config['output']['cache_dir'] = args.cache_dir 660 | 661 | # run training with MONAI dynUnet 662 | run_training(args.train_files_list, args.valid_files_list, config) 663 | 664 | --------------------------------------------------------------------------------
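For completeness, the training entry point can also be driven programmatically rather than through the command line.
The snippet below is an illustrative sketch mirroring what the `__main__` block of `monai_dynunet_training.py` does
(load the default config shipped with the package, fill in the output settings, call `run_training()`); all paths are
placeholders to be replaced with your own:
```
import os
import yaml

import monaifbs
from monaifbs.src.train.monai_dynunet_training import run_training

# load the default training configuration shipped with the package
config_file = os.path.join(os.path.dirname(monaifbs.__file__),
                           "config", "monai_dynUnet_training_config.yml")
with open(config_file) as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# fill in the output options that are normally taken from the command-line arguments
config['output']['out_postfix'] = 'monai_dynUnet_2D'
config['output']['out_dir'] = '/path/to/out_folder'
os.makedirs(config['output']['out_dir'], exist_ok=True)

# run training with the two-column file lists described in the README
run_training('/path/to/train_files_list.txt',
             '/path/to/valid_files_list.txt',
             config)
```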