├── src ├── __init__.py ├── Utils │ ├── __init__.py │ ├── io.py │ ├── volume_utilities.py │ └── configuration_parser.py ├── Inference │ ├── __init__.py │ ├── predictions_reconstruction.py │ └── predictions.py ├── Models │ └── UNet │ │ ├── __init__.py │ │ ├── AttentionGatedUNet.py │ │ └── DualAttentionUNet.py ├── PreProcessing │ ├── __init__.py │ └── pre_processing.py └── fit.py ├── resources └── images │ └── Architecture.png ├── .gitignore ├── requirements.txt ├── LICENSE ├── setup.py ├── Dockerfile ├── main.py └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | -------------------------------------------------------------------------------- /src/Utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | -------------------------------------------------------------------------------- /src/Inference/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | -------------------------------------------------------------------------------- /src/Models/UNet/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | -------------------------------------------------------------------------------- /src/PreProcessing/__init__.py: -------------------------------------------------------------------------------- 1 | from . import * 2 | 3 | -------------------------------------------------------------------------------- /resources/images/Architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dbouget/ct_mediastinal_structures_segmentation/HEAD/resources/images/Architecture.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # use glob syntax. 
2 | syntax: glob 3 | 4 | *.nii.gz 5 | *.nii 6 | *.elc 7 | *.pyc 8 | *~ 9 | .idea/ 10 | *.xml 11 | os 12 | numpy 13 | *.h5 14 | *.hd5 15 | *.hdf5 16 | *.mhd 17 | *.raw 18 | /resources 19 | *.ini 20 | *.sqlite3 21 | *.zip 22 | .bash_history 23 | /venv 24 | 25 | 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | astor==0.8.1 3 | cached-property==1.5.2 4 | certifi==2020.12.5 5 | chardet==4.0.0 6 | cycler==0.10.0 7 | decorator==4.4.2 8 | filelock==3.0.12 9 | future==0.18.2 10 | gast==0.4.0 11 | gdown==3.12.2 12 | google-pasta==0.2.0 13 | grpcio==1.34.0 14 | h5py==2.10.0 15 | idna==2.10 16 | imageio==2.9.0 17 | importlib-metadata==3.3.0 18 | Keras-Applications==1.0.8 19 | Keras-Preprocessing==1.1.2 20 | kiwisolver==1.3.1 21 | Markdown==3.3.3 22 | matplotlib==3.3.3 23 | networkx==2.5 24 | nibabel==3.0.1 25 | numpy==1.19.3 26 | Pillow==8.0.1 27 | protobuf==3.14.0 28 | pyparsing==2.4.7 29 | PySocks==1.7.1 30 | python-dateutil==2.8.1 31 | PyWavelets==1.1.1 32 | requests==2.25.1 33 | scikit-image==0.16.2 34 | scipy==1.5.4 35 | SimpleITK==1.2.4 36 | six==1.15.0 37 | tensorboard==1.14.0 38 | tensorflow==1.14.0 39 | tensorflow-estimator==1.14.0 40 | tensorflow-gpu==1.14.0 41 | termcolor==1.1.0 42 | tqdm==4.54.1 43 | typing-extensions==3.7.4.3 44 | urllib3==1.26.2 45 | Werkzeug==1.0.1 46 | wrapt==1.12.1 47 | zipp==3.4.0 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021, dbouget 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | 4 | 5 | def setup_repository(): 6 | # Downloading, extracting models. 
7 | models_url = 'https://drive.google.com/uc?id=1DBIl8JyXEo6YdM9uNyo3vrv5T2WsYSXT' 8 | models_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'models') 9 | os.makedirs(models_path, exist_ok=True) 10 | md5 = '434775bebd64910e01f4198eab251666' 11 | models_archive_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models.zip') 12 | gdown.cached_download(url=models_url, path=models_archive_path, md5=md5) 13 | gdown.extractall(path=models_archive_path, to=models_path) 14 | os.remove(models_archive_path) 15 | 16 | # Setting up the data folder with runtime_config.ini file 17 | data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'data') 18 | os.makedirs(data_path, exist_ok=True) 19 | runtime_config_path = os.path.join(data_path, 'runtime_config.ini') 20 | if os.path.exists(runtime_config_path): 21 | os.remove(runtime_config_path) 22 | pfile = open(runtime_config_path, 'w') 23 | pfile.write("[Predictions]\n") 24 | pfile.write("non_overlapping=true\n") 25 | pfile.write("reconstruction_method=probabilities #probabilities, thresholding\n") 26 | pfile.write("reconstruction_order=resample_first #resample_first, resample_second\n") 27 | pfile.write("probability_threshold=0.4\n") 28 | pfile.close() 29 | 30 | 31 | setup_repository() 32 | 33 | -------------------------------------------------------------------------------- /src/Utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nibabel as nib 3 | from nibabel import four_to_three 4 | import SimpleITK as sitk 5 | 6 | 7 | def load_nifti_volume(volume_path): 8 | nib_volume = nib.load(volume_path) 9 | if len(nib_volume.shape) > 3: 10 | if len(nib_volume.shape) == 4: # Common problem 11 | nib_volume = four_to_three(nib_volume)[0] 12 | else: # DWI volumes 13 | nib_volume = nib.Nifti1Image(nib_volume.get_data()[:, :, :, 0, 0], affine=nib_volume.affine) 14 | 15 | return nib_volume 16 | 17 | 18 | def dump_predictions(predictions, training_parameters, runtime_parameters, nib_volume, storage_prefix): 19 | print("Writing predictions to files...") 20 | naming_suffix = 'pred' if runtime_parameters.predictions_reconstruction_method == 'probabilities' else 'labels' 21 | class_names = training_parameters.training_class_names 22 | 23 | if len(predictions.shape) == 4: 24 | for c in range(1, predictions.shape[-1]): 25 | img = nib.Nifti1Image(predictions[..., c], affine=nib_volume.affine) 26 | predictions_output_path = os.path.join(storage_prefix + '-' + naming_suffix + '_' + class_names[c] + '.nii.gz') 27 | os.makedirs(os.path.dirname(predictions_output_path), exist_ok=True) 28 | nib.save(img, predictions_output_path) 29 | else: 30 | img = nib.Nifti1Image(predictions, affine=nib_volume.affine) 31 | predictions_output_path = os.path.join(storage_prefix + '-' + naming_suffix + '_' + 'argmax' + '.nii.gz') 32 | os.makedirs(os.path.dirname(predictions_output_path), exist_ok=True) 33 | nib.save(img, predictions_output_path) 34 | 35 | 36 | def convert_and_export_to_nifti(input_filepath): 37 | input_sitk = sitk.ReadImage(input_filepath) 38 | output_filepath = input_filepath.split('.')[0] + '.nii.gz' 39 | sitk.WriteImage(input_sitk, output_filepath) 40 | 41 | return output_filepath -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # creates virtual ubuntu in docker image 2 | FROM ubuntu:18.04 3 | 4 | # maintainer of 
docker file 5 | MAINTAINER David Bouget 6 | 7 | # set language, format and stuff 8 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 9 | 10 | # installing python3 11 | RUN apt-get update -y && \ 12 | apt-get install python3-pip -y && \ 13 | apt-get -y install sudo && \ 14 | apt-get update && \ 15 | pip3 install bs4 && \ 16 | pip3 install requests && \ 17 | apt-get install python3-lxml -y && \ 18 | pip3 install Pillow && \ 19 | apt-get install libopenjp2-7 -y && \ 20 | apt-get install libtiff5 -y 21 | 22 | # install curl 23 | RUN apt-get install curl -y 24 | 25 | # install nano 26 | RUN apt-get install nano -y 27 | 28 | # install git (OBS: the -y flag conveniently answers yes to all questions automatically) 29 | RUN apt-get update && apt-get install -y git 30 | 31 | # give user sudo access and access to python directories 32 | RUN useradd -m ubuntu && echo "ubuntu:ubuntu" | chpasswd && adduser ubuntu sudo 33 | ENV PYTHON_DIR /usr/bin/python3 34 | RUN chown ubuntu $PYTHON_DIR -R 35 | USER ubuntu 36 | 37 | # Python 38 | RUN pip3 install tensorflow==1.14.0 39 | RUN pip3 install tensorflow-gpu==1.14.0 40 | RUN pip3 install progressbar2 41 | RUN pip3 install nibabel 42 | RUN pip3 install h5py==2.10.0 43 | RUN pip3 install scipy 44 | RUN pip3 install scikit-image==0.16.2 45 | RUN pip3 install progressbar2 46 | RUN pip3 install tqdm 47 | RUN pip3 install SimpleITK==1.2.4 48 | RUN pip3 install numpy==1.19.3 49 | 50 | RUN mkdir /home/ubuntu/src 51 | WORKDIR "/home/ubuntu/src" 52 | COPY src/ . 53 | WORKDIR "/home/ubuntu" 54 | COPY Dockerfile . 55 | COPY main.py . 56 | 57 | RUN mkdir /home/ubuntu/resources 58 | USER root 59 | RUN chown -R ubuntu:ubuntu /home/ubuntu/resources 60 | RUN chmod -R 777 /home/ubuntu/resources 61 | USER ubuntu 62 | EXPOSE 8888 63 | 64 | #RUN echo 'alias python=python3' >> ~/.bashrc 65 | 66 | # CMD ["/bin/bash"] 67 | ENTRYPOINT ["python3","/home/ubuntu/main.py"] 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /src/PreProcessing/pre_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import nibabel as nib 4 | from nibabel.processing import resample_to_output 5 | from src.Utils.volume_utilities import intensity_normalization, resize_volume, crop_CT 6 | from src.Utils.io import load_nifti_volume, convert_and_export_to_nifti 7 | 8 | 9 | def run_pre_processing(filename, pre_processing_parameters, lungs_mask_filename, storage_prefix): 10 | print("Extracting data...") 11 | ext_split = filename.split('.') 12 | extension = '.'.join(ext_split[1:]) 13 | 14 | if extension not in ['nii', 'nii.gz']: 15 | filename = convert_and_export_to_nifti(input_filepath=filename) 16 | 17 | 18 | nib_volume = load_nifti_volume(filename) 19 | 20 | print("Pre-processing...") 21 | # Normalize spacing 22 | new_spacing = pre_processing_parameters.output_spacing 23 | if pre_processing_parameters.output_spacing is None: 24 | tmp = np.min(nib_volume.header.get_zooms()) 25 | new_spacing = [tmp, tmp, tmp] 26 | 27 | library = pre_processing_parameters.preprocessing_library 28 | if library == 'nibabel': 29 | resampled_volume = resample_to_output(nib_volume, new_spacing, order=1) 30 | data = resampled_volume.get_data().astype('float32') 31 | 32 | crop_bbox = None 33 | # Exclude background 34 | if pre_processing_parameters.crop_background: 35 | data, crop_bbox = crop_CT(filename, data, lungs_mask_filename, new_spacing, storage_prefix) 36 | 37 | # Resize
to network input size 38 | data = resize_volume(data, pre_processing_parameters.new_axial_size, pre_processing_parameters.slicing_plane, 39 | order=1) 40 | # Normalize values 41 | data = intensity_normalization(volume=data, parameters=pre_processing_parameters) 42 | 43 | return nib_volume, resampled_volume, data, crop_bbox 44 | 45 | 46 | def run_pre_processing_guided(filename, pre_processing_parameters, lungs_mask_filename, storage_prefix, anatomical_priors_filename): 47 | nib_volume, resampled_volume, data, crop_bbox = run_pre_processing(filename, pre_processing_parameters, 48 | lungs_mask_filename, 49 | storage_prefix) 50 | 51 | if anatomical_priors_filename is not None and os.path.exists(anatomical_priors_filename): 52 | nib_volume, resampled_volume, data_apg, crop_bbox = run_pre_processing(anatomical_priors_filename, 53 | pre_processing_parameters, 54 | lungs_mask_filename, 55 | storage_prefix) 56 | else: 57 | data_apg = np.zeros((data.shape)) 58 | 59 | final_pre_proc = np.zeros((data.shape) + (2,)) 60 | final_pre_proc[..., 0] = data 61 | final_pre_proc[..., 1] = data_apg 62 | 63 | return nib_volume, resampled_volume, final_pre_proc, crop_bbox -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import getopt 2 | import os 3 | import sys 4 | from src.fit import fit, fit_ensemble 5 | 6 | 7 | def main(argv): 8 | input_filename = '' 9 | output_prefix = '' 10 | model_list = '' 11 | lungs_mask_filename = None 12 | apg_mask_filename = None 13 | gpu_id = '-1' 14 | try: 15 | opts, args = getopt.getopt(argv, "hi:o:m:l:a:g:", ["Input=", "Output=", "Model=", "Lungs=", "APG=", "GPU="]) 16 | except getopt.GetoptError: 17 | print('usage: main.py --Input <ct_filepath> --Output <output_prefix> --Model <model_name> ' 18 | ' --Lungs <lungs_mask_filepath> --APG <apg_mask_filepath> --GPU <gpu_id>') 19 | sys.exit(2) 20 | for opt, arg in opts: 21 | if opt == '-h': 22 | print('main.py --Input <ct_filepath> --Output <output_prefix> --Model <model_name> ' 23 | ' --Lungs <lungs_mask_filepath> --APG <apg_mask_filepath> --GPU <gpu_id>') 24 | sys.exit() 25 | elif opt in ("-i", "--Input"): 26 | input_filename = arg 27 | elif opt in ("-o", "--Output"): 28 | output_prefix = arg 29 | elif opt in ("-m", "--Model"): 30 | model_list = arg 31 | elif opt in ("-l", "--Lungs"): 32 | lungs_mask_filename = arg 33 | elif opt in ("-a", "--APG"): 34 | apg_mask_filename = arg 35 | elif opt in ("-g", "--GPU"): 36 | if arg.isnumeric(): 37 | gpu_id = arg 38 | if input_filename == '': 39 | print('usage: main.py --Input <ct_filepath> --Output <output_prefix> --Model <model_name> ' 40 | ' --Lungs <lungs_mask_filepath> --APG <apg_mask_filepath> --GPU <gpu_id>') 41 | sys.exit() 42 | 43 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 44 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id 45 | 46 | if os.path.exists(input_filename): 47 | real_path = os.path.realpath(os.path.dirname(input_filename)) 48 | input_filename = os.path.join(real_path, os.path.basename(input_filename)) 49 | else: 50 | print('Input filename does not exist on disk, with argument: {}'.format(input_filename)) 51 | sys.exit(2) 52 | 53 | if os.path.exists(os.path.dirname(output_prefix)): 54 | real_path = os.path.realpath(os.path.dirname(output_prefix)) 55 | output_prefix = os.path.join(real_path, os.path.basename(output_prefix)) 56 | else: 57 | print('Directory name for the output prefix does not exist on disk, with argument: {}'.format(output_prefix)) 58 | sys.exit(2) 59 | 60 | model_list = model_list.split(',') 61 | if len(model_list) == 1: 62 | fit(input_filename=input_filename, output_path=output_prefix, selected_model=model_list[0], 63 | lungs_mask_filename=lungs_mask_filename, anatomical_priors_filename=apg_mask_filename) 64 | else: 65 |
fit_ensemble(input_filename=input_filename, output_path=output_prefix, model_list=model_list, 66 | lungs_mask_filename=lungs_mask_filename, anatomical_priors_filename=apg_mask_filename) 67 | 68 | 69 | if __name__ == "__main__": 70 | main(sys.argv[1:]) 71 | 72 | -------------------------------------------------------------------------------- /src/fit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | from copy import deepcopy 6 | from src.Utils.configuration_parser import * 7 | from src.PreProcessing.pre_processing import run_pre_processing, run_pre_processing_guided 8 | from src.Inference.predictions import run_predictions 9 | from src.Inference.predictions_reconstruction import reconstruct_post_predictions, perform_ensemble 10 | from src.Utils.io import dump_predictions 11 | 12 | MODELS_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../', 'resources/models') 13 | print(MODELS_PATH) 14 | sys.path.insert(1, MODELS_PATH) 15 | 16 | 17 | def fit(input_filename, output_path, selected_model, lungs_mask_filename, anatomical_priors_filename, 18 | user_runtime=None): 19 | """ 20 | 21 | """ 22 | print("Starting inference for file: {}, with model: {}.\n".format(input_filename, selected_model)) 23 | overall_start = time.time() 24 | pre_processing_parameters = PreProcessingParser(model_name=selected_model) 25 | if user_runtime is None: 26 | user_runtime = RuntimeConfigParser() 27 | valid_extensions = ['.h5', '.hd5', '.hdf5', '.hdf', '.ckpt'] 28 | model_path = '' 29 | for e, ext in enumerate(valid_extensions): 30 | model_path = os.path.join(MODELS_PATH, selected_model, 'model' + ext) 31 | if os.path.exists(model_path): 32 | break 33 | 34 | if not os.path.exists(model_path): 35 | raise ValueError('Could not find any model matching the requested type \'{}\'.'.format(selected_model)) 36 | 37 | if 'APG' in selected_model: 38 | nib_volume, resampled_volume, data, crop_bbox = run_pre_processing_guided(filename=input_filename, 39 | pre_processing_parameters=pre_processing_parameters, 40 | lungs_mask_filename=lungs_mask_filename, 41 | storage_prefix=output_path, 42 | anatomical_priors_filename=anatomical_priors_filename) 43 | else: 44 | nib_volume, resampled_volume, data, crop_bbox = run_pre_processing(filename=input_filename, 45 | pre_processing_parameters=pre_processing_parameters, 46 | lungs_mask_filename=lungs_mask_filename, 47 | storage_prefix=output_path) 48 | data = np.expand_dims(data, axis=-1) 49 | start = time.time() 50 | predictions = run_predictions(data=data, model_path=model_path, training_parameters=pre_processing_parameters, 51 | runtime_parameters=user_runtime) 52 | print('Model loading + inference time: {} seconds.'.format(time.time() - start)) 53 | 54 | final_predictions = reconstruct_post_predictions(predictions=predictions, parameters=user_runtime, 55 | crop_bbox=crop_bbox, nib_volume=nib_volume, 56 | resampled_volume=resampled_volume) 57 | 58 | dump_predictions(predictions=final_predictions, training_parameters=pre_processing_parameters, 59 | runtime_parameters=user_runtime, nib_volume=nib_volume, 60 | storage_prefix=output_path) 61 | print('Total processing time: {:.2f} seconds.\n'.format(time.time() - overall_start)) 62 | 63 | 64 | def fit_ensemble(input_filename, output_path, model_list, lungs_mask_filename, anatomical_priors_filename): 65 | overall_start = time.time() 66 | user_runtime = RuntimeConfigParser() 67 | 68 | for model in model_list: 69 | 
ensemble_runtime = RuntimeConfigParser() 70 | ensemble_runtime.set_default_runtime() 71 | outpath = output_path + model 72 | fit(input_filename, outpath, model, lungs_mask_filename, anatomical_priors_filename=anatomical_priors_filename, 73 | user_runtime=ensemble_runtime) 74 | 75 | perform_ensemble(input_filename, output_path, model_list, user_runtime) 76 | print('Total ensemble processing time: {:.2f} seconds.\n'.format(time.time() - overall_start)) 77 | -------------------------------------------------------------------------------- /src/Utils/volume_utilities.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | from skimage.transform import resize 4 | from scipy.ndimage import binary_fill_holes 5 | from skimage.measure import regionprops 6 | from src.Utils.configuration_parser import * 7 | import subprocess 8 | from nibabel.processing import resample_to_output 9 | from src.Utils.io import load_nifti_volume, convert_and_export_to_nifti 10 | 11 | 12 | def crop_CT(filepath, volume, lungs_mask_filename, new_spacing, storage_prefix): 13 | # @TODO. Should name the lungs mask file with a specific name, otherwise will be multiple instances when ensemble... 14 | if lungs_mask_filename is None or not os.path.exists(lungs_mask_filename): 15 | script_path = '/'.join(os.path.dirname(os.path.realpath(__file__)).split('/')[:-2]) + '/main.py' 16 | #output_prefix = '/'.join(os.path.dirname(os.path.realpath(__file__)).split('/')[:-2]) + '/tmp' 17 | subprocess.call(['python3', '{script}'.format(script=script_path), 18 | '-i{input}'.format(input=filepath), 19 | '-o{output}'.format(output=storage_prefix), 20 | '-m{model}'.format(model='CT_Lungs'), 21 | '-g{gpu}'.format(gpu=os.environ["CUDA_VISIBLE_DEVICES"])]) 22 | lungs_mask_filename = storage_prefix + '-pred_Lungs.nii.gz' 23 | else: 24 | ext_split = lungs_mask_filename.split('.') 25 | extension = '.'.join(ext_split[1:]) 26 | 27 | if extension not in ['nii', 'nii.gz']: 28 | lungs_mask_filename = convert_and_export_to_nifti(input_filepath=lungs_mask_filename) 29 | 30 | lungs_mask_ni = load_nifti_volume(lungs_mask_filename) 31 | resampled_volume = resample_to_output(lungs_mask_ni, new_spacing, order=0) 32 | lungs_mask = resampled_volume.get_data().astype('float32') 33 | lungs_mask[lungs_mask < 0.5] = 0 34 | lungs_mask[lungs_mask >= 0.5] = 1 35 | lungs_mask = lungs_mask.astype('uint8') 36 | 37 | 38 | lung_region = regionprops(lungs_mask) 39 | min_row, min_col, min_depth, max_row, max_col, max_depth = lung_region[0].bbox 40 | print('cropping params', min_row, min_col, min_depth, max_row, max_col, max_depth) 41 | 42 | cropped_volume = volume[min_row:max_row, min_col:max_col, min_depth:max_depth] 43 | bbox = [min_row, min_col, min_depth, max_row, max_col, max_depth] 44 | 45 | return cropped_volume, bbox 46 | 47 | 48 | def resize_volume(volume, new_slice_size, slicing_plane, order=1): 49 | new_volume = None 50 | if len(new_slice_size) == 2: 51 | if slicing_plane == 'axial': 52 | new_val = int(volume.shape[2] * (new_slice_size[1] / volume.shape[1])) 53 | new_volume = resize(volume, (new_slice_size[0], new_slice_size[1], new_val), order=order) 54 | elif slicing_plane == 'sagittal': 55 | new_val = new_slice_size[0] 56 | new_volume = resize(volume, (new_val, new_slice_size[0], new_slice_size[1]), order=order) 57 | elif slicing_plane == 'coronal': 58 | new_val = new_slice_size[0] 59 | new_volume = resize(volume, (new_slice_size[0], new_val, new_slice_size[1]), order=order) 60
| elif len(new_slice_size) == 3: 61 | new_volume = resize(volume, new_slice_size, order=order) 62 | return new_volume 63 | 64 | 65 | def intensity_normalization_CT(volume, parameters): 66 | result = deepcopy(volume).astype('float32') 67 | 68 | result[volume < parameters.intensity_clipping_values[0]] = parameters.intensity_clipping_values[0] 69 | result[volume > parameters.intensity_clipping_values[1]] = parameters.intensity_clipping_values[1] 70 | 71 | if parameters.normalization_method == 'zeromean': 72 | mean_val = np.mean(result) 73 | var_val = np.std(result) 74 | tmp = (result - mean_val) / var_val 75 | result = tmp 76 | else: 77 | min_val = np.min(result) 78 | max_val = np.max(result) 79 | if (max_val - min_val) != 0: 80 | tmp = (result - min_val) / (max_val - min_val) 81 | result = tmp 82 | 83 | return result 84 | 85 | 86 | def intensity_normalization(volume, parameters): 87 | return intensity_normalization_CT(volume, parameters) 88 | 89 | 90 | def padding_for_inference(data, slab_size, slicing_plane): 91 | new_data = data 92 | if slicing_plane == 'axial': 93 | missing_dimension = (slab_size - (data.shape[2] % slab_size)) % slab_size 94 | if missing_dimension != 0: 95 | new_data = np.pad(data, ((0, 0), (0, 0), (0, missing_dimension), (0, 0)), mode='edge') 96 | elif slicing_plane == 'sagittal': 97 | missing_dimension = (slab_size - (data.shape[0] % slab_size)) % slab_size 98 | if missing_dimension != 0: 99 | new_data = np.pad(data, ((0, missing_dimension), (0, 0), (0, 0), (0, 0)), mode='edge') 100 | elif slicing_plane == 'coronal': 101 | missing_dimension = (slab_size - (data.shape[1] % slab_size)) % slab_size 102 | if missing_dimension != 0: 103 | new_data = np.pad(data, ((0, 0), (0, missing_dimension), (0, 0), (0, 0)), mode='edge') 104 | 105 | return new_data, missing_dimension 106 | 107 | 108 | def padding_for_inference_both_ends(data, slab_size, slicing_plane): 109 | new_data = data 110 | padding_val = int(slab_size / 2) 111 | if slicing_plane == 'axial': 112 | new_data = np.pad(data, ((0, 0), (0, 0), (padding_val, padding_val), (0, 0)), mode='edge') 113 | elif slicing_plane == 'sagittal': 114 | new_data = np.pad(data, ((padding_val, padding_val), (0, 0), (0, 0), (0, 0)), mode='edge') 115 | elif slicing_plane == 'coronal': 116 | new_data = np.pad(data, ((0, 0), (padding_val, padding_val), (0, 0), (0, 0)), mode='edge') 117 | 118 | return new_data 119 | -------------------------------------------------------------------------------- /src/Inference/predictions_reconstruction.py: -------------------------------------------------------------------------------- 1 | from nibabel import four_to_three 2 | from nibabel.processing import resample_to_output, resample_from_to 3 | from skimage.measure import regionprops, label 4 | from skimage.transform import resize 5 | from tensorflow.python.keras.models import load_model 6 | import matplotlib.pyplot as plt 7 | from scipy.ndimage import zoom 8 | import os 9 | import nibabel as nib 10 | from os.path import join 11 | import numpy as np 12 | import sys 13 | from shutil import copy 14 | from math import ceil, floor 15 | from copy import deepcopy 16 | from src.Utils.io import load_nifti_volume 17 | 18 | 19 | def reconstruct_post_predictions(predictions, parameters, crop_bbox, nib_volume, resampled_volume): 20 | print("Resampling predictions...") 21 | reconstruction_method = parameters.predictions_reconstruction_method 22 | probability_thresholds = parameters.predictions_probability_thresholds 23 | 24 | if 
parameters.predictions_reconstruction_order == 'resample_first': 25 | resampled_predictions = __resample_predictions(predictions=predictions, crop_bbox=crop_bbox, 26 | nib_volume=nib_volume, 27 | resampled_volume=resampled_volume, 28 | reconstruction_method=reconstruction_method) 29 | 30 | final_predictions = __cut_predictions(predictions=resampled_predictions, 31 | reconstruction_method=reconstruction_method, 32 | probability_threshold=probability_thresholds) 33 | else: 34 | thresh_predictions = __cut_predictions(predictions=predictions, reconstruction_method=reconstruction_method, 35 | probability_threshold=probability_thresholds) 36 | final_predictions = __resample_predictions(predictions=thresh_predictions, crop_bbox=crop_bbox, 37 | nib_volume=nib_volume, 38 | resampled_volume=resampled_volume, 39 | reconstruction_method=reconstruction_method) 40 | 41 | return final_predictions 42 | 43 | 44 | def __cut_predictions(predictions, probability_threshold, reconstruction_method): 45 | if reconstruction_method == 'probabilities': 46 | return predictions 47 | elif reconstruction_method == 'thresholding': 48 | final_predictions = np.zeros(predictions.shape).astype('uint8') 49 | if len(probability_threshold) != predictions.shape[-1]: 50 | probability_threshold = np.full(shape=(predictions.shape[-1]), fill_value=probability_threshold[0]) 51 | 52 | for c in range(0, predictions.shape[-1]): 53 | channel = deepcopy(predictions[:, :, :, c]) 54 | channel[channel < probability_threshold[c]] = 0 55 | channel[channel >= probability_threshold[c]] = 1 56 | final_predictions[:, :, :, c] = channel.astype('uint8') 57 | elif reconstruction_method == 'argmax': 58 | final_predictions = np.argmax(predictions, axis=-1).astype('uint8') 59 | else: 60 | raise ValueError('Unknown reconstruction_method!') 61 | 62 | return final_predictions 63 | 64 | 65 | def __resample_predictions(predictions, crop_bbox, nib_volume, resampled_volume, reconstruction_method): 66 | labels_type = predictions.dtype 67 | order = 0 if labels_type == np.uint8 else 1 68 | data = deepcopy(predictions).astype(labels_type) 69 | nb_classes = predictions.shape[-1] 70 | 71 | # Undo resizing (which is performed in function crop()) 72 | if crop_bbox is not None: 73 | resize_ratio = (crop_bbox[3] - crop_bbox[0], crop_bbox[4] - crop_bbox[1], crop_bbox[5] - crop_bbox[2]) / np.asarray(data.shape[0:3]) 74 | if len(data.shape) == 4: 75 | resize_ratio = list(resize_ratio) + [1.] 76 | if list(resize_ratio)[0:3] != [1., 1., 1.]: 77 | data = zoom(data, resize_ratio, order=order) 78 | 79 | # Undo cropping (which is performed in function crop()) 80 | if reconstruction_method == 'probabilities' or reconstruction_method == 'thresholding': 81 | new_data = np.zeros((resampled_volume.get_data().shape) + (nb_classes,), dtype=labels_type) 82 | else: 83 | new_data = np.zeros((resampled_volume.get_data().shape), dtype=labels_type) 84 | new_data[crop_bbox[0]:crop_bbox[3], crop_bbox[1]:crop_bbox[4], crop_bbox[2]:crop_bbox[5]] = data 85 | else: 86 | resize_ratio = resampled_volume.get_data().shape / np.asarray(data.shape)[0:3] 87 | if len(data.shape) == 4: 88 | resize_ratio = list(resize_ratio) + [1.] 
89 | if list(resize_ratio)[0:3] != [1., 1., 1.]: 90 | new_data = zoom(data, resize_ratio, order=order) 91 | else: 92 | new_data = data 93 | 94 | # Resampling to the size and spacing of the original input volume 95 | if reconstruction_method == 'probabilities' or reconstruction_method == 'thresholding': 96 | resampled_predictions = np.zeros(nib_volume.get_data().shape + (nb_classes,)).astype(labels_type) 97 | for c in range(0, nb_classes): 98 | img = nib.Nifti1Image(new_data[..., c].astype(labels_type), affine=resampled_volume.affine) 99 | resampled_channel = resample_from_to(img, nib_volume, order=order) 100 | resampled_predictions[..., c] = resampled_channel.get_data() 101 | else: 102 | resampled_predictions = np.zeros(nib_volume.get_data().shape).astype(labels_type) 103 | img = nib.Nifti1Image(new_data.astype(labels_type), affine=resampled_volume.affine) 104 | resampled_channel = resample_from_to(img, nib_volume, order=order) 105 | resampled_predictions = resampled_channel.get_data() 106 | 107 | # Range has to be set to [0, 1] again after resampling with order 0 108 | if order == 3: 109 | for c in range(0, nb_classes): 110 | min_val = np.min(resampled_predictions[..., c]) 111 | max_val = np.max(resampled_predictions[..., c]) 112 | 113 | if (max_val - min_val) != 0: 114 | resampled_predictions[..., c] = (resampled_predictions[..., c] - min_val) / (max_val - min_val) 115 | 116 | return resampled_predictions 117 | 118 | 119 | def perform_ensemble(input_filename, output_path, model_list, user_runtime): 120 | input_ni = load_nifti_volume(input_filename) 121 | ensemble_pred = np.zeros(shape=(input_ni.shape)) 122 | for i, model in enumerate(model_list): 123 | pred_filename = output_path + model + '-pred_LymphNodes.nii.gz' 124 | predictions_ni = load_nifti_volume(pred_filename) 125 | if i == 0: 126 | ensemble_pred = predictions_ni.get_data()[:] 127 | else: 128 | ensemble_pred = np.maximum(ensemble_pred, predictions_ni.get_data()[:]) 129 | 130 | if user_runtime.predictions_reconstruction_method == 'thresholding': 131 | ensemble_pred_bin = np.zeros(ensemble_pred.shape).astype('uint8') 132 | ensemble_pred_bin[ensemble_pred>=user_runtime.predictions_probability_thresholds[0]] = 1 133 | nib.save(nib.Nifti1Image(ensemble_pred_bin, input_ni.affine), output_path + '-labels_Ensemble_LymphNodes.nii.gz') 134 | else: 135 | nib.save(nib.Nifti1Image(ensemble_pred, input_ni.affine), output_path + '-pred_Ensemble_LymphNodes.nii.gz') 136 | -------------------------------------------------------------------------------- /src/Inference/predictions.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras.models import load_model 2 | import matplotlib.pyplot as plt 3 | from scipy.ndimage import zoom 4 | import os 5 | from os.path import join 6 | import numpy as np 7 | import sys 8 | from shutil import copy 9 | from math import ceil, floor 10 | from copy import deepcopy 11 | from src.Utils.volume_utilities import padding_for_inference, padding_for_inference_both_ends 12 | from src.Models.UNet.DualAttentionUNet import PAM, CAM 13 | 14 | 15 | def run_predictions(data, model_path, training_parameters, runtime_parameters): 16 | """ 17 | Only the prediction is done in this function, possible thresholdings and re-sampling are not included here. 
18 | :param data: 19 | :return: 20 | """ 21 | print("Loading model...") 22 | model = load_model(model_path, custom_objects={'PAM': PAM, 'CAM': CAM}, compile=False) 23 | 24 | whole_input_at_once = False 25 | if len(training_parameters.new_axial_size) == 3: 26 | whole_input_at_once = True 27 | 28 | final_result = None 29 | 30 | print("Predicting...") 31 | if whole_input_at_once: 32 | final_result = __run_predictions_whole(data=data, model=model, 33 | deep_supervision=training_parameters.training_deep_supervision) 34 | else: 35 | final_result = __run_predictions_slabbed(data=data, model=model, training_parameters=training_parameters, 36 | runtime_parameters=runtime_parameters) 37 | 38 | return final_result.astype('float32') 39 | 40 | 41 | def __run_predictions_whole(data, model, deep_supervision=False): 42 | data_prep = np.expand_dims(data, axis=0) 43 | 44 | predictions = model.predict(data_prep) 45 | 46 | if deep_supervision: 47 | return predictions[0][0] 48 | else: 49 | return predictions[0] 50 | 51 | 52 | def __run_predictions_slabbed(data, model, training_parameters, runtime_parameters): 53 | """ 54 | Working/tested for the axial and sagittal planes. 55 | """ 56 | slicing_plane = training_parameters.slicing_plane 57 | slab_size = training_parameters.training_slab_size 58 | new_axial_size = training_parameters.new_axial_size 59 | 60 | upper_boundary = data.shape[2] 61 | if slicing_plane == 'sagittal': 62 | upper_boundary = data.shape[0] 63 | elif slicing_plane == 'coronal': 64 | upper_boundary = data.shape[1] 65 | 66 | # Placeholder for the final predictions 67 | final_result = np.zeros(data.shape[:-1] + (training_parameters.training_nb_classes,)) 68 | count = 0 69 | 70 | if runtime_parameters.predictions_non_overlapping: 71 | data, pad_value = padding_for_inference(data=data, slab_size=slab_size, slicing_plane=slicing_plane) 72 | scale = ceil(upper_boundary / slab_size) 73 | unpad = False 74 | for chunk in range(scale): 75 | if chunk == scale-1 and pad_value != 0: 76 | unpad = True 77 | 78 | if slicing_plane == 'axial': 79 | slab_CT = data[:, :, int(chunk * slab_size):int((chunk + 1) * slab_size), 0] 80 | elif slicing_plane == 'sagittal': 81 | tmp = data[int(chunk * slab_size):int((chunk + 1) * slab_size), :, :, 0] 82 | slab_CT = tmp.transpose((1, 2, 0)) 83 | elif slicing_plane == 'coronal': 84 | tmp = data[:, int(chunk * slab_size):int((chunk + 1) * slab_size), :, 0] 85 | slab_CT = tmp.transpose((0, 2, 1)) 86 | 87 | slab_CT = np.expand_dims(np.expand_dims(slab_CT, axis=0), axis=-1) 88 | slab_CT_pred = model.predict(slab_CT) 89 | 90 | if not unpad: 91 | for c in range(0, slab_CT_pred.shape[-1]): 92 | if slicing_plane == 'axial': 93 | final_result[:, :, int(chunk * slab_size):int((chunk + 1) * slab_size), c] = \ 94 | slab_CT_pred[0][:, :, :slab_size, c] 95 | elif slicing_plane == 'sagittal': 96 | final_result[int(chunk * slab_size):int((chunk + 1) * slab_size), :, :, c] = \ 97 | slab_CT_pred[0][:, :, :slab_size, c].transpose((2, 0, 1)) 98 | elif slicing_plane == 'coronal': 99 | final_result[:, int(chunk * slab_size):int((chunk + 1) * slab_size), :, c] = \ 100 | slab_CT_pred[0][:, :, :slab_size, c].transpose((0, 2, 1)) 101 | else: 102 | for c in range(0, slab_CT_pred.shape[-1]): 103 | if slicing_plane == 'axial': 104 | final_result[:, :, int(chunk * slab_size):, c] = \ 105 | slab_CT_pred[0][:, :, :slab_size-pad_value, c] 106 | elif slicing_plane == 'sagittal': 107 | final_result[int(chunk * slab_size):, :, :, c] = \ 108 | slab_CT_pred[0][:, :, :slab_size-pad_value, c].transpose((2, 0, 1)) 
109 | elif slicing_plane == 'coronal': 110 | final_result[:, int(chunk * slab_size):, :, c] = \ 111 | slab_CT_pred[0][:, :, :slab_size-pad_value, c].transpose((0, 2, 1)) 112 | 113 | print(count) 114 | count = count + 1 115 | else: 116 | if slab_size == 1: 117 | for slice in range(0, data.shape[2]): 118 | slab_CT = data[:, :, slice, 0] 119 | if np.sum(slab_CT > 0.1) == 0: 120 | continue 121 | slab_CT_pred = model.predict(np.reshape(slab_CT, (1, new_axial_size[0], new_axial_size[1], 1))) 122 | for c in range(0, slab_CT_pred.shape[-1]): 123 | final_result[:, :, slice, c] = slab_CT_pred[0][:, :, c] 124 | else: 125 | data = padding_for_inference_both_ends(data=data, slab_size=slab_size, slicing_plane=slicing_plane) 126 | half_slab_size = int(slab_size / 2) 127 | for slice in range(half_slab_size, upper_boundary): 128 | if slicing_plane == 'axial': 129 | slab_CT = data[:, :, slice - half_slab_size:slice + half_slab_size, 0] 130 | elif slicing_plane == 'sagittal': 131 | slab_CT = data[slice - half_slab_size:slice + half_slab_size, :, :, 0] 132 | slab_CT = slab_CT.transpose((1, 2, 0)) 133 | elif slicing_plane == 'coronal': 134 | slab_CT = data[:, slice - half_slab_size:slice + half_slab_size, :, 0] 135 | slab_CT = slab_CT.transpose((0, 2, 1)) 136 | 137 | slab_CT = np.reshape(slab_CT, (1, new_axial_size[0], new_axial_size[1], slab_size, 1)) 138 | if np.sum(slab_CT > 0.1) == 0: 139 | continue 140 | 141 | slab_CT_pred = model.predict(slab_CT) 142 | 143 | for c in range(0, slab_CT_pred.shape[-1]): 144 | if slicing_plane == 'axial': 145 | final_result[:, :, slice - half_slab_size, c] = slab_CT_pred[0][:, :, half_slab_size, c] 146 | elif slicing_plane == 'sagittal': 147 | final_result[slice, :, :, c] = slab_CT_pred[0][:, :, half_slab_size, c] 148 | elif slicing_plane == 'coronal': 149 | final_result[:, slice, :, c] = slab_CT_pred[0][:, :, half_slab_size, c] 150 | 151 | print(count) 152 | count = count + 1 153 | 154 | return final_result 155 | -------------------------------------------------------------------------------- /src/Utils/configuration_parser.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | import sys 4 | 5 | 6 | class PreProcessingParser: 7 | def __init__(self, model_name): 8 | self.preprocessing_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../', 9 | 'resources/models', model_name, 'pre_processing.ini') 10 | if not os.path.exists(self.preprocessing_filename): 11 | raise ValueError('Missing configuration file with pre-processing parameters: {}'.
12 | format(self.preprocessing_filename)) 13 | 14 | self.pre_processing_config = configparser.ConfigParser() 15 | self.pre_processing_config.read(self.preprocessing_filename) 16 | self.__parse_content() 17 | 18 | def __parse_content(self): 19 | self.__parse_pre_processing_content() 20 | self.__parse_training_content() 21 | 22 | def __parse_training_content(self): 23 | self.training_nb_classes = None 24 | self.training_class_names = None 25 | self.training_slab_size = None 26 | self.training_optimal_thresholds = None 27 | self.training_deep_supervision = False 28 | 29 | if self.pre_processing_config.has_option('Training', 'nb_classes'): 30 | self.training_nb_classes = int(self.pre_processing_config['Training']['nb_classes'].split('#')[0]) 31 | 32 | if self.pre_processing_config.has_option('Training', 'classes'): 33 | if self.pre_processing_config['Training']['classes'].split('#')[0].strip() != '': 34 | self.training_class_names = [x.strip() for x in self.pre_processing_config['Training']['classes'].split('#')[0].split(',')] 35 | 36 | if self.pre_processing_config.has_option('Training', 'slab_size'): 37 | self.training_slab_size = int(self.pre_processing_config['Training']['slab_size'].split('#')[0]) 38 | 39 | if self.pre_processing_config.has_option('Training', 'optimal_thresholds'): 40 | if self.pre_processing_config['Training']['optimal_thresholds'].split('#')[0].strip() != '': 41 | self.training_optimal_thresholds = [float(x.strip()) for x in self.pre_processing_config['Training']['optimal_thresholds'].split('#')[0].split(',')] 42 | 43 | if self.pre_processing_config.has_option('Training', 'deep_supervision'): 44 | if self.pre_processing_config['Training']['deep_supervision'].split('#')[0].strip() != '': 45 | self.training_deep_supervision = True if self.pre_processing_config['Training']['deep_supervision'].split('#')[0].strip().lower() == 'true' else False 46 | 47 | def __parse_pre_processing_content(self): 48 | self.preprocessing_library = 'nibabel' 49 | self.output_spacing = None 50 | self.crop_background = False 51 | self.intensity_clipping_values = None 52 | self.intensity_clipping_range = [0.0, 100.0] 53 | self.intensity_target_range = [0.0, 1.0] 54 | self.new_axial_size = None 55 | self.slicing_plane = 'axial' 56 | self.normalization_method = None 57 | 58 | if self.pre_processing_config.has_option('PreProcessing', 'library'): 59 | if self.pre_processing_config['PreProcessing']['library'].split('#')[0].strip() == 'dipy': 60 | self.preprocessing_library = 'dipy' 61 | 62 | if self.pre_processing_config.has_option('PreProcessing', 'output_spacing'): 63 | if self.pre_processing_config['PreProcessing']['output_spacing'].split('#')[0].strip() != '': 64 | self.output_spacing = [float(x) for x in self.pre_processing_config['PreProcessing']['output_spacing'].split('#')[0].split(',')] 65 | 66 | if self.pre_processing_config.has_option('PreProcessing', 'intensity_clipping_values'): 67 | if self.pre_processing_config['PreProcessing']['intensity_clipping_values'].split('#')[0].strip() != '': 68 | self.intensity_clipping_values = [float(x) for x in self.pre_processing_config['PreProcessing']['intensity_clipping_values'].split('#')[0].split(',')] 69 | 70 | if self.pre_processing_config.has_option('PreProcessing', 'intensity_clipping_range'): 71 | if self.pre_processing_config['PreProcessing']['intensity_clipping_range'].split('#')[0].strip() != '': 72 | self.intensity_clipping_range = [float(x) for x in 
self.pre_processing_config['PreProcessing']['intensity_clipping_range'].split('#')[0].split(',')] 73 | 74 | if self.pre_processing_config.has_option('PreProcessing', 'intensity_final_range'): 75 | if self.pre_processing_config['PreProcessing']['intensity_final_range'].split('#')[0].strip() != '': 76 | self.intensity_target_range = [float(x) for x in self.pre_processing_config['PreProcessing']['intensity_final_range'].split('#')[0].split(',')] 77 | 78 | if self.pre_processing_config.has_option('PreProcessing', 'background_cropping'): 79 | if self.pre_processing_config['PreProcessing']['background_cropping'].split('#')[0].strip() != '': 80 | self.crop_background = True if self.pre_processing_config['PreProcessing']['background_cropping'].split('#')[0].lower()\ 81 | == 'true' else False 82 | 83 | if self.pre_processing_config.has_option('PreProcessing', 'new_axial_size'): 84 | if self.pre_processing_config['PreProcessing']['new_axial_size'].split('#')[0].strip() != '': 85 | self.new_axial_size = [int(x) for x in self.pre_processing_config['PreProcessing']['new_axial_size'].split('#')[0].split(',')] 86 | 87 | if self.pre_processing_config.has_option('PreProcessing', 'slicing_plane'): 88 | if self.pre_processing_config['PreProcessing']['slicing_plane'].split('#')[0].strip() != '': 89 | self.slicing_plane = self.pre_processing_config['PreProcessing']['slicing_plane'].split('#')[0] 90 | 91 | if self.pre_processing_config.has_option('PreProcessing', 'normalization_method'): 92 | self.normalization_method = self.pre_processing_config['PreProcessing']['normalization_method'].split('#')[0] 93 | 94 | 95 | class RuntimeConfigParser: 96 | def __init__(self): 97 | self.runtime_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../', 'resources/data', 98 | 'runtime_config.ini') 99 | if not os.path.exists(self.runtime_filename): 100 | raise ValueError('Missing configuration file with runtime parameters: {}'. 
101 | format(self.runtime_filename)) 102 | 103 | self.runtime_config = configparser.ConfigParser() 104 | self.runtime_config.read(self.runtime_filename) 105 | self.__parse_content() 106 | 107 | def __parse_content(self): 108 | self.predictions_non_overlapping = True 109 | self.predictions_reconstruction_method = None 110 | self.predictions_reconstruction_order = None 111 | self.predictions_probability_thresholds = [0.5] 112 | 113 | if self.runtime_config.has_option('Predictions', 'non_overlapping'): 114 | self.predictions_non_overlapping = True if self.runtime_config['Predictions']['non_overlapping'].split('#')[0].lower().strip()\ 115 | == 'true' else False 116 | 117 | if self.runtime_config.has_option('Predictions', 'reconstruction_method'): 118 | self.predictions_reconstruction_method = self.runtime_config['Predictions']['reconstruction_method'].split('#')[0].strip() 119 | 120 | if self.runtime_config.has_option('Predictions', 'reconstruction_order'): 121 | self.predictions_reconstruction_order = self.runtime_config['Predictions']['reconstruction_order'].split('#')[0].strip() 122 | 123 | if self.runtime_config.has_option('Predictions', 'probability_threshold'): 124 | if self.runtime_config['Predictions']['probability_threshold'].split('#')[0].strip() != '': 125 | self.predictions_probability_thresholds = [float(x.strip()) for x in self.runtime_config['Predictions']['probability_threshold'].split('#')[0].split(',')] 126 | 127 | def set_default_runtime(self): 128 | self.predictions_non_overlapping = True 129 | self.predictions_reconstruction_method = 'probabilities' 130 | self.predictions_reconstruction_order = 'resample_first' 131 | self.predictions_probability_thresholds = [0.5] 132 | -------------------------------------------------------------------------------- /src/Models/UNet/AttentionGatedUNet.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.keras.layers import Input, Dense, Convolution2D, MaxPooling2D, Dropout, Flatten, SpatialDropout2D, \ 2 | ZeroPadding2D, Activation, AveragePooling2D, UpSampling2D, BatchNormalization, ConvLSTM2D, \ 3 | TimeDistributed, Concatenate, Lambda, Reshape, UpSampling3D, Convolution3D, MaxPooling3D, SpatialDropout3D,\ 4 | Conv2DTranspose, Conv3DTranspose, add, multiply, Reshape, Softmax, AveragePooling3D, Add, Layer 5 | from tensorflow.python.keras.models import Model 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | 10 | def convolution_block(x, nr_of_convolutions, use_bn=False, spatial_dropout=None): 11 | for i in range(2): 12 | x = Convolution3D(nr_of_convolutions, 3, padding='same')(x) 13 | if use_bn: 14 | x = BatchNormalization()(x) 15 | x = Activation('relu')(x) 16 | if spatial_dropout: 17 | x = SpatialDropout3D(spatial_dropout)(x) 18 | 19 | return x 20 | 21 | 22 | def attention_block(g, x, nr_of_convolutions): 23 | """ 24 | Taken from https://github.com/LeeJunHyun/Image_Segmentation 25 | """ 26 | g1 = Convolution3D(nr_of_convolutions, kernel_size=1, strides=1, padding='same', use_bias=True)(g) 27 | g1 = BatchNormalization()(g1) 28 | 29 | x1 = Convolution3D(nr_of_convolutions, kernel_size=1, strides=1, padding='same', use_bias=True)(x) 30 | x1 = BatchNormalization()(x1) 31 | 32 | psi = Concatenate()([g1, x1]) 33 | psi = Activation(activation='relu')(psi) 34 | psi = Convolution3D(1, kernel_size=1, strides=1, padding='same', use_bias=True)(psi) 35 | psi = BatchNormalization()(psi) 36 | psi = Activation(activation='sigmoid')(psi) 37 | 38 | return multiply([x, psi]) 39 | 40 | 41 | 
def attention_block_oktay(g, x, nr_of_convolutions): 42 | """ 43 | Following the original paper and implementation at https://github.com/ozan-oktay/Attention-Gated-Networks 44 | """ 45 | g1 = Convolution3D(nr_of_convolutions, kernel_size=1, strides=1, padding='same', use_bias=True)(g) 46 | g1 = BatchNormalization()(g1) 47 | 48 | x1 = MaxPooling3D([2, 2, 2])(x) 49 | x1 = Convolution3D(nr_of_convolutions, kernel_size=1, strides=1, padding='same', use_bias=True)(x1) 50 | x1 = BatchNormalization()(x1) 51 | 52 | psi = Concatenate()([g1, x1]) 53 | psi = Activation(activation='relu')(psi) 54 | psi = Convolution3D(1, kernel_size=1, strides=1, padding='same', use_bias=True)(psi) 55 | psi = BatchNormalization()(psi) 56 | psi = Activation(activation='sigmoid')(psi) 57 | 58 | return multiply([x, psi]) 59 | 60 | 61 | def encoder_block(x, nr_of_convolutions, use_bn=False, spatial_dropout=None): 62 | x_before_downsampling = convolution_block(x, nr_of_convolutions, use_bn, spatial_dropout) 63 | downsample = [2, 2, 2] 64 | for i in range(1, 4): 65 | if x.shape[i] <= 4: 66 | downsample[i-1] = 1 67 | 68 | x = MaxPooling3D(downsample)(x_before_downsampling) 69 | 70 | return x, x_before_downsampling 71 | 72 | 73 | def encoder_block_pyramid(x, input_ds, nr_of_convolutions, use_bn=False, spatial_dropout=None): 74 | pyramid_conv = Convolution3D(filters=nr_of_convolutions, kernel_size=(3, 3, 3), padding='same', activation='relu')(input_ds) 75 | x = Concatenate(axis=-1)([pyramid_conv, x]) 76 | x_before_downsampling = convolution_block(x, nr_of_convolutions, use_bn, spatial_dropout) 77 | downsample = [2, 2, 2] 78 | for i in range(1, 4): 79 | if x.shape[i] <= 4: 80 | downsample[i-1] = 1 81 | 82 | x = MaxPooling3D(downsample)(x_before_downsampling) 83 | 84 | return x, x_before_downsampling 85 | 86 | 87 | def decoder_block(x, cross_over_connection, nr_of_convolutions, use_bn=False, spatial_dropout=None): 88 | x = Conv3DTranspose(nr_of_convolutions, kernel_size=3, padding='same', strides=2)(x) 89 | if use_bn: 90 | x = BatchNormalization()(x) 91 | x = Activation('relu')(x) 92 | attention = attention_block(g=x, x=cross_over_connection, nr_of_convolutions=int(nr_of_convolutions/2)) 93 | x = Concatenate()([x, attention]) 94 | x = convolution_block(x, nr_of_convolutions, use_bn, spatial_dropout) 95 | 96 | return x 97 | 98 | 99 | class AttentionGatedUnet(): 100 | def __init__(self, input_shape, nb_classes, deep_supervision=False, input_pyramid=False): 101 | if len(input_shape) != 3 and len(input_shape) != 4: 102 | raise ValueError('Input shape must have 3 or 4 dimensions') 103 | if nb_classes <= 1: 104 | raise ValueError('Segmentation classes must be > 1') 105 | self.dims = 3 106 | self.input_shape = input_shape 107 | self.nb_classes = nb_classes 108 | self.deep_supervision = deep_supervision 109 | self.input_pyramid = input_pyramid 110 | self.convolutions = None 111 | self.encoder_use_bn = True 112 | self.decoder_use_bn = True 113 | self.encoder_spatial_dropout = None 114 | self.decoder_spatial_dropout = None 115 | 116 | def set_convolutions(self, convolutions): 117 | self.convolutions = convolutions 118 | 119 | def get_dice_loss(self): 120 | def dice_loss(target, output, epsilon=1e-10): 121 | smooth = 1. 
122 | dice = 0 123 | 124 | for object in range(0, self.nb_classes): 125 | if self.dims == 2: 126 | output1 = output[:, :, :, object] 127 | target1 = target[:, :, :, object] 128 | else: 129 | output1 = output[:, :, :, :, object] 130 | target1 = target[:, :, :, :, object] 131 | intersection1 = tf.reduce_sum(output1 * target1) 132 | union1 = tf.reduce_sum(output1 * output1) + tf.reduce_sum(target1 * target1) 133 | dice += (2. * intersection1 + smooth) / (union1 + smooth) 134 | 135 | dice /= (self.nb_classes - 1) 136 | 137 | return tf.clip_by_value(1. - dice, 0., 1. - epsilon) 138 | 139 | return dice_loss 140 | 141 | def create(self): 142 | """ 143 | Create model and return it 144 | 145 | :return: keras model 146 | """ 147 | 148 | input_layer = Input(shape=self.input_shape) 149 | x = input_layer 150 | 151 | init_size = max(self.input_shape[:-1]) 152 | size = init_size 153 | 154 | convolutions = self.convolutions 155 | connection = [] 156 | i = 0 157 | 158 | if self.input_pyramid: 159 | scaled_input = [] 160 | scaled_input.append(x) 161 | for i, nbc in enumerate(self.convolutions[:-1]): 162 | ds_input = AveragePooling3D(pool_size=(2, 2, 2))(scaled_input[i]) 163 | scaled_input.append(ds_input) 164 | 165 | for i, nbc in enumerate(self.convolutions[:-1]): 166 | if not self.input_pyramid or (i == 0): 167 | x, x_before_ds = encoder_block(x, nbc, use_bn=self.encoder_use_bn, 168 | spatial_dropout=self.encoder_spatial_dropout) 169 | else: 170 | x, x_before_ds = encoder_block_pyramid(x, scaled_input[i], nbc, use_bn=self.encoder_use_bn, 171 | spatial_dropout=self.encoder_spatial_dropout) 172 | connection.insert(0, x_before_ds) # Append in reverse order for easier use in the next block 173 | 174 | x = convolution_block(x, self.convolutions[-1], self.encoder_use_bn, self.encoder_spatial_dropout) 175 | connection.insert(0, x) 176 | 177 | inverse_conv = self.convolutions[::-1] 178 | inverse_conv = inverse_conv[1:] 179 | decoded_layers = [] 180 | 181 | for i, nbc in enumerate(inverse_conv): 182 | x = decoder_block(x, connection[i+1], nbc, use_bn=self.decoder_use_bn, 183 | spatial_dropout=self.decoder_spatial_dropout) 184 | decoded_layers.append(x) 185 | 186 | if not self.deep_supervision: 187 | # Final activation layer 188 | x = Convolution3D(self.nb_classes, 1, activation='softmax')(x) 189 | else: 190 | recons_list = [] 191 | for i, lay in enumerate(decoded_layers): 192 | x = Convolution3D(self.nb_classes, 1, activation='softmax')(lay) 193 | recons_list.append(x) 194 | x = recons_list[::-1] 195 | 196 | return Model(inputs=input_layer, outputs=x) 197 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mediastinal lymph nodes segmentation using 3D convolutional neural network ensembles and anatomical priors guiding 2 | - - - 3 | 4 | The repository contains the architectures, inference code, 5 | and trained models for lymph nodes segmentation in contrast-enhanced CT volumes. 6 | 7 | ![Dual attention guided U-Net architecture](resources/images/Architecture.png) 8 | 9 | ## Description 10 | The following models are provided, all trained on one fold of our 5-fold cross-validation (i.e., 72 patients): 11 | * CT_Lungs: Model trained using a simple 3D U-Net architecture, using the entire input volume. 12 | * UNet-SW32: Model trained using a simple 3D U-Net architecture, on slabs of 32 slices. 
13 | * AGUNet: Model trained using the 3D Attention-Gated U-Net architecture, using the entire input volume. 14 | * AGUNet-APG: Model trained using the 3D Attention-Gated U-Net architecture, using the entire input volume, 15 | and a mask with anatomical priors (esophagus, azygos vein, subclavian arteries and brachiocephalic veins). 16 | :warning: the mask must be provided manually by the user; no trained model is available for this step. 17 | 18 | ## Citing 19 | Please cite the following article if you re-use any part of the code, the trained models, or suggested ground truth files: 20 |
 21 |   @article{bouget2021mediastinal,
 22 |   author = {David Bouget and André Pedersen and Johanna Vanel and Haakon O. Leira and Thomas Langø},
 23 |   title = {Mediastinal lymph nodes segmentation using 3D convolutional neural network ensembles and anatomical priors guiding},
 24 |   journal = {Computer Methods in Biomechanics and Biomedical Engineering: Imaging \& Visualization},
 25 |   volume = {0},
 26 |   number = {0},
27 |   pages = {1--15},
28 |   year = {2022},
 29 |   publisher = {Taylor & Francis},
 30 |   doi = {10.1080/21681163.2022.2043778},
 31 |   URL = {https://doi.org/10.1080/21681163.2022.2043778},
 32 |   eprint = {https://doi.org/10.1080/21681163.2022.2043778}}
 33 | 
34 | 35 | ## Stand-alone data and annotation access 36 | The annotations for the benchmark subset have been proofed by an expert radiologist. 37 | For the NIH dataset, the annotations were performed by a medical trainee under supervision of the expert. 38 | **Benchmark subset:** Fifteen contrast-enhanced CT volumes from St. Olavs University Hospital patients. 39 | Stations and segmentations for all lymph nodes are provided, in addition to the segmentation for the esophagus, 40 | azygos vein, subclavian arteries, and brachiocephalic veins. 41 | Available for direct download [here](https://drive.google.com/uc?id=1ZsFq7PslqQ5ow_dXB01kDkaKPqYDXD5d). 42 | **NIH dataset**: Stations and refined segmentations for all lymph nodes in 89 patients from the open-source NIH dataset, available [here](https://drive.google.com/uc?id=1iVCnZc1GHwtx9scyAXdANqz2HdQArTHn). 43 | The CT volumes are available for download on the official [web-page](https://wiki.cancerimagingarchive.net/display/Public/CT+Lymph+Nodes). 44 | 45 | **Mediastinal CT dataset**: The dataset described in our previous article is available for download [here](https://drive.google.com/uc?id=1YqCRcBpsFoE4JsBq5NROqIpeijnITpe1). 46 |
 47 |   @article{bouget2019semantic,  
 48 |   title={Semantic segmentation and detection of mediastinal lymph nodes and anatomical structures in CT data for lung cancer staging},  
 49 |   author={Bouget, David and Jørgensen, Arve and Kiss, Gabriel and Leira, Haakon Olav and Langø, Thomas},  
 50 |   journal={International journal of computer assisted radiology and surgery},  
 51 |   volume={14},  
 52 |   number={6},  
 53 |   pages={977--986},  
 54 |   year={2019},  
 55 |   publisher={Springer}}
 56 | 
57 | 
58 | ## Installation 
59 | The following steps have been tested on both Ubuntu and Windows. The details below are for Linux; see the Troubleshooting section below for Windows-specific details. 
60 | ### a. Python 
61 | The Python virtual environment can be set up using the following commands: 
62 | 
63 | > `virtualenv -p python3 venv` 
64 | `source venv/bin/activate` 
65 | `pip install -r requirements.txt` 
66 | 
67 | ### b. Docker 
68 | Simply download the corresponding Docker image: 
69 | 
70 | > `docker pull dbouget/ct_mediastinal_structures_segmentation:v1` 
71 | 
72 | ### c. Models 
73 | In order to download the models locally and prepare the folders, simply run the following: 
74 | 
75 | > `source venv/bin/activate` 
76 | `python setup.py` 
77 | `deactivate` 
78 | 
79 | ## Usage 
80 | ### a. Command line parameters 
81 | The command line input parameters are the following: 
82 | - i [Input]: Path to the CT volume file, preferably in NIfTI format. Other formats will 
83 | be converted to NIfTI before being processed, using SimpleITK. 
84 | - o [Output]: Path and prefix for the predictions file. The base directory must already exist 
85 | on disk. 
86 | - m [Model]: Name of the model to use for inference, from the list [UNet-SW32, AGUNet, AGUNet-APG]. 
87 | To run ensembles, provide a comma-separated list of model names. 
88 | - g [GPU]: (Optional) ID of the GPU to use for inference. The CPU is used if no eligible number is provided. 
89 | - l [Lungs]: (Optional) Path to a pre-generated lungs mask, with the same parameters as the main CT volume 
90 | (i.e., dimensions, spacings, origin). If none is provided, the mask will be generated automatically. 
91 | - a [APG]: (Optional) Path to a pre-generated mask containing the anatomical priors for guiding, with the same 
92 | parameters as the main CT volume (i.e., dimensions, spacings, origin). If none is provided, the mask will be filled with zeros. 
93 | 
94 | ### b. Extra-configuration parameters 
95 | A runtime configuration file is provided in resources/data/runtime_config.ini, 
96 | where additional variables can be modified (a purely illustrative sketch of this file is shown further below): 
97 | - non_overlapping: [true, false], only in effect for the UNet-SW32 model. 
98 | True indicates no overlap between predictions, while false indicates a stride-1 overlap. 
99 | - reconstruction_method: [thresholding, probabilities]. With the latter, raw prediction maps 
100 | in the range [0, 1] are dumped, while with the former a binary mask is dumped using a pre-set 
101 | probability threshold value. 
102 | - reconstruction_order: [resample_first, resample_second]. With the former, the raw probability map 
103 | is resampled to the original patient's space before the reconstruction happens (slower), while 
104 | with the latter the opposite is performed (faster). 
105 | - probability_threshold: threshold value to be used when the reconstruction method is set to thresholding 
106 | (optimal values for each model can be found in the paper). 
107 | 
108 | ### c. Python execution 
109 | To run inference with the attention-gated U-Net model, using GPU 0, execute the following in the project root directory: 
110 | > `source venv/bin/activate` 
111 | `python main.py -i /path/to/file.nii.gz -o /output/path/to/output_prefix -m AGUNet -g 0` 
112 | `deactivate` 
113 | 
114 | ### d. Docker execution 
115 | The local resources sub-folder is mapped to the resources sub-folder within the Docker container. 
116 | As such, input CT volumes have to be copied inside resources/data to be processed, and the output folder 
117 | for the predictions has to be set within the resources sub-folder to be accessible locally. 
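As referenced in section b above, here is a purely illustrative sketch of what resources/data/runtime_config.ini could look like. The key names follow the list in section b, but the section header and the example values are assumptions, not the repository's actual defaults:

```ini
; Illustrative sketch only: key names taken from section b above,
; section name and values are assumed, not the shipped defaults.
[Predictions]
non_overlapping = true
reconstruction_method = thresholding
reconstruction_order = resample_first
probability_threshold = 0.5
```
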
118 | :warning: The Docker container does not have GPU support, so all inferences are performed on CPU only. 
119 | 
120 | > `cp /path/to/ct.nii.gz /path/to/ct_mediastinal_structures_segmentation/resources/data/ct.nii.gz` 
121 | `docker run --entrypoint /bin/bash -v /path/to/ct_mediastinal_structures_segmentation/resources:/home/ubuntu/resources -t -i dbouget/ct_mediastinal_structures_segmentation:v1` 
122 | `python3 main.py -i ./resources/data/ct.nii.gz -o ./resources/output_prefix -m AGUNet` 
123 | 
124 | 
125 | ## Acknowledgements 
126 | Parts of the models' architectures were adapted from the following repositories: 
127 | - https://github.com/niecongchong/DANet-keras/ 
128 | - https://github.com/ozan-oktay/Attention-Gated-Networks 
129 | 
130 | For more detailed information about attention mechanisms, please read the corresponding publications: 
131 | 
132 | 
133 |   @inproceedings{fu2019dual,
134 |   title={Dual attention network for scene segmentation},
135 |   author={Fu, Jun and Liu, Jing and Tian, Haijie and Li, Yong and Bao, Yongjun and Fang, Zhiwei and Lu, Hanqing},
136 |   booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
137 |   pages={3146--3154},
138 |   year={2019}}
139 | 
140 | 141 |
142 |   @article{oktay2018attention,
143 |   title={Attention u-net: Learning where to look for the pancreas},
144 |   author={Oktay, Ozan and Schlemper, Jo and Folgoc, Loic Le and Lee, Matthew and Heinrich, Mattias and Misawa, Kazunari and Mori, Kensaku and McDonagh, Steven and Hammerla, Nils Y and Kainz, Bernhard and others},
145 |   journal={arXiv preprint arXiv:1804.03999},
146 |   year={2018}}
147 | 
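At their core, both attention modules reduce to a couple of matrix products over a flattened feature map. The following minimal NumPy sketch illustrates the channel-attention idea implemented by the CAM layer in DualAttentionUNet.py below; it is an illustration of the mechanism only, not the repository's TensorFlow implementation:

```python
import numpy as np

def softmax(m, axis=-1):
    e = np.exp(m - m.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def channel_attention(a, gamma=0.1):
    # a: feature map flattened to (N, C), with N = H*W*D voxels and C channels.
    affinity = softmax(a.T @ a)  # (C, C) channel-to-channel affinity matrix
    attended = a @ affinity      # each channel re-expressed as a mix of all channels
    return gamma * attended + a  # residual connection; gamma is a learned scalar in the real layer

features = np.random.rand(4 * 4 * 4, 8)  # toy 4x4x4 volume with 8 channels
print(channel_attention(features).shape)  # (64, 8)
```

The position-attention module (PAM) follows the same pattern, except the affinity matrix is computed between voxel positions, of size (H\*W\*D, H\*W\*D), rather than between channels.
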
148 | 
149 | ## Troubleshooting 
150 | On Windows, to activate the virtual environment, run: 
151 | > `.\venv\Scripts\activate` 
152 | 
153 | This assumes that one is using [virtualenv](https://pypi.org/project/virtualenv/) to create virtual environments, which can be installed using pip: 
154 | > `pip install virtualenv` 
155 | 
--------------------------------------------------------------------------------
/src/Models/UNet/DualAttentionUNet.py:
--------------------------------------------------------------------------------
1 | from tensorflow.python.keras.layers import Input, Dense, Convolution2D, MaxPooling2D, Dropout, Flatten, SpatialDropout2D, \
2 |     ZeroPadding2D, Activation, AveragePooling2D, UpSampling2D, BatchNormalization, ConvLSTM2D, \
3 |     TimeDistributed, Concatenate, Lambda, Reshape, UpSampling3D, Convolution3D, MaxPooling3D, SpatialDropout3D, \
4 |     Conv2DTranspose, Conv3DTranspose, add, multiply, Softmax, AveragePooling3D, Add, Layer
5 | from tensorflow.python.keras.models import Model
6 | import tensorflow as tf
7 | import numpy as np
8 | import math
9 | 
10 | 
11 | class CAM(Layer):
12 |     """
13 |     Channel Attention Module. Implementation from https://github.com/niecongchong/DANet-keras/
14 |     """
15 |     def __init__(self,
16 |                  gamma_initializer=tf.zeros_initializer(),
17 |                  gamma_regularizer=None,
18 |                  gamma_constraint=None,
19 |                  **kwargs):
20 |         super(CAM, self).__init__(**kwargs)
21 |         self.gamma_initializer = gamma_initializer
22 |         self.gamma_regularizer = gamma_regularizer
23 |         self.gamma_constraint = gamma_constraint
24 | 
25 |     def build(self, input_shape):
26 |         self.gamma = self.add_weight(shape=(1, ),
27 |                                      initializer=self.gamma_initializer,
28 |                                      name='gamma',
29 |                                      regularizer=self.gamma_regularizer,
30 |                                      constraint=self.gamma_constraint)
31 | 
32 |         self.built = True
33 | 
34 |     def compute_output_shape(self, input_shape):
35 |         return input_shape
36 | 
37 |     def call(self, x):
38 |         input_shape = x.get_shape().as_list()
39 |         _, h, w, d, filters = input_shape
40 | 
41 |         vec_a = Reshape(target_shape=(h * w * d, filters))(x)  # flatten voxels: (H*W*D, C)
42 |         vec_aT = tf.transpose(vec_a, perm=[0, 2, 1])
43 |         aTa = tf.linalg.matmul(vec_aT, vec_a)  # (C, C) channel-to-channel affinity matrix
44 |         softmax_aTa = Activation('softmax')(aTa)  # attention weights over channels
45 |         aaTa = tf.linalg.matmul(vec_a, softmax_aTa)  # re-weight each channel by its affinities
46 |         aaTa = Reshape(target_shape=(h, w, d, filters))(aaTa)
47 |         out = (self.gamma * aaTa) + x  # residual connection, gamma is learned (initialised at 0)
48 |         return out
49 | 
50 | 
51 | class PAM(Layer):
52 |     """
53 |     Position Attention Module. Implementation from https://github.com/niecongchong/DANet-keras/
54 |     """
55 |     def __init__(self,
56 |                  gamma_initializer=tf.zeros_initializer(),
57 |                  gamma_regularizer=None,
58 |                  gamma_constraint=None,
59 |                  **kwargs):
60 |         super(PAM, self).__init__(**kwargs)
61 |         self.gamma_initializer = gamma_initializer
62 |         self.gamma_regularizer = gamma_regularizer
63 |         self.gamma_constraint = gamma_constraint
64 | 
65 |     def build(self, input_shape):
66 |         self.gamma = self.add_weight(shape=(1, ),
67 |                                      initializer=self.gamma_initializer,
68 |                                      name='gamma',
69 |                                      regularizer=self.gamma_regularizer,
70 |                                      constraint=self.gamma_constraint)
71 | 
72 |         self.built = True
73 | 
74 |     def compute_output_shape(self, input_shape):
75 |         return input_shape
76 | 
77 |     def call(self, x):
78 |         input_shape = x.get_shape().as_list()
79 |         _, h, w, d, filters = input_shape
80 |         b_layer = Convolution3D(filters // 8, 1, use_bias=False)(x)  # query projection
81 |         c_layer = Convolution3D(filters // 8, 1, use_bias=False)(x)  # key projection
82 |         d_layer = Convolution3D(filters, 1, use_bias=False)(x)  # value projection
83 | 
84 |         b_layer = tf.transpose(Reshape(target_shape=(h * w * d, filters // 8))(b_layer), perm=[0, 2, 1])
85 |         c_layer = Reshape(target_shape=(h * w * d, filters // 8))(c_layer)
86 |         d_layer = Reshape(target_shape=(h * w * d, filters))(d_layer)
87 | 
88 |         # The bc_mul matrix should be of size (H*W*D) * (H*W*D)
89 |         bc_mul = tf.linalg.matmul(c_layer, b_layer)
90 |         activation_bc_mul = Activation(activation='softmax')(bc_mul)
91 |         bcd_mul = tf.linalg.matmul(activation_bc_mul, d_layer)
92 |         bcd_mul = Reshape(target_shape=(h, w, d, filters))(bcd_mul)
93 |         out = (self.gamma * bcd_mul) + x
94 |         return out
95 | 
96 | 
97 | def convolution_block(x, nr_of_convolutions, use_bn=False, spatial_dropout=None):
98 |     for i in range(2):
99 |         x = Convolution3D(nr_of_convolutions, 3, padding='same')(x)
100 |         if use_bn:
101 |             x = BatchNormalization()(x)
102 |         x = Activation('relu')(x)
103 |         if spatial_dropout:
104 |             x = SpatialDropout3D(spatial_dropout)(x)
105 | 
106 |     return x
107 | 
108 | 
109 | def encoder_block(x, nr_of_convolutions, use_bn=False, spatial_dropout=None):
110 |     x_before_downsampling = convolution_block(x, nr_of_convolutions, use_bn, spatial_dropout)
111 |     downsample = [2, 2, 2]
112 |     for i in range(1, 4):
113 |         if x.shape[i] <= 4:
114 |             downsample[i-1] = 1
115 | 
116 |     x = MaxPooling3D(downsample)(x_before_downsampling)
117 | 
118 |     return x, x_before_downsampling
119 | 
120 | 
121 | def encoder_block_pyramid(x, input_ds, nr_of_convolutions, use_bn=False, spatial_dropout=None):
122 |     pyramid_conv = Convolution3D(filters=nr_of_convolutions, kernel_size=(3, 3, 3), padding='same', activation='relu')(input_ds)
123 |     x = Concatenate(axis=-1)([pyramid_conv, x])
124 |     x_before_downsampling = convolution_block(x, nr_of_convolutions, use_bn, spatial_dropout)
125 |     downsample = [2, 2, 2]
126 |     for i in range(1, 4):
127 |         if x.shape[i] <= 4:
128 |             downsample[i-1] = 1
129 | 
130 |     x = MaxPooling3D(downsample)(x_before_downsampling)
131 | 
132 |     return x, x_before_downsampling
133 | 
134 | 
135 | def decoder_block(x, cross_over_connection, nr_of_convolutions, use_bn=False, spatial_dropout=None):
136 |     x = Conv3DTranspose(nr_of_convolutions, kernel_size=3, padding='same', strides=2)(x)
137 |     x = Concatenate()([cross_over_connection, x])
138 |     if use_bn:
139 |         x = BatchNormalization()(x)
140 |     x = Activation('relu')(x)
141 |     x = convolution_block(x, nr_of_convolutions, use_bn, spatial_dropout)
142 | 
143 |     return x
144 | 
145 | 
146 | def decoder_block_guided(x, cross_over_connection, nr_of_convolutions, iteration, attention_layer, use_bn=False, spatial_dropout=None):
147 |     x = Conv3DTranspose(nr_of_convolutions, kernel_size=3, padding='same', strides=2)(x)
148 |     upsampling_factor = int(math.pow(2, iteration))
149 |     attention_layer_up = Conv3DTranspose(nr_of_convolutions, kernel_size=3, padding='same', strides=upsampling_factor)(attention_layer)
150 |     x = Concatenate()([attention_layer_up, cross_over_connection, x])
151 |     if use_bn:
152 |         x = BatchNormalization()(x)
153 |     x = Activation('relu')(x)
154 |     x = convolution_block(x, nr_of_convolutions, use_bn, spatial_dropout)
155 | 
156 |     return x
157 | 
158 | 
159 | class DualAttentionUnet():
160 |     def __init__(self, input_shape, nb_classes, deep_supervision=False, input_pyramid=False, attention_guiding=False):
161 |         if len(input_shape) != 3 and len(input_shape) != 4:
162 |             raise ValueError('Input shape must have 3 or 4 dimensions')
163 |         if nb_classes <= 1:
164 |             raise ValueError('Segmentation classes must be > 1')
165 |         self.dims = 3
166 |         self.input_shape = input_shape
167 |         self.nb_classes = nb_classes
168 |         self.deep_supervision = deep_supervision
169 |         self.input_pyramid = input_pyramid
170 |         self.attention_guided = attention_guiding
171 |         self.convolutions = None
172 |         self.encoder_use_bn = True
173 |         self.decoder_use_bn = True
174 |         self.encoder_spatial_dropout = None
175 |         self.decoder_spatial_dropout = None
176 | 
177 |     def set_convolutions(self, convolutions):
178 |         self.convolutions = convolutions
179 | 
180 |     def get_dice_loss(self):
181 |         def dice_loss(target, output, epsilon=1e-10):
182 |             smooth = 1.
183 |             dice = 0
184 | 
185 |             for object in range(0, self.nb_classes):  # note: iterates over all classes, background included
186 |                 if self.dims == 2:
187 |                     output1 = output[:, :, :, object]
188 |                     target1 = target[:, :, :, object]
189 |                 else:
190 |                     output1 = output[:, :, :, :, object]
191 |                     target1 = target[:, :, :, :, object]
192 |                 intersection1 = tf.reduce_sum(output1 * target1)
193 |                 union1 = tf.reduce_sum(output1 * output1) + tf.reduce_sum(target1 * target1)
194 |                 dice += (2. * intersection1 + smooth) / (union1 + smooth)  # soft Dice for this class
195 | 
196 |             dice /= (self.nb_classes - 1)
197 | 
198 |             return tf.clip_by_value(1. - dice, 0., 1. - epsilon)
199 | 
200 |         return dice_loss
201 | 
202 |     def create(self):
203 |         """
204 |         Create model and return it
205 | 
206 |         :return: keras model
207 |         """
208 | 
209 |         input_layer = Input(shape=self.input_shape)
210 |         x = input_layer
211 | 
212 |         init_size = max(self.input_shape[:-1])
213 |         size = init_size
214 | 
215 |         convolutions = self.convolutions
216 |         connection = []
217 |         i = 0
218 | 
219 |         if self.input_pyramid:
220 |             scaled_input = []
221 |             scaled_input.append(x)
222 |             for i, nbc in enumerate(self.convolutions[:-1]):
223 |                 ds_input = AveragePooling3D(pool_size=(2, 2, 2))(scaled_input[i])
224 |                 scaled_input.append(ds_input)
225 | 
226 |         for i, nbc in enumerate(self.convolutions[:-1]):
227 |             if not self.input_pyramid or (i == 0):
228 |                 x, x_before_ds = encoder_block(x, nbc, use_bn=self.encoder_use_bn,
229 |                                                spatial_dropout=self.encoder_spatial_dropout)
230 |             else:
231 |                 x, x_before_ds = encoder_block_pyramid(x, scaled_input[i], nbc, use_bn=self.encoder_use_bn,
232 |                                                        spatial_dropout=self.encoder_spatial_dropout)
233 |             connection.insert(0, x_before_ds)  # Append in reverse order for easier use in the next block
234 | 
235 |         x = convolution_block(x, self.convolutions[-1], self.encoder_use_bn, self.encoder_spatial_dropout)
236 |         connection.insert(0, x)
237 | 
238 |         pam = PAM()(x)  # position (spatial) attention branch
239 |         pam = Convolution3D(self.convolutions[-1], 3, padding='same')(pam)
240 |         pam = BatchNormalization()(pam)
241 |         pam = Activation('relu')(pam)
242 |         pam = SpatialDropout3D(0.5)(pam)
243 |         pam = Convolution3D(self.convolutions[-1], 3, padding='same')(pam)
244 | 
245 |         cam = CAM()(x)  # channel attention branch
246 |         cam = Convolution3D(self.convolutions[-1], 3, padding='same')(cam)
247 |         cam = BatchNormalization()(cam)
248 |         cam = Activation('relu')(cam)
249 |         cam = SpatialDropout3D(0.5)(cam)
250 |         cam = Convolution3D(self.convolutions[-1], 3, padding='same')(cam)
251 | 
252 |         x = add([pam, cam])  # fuse the two attention branches
253 |         x = SpatialDropout3D(0.5)(x)
254 |         x = Convolution3D(self.convolutions[-1], 1, padding='same')(x)
255 |         x_bottom = x = BatchNormalization()(x)
256 | 
257 |         inverse_conv = self.convolutions[::-1]
258 |         inverse_conv = inverse_conv[1:]
259 |         decoded_layers = []
260 |         for i, nbc in enumerate(inverse_conv):
261 |             if not self.attention_guided:
262 |                 x = decoder_block(x, connection[i+1], nbc, use_bn=self.decoder_use_bn,
263 |                                   spatial_dropout=self.decoder_spatial_dropout)
264 |             else:
265 |                 x = decoder_block_guided(x, connection[i + 1], nbc, iteration=i+1, attention_layer=x_bottom,
266 |                                          use_bn=self.decoder_use_bn, spatial_dropout=self.decoder_spatial_dropout)
267 |             decoded_layers.append(x)
268 | 
269 |         if not self.deep_supervision:
270 |             # Final activation layer
271 |             x = Convolution3D(self.nb_classes, 1, activation='softmax')(x)
272 |         else:
273 |             recons_list = []
274 |             for i, lay in enumerate(decoded_layers):
275 |                 x = Convolution3D(self.nb_classes, 1, activation='softmax')(lay)
276 |                 recons_list.append(x)
277 |             x = recons_list[::-1]
278 | 
279 |         return Model(inputs=input_layer, outputs=x)
280 | 
--------------------------------------------------------------------------------
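A minimal usage sketch for the DualAttentionUnet class above (illustrative only: the input shape, number of classes, and filter widths are assumptions, not the settings used to train the released models):

```python
from src.Models.UNet.DualAttentionUNet import DualAttentionUnet

# Illustrative values only; the released models define their own shapes and widths.
architecture = DualAttentionUnet(input_shape=(64, 64, 64, 1), nb_classes=2,
                                 deep_supervision=False, input_pyramid=True,
                                 attention_guiding=False)
architecture.set_convolutions([8, 16, 32, 64, 128])  # one filter width per resolution level
model = architecture.create()
model.compile(optimizer='adam', loss=architecture.get_dice_loss())
model.summary()
```

With five filter widths, the encoder downsamples four times (64 -> 4 voxels per axis), the dual-attention block operates on the bottleneck, and the decoder mirrors the encoder back to full resolution before the final softmax layer.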