├── data ├── OCTA-Data.tif └── OCTA-Labels.tif ├── requirements.txt ├── LICENSE ├── tubenet_env.yml ├── CITATION.cff ├── test.py ├── predict.py ├── preprocessing.py ├── tUbeNet_classes.py ├── README.md ├── train.py ├── model.py └── tUbeNet_functions.py /data/OCTA-Data.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/natalie11/tUbeNet/HEAD/data/OCTA-Data.tif -------------------------------------------------------------------------------- /data/OCTA-Labels.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/natalie11/tUbeNet/HEAD/data/OCTA-Labels.tif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=2.20.0 2 | tensorboard>=2.20.0 3 | numpy>=2.3.0 4 | scipy>=1.16.0 5 | scikit-learn>=1.7.0 6 | scikit-image>=0.25.2 7 | simpleitk>=2.5.0 8 | matplotlib>=3.10.0 9 | nibabel>=5.3.0 10 | dask>=2025.7.0 11 | tqdm>=4.67.0 12 | tifffile>=2025.6.11 13 | zarr>=3.1.0 14 | 15 | # Tested NVIDIA GPU libraries (Linux only - DELETE before installing on Windows) 16 | nvidia-cublas-cu12==12.9.1.4 17 | nvidia-cuda-cupti-cu12==12.9.79 18 | nvidia-cuda-nvcc-cu12==12.9.86 19 | nvidia-cuda-nvrtc-cu12==12.9.86 20 | nvidia-cuda-runtime-cu12==12.9.79 21 | nvidia-cudnn-cu12==9.3.0.75 22 | nvidia-nvjitlink-cu12==12.9.86 23 | nvidia-cufft-cu12==11.4.1.4 24 | nvidia-curand-cu12==10.3.10.19 25 | nvidia-cusparse-cu12==12.5.10.65 26 | nvidia-cusolver-cu12==11.7.5.82 27 | nvidia-nccl-cu12==2.27.7 28 | 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Natalie Holroyd 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/tubenet_env.yml:
--------------------------------------------------------------------------------
1 | name: tubenet
2 | channels:
3 |   - conda-forge
4 |   - defaults
5 | dependencies:
6 |   - python=3.11
7 |   - pip
8 |   - pip:
9 |     - tensorflow>=2.20.0
10 |     - tensorboard>=2.20.0
11 |     - numpy>=2.3.0
12 |     - scipy>=1.16.0
13 |     - scikit-learn>=1.7.0
14 |     - scikit-image>=0.25.2
15 |     - simpleitk>=2.5.0
16 |     - matplotlib>=3.10.0
17 |     - nibabel>=5.3.0
18 |     - dask>=2025.7.0
19 |     - tqdm>=4.67.0
20 |     - tifffile>=2025.6.11
21 |     - zarr>=3.1.0
22 | 
23 | 
24 |     # Tested NVIDIA GPU libraries (Linux only - DELETE before installing on Windows)
25 |     - nvidia-cublas-cu12==12.9.1.4
26 |     - nvidia-cuda-cupti-cu12==12.9.79
27 |     - nvidia-cuda-nvcc-cu12==12.9.86
28 |     - nvidia-cuda-nvrtc-cu12==12.9.86
29 |     - nvidia-cuda-runtime-cu12==12.9.79
30 |     - nvidia-cudnn-cu12==9.3.0.75
31 |     - nvidia-nvjitlink-cu12==12.9.86
32 |     - nvidia-cufft-cu12==11.4.1.4
33 |     - nvidia-curand-cu12==10.3.10.19
34 |     - nvidia-cusparse-cu12==12.5.10.65
35 |     - nvidia-cusolver-cu12==11.7.5.82
36 |     - nvidia-nccl-cu12==2.27.7
37 | 
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | # This CITATION.cff file was generated with cffinit.
2 | # Visit https://bit.ly/cffinit to generate yours today!
3 | 
4 | cff-version: 1.2.0
5 | title: tUbeNet
6 | message: Use this file to cite any uses of this ML model
7 | type: software
8 | authors:
9 |   - given-names: Natalie Aroha
10 |     family-names: Holroyd
11 |     email: natalie.holroyd.16@ucl.ac.uk
12 |     affiliation: University College London
13 |     orcid: 'https://orcid.org/0000-0001-9174-1346'
14 |   - given-names: Zhongwang
15 |     family-names: Li
16 |     affiliation: University College London
17 |   - given-names: Claire
18 |     family-names: Walsh
19 |     affiliation: University College London
20 |     orcid: 'https://orcid.org/0000-0003-3769-3392'
21 |   - given-names: Emmeline
22 |     family-names: Brown
23 |     affiliation: University College London
24 |     orcid: 'https://orcid.org/0000-0001-6222-0146'
25 |   - given-names: Rebecca
26 |     family-names: Shipley
27 |     affiliation: University College London
28 |     orcid: 'https://orcid.org/0000-0002-2818-6228'
29 |   - given-names: Simon
30 |     family-names: Walker-Samuel
31 |     affiliation: University College London
32 |     orcid: 'https://orcid.org/0000-0003-3530-9166'
33 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri Sep 5 09:56:27 2025
5 | 
6 | @author: natalie
7 | """
8 | 
9 | #Import libraries
10 | import os
11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' #Suppress info logs from tf
12 | import pickle
13 | from model import tUbeNet
14 | import tUbeNet_functions as tube
15 | from tUbeNet_classes import DataDir
16 | import argparse
17 | 
18 | def main(args):
19 |     """Set parameters and file paths:"""
20 |     # Parameters
21 |     volume_dims = args.volume_dims
22 |     overlap = args.overlap
23 |     n_classes = 2 #TODO: expand to handle multi-class case
24 | 
25 |     binary_output = args.binary_output
26 |     attention = args.attention
27 | 
28 |     data_headers = args.data_headers
29 |     model_path = args.model_path
30 |     output_path = args.output_path
31 | 
32 |     #----------------------------------------------------------------------------------------------------------------------------------------------
33 |     """ Create Data Directory"""
34 |     # Load data headers into a list
35 |     header_filenames=[f for f in os.listdir(data_headers) if os.path.isfile(os.path.join(data_headers, f))]
36 |     headers = []
37 |     try:
38 |         for file in header_filenames: #Iterate through header files
39 |             file=os.path.join(data_headers,file)
40 |             with open(file, "rb") as f:
41 |                 data_header = pickle.load(f) # Unpickle DataHeader object
42 |                 headers.append(data_header) # Add to list of headers
43 |     except IndexError: print(f"Unable to load data header files from {data_headers}")
44 | 
45 |     # Create empty data directory
46 |     data_dir = DataDir([], image_dims=[],
47 |                        image_filenames=[],
48 |                        label_filenames=[],
49 |                        data_type=[], exclude_region=[])
50 | 
51 |     # Fill directory from headers
52 |     for header in headers:
53 |         data_dir.list_IDs.append(header.ID)
54 |         data_dir.image_dims.append(header.image_dims)
55 |         data_dir.image_filenames.append(header.image_filename)
56 |         data_dir.label_filenames.append(header.label_filename)
57 |         data_dir.data_type.append('float32')
58 |         data_dir.exclude_region.append((None,None,None)) #region to be left out of training for use as validation data (under development)
59 | 
60 | 
61 |     """ Load Model """
62 |     tubenet = tUbeNet(n_classes=n_classes, input_dims=volume_dims, attention=attention)
63 | 
64 |     # Load existing model with or without fine tuning adjustment (fine tuning -> classifier replaced and first 10 layers frozen)
65 |     model = tubenet.load_weights(filename=model_path,
66 |                                  loss='DICE BCE',
67 |                                  metrics=['accuracy', 'recall', 'precision', tube.dice])
68 | 
69 |     """ Plot ROC """
70 |     # Evaluate model on data
71 |     validation_metrics = tube.roc_analysis(model, data_dir,
72 |                                            volume_dims=volume_dims,
73 |                                            n_classes=n_classes,
74 |                                            overlap=overlap,
75 |                                            output_path=output_path,
76 |                                            binary_output=binary_output)
77 | 
78 | def parse_dims(values):
79 |     """Parse volume dimensions: allow either one int (isotropic) or three ints (anisotropic)."""
80 |     if len(values) == 1:
81 |         return (values[0], values[0], values[0])
82 |     elif len(values) == 3:
83 |         return tuple(values)
84 |     else:
85 |         raise argparse.ArgumentTypeError(
86 |             "volume_dims must be either a single value (e.g. --volume_dims 64) "
87 |             "or three values (e.g. --volume_dims 64 64 32).")
88 | 
89 | if __name__ == "__main__":
90 |     parser = argparse.ArgumentParser(description="Evaluate TubeNet model on paired data.")
91 | 
92 |     parser.add_argument("--data_headers", type=str, required=True,
93 |                         help="Path to directory containing preprocessed header files.")
94 |     parser.add_argument("--model_path", type=str, required=True,
95 |                         help="Path to trained model (.h5 file).")
96 |     parser.add_argument("--output_path", type=str, required=True,
97 |                         help="Directory where predictions will be saved.")
98 | 
99 |     parser.add_argument("--volume_dims", type=int, nargs="+", default=[64, 64, 64],
100 |                         help="Volume dimensions passed to CNN. Provide 1 value (isotropic) "
101 |                              "or 3 values (anisotropic). E.g. --volume_dims 64 OR --volume_dims 32 64 64")
102 |     parser.add_argument("--overlap", type=int, nargs="+", default=None,
103 |                         help="Overlap between patches during inference. Provide 1 value (isotropic) "
104 |                              "or 3 values (anisotropic). E.g. --overlap 32 OR --overlap 16 32 32. "
" 105 | "Defaults to half of volume_dims.") 106 | parser.add_argument("--binary_output", action="store_true", 107 | help="Save predictions as binary (True) or softmax (False).") 108 | parser.add_argument("--attention", action="store_true", 109 | help="Use this flag if loading a tubenet model built with attention blocks") 110 | 111 | args = parser.parse_args() 112 | args.volume_dims = parse_dims(args.volume_dims) 113 | if args.overlap: args.overlap = parse_dims(args.overlap) #Parse if not None 114 | main(args) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Sep 3 16:22:26 2025 5 | 6 | @author: natalie 7 | """ 8 | 9 | #Import libraries 10 | import os 11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' #Suppress info logs from tf 12 | import pickle 13 | from model import tUbeNet 14 | import tUbeNet_functions as tube 15 | from tUbeNet_classes import DataDir 16 | import argparse 17 | 18 | def main(args): 19 | """Set parameters and file paths:""" 20 | # Paramters 21 | volume_dims = args.volume_dims 22 | overlap = args.overlap 23 | n_classes = 2 #TO DO expand to handel multi-class case 24 | 25 | binary_output = args.binary_output 26 | preview = args.preview 27 | attention = args.attention 28 | 29 | data_headers = args.data_headers 30 | model_path = args.model_path 31 | output_path = args.output_path 32 | tiff_path = args.tiff_path 33 | 34 | #---------------------------------------------------------------------------------------------------------------------------------------------- 35 | """ Create Data Directory""" 36 | # Load data headers into a list 37 | header_filenames=[f for f in os.listdir(data_headers) if os.path.isfile(os.path.join(data_headers, f))] 38 | headers = [] 39 | try: 40 | for file in header_filenames: #Iterate through header files 41 | file=os.path.join(data_headers,file) 42 | with open(file, "rb") as f: 43 | data_header = pickle.load(f) # Unpickle DataHeader object 44 | headers.append(data_header) # Add to list of headers 45 | except IndexError: print("Unable to load data header files from {}".format(data_headers)) 46 | 47 | # Create empty data directory 48 | data_dir = DataDir([], image_dims=[], 49 | image_filenames=[], 50 | label_filenames=[], 51 | data_type=[], exclude_region=[]) 52 | 53 | # Fill directory from headers 54 | for header in headers: 55 | data_dir.list_IDs.append(header.ID) 56 | data_dir.image_dims.append(header.image_dims) 57 | data_dir.image_filenames.append(header.image_filename) 58 | data_dir.label_filenames.append(None) #Labels not required for prediction 59 | data_dir.data_type.append('float32') 60 | data_dir.exclude_region.append((None,None,None)) #region to be left out of training for use as validation data (under development) 61 | 62 | 63 | """ Load Model """ 64 | # Initialise model 65 | tubenet = tUbeNet(n_classes=n_classes, input_dims=volume_dims, attention=attention) 66 | # Load weights 67 | model = tubenet.load_weights(filename=model_path, loss='DICE BCE') 68 | 69 | # If undefined set overlap to half volume_dims 70 | if not overlap: 71 | overlap = (volume_dims[0]//2,volume_dims[1]//2,volume_dims[2]//2) 72 | 73 | """Predict segmentation""" 74 | for i in data_dir.image_filenames: 75 | # Isolate image filename 76 | image_directory, image_filename = os.path.split(i.replace('\\','/')) 77 | print("Begining Inference on {}".format(image_filename)) 78 | 79 
79 |         # Create output filenames
80 |         dask_name = os.path.join(output_path,str(image_filename)+"_segmentation")
81 |         if tiff_path: tiff_name=os.path.join(tiff_path,str(image_filename)+"_segmentation.tiff")
82 |         else: tiff_name = None
83 |         tube.predict_segmentation_dask(
84 |             model,
85 |             i,
86 |             dask_name,
87 |             volume_dims=volume_dims,
88 |             overlap=overlap,
89 |             n_classes=n_classes,
90 |             export_bigtiff=tiff_name,
91 |             preview=preview,
92 |             binary_output=binary_output,
93 |             prob_channel=1,
94 |         )
95 | 
96 | def parse_dims(values):
97 |     """Parse volume dimensions: allow either one int (isotropic) or three ints (anisotropic)."""
98 |     if len(values) == 1:
99 |         return (values[0], values[0], values[0])
100 |     elif len(values) == 3:
101 |         return tuple(values)
102 |     else:
103 |         raise argparse.ArgumentTypeError(
104 |             "volume_dims must be either a single value (e.g. --volume_dims 64) "
105 |             "or three values (e.g. --volume_dims 64 64 32).")
106 | 
107 | if __name__ == "__main__":
108 |     parser = argparse.ArgumentParser(description="Run prediction using TubeNet model.")
109 | 
110 |     parser.add_argument("--data_headers", type=str, required=True,
111 |                         help="Path to directory containing preprocessed header files.")
112 |     parser.add_argument("--model_path", type=str, required=True,
113 |                         help="Path to trained model (.h5 file).")
114 |     parser.add_argument("--output_path", type=str, required=True,
115 |                         help="Directory where predictions will be saved.")
116 |     parser.add_argument("--tiff_path", type=str, default=None,
117 |                         help="Optional path to save TIFF output in addition to Zarr.")
118 | 
119 |     parser.add_argument("--volume_dims", type=int, nargs="+", default=[64, 64, 64],
120 |                         help="Volume dimensions passed to CNN. Provide 1 value (isotropic) "
121 |                              "or 3 values (anisotropic). E.g. --volume_dims 64 OR --volume_dims 32 64 64")
122 |     parser.add_argument("--overlap", type=int, nargs="+", default=None,
123 |                         help="Overlap between patches during inference. Provide 1 value (isotropic) "
124 |                              "or 3 values (anisotropic). E.g. --overlap 32 OR --overlap 16 32 32. "
125 |                              "Defaults to half of volume_dims.")
126 |     parser.add_argument("--binary_output", action="store_true",
127 |                         help="Save predictions as binary image. Otherwise, the softmax output will be saved.")
128 |     parser.add_argument("--preview", action="store_true",
129 |                         help="Display preview of predicted segmentation during inference.")
130 |     parser.add_argument("--attention", action="store_true",
131 |                         help="Use this flag if loading a tUbeNet model built with attention blocks")
132 | 
133 |     args = parser.parse_args()
134 |     args.volume_dims = parse_dims(args.volume_dims)
135 |     if args.overlap: args.overlap = parse_dims(args.overlap) #Parse if not None
136 |     main(args)
--------------------------------------------------------------------------------
/preprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Sep 3 16:04:47 2025
5 | 
6 | @author: natalie
7 | """
8 | 
9 | # -*- coding: utf-8 -*-
10 | """tUbeNet 3D
11 | Data Preprocessing script: load image data (and optional labels) and convert into zarr format with data header
12 | 
13 | 
14 | Developed by Natalie Holroyd (UCL)
15 | """
16 | 
17 | #Import libraries
18 | import os
19 | import numpy as np
20 | import tUbeNet_functions as tube
21 | import argparse
22 | 
23 | def main(args):
24 |     #----------------------------------------------------------------------------------------------------------------------------------------------
25 |     """Set parameters and file paths:"""
26 | 
27 |     # Parameters
28 |     chunks = tuple(args.chunks) # chunk size for saving zarr files; should match the chunk size used by the model
29 |     val_fraction = args.val_fraction # fraction of data to use for validation
30 |     crop = args.crop # crop images if there are large sections of background containing no vessels
31 | 
32 |     image_directory = args.image_directory
33 |     label_directory = args.label_directory
34 |     output_path = args.output_path
35 | 
36 |     #----------------------------------------------------------------------------------------------------------------------------------------------
37 |     # Create list of image files
38 |     if os.path.isdir(image_directory):
39 |         # Add all file paths in the image directory
40 |         image_filenames = [f for f in os.listdir(image_directory) if os.path.isfile(os.path.join(image_directory, f))]
41 |     elif os.path.isfile(image_directory):
42 |         # If file is given, process this file only
43 |         image_directory, image_filenames = os.path.split(image_directory.replace('\\','/'))
44 |         image_filenames = [image_filenames]
45 |     else: raise ValueError('Image directory could not be found')
46 | 
47 |     # Create list of label files
48 |     if label_directory is not None:
49 |         if os.path.isdir(label_directory):
50 |             # Add all file paths in the label directory
51 |             label_filenames = [f for f in os.listdir(label_directory) if os.path.isfile(os.path.join(label_directory, f))]
52 |         elif os.path.isfile(label_directory):
53 |             # If file is given, process this file only (split filename from rest of path)
54 |             label_directory, label_filenames = os.path.split(label_directory.replace('\\','/'))
55 |             label_filenames = [label_filenames]
56 |         else: raise ValueError('A label directory was provided but could not be found. Set label_directory to None if not using.')
57 | 
58 |         assert len(image_filenames)==len(label_filenames), "Expected same number of image and label files. Set label_directory to None if not using."
59 |     else:
60 |         label_filenames = [None]*len(image_filenames)
61 | 
62 | 
63 |     # Process and save each dataset in directory
64 |     for image_filename, label_filename in zip(image_filenames, label_filenames):
65 |         # Set names and paths
66 |         output_name = os.path.splitext(image_filename)[0]
67 |         image_path = os.path.join(image_directory, image_filename)
68 |         if label_filename is not None:
69 |             label_path = os.path.join(label_directory, label_filename)
70 |         else: label_path = None
71 | 
72 |         # Run preprocessing
73 |         data, labels, classes = tube.data_preprocessing(image_path=image_path,
74 |                                                         label_path=label_path)
75 | 
76 |         # Set data type
77 |         data = data.astype('float32')
78 |         if labels is not None:
79 |             labels = labels.astype('int8')
80 | 
81 |         # Crop
82 |         if crop and labels is not None:
83 |             labels, data = tube.crop_from_labels(labels, data)
84 | 
85 |         # Split into test and train
86 |         if val_fraction > 0 and labels is not None:
87 | 
88 |             n_training_imgs = int(data.shape[0]-np.floor(data.shape[0]*val_fraction))
89 | 
90 |             train_data = data[0:n_training_imgs,...]
91 |             train_labels = labels[0:n_training_imgs,...]
92 | 
93 |             test_data = data[n_training_imgs:,...]
94 |             test_labels = labels[n_training_imgs:,...]
95 | 
96 |             # Create folders
97 |             train_folder = os.path.join(output_path,"train")
98 |             if not os.path.exists(train_folder):
99 |                 os.makedirs(train_folder)
100 |             train_name = str(output_name)+"_train"
101 | 
102 |             test_folder = os.path.join(output_path,"test")
103 |             if not os.path.exists(test_folder):
104 |                 os.makedirs(test_folder)
105 |             test_name = str(output_name)+"_test"
106 | 
107 |             # Save train data
108 | 
109 | 
110 |             train_path, train_header = tube.save_as_dask_array(train_data, labels=train_labels,
111 |                                                                output_path=train_folder,
112 |                                                                output_name=train_name,
113 |                                                                chunks=chunks)
114 |             print("Processed training data and header files saved to "+str(train_path))
115 | 
116 |             # Save test data
117 |             test_path, test_header = tube.save_as_dask_array(test_data, labels=test_labels,
118 |                                                              output_path=test_folder,
119 |                                                              output_name=test_name,
120 |                                                              chunks=chunks)
121 |             print("Processed test data and header files saved to "+str(test_path))
122 | 
123 |         else:
124 |             save_path, save_header = tube.save_as_dask_array(data, labels=labels,
125 |                                                              output_path=output_path,
126 |                                                              output_name=output_name,
127 |                                                              chunks=chunks)
128 |             print("Processed data and header files saved to "+str(save_path))
129 | 
130 | def parse_chunks(values):
131 |     if len(values) == 1:
132 |         return (values[0], values[0], values[0])
133 |     elif len(values) == 3:
134 |         return tuple(values)
135 |     else:
136 |         raise argparse.ArgumentTypeError(
137 |             "Chunks must be either a single value (e.g. --chunks 64) "
138 |             "or three values (e.g. --chunks 64 64 32).")
139 | 
140 | if __name__ == "__main__":
141 |     parser = argparse.ArgumentParser(description="Preprocess image and label datasets for TubeNet.")
142 |     parser.add_argument("--image_directory", type=str, required=True,
143 |                         help="Path to image file or directory")
144 |     parser.add_argument("--label_directory", type=str, default=None,
145 |                         help="Path to label file or directory. Set to None if not using labels.")
146 |     parser.add_argument("--output_path", type=str, required=True,
147 |                         help="Directory where processed data will be saved")
148 |     parser.add_argument("--chunks", type=int, nargs="+", default=[64, 64, 64],
149 |                         help="Chunk size for saving zarr files. "
150 |                              "Provide 1 value (isotropic) or 3 values (anisotropic). "
151 |                              "E.g. --chunks 64 OR --chunks 64 64 32")
152 |     parser.add_argument("--val_fraction", type=float, default=0.0,
153 |                         help="Fraction of data to use for validation (0-1)")
154 |     parser.add_argument("--crop", action='store_true',
155 |                         help="Enable cropping if there are large background sections with no vessels")
156 | 
157 | 
158 |     args = parser.parse_args()
159 |     args.chunks = parse_chunks(args.chunks) #create tuple of values for chunk dimensions
160 | 
161 |     main(args)
--------------------------------------------------------------------------------
/tUbeNet_classes.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """tUbeNet 3D
3 | U-Net based CNN for vessel segmentation
4 | 
5 | Developed by Natalie Holroyd (UCL)
6 | """
7 | 
8 | #Import libraries
9 | import numpy as np
10 | import math
11 | import random
12 | import pickle
13 | import os
14 | join = os.path.join
15 | 
16 | import io
17 | from matplotlib import pyplot as plt
18 | from scipy.ndimage import rotate, zoom
19 | import tensorflow as tf
20 | import dask.array as da
21 | 
22 | from tensorflow.keras.utils import Sequence, to_categorical
23 | #---------------------------------------------------------------------------------------------------------------------------------------------
24 | class DataHeader:
25 |     def __init__(self, ID=None, image_dims=(1024,1024,1024), image_filename=None, label_filename=None):
26 |         'Initialization'
27 |         self.ID = ID
28 |         self.image_dims = image_dims
29 |         self.image_filename = image_filename
30 |         self.label_filename = label_filename
31 |     def save(self, filename):
32 |         with open(filename, 'wb') as f:
33 |             pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
34 | 
35 | class DataDir:
36 |     def __init__(self, list_IDs, image_dims=(1024,1024,1024), image_filenames=None, label_filenames=None, data_type='float64', exclude_region=None):
37 |         'Initialization'
38 |         self.image_dims = image_dims
39 |         self.image_filenames = image_filenames
40 |         self.label_filenames = label_filenames
41 |         self.list_IDs = list_IDs
42 |         self.data_type = data_type
43 |         self.exclude_region = exclude_region
44 | 
45 | class DataGenerator(Sequence):
46 |     def __init__(self, data_dir, batch_size=32, volume_dims=(64,64,64), shuffle=True, n_classes=2,
47 |                  dataset_weighting=None, augment=False, vessel_threshold=0.001, **kwargs):
48 |         'Initialization'
49 |         super().__init__(**kwargs)
50 | 
51 |         self.volume_dims = volume_dims
52 |         self.batch_size = batch_size
53 |         self.shuffle = shuffle
54 |         self.data_dir = data_dir
55 |         self.on_epoch_end()
56 |         self.n_classes = n_classes
57 |         self.dataset_weighting = dataset_weighting
58 |         self.augment = augment
59 |         self.vessel_threshold = vessel_threshold
60 | 
61 |         # Open zarr arrays
62 |         self._images = [da.from_zarr(p) for p in self.data_dir.image_filenames]
63 |         self._labels = [da.from_zarr(p) for p in self.data_dir.label_filenames]
64 | 
65 |     def __len__(self):
66 |         'Denotes the max number of batches per epoch'
67 |         batches=0
68 |         for i in range(len(self.data_dir.list_IDs)):
69 |             batches_per_dataset = int(np.floor(np.prod(self.data_dir.image_dims[i])/np.prod(self.volume_dims)))
70 |             batches += batches_per_dataset
71 |         return batches
72 | 
73 |     def __getitem__(self, index):
74 |         'Generate one batch of data'
75 |         # randomly generate list of IDs for batch, weighted according to given 'dataset_weighting' if not None
76 |         if len(self.data_dir.list_IDs)>1:
77 |             list_IDs_temp = random.choices(self.data_dir.list_IDs, weights=self.dataset_weighting, k=self.batch_size)
78 |         else: list_IDs_temp=[self.data_dir.list_IDs[0]]*self.batch_size #single-dataset case
79 |         # Generate data
80 |         X, y = self.__data_generation(list_IDs_temp)
81 |         if self.augment:
82 |             self._augmentation(X,y) #arrays are modified in place
83 | 
84 |         # Reshape to add depth of 1, one hot encode labels
85 |         X = X.reshape(*X.shape, 1)
86 |         y = to_categorical(y, num_classes=self.n_classes)
87 |         return X, y
88 | 
89 |     def on_epoch_end(self):
90 |         'Updates indexes after each epoch'
91 |         self.indexes = np.arange(len(self.data_dir.list_IDs))
92 |         if self.shuffle == True:
93 |             np.random.shuffle(self.indexes)
94 | 
95 | 
96 |     def random_coordinates(self, image_dims, exclude_region):
97 |         coords=np.zeros(3, dtype=int) #integer voxel coordinates, used for slicing
98 |         for ax in range(3):
99 |             coords[ax] = random.randint(0,(image_dims[ax]-self.volume_dims[ax]))
100 |             if exclude_region[ax] is not None:
101 |                 exclude = range(exclude_region[ax][0]-self.volume_dims[ax], exclude_region[ax][1])
102 |                 while coords[ax] in exclude: # if coordinate falls in excluded region, generate new coordinate
103 |                     coords[ax] = random.randint(0,(image_dims[ax]-self.volume_dims[ax]))
104 | 
105 |         return coords
106 | 
107 |     def __data_generation(self, list_IDs_temp):
108 |         'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
109 |         # Initialization
110 |         X = np.empty((self.batch_size, *self.volume_dims))
111 |         y = np.empty((self.batch_size, *self.volume_dims))
112 |         for i, ID_temp in enumerate(list_IDs_temp):
113 |             index=self.data_dir.list_IDs.index(ID_temp)
114 | 
115 |             X_da = self._images[index]
116 |             y_da = self._labels[index]
117 | 
118 |             vessels_present=False
119 |             count=0
120 |             while not vessels_present:
121 |                 #Generate random coordinates within dataset
122 |                 count+=1
123 |                 z0, x0, y0 = self.random_coordinates(self.data_dir.image_dims[index],
124 |                                                      self.data_dir.exclude_region[index])
125 |                 dz, dx, dy = self.volume_dims
126 |                 #Load labels at coordinates
127 |                 y_slice = y_da[z0:z0+dz, x0:x0+dx, y0:y0+dy]
128 |                 y_slice = y_slice.compute() # brings just this sub-volume to RAM as np.array
129 | 
130 |                 #Check fraction of pixels classed as vessel in labels before loading in image data
131 |                 frac = y_slice.astype(bool).mean()
132 |                 if frac>self.vessel_threshold or count>5: vessels_present=True
133 |                 if vessels_present:
134 |                     X_slice = X_da[z0:z0+dz, x0:x0+dx, y0:y0+dy]
135 |                     X_slice = X_slice.compute()
136 | 
137 |             X[i]=X_slice.astype(np.float32)
138 |             y[i]=y_slice.astype(np.int32)
139 |         return X, y
140 | 
141 |     def _augmentation(self, X, y):
142 |         # Apply data augmentations to each image/label pair in batch
143 |         for i in range(self.batch_size):
144 |             #Rotate
145 |             angle = np.random.uniform(-30,30, size=1)
146 |             X[i] = rotate(X[i], float(angle), reshape=False, order=3, mode='reflect')
147 |             y[i] = rotate(y[i], float(angle), reshape=False, order=0, mode='reflect')
148 |             #Zoom and crop
149 |             scale = np.random.uniform(1.0,1.25, size=1)
150 |             Xzoom = zoom(X[i], float(scale), order=3, mode='reflect')
151 |             yzoom = zoom(y[i], float(scale), order=0, mode='reflect')
152 |             (d,h,w)=X[i].shape
153 |             (dz,hz,wz)=Xzoom.shape
154 |             dz=int((dz-d)//2)
155 |             hz=int((hz-h)//2)
156 |             wz=int((wz-w)//2)
157 |             X[i]=Xzoom[dz:int(dz+d), hz:int(hz+h), wz:int(wz+w)]
158 |             y[i]=yzoom[dz:int(dz+d), hz:int(hz+h), wz:int(wz+w)]
159 |             #Flip
160 |             #NB: do not flip in z axis due to asymmetric PSF in HREM data
161 |             axes = np.random.randint(4, size=1)
162 |             if axes==0:
163 |                 #flip in x axis
164 |                 X[i] = np.flip(X[i],1)
165 |                 y[i] = np.flip(y[i],1)
166 |             elif axes==1:
167 |                 #flip in y axis
168 |                 X[i] = np.flip(X[i],2)
169 |                 y[i] = np.flip(y[i],2)
170 |             elif axes==2:
171 |                 #flip in x and y axis
172 |                 X[i] = np.flip(X[i],(1,2))
173 |                 y[i] = np.flip(y[i],(1,2))
174 |             #if axes==3, no flip
175 |         return X, y
176 | 
177 | 
178 | class MetricDisplayCallback(tf.keras.callbacks.Callback):
179 | 
180 |     def __init__(self,log_dir=None):
181 |         super().__init__()
182 |         self.log_dir = log_dir # directory where logs are saved
183 |         self.file_writer = tf.summary.create_file_writer(log_dir)
184 | 
185 |     def on_epoch_end(self, epoch, logs={}):
186 |         # have tf log custom metrics and save to file
187 |         with self.file_writer.as_default():
188 |             for k, v in logs.items():
189 |                 # iterate through monitored metrics (k) and values (v)
190 |                 tf.summary.scalar(k, v, step=epoch)
191 | 
192 | class ImageDisplayCallback(tf.keras.callbacks.Callback):
193 | 
194 |     def __init__(self, generator, log_dir=None, index=0):
195 |         super().__init__()
196 |         self.x = None
197 |         self.y = None
198 |         self.pred = None
199 |         self.data_generator = generator #data generator
200 |         self.log_dir = log_dir # directory where logs are saved
201 |         self.file_writer = tf.summary.create_file_writer(log_dir)
202 |         self.index=index
203 | 
204 |     def on_epoch_end(self, epoch, logs={}):
205 |         self.x, self.y = self.data_generator.__getitem__(self.index)
206 |         self.pred = self.model.predict(self.x)
207 | 
208 |         x_shape=self.x.shape
209 |         z_centre = int(x_shape[1]/2)
210 |         img = self.x[0,z_centre,:,:,:] #take centre slice in z-stack
211 |         labels = np.reshape(np.argmax(self.y[0,z_centre,:,:,:], axis=-1),(x_shape[2],x_shape[3],1)) #reverse one hot encoding
212 |         pred = np.reshape(np.argmax(self.pred[0,z_centre,:,:,:], axis=-1),(x_shape[2],x_shape[3],1)) #reverse one hot encoding
213 |         img = tf.convert_to_tensor(img,dtype=tf.float32)
214 |         labels = tf.convert_to_tensor(labels,dtype=tf.float32)
215 |         pred = tf.convert_to_tensor(pred,dtype=tf.float32)
216 |         with self.file_writer.as_default():
217 |             tf.summary.image("Example output", [img, labels, pred], step=epoch)
218 | 
219 | 
220 | class FilterDisplayCallback(tf.keras.callbacks.Callback):
221 | 
222 |     def __init__(self,log_dir=None):
223 |         super().__init__()
224 |         self.log_dir = log_dir # directory where logs are saved
225 |         self.file_writer = tf.summary.create_file_writer(log_dir)
226 | 
227 |     def find_grid_dims(self, n):
228 |         # n = number of filters
229 |         # Starting at sqrt(n), check if i is a factor, if yes use i and n/i as rows/columns respectively
230 |         for i in range(int(np.sqrt(float(n))), 0, -1):
231 |             if n % i == 0: #NB if i==1, you have a prime number of filters
232 |                 return (i, int(n / i))
233 | 
234 |     def make_grid(self, n, filters):
235 |         # n = number of filters
236 |         (rows, columns)=self.find_grid_dims(n)
237 |         # normalize filter between 0-1
238 |         f_min, f_max = filters.min(), filters.max()
239 |         filters = (filters - f_min) / (f_max - f_min)
240 |         cz = int(math.ceil(filters.shape[0]/2))
241 |         fig = plt.figure()
242 |         index=1
243 |         for i in range(rows):
244 |             for j in range(columns):
245 |                 plt.subplot(rows, columns, index)
246 |                 plt.xticks([]) #no ticks
247 |                 plt.yticks([])
248 |                 plt.grid(False)
249 |                 plt.imshow(filters[cz,:,:,0,index-1]) #plot central slice of 3D filter
250 |                 index=index+1
251 | 
252 |         return fig
253 | 
254 |     def plot_to_img(self, plot):
255 |         buf = io.BytesIO()
256 |         plt.savefig(buf, format='png')
257 |         plt.close(plot)
258 |         buf.seek(0)
259 |         image = tf.image.decode_png(buf.getvalue(), channels=4)
260 |         image = tf.expand_dims(image, 0)
261 |         return image
262 | 
263 |     def on_epoch_end(self, epoch, logs={}):
264 |         # visualise filters for conv1 layer
265 |         layer=self.model.layers[1] #first block after input layer
266 |         # get filter weights
267 |         filters = layer.get_weights()[0] #first conv layer only
268 |         n = filters.shape[-1] #number of filters
269 |         plot = self.make_grid(n, filters)
270 |         image = self.plot_to_img(plot)
271 | 
272 |         with self.file_writer.as_default():
273 |             tf.summary.image("Convolution 1 filters from layer "+str(layer.name), image, step=epoch)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![DOI](https://zenodo.org/badge/187050295.svg)](https://doi.org/10.5281/zenodo.15683547) [![DOI](https://img.shields.io/badge/DOI-10.1093%2Fbiomethods%2Fbpaf087-blue)](https://doi.org/10.1093/biomethods/bpaf087) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
2 | # tUbeNet
3 | tUbeNet is a 3D convolutional neural network (CNN) for semantic segmentation of vasculature from 3D grayscale medical images. It was trained on varied data across different modalities, scales and pathologies to create a generalisable foundation model, which can be fine-tuned to new images with minimal additional training ([Paper here](https://doi.org/10.1093/biomethods/bpaf087)).
4 | 
5 | * Download pretrained weights [here](https://doi.org/10.5522/04/25498603.v2).
6 | * The original training/test data can be found [here](https://doi.org/10.5522/04/25715604.v1).
7 | * Contact: natalie.holroyd.16@ucl.ac.uk for questions, troubleshooting and tips!
8 | 
9 | ![github_fig](https://github.com/natalie11/tUbeNet/assets/30265332/49dde486-2e54-41e1-98cc-f83f6f910688)
10 | 
11 | ## Installation
12 | tUbeNet uses Python 3.11 and TensorFlow 2.20. You can create an environment for running tUbeNet using **pip** or **conda**.
13 | 
14 | ### Option 1: Conda
15 | First install anaconda or miniconda following the instructions [here](https://www.anaconda.com/docs/getting-started/anaconda/install).
16 | You can then create a new virtual environment using the .yml file included in this repository by running these commands in your command prompt.
17 | ```
18 | # Create environment from YAML
19 | conda env create -f tubenet_env.yml
20 | # Activate environment
21 | conda activate tubenet
22 | ```
23 | 
24 | ### Option 2: Pip
25 | Create a virtual environment using venv and then install all the required libraries using pip.
26 | ```
27 | # Create Environment
28 | python -m venv '\path\to\environment'
29 | # Activate Environment
30 | '\path\to\environment\Scripts\activate.bat'
31 | # Install requirements
32 | pip install -r requirements.txt
33 | ```
34 | 
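After installing, you can quickly confirm that TensorFlow imports correctly and (on Linux) can see your GPU - an empty list means tUbeNet will run on CPU only:

```
python -c "import tensorflow as tf; print(tf.__version__); print(tf.config.list_physical_devices('GPU'))"
```
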
35 | **Note on GPU usage:** tUbeNet has been tested with CUDA 12.9 and cuDNN 9.3 (pinned in requirements.txt and tubenet_env.yml). These versions are compatible with NVIDIA GPUs with the Pascal microarchitecture (e.g. GeForce GTX 10 series) and newer. GPU users will need an NVIDIA driver >=525.60.13. GPU execution is not supported on Windows or macOS. On memory usage: the pre-trained model was trained on two 8 GB NVIDIA GeForce GTX 1080 GPUs, but tUbeNet is also compatible with single-GPU training. Peak memory usage was measured at 5.49 GB when training on a single GPU. Inference time was 222 ms per 64 x 64 x 64 volume when run on a single GPU.
36 | 
37 | **Note for Windows/macOS users:** TensorFlow no longer supports GPU usage on Windows or Mac. You can still run tUbeNet with CPU only - just make sure you DELETE the NVIDIA packages from requirements.txt / tubenet_env.yml before installing. Or see [TensorFlow's website](https://www.tensorflow.org/install/pip#windows-wsl2) for instructions on using Windows Subsystem for Linux (WSL) to allow GPU utilization on a Windows machine.
38 | 
39 | ## How to use
40 | 
41 | ### Workflow
42 | 
43 | tUbeNet is organized into four callable scripts:
44 | 
45 | * preprocessing.py → Prepares raw data into model-ready Zarr format.
46 | * train.py → Train a new model or fine-tune the pretrained one.
47 | * test.py → Evaluate a trained model on labeled data (with ROC analysis).
48 | * predict.py → Run inference on unlabeled data and save segmentations.
49 | 
50 | Small volumes of OCT-A imaging data (OCTA-Data.tif) and paired manual labels (OCTA-Labels.tif) are provided to enable quick testing of the model to confirm successful installation.
51 | 
52 | ### Preparing data
53 | This step converts raw image volumes (.tif/.nii) and optional binary labels into Zarr format with header files that can be read by the train/test/predict scripts. You can run this script on an individual image or a folder of images.
54 | 
55 | The zarr format allows individual chunks of an image to be read from the disk, making preprocessing, training and inference much more memory-efficient. You can set the chunk size yourself (as below) or use the default size of 64 x 64 x 64 pixels. This script can also optionally crop your images based on the labels provided - creating a subvolume that contains all the labelled vessels while trimming image regions devoid of vessels. Finally, using 'val_fraction' you can optionally choose a proportion of each image volume to reserve for validation.
56 | 
57 | With labels (and optional validation data split, cropping):
58 | ```
59 | python preprocessing.py \
60 |   --image_directory '\path\to\images' \
61 |   --label_directory '\path\to\labels' \
62 |   --output_path 'path\to\processed' \
63 |   --chunks 64 \
64 |   --val_fraction 0.2 \
65 |   --crop
66 | ```
67 | 
68 | Without labels (prediction only):
69 | ```
70 | python preprocessing.py \
71 |   --image_directory '\path\to\images' \
72 |   --output_path 'path\to\processed' \
73 |   --chunks 64
74 | ```
75 | 
76 | Key arguments:
77 | 
78 | ```--image_directory``` → Path to raw images (file or folder).
79 | 
80 | ```--label_directory``` → Path to labels (optional).
81 | 
82 | ```--output_path``` → Where processed Zarr + header files are saved.
83 | 
84 | ```--chunks``` → Patch size for saving (default: 64 64 64).
85 | 
86 | ```--val_fraction``` → Split fraction for validation (0–1).
87 | 
88 | ```--crop``` → Crop background regions without vessels.
89 | 
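As a quick sanity check after preprocessing, you can open a pickled header and the Zarr array it points to directly. This is a minimal sketch - the header filename shown is illustrative and will depend on your image names, but the attribute names follow the DataHeader class in tUbeNet_classes.py:

```
import pickle
import dask.array as da
from tUbeNet_classes import DataHeader  # needed so pickle can reconstruct the header object

# Load one of the pickled DataHeader objects written during preprocessing
# (example filename - yours will differ)
with open('path/to/processed/train/OCTA-Data_train_header.pkl', 'rb') as f:
    header = pickle.load(f)

print(header.ID, header.image_dims)           # dataset ID and image dimensions
image = da.from_zarr(header.image_filename)   # lazy view onto the zarr volume - nothing loaded yet
print(image.shape, image.dtype, image.chunksize)
```
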
90 | ### Training and Fine-tuning
91 | Run train.py to train from scratch or fine-tune a pretrained model. Training can be run with or without validation data. During training, batches of image subvolumes (64 x 64 x 64 pixels) will be generated - the steps_per_epoch argument sets the number of batches generated per training epoch. By providing pre-trained model weights and using the '--fine_tune' flag, you can fine-tune our existing model to your own data. Updated model weights will be saved to the model path provided. Predicted labels and evaluation metrics for the validation data (Receiver Operating Characteristic Curve and Precision Recall Curve - only if validation data was provided) will be saved to the provided output path.
92 | 
93 | Train from scratch:
94 | ```
95 | python train.py \
96 |   --data_headers 'path\to\train\headers' \
97 |   --val_headers 'path\to\test\headers' \
98 |   --model_path 'path\to\model_output' \
99 |   --output_path 'path\to\prediction' \
100 |   --n_epochs 100 \
101 |   --steps_per_epoch 200 \
102 |   --batch_size 6 \
103 |   --loss "DICE BCE" \
104 |   --lr0 0.0005
105 | ```
106 | 
107 | Fine-tune a pretrained model:
108 | ```
109 | python train.py \
110 |   --data_headers 'path\to\train\headers' \
111 |   --val_headers 'path\to\test\headers' \
112 |   --model_path 'path\to\model_output' \
113 |   --model_weights_file 'path\pretrained_model.weights.h5' \
114 |   --output_path 'path\to\prediction' \
115 |   --n_epochs 50 \
116 |   --steps_per_epoch 200 \
117 |   --fine_tune
118 | ```
119 | 
120 | Key arguments:
121 | 
122 | ```--data_headers``` → Folder containing headers for training data (generated from preprocessing).
123 | 
124 | ```--val_headers``` → Folder containing headers for validation data (optional).
125 | 
126 | ```--model_path``` → Where models and logs are saved.
127 | 
128 | ```--model_weights_file``` → Pretrained model weights (optional).
129 | 
130 | ```--n_epochs```, ```--steps_per_epoch```, ```--batch_size``` → Epochs, batches per epoch and batch size respectively.
131 | 
132 | ```--loss``` → Loss function - choose from "DICE BCE" (recommended), "focal" or "WCCE" (Weighted Categorical CrossEntropy). See the tUbeNet paper for details of these loss functions, and the sketch after this list for intuition.
133 | 
134 | ```--lr0``` → Initial learning rate.
135 | 
136 | ```--class_weights``` → Relative weights of background to vessels - only relevant when using Weighted Categorical CrossEntropy loss (WCCE).
137 | 
138 | ```--fine_tune``` → Enables fine-tuning by freezing the first 2 encoding blocks and replacing the classifier layer.
139 | 
140 | ```--volume_dims``` → Input patch size (default: 64 64 64).
141 | 
142 | ```--attention``` → Enable attention blocks in place of skips (experimental).
143 | 
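For intuition, the recommended "DICE BCE" loss combines a soft Dice term with binary cross-entropy. The sketch below illustrates the general idea only - it is not the exact tube.DiceBCELoss implementation used by the scripts:

```
import tensorflow as tf

def dice_bce_sketch(y_true, y_pred, smooth=1e-6):
    """Illustrative soft Dice + binary cross-entropy loss (conceptual sketch)."""
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    # Dice loss: 1 minus the overlap between prediction and ground truth
    dice = 1.0 - (2.0 * intersection + smooth) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    # Binary cross-entropy averaged over all voxels
    bce = tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true_f, y_pred_f))
    return dice + bce
```
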
144 | #### Monitoring training
145 | Training logs can be viewed in TensorBoard using ```tensorboard --logdir path\to\model_output\logs```.
146 | 
147 | ### Testing
148 | 
149 | Use test.py to evaluate a trained model on labeled test data. This will generate ROC and Precision Recall Curve graphs, as well as labelled images in tiff and Zarr format.
150 | 
151 | ```
152 | python test.py \
153 |   --data_headers 'path\to\data\headers' \
154 |   --model_path 'path\pretrained_model.weights.h5' \
155 |   --output_path 'path\to\prediction' \
156 |   --volume_dims 64 \
157 |   --overlap 32 \
158 |   --binary_output
159 | ```
160 | 
161 | Key arguments:
162 | 
163 | ```--data_headers``` → Folder containing headers for test data (generated from preprocessing).
164 | 
165 | ```--model_path``` → Trained model weights (.h5 or .weights.h5 file).
166 | 
167 | ```--output_path``` → Folder where predictions and evaluation outputs will be saved.
168 | 
169 | ```--volume_dims``` → The size of image chunks passed to the model for inference (default: 64 64 64). Note: this should match the chunk size passed to the model during training.
170 | 
171 | ```--overlap``` → Labels are predicted on overlapping image chunks and averaged (to avoid boundary artefacts). The overlap should be approximately half of the chunk size, but can be reduced (to speed up inference time) or increased as desired.
172 | 
173 | ```--binary_output``` → Use this flag to save label predictions as binary images. Otherwise, the softmax output from the final model layer will be saved.
174 | 
175 | ### Predicting on Unlabelled Data
176 | 
177 | Use predict.py for running inference (label prediction) on new data without labels. Predicted labels will be saved in zarr format, and optionally as 3D tiff images. Use the --binary_output flag to save label predictions as binary images. Otherwise, the softmax output from the final model layer will be saved (values between 0 and 1, with values closer to 1 implying higher likelihood of the pixel belonging to a vessel). The softmax output is often useful for identifying areas of the image that the model is struggling to classify, and allows you to set your own threshold for classifying vessels (see the example at the end of this section).
178 | 
179 | ```
180 | python predict.py \
181 |   --data_headers 'path\to\data\headers' \
182 |   --model_path 'path\pretrained_model.weights.h5' \
183 |   --output_path 'path\to\prediction' \
184 |   --tiff_path 'path\to\tiff_outputs' \
185 |   --volume_dims 64 64 64 \
186 |   --overlap 32 32 32 \
187 |   --binary_output \
188 |   --preview
189 | ```
190 | 
191 | Key arguments:
192 | 
193 | ```--data_headers``` → Folder containing headers for data (generated from preprocessing).
194 | 
195 | ```--model_path``` → Trained model weights (.h5 or .weights.h5 file).
196 | 
197 | ```--output_path``` → Folder where predictions will be saved in zarr format.
198 | 
199 | ```--tiff_path``` → Folder where predictions will be saved as 3D tiff images (optional).
200 | 
201 | ```--volume_dims``` → The size of image chunks passed to the model for inference (default: 64 64 64). Note: this should match the chunk size passed to the model during training.
202 | 
203 | ```--overlap``` → Labels are predicted on overlapping image chunks and averaged (to avoid boundary artefacts). The overlap should be approximately half of the chunk size, but can be reduced (to speed up inference time) or increased as desired.
204 | 
205 | ```--binary_output``` → Use this flag to save label predictions as binary images. Otherwise, the softmax output from the final model layer will be saved.
206 | 
207 | 
208 | ```--preview``` → Use this flag to save prediction previews at regular intervals throughout inference. This is useful for checking that the model prediction is sensible without having to wait for the entire image to be processed.
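Once inference has finished, the saved Zarr output can be opened lazily with dask and, for softmax outputs, thresholded at a value of your choosing. A minimal sketch (the paths are illustrative - predict.py names each output `<image name>_segmentation`):

```
import dask.array as da
import tifffile

# Open the predicted segmentation written by predict.py (lazy - nothing is loaded yet)
seg = da.from_zarr('path/to/prediction/OCTA-Data.tif_segmentation')

# For softmax output, apply your own threshold to obtain a binary vessel mask
mask = (seg > 0.5).astype('uint8')

# Materialise and save a copy as a 3D TIFF for viewing (e.g. in Fiji)
tifffile.imwrite('OCTA-Data_mask.tiff', mask.compute(), bigtiff=True)
```
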
209 | 
210 | ## Citing
211 | If you use this model in any published work, please cite our [paper](https://doi.org/10.1093/biomethods/bpaf087).
212 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Wed Sep 3 16:21:26 2025
5 | 
6 | @author: natalie
7 | """
8 | 
9 | #Import libraries
10 | import os
11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' #Suppress info logs from tf
12 | import pickle
13 | import datetime
14 | import argparse
15 | from model import tUbeNet
16 | import tUbeNet_functions as tube
17 | from tUbeNet_classes import DataDir, DataGenerator, ImageDisplayCallback, MetricDisplayCallback, FilterDisplayCallback
18 | from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
19 | 
20 | def main(args):
21 |     """Set parameters and file paths:"""
22 |     # Model parameters
23 |     volume_dims = args.volume_dims
24 |     n_epochs = args.n_epochs
25 |     steps_per_epoch = args.steps_per_epoch
26 |     batch_size = args.batch_size
27 |     dataset_weighting = args.dataset_weighting
28 |     loss = args.loss
29 |     lr0 = args.lr0
30 |     class_weights = args.class_weights
31 |     n_classes = 2 #TODO: expand to handle multi-class case
32 | 
33 |     # Training and prediction options
34 |     fine_tune = args.fine_tune
35 |     binary_output = args.binary_output
36 |     augment = args.no_augment
37 |     attention = args.attention
38 | 
39 |     """ Paths and filenames """
40 |     # Training data
41 |     data_headers = args.data_headers
42 | 
43 |     # Validation data
44 |     val_headers = args.val_headers
45 | 
46 |     # Model
47 |     model_path = args.model_path
48 |     model_weights_file = args.model_weights_file
49 | 
50 |     # Image output
51 |     output_path = args.output_path
52 | 
53 |     #----------------------------------------------------------------------------------------------------------------------------------------------
54 |     """ Create Data Directory"""
55 |     # Load data headers into a list
56 |     header_filenames=[f for f in os.listdir(data_headers) if os.path.isfile(os.path.join(data_headers, f))]
57 |     headers = []
58 |     try:
59 |         for file in header_filenames: #Iterate through header files
60 |             file=os.path.join(data_headers,file)
61 |             with open(file, "rb") as f:
62 |                 data_header = pickle.load(f) # Unpickle DataHeader object
63 |                 headers.append(data_header) # Add to list of headers
64 |     except FileNotFoundError: print(f"Unable to load data header files from {data_headers}")
65 | 
66 |     # Create empty data directory
67 |     data_dir = DataDir([], image_dims=[],
68 |                        image_filenames=[],
69 |                        label_filenames=[],
70 |                        data_type=[], exclude_region=[])
71 | 
72 |     # Fill directory from headers
73 |     for header in headers:
74 |         data_dir.list_IDs.append(header.ID)
75 |         data_dir.image_dims.append(header.image_dims)
76 |         data_dir.image_filenames.append(header.image_filename)
77 |         data_dir.label_filenames.append(header.label_filename)
78 |         data_dir.data_type.append('float32')
79 |         data_dir.exclude_region.append((None,None,None)) #region to be left out of training for use as validation data (under development)
80 | 
81 |     """ Create Data Generator """
82 |     params = {'batch_size': batch_size,
83 |               'volume_dims': volume_dims,
84 |               'n_classes': n_classes,
85 |               'dataset_weighting': dataset_weighting,
86 |               'augment':augment,
87 |               'shuffle': False}
88 | 
89 |     data_generator=DataGenerator(data_dir, **params)
90 | 
91 |     """ Load or Build Model """
92 |     tubenet = tUbeNet(n_classes=n_classes, input_dims=volume_dims, attention=attention)
93 | 
94 |     if model_weights_file is not None:
95 |         # Load existing model with or without fine tuning adjustment (fine tuning -> classifier replaced and first 2 blocks frozen)
96 |         if not os.path.isfile(model_weights_file):
97 |             if os.path.isfile(os.path.join(model_path, model_weights_file)):
98 |                 model_weights_file=os.path.join(model_path, model_weights_file)
99 |             else:
100 |                 raise FileNotFoundError("Could not locate model weights file at {}".format(model_weights_file))
101 | 
102 |         model = tubenet.load_weights(filename=model_weights_file, #full path already resolved above
103 |                                      loss=loss,
104 |                                      class_weights=class_weights,
105 |                                      learning_rate=lr0,
106 |                                      metrics=['accuracy', 'recall', 'precision', tube.dice],
107 |                                      freeze_layers=2, fine_tune=fine_tune)
108 | 
109 |     else:
110 |         model = tubenet.create(learning_rate=lr0,
111 |                                loss=loss,
112 |                                class_weights=class_weights,
113 |                                metrics=['accuracy', 'recall', 'precision', tube.dice])
114 | 
115 | 
116 |     """ Train and save model """
117 | 
118 |     # Create folder for log files
119 |     date = datetime.datetime.now()
120 |     filepath = os.path.join(model_path,"{}_model_checkpoint.weights.h5".format(date.strftime("%d%m%y")))
121 |     log_dir = os.path.join(model_path,'logs')
122 |     if not os.path.exists(log_dir):
123 |         os.makedirs(log_dir)
124 | 
125 |     # Define callbacks
126 |     if val_headers is not None:
127 |         monitored_metric='val_loss'
128 |     else:
129 |         monitored_metric='loss'
130 |     checkpoint = ModelCheckpoint(filepath, monitor=monitored_metric, verbose=1, save_weights_only=True, save_best_only=True, mode='min') #monitored metric is a loss, so lower is better
131 |     tbCallback = TensorBoard(log_dir=log_dir, histogram_freq=1, write_graph=False, write_images=True)
132 |     imageCallback = ImageDisplayCallback(data_generator,log_dir=os.path.join(log_dir,'images'))
133 |     filterCallback = FilterDisplayCallback(log_dir=os.path.join(log_dir,'filters')) #experimental
134 |     metricCallback = MetricDisplayCallback(log_dir=log_dir)
135 | 
136 |     # Create directory of validation data
137 |     if val_headers is not None:
138 |         # Import data header
139 |         header_filenames=[f for f in os.listdir(val_headers) if os.path.isfile(os.path.join(val_headers, f))]
140 |         headers = []
141 |         try:
142 |             for file in header_filenames: #Iterate through header files
143 |                 file=os.path.join(val_headers,file)
144 |                 with open(file, "rb") as f:
145 |                     val_header = pickle.load(f) # Unpickle DataHeader object
146 |                     headers.append(val_header) # Add to list of headers
147 |         except FileNotFoundError: print(f"Unable to load data header files from {val_headers}")
148 | 
149 |         # Create empty data directory
150 |         val_dir = DataDir([], image_dims=[],
151 |                           image_filenames=[],
152 |                           label_filenames=[],
153 |                           data_type=[], exclude_region=[])
154 | 
155 |         # Fill directory from headers
156 |         for header in headers:
157 |             val_dir.list_IDs.append(header.ID)
158 |             val_dir.image_dims.append(header.image_dims)
159 |             val_dir.image_filenames.append(header.image_filename)
160 |             val_dir.label_filenames.append(header.label_filename)
161 |             val_dir.data_type.append('float32')
162 |             val_dir.exclude_region.append((None,None,None))
163 | 
164 | 
165 |         vparams = {'batch_size': batch_size,
166 |                    'volume_dims': volume_dims,
167 |                    'n_classes': n_classes,
168 |                    'dataset_weighting': None,
169 |                    'augment': False,
170 |                    'shuffle': False}
171 | 
172 |         val_generator=DataGenerator(val_dir, **vparams)
173 | 
174 |         # TRAIN with validation
175 |         history=model.fit(data_generator, validation_data=val_generator,
176 |                           validation_steps=5, epochs=n_epochs, 
steps_per_epoch=steps_per_epoch, 177 | callbacks=[checkpoint, tbCallback, imageCallback, filterCallback, metricCallback]) 178 | 179 | else: 180 | # TRAIN without validation 181 | history=model.fit(data_generator, epochs=n_epochs, 182 | steps_per_epoch=steps_per_epoch, 183 | callbacks=[checkpoint, tbCallback, imageCallback, filterCallback, metricCallback]) 184 | 185 | # SAVE MODEL 186 | model.save_weights(os.path.join(model_path,"{}_trained_model.weights.h5".format(date.strftime("%d%m%y")))) 187 | 188 | """ Plot ROC """ 189 | # Evaluate model on validation data 190 | if val_headers is not None: 191 | validation_metrics = tube.roc_analysis(model, val_dir, 192 | volume_dims=volume_dims, 193 | n_classes=n_classes, 194 | output_path=output_path, 195 | binary_output=binary_output) 196 | 197 | def parse_dims(values): 198 | """Parse volume dimensions: allow either one int (isotropic) or three ints (anisotropic).""" 199 | if len(values) == 1: 200 | return (values[0], values[0], values[0]) 201 | elif len(values) == 3: 202 | return tuple(values) 203 | else: 204 | raise argparse.ArgumentTypeError( 205 | "volume_dims must be either a single value (e.g. --volume_dims 64) " 206 | "or three values (e.g. --volume_dims 64 64 32).") 207 | 208 | if __name__ == "__main__": 209 | parser = argparse.ArgumentParser(description="Train TubeNet model.") 210 | 211 | # Data paths 212 | parser.add_argument("--data_headers", type=str, required=True, 213 | help="Path to directory containing training header files.") 214 | parser.add_argument("--val_headers", type=str, default=None, 215 | help="Path to directory containing validation header files (optional).") 216 | parser.add_argument("--model_path", type=str, required=True, 217 | help="Directory where trained model will be saved.") 218 | parser.add_argument("--model_weights_file", type=str, default=None, 219 | help="Filename for pre-trained model weights (ending .h5 or .weights.h5). " 220 | "If unset, the model will be trained from scratch.") 221 | parser.add_argument("--output_path", type=str, required=True, 222 | help="Directory where predictions/analysis outputs will be saved.") 223 | 224 | # Model parameters 225 | parser.add_argument("--volume_dims", type=int, nargs="+", default=[64, 64, 64], 226 | help="Volume dimensions passed to CNN. 
Provide 1 value (isotropic) or 3 values (anisotropic).")
227 |     parser.add_argument("--n_epochs", type=int, default=100,
228 |                         help="Number of epochs for training.")
229 |     parser.add_argument("--steps_per_epoch", type=int, default=100,
230 |                         help="Number of batches generated per epoch.")
231 |     parser.add_argument("--batch_size", type=int, default=6,
232 |                         help="Batch size.")
233 |     parser.add_argument("--dataset_weighting", type=float, nargs="+", default=None,
234 |                         help="Relative weighting when pulling training data from multiple datasets.")
235 |     parser.add_argument("--loss", type=str, default="DICE BCE",
236 |                         choices=["DICE BCE", "focal", "WCCE"],
237 |                         help="Loss function.")
238 |     parser.add_argument("--lr0", type=float, default=1e-3,
239 |                         help="Initial learning rate.")
240 |     parser.add_argument("--class_weights", type=float, nargs=2, default=[1.0, 1.0],
241 |                         help="Relative class weights (background, vessels).")
242 |     parser.add_argument("--no_augment", action="store_false",
243 |                         help="Disable data augmentation.")
244 |     parser.add_argument("--attention", action="store_true",
245 |                         help="Enable attention mechanism in model (experimental).")
246 | 
247 |     # Training options
248 |     parser.add_argument("--fine_tune", action="store_true",
249 |                         help="Enable fine-tuning by freezing shallow layers.")
250 |     parser.add_argument("--binary_output", action="store_true",
251 |                         help="Save predictions as binary image instead of softmax.")
252 | 
253 |     args = parser.parse_args()
254 |     args.volume_dims = parse_dims(args.volume_dims)
255 |     main(args)
256 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Jun 8 13:21:04 2022
4 | 
5 | @author: Natal
6 | """
7 | #Import libraries
8 | import os
9 | from functools import partial
10 | import tUbeNet_functions as tube
11 | 
12 | # import required objects and functions from keras
13 | from tensorflow.keras.models import Model
14 | # CNN layers
15 | from tensorflow.keras.layers import (
16 |     Input, concatenate, Conv3D, MaxPooling3D,
17 |     Conv3DTranspose, LeakyReLU, Dropout, Dense, Flatten, GroupNormalization)
18 | # optimiser
19 | from tensorflow.keras.optimizers import Adam
20 | 
21 | # import tensorflow
22 | import tensorflow as tf
23 | 
24 | # set backend and dim ordering
25 | K=tf.keras.backend
26 | K.set_image_data_format('channels_last')
27 | 
28 | # set memory limit on gpu
29 | physical_devices = tf.config.list_physical_devices('GPU')
30 | try:
31 |     for gpu in physical_devices:
32 |         tf.config.experimental.set_memory_growth(gpu, True)
33 | except (ValueError, RuntimeError): #memory growth must be set before GPUs are initialised
34 |     pass
35 | 
36 | 
37 | """Model blocks"""
38 | class AttnBlock(tf.keras.layers.Layer):
39 |     def __init__(self, channels=32):
40 |         super(AttnBlock,self).__init__()
41 |         self.Wq = Conv3D(channels, (3, 3, 3), padding='same', kernel_initializer='he_uniform')
42 |         self.Wk = Conv3D(channels, (3, 3, 3), padding='same', kernel_initializer='he_uniform')
43 |         self.map = Conv3D(channels, (1, 1, 1), activation= 'sigmoid', padding='same', kernel_initializer='he_uniform')
44 |     def call (self, query, key):
45 |         w_query=self.Wq(query)
46 |         w_key=self.Wk(key)
47 |         dot_prod=tf.matmul(w_query, w_key, transpose_b=True)
48 |         attn_map=self.map(dot_prod)
49 |         return attn_map*query
50 | 
51 | class EncodeBlock(tf.keras.layers.Layer):
52 |     def __init__(self, channels=32, alpha=0.2, dropout=0.3):
53 |         super(EncodeBlock,self).__init__()
54 |         self.conv1 = Conv3D(channels, (3, 3, 3), 
37 | """Model blocks"""
38 | class AttnBlock(tf.keras.layers.Layer):
39 |     def __init__(self, channels=32):
40 |         super(AttnBlock,self).__init__()
41 |         self.Wq = Conv3D(channels, (3, 3, 3), padding='same', kernel_initializer='he_uniform')
42 |         self.Wk = Conv3D(channels, (3, 3, 3), padding='same', kernel_initializer='he_uniform')
43 |         self.map = Conv3D(channels, (1, 1, 1), activation='sigmoid', padding='same', kernel_initializer='he_uniform')
44 |     def call(self, query, key):
45 |         w_query=self.Wq(query)
46 |         w_key=self.Wk(key)
47 |         dot_prod=tf.matmul(w_query, w_key, transpose_b=True)
48 |         attn_map=self.map(dot_prod)
49 |         return attn_map*query
50 | 
51 | class EncodeBlock(tf.keras.layers.Layer):
52 |     def __init__(self, channels=32, alpha=0.2, dropout=0.3):
53 |         super(EncodeBlock,self).__init__()
54 |         self.conv1 = Conv3D(channels, (3, 3, 3), activation='linear', padding='same', kernel_initializer='he_uniform')
55 |         self.conv2 = Conv3D(channels, (3, 3, 3), activation='linear', padding='same', kernel_initializer='he_uniform')
56 |         self.norm = GroupNormalization(groups=int(channels/4), axis=4)
57 |         self.lrelu = LeakyReLU(negative_slope=alpha)
58 |         self.pool = MaxPooling3D(pool_size=(2, 2, 2))
59 |         self.dropout = Dropout(dropout)
60 |     def call(self, x):
61 |         conv1 = self.conv1(x)
62 |         activ1 = self.lrelu(conv1)
63 |         norm1 = self.norm(activ1)
64 |         conv2 = self.conv2(norm1)
65 |         activ2 = self.lrelu(conv2)
66 |         norm2 = self.norm(activ2)
67 |         pool = self.pool(norm2)
68 |         drop = self.dropout(pool)
69 |         return drop
70 | 
71 | class DecodeBlock(tf.keras.layers.Layer):
72 |     def __init__(self, channels=32, alpha=0.2):
73 |         super(DecodeBlock,self).__init__()
74 |         self.transpose = Conv3DTranspose(channels, (2, 2, 2), strides=(2, 2, 2), padding='same', kernel_initializer='he_uniform')
75 |         self.conv = Conv3D(channels, (3, 3, 3), activation='linear', padding='same', kernel_initializer='he_uniform')
76 |         self.norm = GroupNormalization(groups=int(channels/4), axis=4)
77 |         self.lrelu = LeakyReLU(negative_slope=alpha)
78 |         self.channels = channels
79 |     def call(self, skip, x, attention=False):
80 |         if attention:
81 |             attn = AttnBlock(channels=self.channels)(skip, x)  # NB: experimental - a new AttnBlock (with fresh weights) is instantiated on every call
82 |         else:
83 |             attn = concatenate([skip, x], axis=4)
84 |         transpose = self.transpose(attn)
85 |         activ1 = self.lrelu(transpose)
86 |         norm1 = self.norm(activ1)
87 |         conv = self.conv(norm1)
88 |         activ2 = self.lrelu(conv)
89 |         norm2 = self.norm(activ2)
90 |         return norm2
91 | 
92 | class UBlock(tf.keras.layers.Layer):
93 |     def __init__(self, channels=32, alpha=0.2):
94 |         super(UBlock,self).__init__()
95 |         self.conv1 = Conv3D(channels, (3, 3, 3), activation='linear', padding='same', kernel_initializer='he_uniform')
96 |         self.conv2 = Conv3D(int(channels/2), (3, 3, 3), activation='linear', padding='same', kernel_initializer='he_uniform')
97 |         self.norm = GroupNormalization(groups=int(channels/4), axis=4)
98 |         self.lrelu = LeakyReLU(negative_slope=alpha)
99 |     def call(self, x):
100 |         conv1 = self.conv1(x)
101 |         activ1 = self.lrelu(conv1)
102 |         norm1 = self.norm(activ1)
103 |         conv2 = self.conv2(norm1)
104 |         activ2 = self.lrelu(conv2)
105 |         return activ2
106 | 
107 | class EncoderOnlyOutput(tf.keras.layers.Layer):
108 |     def __init__(self, channels=64, alpha=0.2):
109 |         super(EncoderOnlyOutput,self).__init__()
110 |         self.flatten = Flatten()
111 |         self.dense1 = Dense(channels, activation='linear', kernel_initializer='he_uniform')
112 |         self.dense2 = Dense(2, activation='softmax') #classifier
113 |         self.lrelu = LeakyReLU(negative_slope=alpha)
114 |     def call(self, x):
115 |         flatten = self.flatten(x)
116 |         dense1 = self.dense1(flatten)
117 |         activ1 = self.lrelu(dense1)
118 |         dense2 = self.dense2(activ1)
119 |         return dense2
120 | 
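# Shape sanity check for the blocks above (illustrative; run in a Python shell):
#   >>> x = tf.zeros((1, 64, 64, 64, 1))
#   >>> EncodeBlock(channels=32)(x).shape    # max-pooling halves each spatial dim
#   TensorShape([1, 32, 32, 32, 32])
#   >>> UBlock(channels=1024)(tf.zeros((1, 2, 2, 2, 512))).shape  # bottleneck keeps spatial dims
#   TensorShape([1, 2, 2, 2, 512])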
121 | """Build Model"""
122 | class tUbeNet(tf.keras.Model):
123 |     def __init__(self, n_classes=2, input_dims=(64,64,64), dropout=0.3, alpha=0.2, attention=False):
124 |         super(tUbeNet,self).__init__()
125 |         self.n_classes=n_classes
126 |         self.input_dims=input_dims
127 |         self.dropout=dropout
128 |         self.alpha=alpha
129 |         self.attention=attention
130 | 
131 |     def build_model(self, encoder_only=False):
132 |         inputs = Input((*self.input_dims, 1))
133 | 
134 |         block1 = EncodeBlock(channels=32, alpha=self.alpha, dropout=self.dropout)(inputs)
135 |         block2 = EncodeBlock(channels=64, alpha=self.alpha, dropout=self.dropout)(block1)
136 |         block3 = EncodeBlock(channels=128, alpha=self.alpha, dropout=self.dropout)(block2)
137 |         block4 = EncodeBlock(channels=256, alpha=self.alpha, dropout=self.dropout)(block3)
138 |         block5 = EncodeBlock(channels=512, alpha=self.alpha, dropout=self.dropout)(block4)
139 | 
140 |         block6 = UBlock(channels=1024, alpha=self.alpha)(block5)
141 | 
142 |         if encoder_only:
143 |             output = EncoderOnlyOutput(channels=64, alpha=self.alpha)(block6)
144 | 
145 |         else:
146 |             upblock1 = DecodeBlock(channels=512, alpha=self.alpha)(block5, block6, attention=self.attention)
147 |             upblock2 = DecodeBlock(channels=256, alpha=self.alpha)(block4, upblock1, attention=self.attention)
148 |             upblock3 = DecodeBlock(channels=128, alpha=self.alpha)(block3, upblock2, attention=self.attention)
149 |             upblock4 = DecodeBlock(channels=64, alpha=self.alpha)(block2, upblock3, attention=self.attention)
150 |             upblock5 = DecodeBlock(channels=32, alpha=self.alpha)(block1, upblock4, attention=self.attention)
151 | 
152 |             output = Conv3D(self.n_classes, (1, 1, 1), activation='softmax')(upblock5)
153 | 
154 |         model = Model(inputs=inputs, outputs=output)
155 |         return model
156 | 
157 |     def selectLoss(self, loss_name, class_weights=None):
158 |         """select loss from custom losses"""
159 |         if loss_name == 'WCCE':
160 |             custom_loss=partial(tube.weighted_crossentropy, weights=class_weights)
161 |             custom_loss.__name__ = "custom_loss" # partial doesn't copy the name or module attributes from the wrapped function
162 |             custom_loss.__module__ = tube.weighted_crossentropy.__module__
163 |         elif loss_name == 'DICE BCE':
164 |             custom_loss=partial(tube.DiceBCELoss,smooth=1e-6)
165 |             custom_loss.__name__ = "custom_loss" # partial doesn't copy the name or module attributes from the wrapped function
166 |             custom_loss.__module__ = tube.DiceBCELoss.__module__
167 |         elif loss_name == 'focal':
168 |             custom_loss=tf.keras.losses.CategoricalFocalCrossentropy(alpha=0.2, gamma=5)
169 |         else:
170 |             print('Loss not recognised, using categorical crossentropy')
171 |             custom_loss='categorical_crossentropy'
172 |         return custom_loss
173 | 
174 |     def create(self, loss=None, class_weights=(1,1), learning_rate=1e-3, metrics=['accuracy'], encoder_only=False):
175 |         custom_loss = self.selectLoss(loss,class_weights)
176 | 
177 |         #Check for multiple GPUs
178 |         physical_devices = tf.config.list_physical_devices('GPU')
179 |         n_gpus=len(physical_devices)
180 |         if n_gpus >1:
181 |             strategy = tf.distribute.MirroredStrategy()
182 |             print("Creating model on {} GPUs".format(n_gpus))
183 |             with strategy.scope():
184 |                 model = self.build_model(encoder_only=encoder_only)
185 |                 model.compile(optimizer=Adam(learning_rate=learning_rate), loss=custom_loss, metrics=metrics)
186 |         else:
187 |             model = self.build_model(encoder_only=encoder_only)
188 |             model.compile(optimizer=Adam(learning_rate=learning_rate), loss=custom_loss, metrics=metrics)
189 | 
190 |         print('Model Summary')
191 |         model.summary()
192 |         return model
193 | 
194 |     def load_weights(self, filename=None, loss=None, class_weights=(1,1),
195 |                      learning_rate=1e-5, metrics=['accuracy'], freeze_layers=0, fine_tune=False):
196 |         """ Fine Tuning
197 |         Replaces classifier layer and freezes shallow layers for fine tuning
198 |         Inputs:
199 |             filename = path to file containing model weights
200 |             freeze_layers = number of layers to freeze for training (int, default 0)
201 |             learning_rate = learning rate (float, default 1e-5)
202 |             loss = loss function, function or string
203 |             metrics = training metrics, list of functions or strings
204 |         Outputs:
205 |             model = compiled model
206 |         """
207 | 
208 |         physical_devices = tf.config.list_physical_devices('GPU')
209 |         n_gpus=len(physical_devices)
210 | 
211 |         # create path for file containing weights
212 |         if filename is None:
213 |             raise ValueError("model weights filename must be provided")
214 |         if os.path.isfile(filename+'.h5'):
215 |             mfile = filename+'.h5'
216 |         elif os.path.isfile(filename+'.hdf5'):
217 |             mfile = filename+'.hdf5'
218 |         else: mfile=filename
219 | 
220 |         custom_loss=self.selectLoss(loss,class_weights)
221 | 
222 |         if fine_tune:
223 |             if n_gpus>1:
224 |                 strategy = tf.distribute.MirroredStrategy()
225 |                 print("Creating model on {} GPUs".format(n_gpus))
226 |                 with strategy.scope():
227 |                     model = self.build_model()
228 |                     # load weights into new model
229 |                     model.load_weights(mfile)
230 | 
231 |                     # recover the output from the last layer in the model and use as input to a new Classifier
232 |                     last = model.layers[-2].output
233 |                     classifier = Conv3D(self.n_classes, (1, 1, 1), activation='softmax', name='newClassifier')(last)
234 |                     model = Model(inputs=[model.input], outputs=[classifier])
235 |                     # freeze weights for selected layers
236 |                     for layer in model.layers[:freeze_layers]: layer.trainable = False
237 | 
238 |                     model.compile(optimizer=Adam(learning_rate=learning_rate), loss=custom_loss, metrics=metrics)
239 |             else:
240 |                 model = self.build_model()
241 |                 # load weights into new model
242 |                 model.load_weights(mfile)
243 | 
244 |                 # recover the output from the last layer in the model and use as input to a new Classifier
245 |                 last = model.layers[-2].output
246 |                 classifier = Conv3D(self.n_classes, (1, 1, 1), activation='softmax', name='newClassifier')(last)
247 |                 model = Model(inputs=[model.input], outputs=[classifier])
248 |                 # freeze weights for selected layers
249 |                 for layer in model.layers[:freeze_layers]: layer.trainable = False
250 | 
251 |                 model.compile(optimizer=Adam(learning_rate=learning_rate), loss=custom_loss, metrics=metrics)
252 | 
253 |         else:
254 |             if n_gpus>1:
255 |                 strategy = tf.distribute.MirroredStrategy()
256 |                 print("Creating model on {} GPUs".format(n_gpus))
257 |                 with strategy.scope():
258 |                     model = self.build_model()
259 |                     # load weights into new model
260 |                     model.load_weights(mfile)
261 |                     model.compile(optimizer=Adam(learning_rate=learning_rate), loss=custom_loss, metrics=metrics)
262 |             else:
263 |                 model = self.build_model()
264 |                 # load weights into new model
265 |                 model.load_weights(mfile)
266 |                 model.compile(optimizer=Adam(learning_rate=learning_rate), loss=custom_loss, metrics=metrics)
267 | 
268 |         print('Model Summary')
269 |         model.summary()
270 |         return model
271 | 
272 |     def load_encoder(self, filename=None, loss=None, class_weights=(1,1),
273 |                      learning_rate=1e-5, metrics=['accuracy']):
274 |         """ Load Encoder
275 |         Builds an encoder-only model, loads pre-trained weights into it, then
276 |         transfers the encoder weights to a full encoder-decoder model
277 |         Inputs:
278 |             filename = path to file containing encoder weights
279 |             learning_rate = learning rate (float, default 1e-5)
280 |             loss = loss function, function or string
281 |             metrics = training metrics, list of functions or strings
282 |         Outputs:
283 |             model = compiled model
284 |         """
285 | 
286 |         physical_devices = tf.config.list_physical_devices('GPU')
287 |         n_gpus=len(physical_devices)
288 |         # create path for file containing weights
289 |         if filename is None:
290 |             raise ValueError("model weights filename must be provided")
291 |         if os.path.isfile(filename+'.h5'):
292 |             mfile = filename+'.h5'
293 |         elif os.path.isfile(filename+'.hdf5'):
294 |             mfile = filename+'.hdf5'
295 |         else: raise FileNotFoundError("No model weights file found at "+str(filename))
296 | 
297 |         custom_loss=self.selectLoss(loss,class_weights)
298 | 
299 |         if n_gpus>1:
300 |             strategy = tf.distribute.MirroredStrategy()
301 |             print("Creating model on {} GPUs".format(n_gpus))
302 |             with strategy.scope():
303 |                 # build and load weights into encoder only model
304 |                 encoder_model = self.build_model(encoder_only=True)
305 |                 encoder_model.load_weights(mfile)
306 | 
307 |                 # build full model
308 |                 model = self.build_model()
309 | 
310 |                 # transfer weights for all encoder layers, except dense layer
311 |                 for i, layer in enumerate(encoder_model.layers[:-1]):
312 |                     print('Copying weights from', layer.name, 'to', model.layers[i].name)
313 |                     weights = layer.get_weights()
314 |                     model.layers[i].set_weights(weights)
315 | 
316 |                 #Compile model
317 |                 model.compile(optimizer=Adam(learning_rate=learning_rate), loss=custom_loss, metrics=metrics)
318 | 
319 |         else:
320 |             # build and load weights into encoder only model
321 |             encoder_model = self.build_model(encoder_only=True)
322 |             encoder_model.load_weights(mfile)
323 | 
324 |             # build full model
325 |             model = self.build_model()
326 | 
327 |             # transfer weights for all encoder layers, except dense layer
328 |             for i, layer in enumerate(encoder_model.layers[:-1]):
329 |                 print('Copying weights from', layer.name, 'to', model.layers[i].name)
330 |                 weights = layer.get_weights()
331 |                 model.layers[i].set_weights(weights)
332 | 
333 |             #Compile model
334 |             model.compile(optimizer=Adam(learning_rate=learning_rate), loss=custom_loss, metrics=metrics)
335 | 
336 | 
337 |         print('Model Summary')
338 |         model.summary()
339 |         return model
340 | 
341 | 
342 | 
343 | 
344 | 
345 | 
346 | 
347 | 
348 | 
349 | 
350 | 
351 | 
352 | 
353 | 
354 | 
355 | 
356 | 
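# Example usage (weight path and class weighting hypothetical):
#   net = tUbeNet(n_classes=2, input_dims=(64, 64, 64))
#   model = net.create(loss='DICE BCE', class_weights=(1, 7), learning_rate=1e-3)  # train from scratch
#   model = net.load_weights(filename='models/pretrained', loss='DICE BCE',
#                            fine_tune=True, freeze_layers=4)                      # or fine-tune instead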
--------------------------------------------------------------------------------
/tUbeNet_functions.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """tUbeNet 3D
3 | U-Net based CNN for vessel segmentation
4 | 
5 | Developed by Natalie Holroyd (UCL)
6 | """
7 | 
8 | #Import libraries
9 | import os
10 | import numpy as np
11 | from skimage import io
12 | from sklearn.metrics import roc_curve, auc, average_precision_score, PrecisionRecallDisplay, precision_recall_curve, precision_score, recall_score
13 | import matplotlib.pyplot as plt
14 | import dask.array as da
15 | import zarr
16 | from tqdm import tqdm
17 | from scipy.signal.windows import general_hamming
18 | import tifffile as tiff
19 | 
20 | # import tensorflow
21 | import tensorflow as tf
22 | 
23 | # set backend and dim ordering
24 | K=tf.keras.backend
25 | K.set_image_data_format('channels_last')
26 | 
27 | # set memory limit on gpu
28 | physical_devices = tf.config.list_physical_devices('GPU')
29 | try:
30 |     for gpu in physical_devices:
31 |         tf.config.experimental.set_memory_growth(gpu, True)
32 | except Exception:
33 |     pass
34 | 
35 | 
36 | #-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
37 | """Custom metrics"""
38 | # Use when y_true/y_pred are np arrays rather than keras tensors
39 | def precision_logical(y_true, y_pred):
40 |     #true positive
41 |     TP = np.sum(np.logical_and(np.equal(y_true,1),np.equal(y_pred,1)))
42 |     #false positive
43 |     FP = np.sum(np.logical_and(np.equal(y_true,0),np.equal(y_pred,1)))
44 |     precision1=TP/(TP+FP)
45 |     return precision1
46 | 
47 | def recall_logical(y_true, y_pred):
48 |     #true positive
49 |     TP = np.sum(np.logical_and(np.equal(y_true,1),np.equal(y_pred,1)))
50 |     #false negative
51 |     FN = np.sum(np.logical_and(np.equal(y_true,1),np.equal(y_pred,0)))
52 |     recall1=TP/(TP+FN)
53 |     return recall1
54 | 
55 | # Use when y_true/y_pred are keras tensors - for passing to model
56 | def precision(y_true, y_pred):
57 |     y_pred = tf.cast(y_pred, tf.float32) # Change tensor dtype
58 |     y_true = tf.cast(y_true, tf.float32)
59 |     true_positives = K.sum(K.round(K.clip(y_true[...,1] * y_pred[...,1], 0, 1)))
60 |     predicted_positives = K.sum(K.round(K.clip(y_pred[...,1], 0, 1)))
61 |     precision = true_positives / (predicted_positives + K.epsilon())
62 |     return precision
63 | 
64 | def recall(y_true, y_pred):
65 |     y_pred = tf.cast(y_pred, tf.float32) # Change tensor dtype
66 |     y_true = tf.cast(y_true, tf.float32)
67 |     true_positives = K.sum(K.round(K.clip(y_true[...,1] * y_pred[...,1], 0, 1)))
68 |     possible_positives = K.sum(K.round(K.clip(y_true[...,1], 0, 1)))
69 |     recall = true_positives / (possible_positives + K.epsilon())
70 |     return recall
71 | 
72 | def dice(y_true, y_pred):
73 |     P = precision(y_true, y_pred)
74 |     R = recall(y_true, y_pred)
75 |     dice = 2*(P*R)/(P+R+K.epsilon())
76 |     return dice
77 | 
78 | """Custom Losses"""
79 | def weighted_crossentropy(y_true, y_pred, weights):
80 |     """Custom loss function - weighted to address class imbalance"""
81 |     weight_mask = y_true[...,0] * weights[0] + y_true[...,1] * weights[1]
82 |     return K.categorical_crossentropy(y_true, y_pred) * weight_mask
83 | 
84 | def DiceBCELoss(y_true, y_pred, smooth=1e-6): # smooth retained for API compatibility; dice() uses K.epsilon()
85 |     BCE = tf.keras.losses.binary_crossentropy(y_true, y_pred)
86 |     dice_loss = 1-dice(y_true, y_pred)
87 |     Dice_BCE = (BCE + dice_loss)/2
88 |     return Dice_BCE
89 | 
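# Toy check of the tensor metrics above (illustrative only):
#   >>> y_true = tf.constant([[[0., 1.], [1., 0.]]])     # channel 1 = vessel
#   >>> y_pred = tf.constant([[[0.1, 0.9], [0.8, 0.2]]])
#   >>> float(precision(y_true, y_pred)), float(recall(y_true, y_pred))
# Both evaluate to ~1.0, since rounding these predictions recovers the labels.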
90 | #---------------------------INFERENCE------------------------------------------------------------------------------------------------------------------------
91 | 
92 | def predict_segmentation_dask(
93 |     model,
94 |     image_path,                   # e.g. "/path/dataset.zarr/image"
95 |     out_store,                    # e.g. "/path/dataset_pred.zarr" (folder will be created)
96 |     volume_dims=(64, 64, 64),     # (Z,X,Y)
97 |     overlap=(16, 16, 16),         # (Z,X,Y) overlap (must be < volume_dims)
98 |     n_classes=2,                  # softmax classes produced by model
99 |     export_bigtiff=None,          # e.g. "/path/dataset_pred.tif" to export 3D TIFF (optional)
100 |     preview=False,                # Preview segmentation for every slab of subvolumes processed in the z axis
101 |     binary_output=False,          # if True: reverses one-hot encoding to give pixel values = class index; else softmax output for foreground channel only
102 |     prob_channel=1,               # which channel to export if binary_output=False (for 2-class, 1 is foreground prob)
103 | ):
104 |     """
105 |     Sliding-window inference with smooth blending.
106 |     Writes two Zarr datasets on disk during accumulation: 'sum' and 'wsum'.
107 |     Final result is written as 'seg' (either class indices or a single-channel probability).
108 |     Optionally, writes a BigTIFF 3D volume without holding everything in RAM.
109 |     """
110 | 
111 |     # Open image using Dask array and check dimensions
112 |     img = da.from_zarr(image_path)  # shape (Z,X,Y) or (Z,X,Y,1)
113 |     if img.ndim == 4 and img.shape[-1] == 1:
114 |         img = img[..., 0]
115 |     assert img.ndim == 3, "Expected (Z,X,Y) image"
116 | 
117 |     # Make all dimensions int
118 |     Z, X, Y = map(int, img.shape)
119 |     stride = np.array(volume_dims, dtype="int32")-np.array(overlap, dtype="int32")
120 | 
121 |     # Check for sensible overlap dimensions
122 |     if any(stride<=0):
123 |         raise ValueError("overlap must be less than volume_dims on each axis")
124 | 
125 |     def auto_pad(img, volume_dims, stride):
126 |         img_shape = np.array(img.shape)
127 |         volume_dims = np.array(volume_dims)
128 | 
129 |         pad_widths = []
130 |         new_shape = []
131 | 
132 |         for shape_i, dim_i, stride_i in zip(img_shape, volume_dims, stride):
133 |             # Number of strides needed to cover full image volume
134 |             target_size = int(np.ceil((shape_i-dim_i)/stride_i)*stride_i+dim_i)
135 | 
136 |             total_pad = target_size-shape_i
137 | 
138 |             # Pad must be at least half volume_dims to avoid boundary artefact
139 |             half = dim_i//2
140 |             before = max(half, total_pad//2)  # pad on either side of image
141 |             after = max(total_pad-before, half)
142 | 
143 |             pad_widths.append((before, after))  # (Before, After) in each dimension
144 |             new_shape.append(shape_i + before + after)
145 | 
146 |         padded = da.pad(img, pad_widths, mode='reflect')
147 |         #print(f"Padded from {img.shape} to {tuple(new_shape)}") #Debugging
148 |         return padded, pad_widths
149 | 
150 |     # Pad image to avoid boundary effects and allow patches to cover whole image
151 |     img, pad_widths = auto_pad(img, volume_dims, stride)
152 | 
153 |     # Prepare output Zarr stores
154 |     # Accumulates weighted sum of softmax outputs, and summed weights from the blending filter (for normalising)
155 |     root = zarr.open(out_store, mode="w")
156 |     sum_arr = root.create_dataset("sum", shape=(*img.shape, n_classes), chunks=(*volume_dims, n_classes),
157 |                                   dtype="float32")
158 |     wsum_arr = root.create_dataset("wsum", shape=(*img.shape, 1), chunks=(*volume_dims, 1),
159 |                                    dtype="float32")
160 | 
161 |     # Compute generalised Hamming window for blending
162 |     wz, wx, wy = general_hamming(volume_dims[0],0.75), general_hamming(volume_dims[1],0.75), general_hamming(volume_dims[2],0.75)
163 |     w_patch = wz[:, None, None] * wx[None, :, None] * wy[None, None, :]
164 |     w_patch /= (w_patch.max() + 1e-8)  # Normalise to max 1
165 |     w_patch = w_patch.astype(np.float32)[...,None]  # (Z,X,Y,1)
166 | 
167 |     # Compute sliding window coordinates
168 |     windows = da.lib.stride_tricks.sliding_window_view(img, volume_dims)[::stride[0],
169 |                                                                          ::stride[1], ::stride[2]]
170 |     #print("windows shape:", windows.shape) #debugging
171 | 
172 |     # Total patches for progress bar
173 |     total_patches = windows.shape[0]*windows.shape[1]*windows.shape[2]
174 | 
175 |     # Preview
176 |     def plot_preview(original, pred, z, out_store):
177 | 
178 |         fig, axs = plt.subplots(1, 2, figsize=(10, 5))
179 |         axs[0].imshow(original, cmap="gray")
180 |         axs[0].set_title(f"Input z={z}")
181 |         axs[0].axis("off")
182 | 
183 |         axs[1].imshow(pred, cmap="viridis")
184 |         axs[1].set_title(f"Prediction z={z}")
185 |         axs[1].axis("off")
186 |         plt.tight_layout()
187 |         fig.savefig(os.path.join(out_store,'preview_z'+str(z)+'.png'))
188 | 
189 | 
190 |     # Inference step - iterate through windows and blend with weighted sum
191 |     step_i = 0
192 |     with tqdm(total=total_patches, desc="Inference", unit="patch") as pbar:
193 |         for zi in range(windows.shape[0]):
194 |             # Define position within image (z-axis)
195 |             z0 = zi * stride[0]
196 |             z1 = z0 + volume_dims[0]
197 | 
198 |             for xi in range(windows.shape[1]):
199 |                 # Define position within image (x-axis)
200 |                 x0 = xi * stride[1]
201 |                 x1 = x0 + volume_dims[1]
202 | 
203 |                 for yi in range(windows.shape[2]):
204 |                     # Define position within image (y-axis)
205 |                     y0 = yi * stride[2]
206 |                     y1 = y0 + volume_dims[2]
207 | 
208 |                     # Read patch (compute only this slice)
209 |                     patch = windows[zi, xi, yi].compute().astype(np.float32, copy=False)
210 |                     patch = patch[None,...,None]  # Reshape to (1,Z,X,Y,C)
211 | 
212 |                     # Predict softmax probability (batch of 1)
213 |                     pred = model.predict(patch, verbose=0)
214 |                     pred = pred[0]  # pred shape: (1,Z,X,Y,C) -> (Z,X,Y,C)
215 | 
216 |                     # Add weighted prediction and weights to accumulators in the correct positions
217 |                     sum_arr[z0:z1, x0:x1, y0:y1, :] += pred * w_patch
218 |                     wsum_arr[z0:z1, x0:x1, y0:y1, :] += w_patch
219 | 
220 |                     # Update progress bar
221 |                     step_i += 1
222 |                     pbar.update(1)
223 | 
224 |                     if preview and z0>0 and z1<img.shape[0]:
225 |                         # Preview the middle slice of this slab once every x/y patch has contributed
226 |                         if xi == windows.shape[1]-1 and yi == windows.shape[2]-1:
227 |                             z_mid_slice = z0 + volume_dims[0]//2
228 |                             # Normalise accumulated softmax by summed blend weights for this slice
229 |                             preview_sum = sum_arr[z_mid_slice, :, :, :]
230 |                             preview_w = wsum_arr[z_mid_slice, :, :, :]
231 |                             preview_pred = np.where(preview_w[..., 0] > 0, preview_sum[..., prob_channel] / np.maximum(preview_w[..., 0], 1e-8), 0.0)
232 |                             preview_pred = preview_pred[pad_widths[1][0]:img.shape[1]-pad_widths[1][1],
233 |                                                         pad_widths[2][0]:img.shape[2]-pad_widths[2][1]]  # Remove padding (X,Y)
234 |                             # Also read the corresponding input slice
235 |                             orig_slice = img[z_mid_slice, :, :].compute()
236 |                             orig_slice = orig_slice[pad_widths[1][0]:img.shape[1]-pad_widths[1][1],
237 |                                                     pad_widths[2][0]:img.shape[2]-pad_widths[2][1]]  # Remove padding
238 |                             plot_preview(orig_slice, preview_pred, z_mid_slice, out_store)
239 | 
240 | 
241 |     # Crop outputs (NB: basic slicing of a zarr array returns an in-memory numpy array)
242 |     sum_arr = sum_arr[pad_widths[0][0]:img.shape[0]-pad_widths[0][1],
243 |                       pad_widths[1][0]:img.shape[1]-pad_widths[1][1],
244 |                       pad_widths[2][0]:img.shape[2]-pad_widths[2][1]]
245 |     wsum_arr = wsum_arr[pad_widths[0][0]:img.shape[0]-pad_widths[0][1],
246 |                         pad_widths[1][0]:img.shape[1]-pad_widths[1][1],
247 |                         pad_widths[2][0]:img.shape[2]-pad_widths[2][1]]
248 | 
249 |     # Normalize and write final zarr
250 |     # Create output 'seg' dataset
251 |     if binary_output:
252 |         # class map (argmax): store as uint8 (or change dtype if you have >255 classes)
253 |         seg = root.create_dataset("seg", shape=(Z, X, Y), chunks=volume_dims, dtype="uint8")
254 |         # Process chunk-by-chunk to avoid OOM
255 |         for zi in tqdm(range(windows.shape[0]), desc="Normalising and saving"):
256 |             z0 = zi * stride[0]
257 |             z1 = z0 + volume_dims[0]
258 | 
259 |             # load a slab of weighted sums and weights
260 |             slab_sum = sum_arr[z0:z1, :, :, :]   # (vz,X,Y,C)
261 |             slab_w = wsum_arr[z0:z1, :, :, 0:1]  # (vz,X,Y,1)
262 |             slab_sum = np.array(slab_sum)        # bring to RAM slab only
263 |             slab_w = np.array(slab_w)
264 |             probs = np.where(slab_w > 0, slab_sum / np.maximum(slab_w, 1e-8), 0.0)  # (vz,X,Y,C)
265 |             seg[z0:z1, :, :] = np.argmax(probs, axis=-1).astype(np.uint8)
266 |     else:
267 |         # store foreground probability as float32
268 |         seg = root.create_dataset("seg", shape=(Z, X, Y), chunks=volume_dims, dtype="float32")
269 |         for zi in tqdm(range(windows.shape[0]), desc="Normalising and saving"):
270 |             z0 = zi * stride[0]
271 |             z1 = z0 + volume_dims[0]
272 | 
273 |             # load a slab of weighted sums (foreground channel only) and weights
274 |             slab_sum = sum_arr[z0:z1, :, :, prob_channel:prob_channel+1]  # (vz,X,Y,1)
275 |             slab_w = wsum_arr[z0:z1, :, :, 0:1]                           # (vz,X,Y,1)
276 |             slab_sum = np.array(slab_sum)
277 |             slab_w = np.array(slab_w)
278 |             prob = np.where(slab_w > 0, slab_sum / np.maximum(slab_w, 1e-8), 0.0)[..., 0]  # (vz,X,Y)
279 |             seg[z0:z1, :, :] = prob
280 | 
281 |     # Delete sum_arr and wsum_arr now that we're finished with them
282 |     del root["sum"]
283 |     del root["wsum"]
284 | 
285 |     # Optional: export BigTIFF 3D, slice-by-slice to handle very large images
286 |     if export_bigtiff:
287 |         with tiff.TiffWriter(export_bigtiff, bigtiff=True) as tw:
288 |             for z in tqdm(range(Z), desc="Export BigTIFF", unit="slice"):
289 |                 seg_slice = np.array(seg[z, :, :])  # bring one 2D slice to RAM
290 |                 tw.write(seg_slice, photometric="minisblack", metadata=None)
291 | 
292 |     return seg, os.path.join(out_store, "seg")
293 | 
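# Example call (paths hypothetical; model compiled as in model.py):
#   seg, seg_path = predict_segmentation_dask(model,
#       "/path/dataset.zarr/image", "/path/dataset_pred.zarr",
#       volume_dims=(64, 64, 64), overlap=(16, 16, 16), n_classes=2,
#       export_bigtiff="/path/dataset_pred.tif")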
294 | #-----------------------PREPROCESSING FUNCTIONS--------------------------------------------------------------------------------------------------------------
295 | def data_preprocessing(image_path=None, label_path=None):
296 |     """# Pre-processing
297 |     Load data, downsample if necessary and normalise.
298 |     Inputs:
299 |         image_path = path to image data (string)
300 |         label_path = path to labels (string)
301 |     Outputs:
302 |         img = image data as an np.array, scaled between 0 and 1
303 |         seg = label data as an np.array, scaled between 0 and 1
304 |         classes = list of classes present in labels
305 |     """
306 | 
307 |     # Load image
308 |     print('Loading images from '+str(image_path))
309 |     img=io.imread(image_path)
310 |     print('Size '+str(img.shape))
311 | 
312 |     if len(img.shape)==4:
313 |         print('Image data has 4 dimensions. Keeping the first channel only')
314 |         img=img[:,:,:,0]
315 |     assert img.ndim == 3, "Expected (Z,X,Y) image"
316 | 
317 |     # Normalise
318 |     print('Rescaling data between 0 and 1')
319 |     img_min = np.amin(img)
320 |     denominator = np.amax(img)-img_min
321 |     try: img = (img-img_min)/denominator  # Rescale between 0 and 1
322 |     except MemoryError:
323 |         try:
324 |             # break image up into quarters and normalise one chunk at a time (in place, so integer images should be converted to float first)
325 |             quarter=int(img.shape[0]/4)
326 |             for i in range(3):
327 |                 img[i*quarter:(i+1)*quarter,:,:]=(img[i*quarter:(i+1)*quarter,:,:]-img_min)/denominator
328 |             img[3*quarter:,:,:]=(img[3*quarter:,:,:]-img_min)/denominator
329 |         except MemoryError:
330 |             sixteenth=int(img.shape[0]/16)
331 |             for i in range(15):
332 |                 img[i*sixteenth:(i+1)*sixteenth,:,:]=(img[i*sixteenth:(i+1)*sixteenth,:,:]-img_min)/denominator
333 |             img[15*sixteenth:,:,:]=(img[15*sixteenth:,:,:]-img_min)/denominator
334 | 
335 |     # Repeat for labels if present
336 |     if label_path is not None:
337 |         print('Loading labels from '+str(label_path))
338 |         seg=io.imread(label_path)
339 | 
340 |         # Normalise
341 |         print('Rescaling data between 0 and 1')
342 |         seg = (seg-np.amin(seg))/(np.amax(seg)-np.amin(seg))
343 | 
344 |         # Find the number of unique classes in segmented training set
345 |         classes = np.unique(seg)
346 | 
347 |         return img, seg, classes
348 | 
349 |     return img, None, None
350 | 
351 | def crop_from_labels(labels, data):
352 |     iz, ix, iy = np.where(labels[...]!=0)  # find indices of non-zero voxels in the labels
353 |     labels = labels[min(iz):max(iz)+1, min(ix):max(ix)+1, min(iy):max(iy)+1]  # crop labels to the bounding box of the labelled region
354 |     data = data[min(iz):max(iz)+1, min(ix):max(ix)+1, min(iy):max(iy)+1]      # crop data to match
355 |     print("Cropped to {}".format(data.shape))
356 | 
357 |     return labels, data
358 | 
359 | def save_as_dask_array(data, labels=None, output_path=None, output_name=None, chunks=(64,64,64)):
360 |     # Create header folder if it does not exist
361 |     header_folder=os.path.join(output_path, "headers")
362 |     if not os.path.exists(header_folder):
363 |         os.makedirs(header_folder)
364 |     header_name=os.path.join(header_folder,str(output_name)+"_header")
365 | 
366 |     # Convert to dask array and save as zarr
367 |     data = da.from_array(data, chunks=chunks)
368 |     data.to_zarr(os.path.join(output_path, output_name))
369 | 
370 |     from tUbeNet_classes import DataHeader
371 | 
372 |     # Repeat for labels if present
373 |     if labels is not None:
374 |         # Convert to dask array and save as zarr
375 |         labels = da.from_array(labels, chunks=chunks)
376 |         labels.to_zarr(os.path.join(output_path, str(output_name)+"_labels"))
377 | 
378 |         # Save data header for easy reading in
379 |         header = DataHeader(ID=output_name, image_dims=labels.shape,
380 |                             image_filename=os.path.join(output_path, output_name),
381 |                             label_filename=os.path.join(output_path, str(output_name)+"_labels"))
382 |         header.save(header_name)
383 |     else:
384 |         # Save data header for easy reading in, with label_filename=None
385 |         header = DataHeader(ID=output_name, image_dims=data.shape,
386 |                             image_filename=os.path.join(output_path, output_name),
387 |                             label_filename=None)
388 |         header.save(header_name)
389 | 
390 |     return output_path, header_name
391 | 
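# End-to-end preprocessing sketch using the helpers above, with the sample
# OCTA volume shipped in this repository (output folder hypothetical):
#   img, seg, classes = data_preprocessing("data/OCTA-Data.tif", "data/OCTA-Labels.tif")
#   seg, img = crop_from_labels(seg, img)
#   out_path, header = save_as_dask_array(img, labels=seg,
#                                         output_path="prepped", output_name="OCTA")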
392 | #---------------------------EVALUATION----------------------------------------------------------------------
393 | 
394 | def roc_analysis(model, data_dir, volume_dims=(64,64,64),
395 |                  overlap=None, n_classes=2,
396 |                  output_path=None,
397 |                  binary_output=False):
398 | 
399 |     optimal_thresholds = []
400 |     recall_scores = []     # lists renamed so they don't shadow the recall/precision/dice functions defined above
401 |     precision_scores = []
402 |     dice_scores = []
403 |     average_precision = []
404 | 
405 |     if not overlap:
406 |         overlap = (volume_dims[0]//2,volume_dims[1]//2,volume_dims[2]//2)
407 | 
408 |     for index in range(0,len(data_dir.list_IDs)):
409 |         print('Evaluating model on '+str(data_dir.list_IDs[index])+' data')
410 | 
411 |         # Build output name from image filename and output path
412 |         dask_name = os.path.join(output_path, str(data_dir.list_IDs[index])+"_prediction")
413 |         tiff_name = str(dask_name)+".tif"
414 | 
415 |         # Predict segmentation
416 |         y_pred, zarr_path = predict_segmentation_dask(
417 |             model,
418 |             data_dir.image_filenames[index],
419 |             dask_name,
420 |             volume_dims=volume_dims,
421 |             overlap=overlap,
422 |             n_classes=n_classes,
423 |             preview=False,
424 |             binary_output=False,
425 |             prob_channel=1,
426 |         )
427 | 
428 |         # Create 1D array of predicted foreground probabilities
429 |         y_pred1D = da.ravel(y_pred).astype(np.float32)
430 | 
431 |         # Create 1D array of true labels
432 |         y_test = da.from_zarr(data_dir.label_filenames[index])
433 |         y_test1D = da.ravel(y_test).astype(np.float32)
434 | 
435 |         # ROC Curve and area under curve
436 |         fpr, tpr, thresholds_roc = roc_curve(y_test1D, y_pred1D, pos_label=1)
437 |         area_under_curve = auc(fpr, tpr)
438 | 
439 |         # Plot ROC
440 |         fig = plt.figure()
441 |         plt.plot(fpr, tpr, color='darkorange',
442 |                  lw=2, label='ROC curve (area = %0.5f)' % area_under_curve)
443 |         plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
444 |         plt.xlim([0.0, 1.0])
445 |         plt.ylim([0.0, 1.05])
446 |         plt.xlabel('False Positive Rate')
447 |         plt.ylabel('True Positive Rate')
448 |         plt.title('Receiver operating characteristic for '+str(data_dir.list_IDs[index]))
449 |         plt.legend(loc="lower right")
450 |         fig.savefig(os.path.join(output_path,'ROC_'+str(data_dir.list_IDs[index])+'.png'))
451 | 
452 |         # Precision-Recall Curve
453 |         fig, ax = plt.subplots()
454 |         disp = PrecisionRecallDisplay.from_predictions(np.asarray(y_test1D), np.asarray(y_pred1D),
455 |                                                        name='PR Curve',
456 |                                                        ax=ax, pos_label=1)
457 |         ax.set_title("Precision-Recall Curve for "+str(data_dir.list_IDs[index]))
458 |         ax.set_xlim([0.0, 1.0])
459 |         ax.set_ylim([0.0, 1.05])
460 |         fig.savefig(os.path.join(output_path,'PRCurve_'+str(data_dir.list_IDs[index])+'.png'))
461 | 
462 |         # Report and log DICE and average precision
463 |         p, r, thresholds = precision_recall_curve(np.asarray(y_test1D), np.asarray(y_pred1D))
464 |         f1 = 2*p*r/(p+r+1e-8)  # epsilon guards against 0/0 where precision and recall are both zero
465 |         optimal_idx = np.argmax(f1)  # Find threshold to maximise DICE
466 | 
467 |         print('Optimal threshold (PR curve): {}'.format(thresholds[optimal_idx]))
468 |         optimal_thresholds.append(thresholds[optimal_idx])
469 |         print('Recall at optimal threshold: {}'.format(r[optimal_idx]))
470 |         recall_scores.append(r[optimal_idx])
471 |         print('Precision at optimal threshold: {}'.format(p[optimal_idx]))
472 |         precision_scores.append(p[optimal_idx])
473 |         print('DICE Score: {}'.format(f1[optimal_idx]))
474 |         dice_scores.append(f1[optimal_idx])
475 |         average_precision.append(average_precision_score(np.asarray(y_test1D), np.asarray(y_pred1D)))
476 |         print('Average Precision Score: {}'.format(average_precision[index]))
477 | 
478 |         # Convert to binary with optimal threshold
479 |         if binary_output:
480 |             y_pred = np.where(np.asarray(y_pred)>thresholds[optimal_idx],1,0).astype(np.uint8)  # y_pred holds foreground probabilities with shape (Z,X,Y)
481 | 
482 |         # Save as tiff
483 |         tiff.imwrite(tiff_name, np.asarray(y_pred), photometric="minisblack", metadata=None)
484 |         print('Predicted segmentation saved to {}'.format(tiff_name))
485 | 
486 |     return optimal_thresholds, recall_scores, precision_scores, dice_scores, average_precision
487 | 
--------------------------------------------------------------------------------
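A hedged usage sketch for the evaluation routine above, assuming tUbeNet_functions is imported as tube, model is a compiled tUbeNet and val_dir is a DataDir pointing at validation headers (output folder hypothetical):

    thresholds, recall, precision, dice, avg_precision = tube.roc_analysis(
        model, val_dir, volume_dims=(64, 64, 64), n_classes=2,
        output_path="outputs", binary_output=True)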