├── kits19cnn ├── __init__.py ├── models │ ├── __init__.py │ ├── nnunet │ │ ├── __init__.py │ │ ├── initialization.py │ │ └── generic_UNet.py │ └── smp_models.py ├── inference │ ├── __init__.py │ ├── ensemble.py │ ├── utils.py │ ├── inference_class.py │ └── evaluate.py ├── experiments │ ├── __init__.py │ ├── infer_2d.py │ ├── train_3d.py │ ├── infer.py │ ├── train_2d.py │ ├── utils.py │ └── train.py ├── io │ ├── __init__.py │ ├── custom_transforms.py │ ├── dataset.py │ ├── resample.py │ ├── preprocess.py │ ├── custom_augmentations.py │ └── dataset_2d.py ├── metrics.py ├── utils.py ├── loss_functions.py └── visualize.py ├── images ├── label_case_00113.png └── pred2d_case_00113.png ├── script_configs ├── eval.yml ├── infer_tu_only │ ├── eval.yml │ └── pred.yml ├── pred.yml └── train.yml ├── .gitignore ├── scripts ├── evaluate.py ├── predict.py └── train_yaml.py ├── setup.py ├── notebooks └── Visualizing Volumes.ipynb ├── README.md └── LICENSE /kits19cnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/label_case_00113.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchen42703/kits19-cnn/HEAD/images/label_case_00113.png -------------------------------------------------------------------------------- /images/pred2d_case_00113.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchen42703/kits19-cnn/HEAD/images/pred2d_case_00113.png -------------------------------------------------------------------------------- /kits19cnn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .nnunet import SegmentationNetwork, Generic_UNet 2 | from .smp_models import wrap_smp_model 3 | -------------------------------------------------------------------------------- /kits19cnn/models/nnunet/__init__.py: -------------------------------------------------------------------------------- 1 | from .neural_network import SegmentationNetwork 2 | from .generic_UNet import Generic_UNet 3 | -------------------------------------------------------------------------------- /kits19cnn/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference_class import Predictor 2 | from .evaluate import Evaluator 3 | from .utils import create_submission, load_weights_infer 4 | from .ensemble import Ensembler 5 | -------------------------------------------------------------------------------- /script_configs/eval.yml: -------------------------------------------------------------------------------- 1 | orig_img_dir: /content/kits_preprocessed 2 | pred_dir: /content/kits19_predictions 3 | label_file_ending: .npy 4 | print_metrics: False 5 | binary_tumor: False 6 | -------------------------------------------------------------------------------- /script_configs/infer_tu_only/eval.yml: -------------------------------------------------------------------------------- 1 | orig_img_dir: /content/kits_preprocessed 2 | pred_dir: /content/kits19_predictions 3 | label_file_ending: .npy 4 | print_metrics: False 5 | binary_tumor: True 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.ipynb_checkpoints 3 | __pycache__/ 4 
|
5 | bin/
6 | build/
7 | develop-eggs/
8 | dist/
9 | eggs/
10 | lib/
11 | lib64/
12 | parts/
13 | sdist/
14 | var/
15 | *.egg-info/
16 | .installed.cfg
17 | *.egg
18 | 
--------------------------------------------------------------------------------
/kits19cnn/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import get_training_augmentation, get_validation_augmentation, \
2 |                    get_preprocessing, seed_everything
3 | from .train_3d import TrainSegExperiment, TrainClfSegExperiment3D
4 | from .train_2d import TrainSegExperiment2D, TrainClfSegExperiment2D
5 | from .infer import SegmentationInferenceExperiment
6 | from .infer_2d import SegmentationInferenceExperiment2D
--------------------------------------------------------------------------------
/kits19cnn/io/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import ClfSegVoxelDataset, VoxelDataset, TestVoxelDataset
2 | from .dataset_2d import SliceDataset, PseudoSliceDataset
3 | from .preprocess import Preprocessor
4 | from .resample import resample_patient
5 | from .custom_transforms import ROICropTransform, RepeatChannelsTransform, \
6 |                                MultiClassToBinaryTransform, \
7 |                                RandomResizedCropTransform
--------------------------------------------------------------------------------
/kits19cnn/inference/ensemble.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | class Ensembler(object):
4 |     """
5 |     Iterates through multiple directories of predicted activation maps
6 |     and averages them. The ensembled results are saved in a separate
7 |     directory, `out_dir`.
8 |     * Assumes the predicted activation maps are called `pred_act.npy`
9 |       in their respective case directories.
10 |     """
11 |     def __init__(self):
12 |         pass
13 | 
--------------------------------------------------------------------------------
/kits19cnn/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | def evaluate_official(y_true, y_pred):
4 |     """
5 |     Official evaluation metric. (numpy)
6 |     Note: numpy scalar division by zero yields nan/inf with a warning
7 |     instead of raising ZeroDivisionError, so the empty-class cases are
8 |     checked explicitly before dividing.
9 |     """
10 |     # Compute tumor+kidney Dice
11 |     tk_pd = np.greater(y_pred, 0)
12 |     tk_gt = np.greater(y_true, 0)
13 |     tk_denom = tk_pd.sum() + tk_gt.sum()
14 |     if tk_denom == 0:
15 |         return 0.0, 0.0
16 |     intersection = np.logical_and(tk_pd, tk_gt).sum()
17 |     tk_dice = 2*intersection/tk_denom
18 | 
19 |     # Compute tumor Dice
20 |     tu_pd = np.greater(y_pred, 1)
21 |     tu_gt = np.greater(y_true, 1)
22 |     tu_denom = tu_pd.sum() + tu_gt.sum()
23 |     if tu_denom == 0:
24 |         return tk_dice, 0.0
25 |     intersection = np.logical_and(tu_pd, tu_gt).sum()
26 |     tu_dice = 2*intersection/tu_denom
27 | 
28 |     return tk_dice, tu_dice
29 | 
--------------------------------------------------------------------------------
/kits19cnn/models/smp_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from kits19cnn.models.nnunet.neural_network import SegmentationNetwork
3 | from kits19cnn.utils import softmax_helper
4 | 
5 | def wrap_smp_model(smp_model_type, smp_model_kwargs=None, num_classes=3,
6 |                    activation="softmax"):
7 |     """
8 |     Wraps a 2D smp model with SegmentationNetwork's methods. Mainly for
9 |     inference, so that the smp_model can use the `predict_3D` method.
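    Example (a sketch mirroring the call in `experiments/infer_2d.py`;
    assumes `import segmentation_models_pytorch as smp` and an
    illustrative encoder):
        wrap_smp_model(smp.Unet, {"encoder_name": "resnet34",
                                  "encoder_weights": None,
                                  "classes": 3, "activation": None})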
10 | """ 11 | class WrappedModel(smp_model_type, SegmentationNetwork): 12 | def __init__(self, model_kwargs={}): 13 | super().__init__(**model_kwargs) 14 | self.conv_op = torch.nn.Conv2d 15 | self.num_classes = num_classes 16 | if activation == "softmax": 17 | self.inference_apply_nonlin = softmax_helper 18 | elif activation == "sigmoid": 19 | self.inference_apply_nonlin = torch.sigmoid 20 | wrapped_model = WrappedModel(smp_model_kwargs) 21 | return wrapped_model 22 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | from kits19cnn.inference import Evaluator 2 | 3 | def main(config): 4 | """ 5 | Main code for running the evaluation of 3D volumes. 6 | Args: 7 | config (dict): dictionary read from a yaml file 8 | i.e. script_configs/eval.yml 9 | Returns: 10 | None 11 | """ 12 | evaluator = Evaluator(config["orig_img_dir"], config["pred_dir"], 13 | label_file_ending=config["label_file_ending"], 14 | binary_tumor=config["binary_tumor"]) 15 | evaluator.evaluate_all(print_metrics=config["print_metrics"]) 16 | 17 | if __name__ == "__main__": 18 | import yaml 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser(description="For training.") 22 | parser.add_argument("--yml_path", type=str, required=True, 23 | help="Path to the .yml config.") 24 | args = parser.parse_args() 25 | 26 | with open(args.yml_path, 'r') as stream: 27 | try: 28 | config = yaml.safe_load(stream) 29 | except yaml.YAMLError as exc: 30 | print(exc) 31 | 32 | main(config) 33 | -------------------------------------------------------------------------------- /script_configs/pred.yml: -------------------------------------------------------------------------------- 1 | in_dir: /content/kits_preprocessed 2 | out_dir: /content/kits19_predictions 3 | with_masks: True 4 | mode: segmentation 5 | checkpoint_path: nnunet3d_exp2_94epochs_last_full.pth 6 | pseudo_3D: False 7 | 8 | io_params: 9 | test_size: 0.2 10 | split_seed: 200 11 | batch_size: 1 12 | num_workers: 2 13 | # aug_key: aug2 14 | file_ending: .npy # nii.gz 15 | 16 | model_params: 17 | architecture: nnunet 18 | instance_norm: False 19 | nnunet: 20 | input_channels: 1 21 | base_num_features: 30 22 | num_classes: 3 23 | num_pool: 5 24 | num_conv_per_stage: 2 25 | feat_map_mul_on_downscale: 2 26 | deep_supervision: False 27 | convolutional_pooling: True 28 | convolutional_upsampling: True 29 | max_num_features: 320 30 | # # 2D ONLY 31 | # encoder: resnet34 32 | # activation: softmax 33 | # unet_smp: 34 | # attention_type: ~ # scse 35 | # decoder_use_batchnorm: True # inplace for InplaceABN 36 | # fpn_smp: 37 | # dropout: 0.2 38 | 39 | predict_3D_params: 40 | do_mirroring: True 41 | num_repeats: 1 42 | use_train_mode: False 43 | batch_size: 1 44 | mirror_axes: 45 | - 0 46 | - 1 47 | - 2 48 | tiled: True 49 | tile_in_z: True 50 | step: 2 51 | patch_size: 52 | - 96 53 | - 160 54 | - 160 55 | regions_class_order: ~ #argmax 56 | use_gaussian: False 57 | pad_border_mode: edge 58 | pad_kwargs: {} 59 | all_in_gpu: False 60 | -------------------------------------------------------------------------------- /script_configs/infer_tu_only/pred.yml: -------------------------------------------------------------------------------- 1 | in_dir: /content/kits_preprocessed 2 | out_dir: /content/kits19_predictions 3 | with_masks: True 4 | mode: segmentation 5 | checkpoint_path: resnet34unet_seg_tuonly2d2_381epochs_seed15_best.pth 6 | pseudo_3D: False 7 | 8 
| io_params: 9 | test_size: 0.2 10 | split_seed: 15 11 | batch_size: 1 12 | num_workers: 2 13 | file_ending: .npy # nii.gz 14 | 15 | model_params: 16 | architecture: unet_smp 17 | # instance_norm: False 18 | # nnunet: 19 | # input_channels: 1 20 | # base_num_features: 30 21 | # num_classes: 3 22 | # num_pool: 5 23 | # num_conv_per_stage: 2 24 | # feat_map_mul_on_downscale: 2 25 | # deep_supervision: False 26 | # convolutional_pooling: True 27 | # convolutional_upsampling: True 28 | # max_num_features: 320 29 | # # 2D ONLY 30 | encoder: resnet34 31 | activation: sigmoid 32 | unet_smp: 33 | attention_type: ~ # scse 34 | classes: 1 35 | decoder_use_batchnorm: True # inplace for InplaceABN 36 | # fpn_smp: 37 | # classes: 1 38 | # dropout: 0.2 39 | 40 | predict_3D_params: 41 | do_mirroring: True 42 | num_repeats: 1 43 | use_train_mode: False 44 | batch_size: 1 45 | mirror_axes: 46 | - 0 47 | - 1 48 | # - 2 49 | tiled: True 50 | tile_in_z: True 51 | step: 2 52 | patch_size: 53 | - 256 54 | - 256 55 | # - 96 56 | # - 160 57 | # - 160 58 | regions_class_order: # ~ #argmax 59 | - 0 60 | - 1 61 | use_gaussian: False 62 | pad_border_mode: edge 63 | pad_kwargs: {} 64 | all_in_gpu: False 65 | -------------------------------------------------------------------------------- /kits19cnn/models/nnunet/initialization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
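# Typical usage (standard PyTorch, not specific to this repo):
#     network.apply(InitWeights_He(1e-2))
# `nn.Module.apply` calls the initializer on every submodule recursively.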
14 | 
15 | from torch import nn
16 | 
17 | 
18 | class InitWeights_He(object):
19 |     def __init__(self, neg_slope=1e-2):
20 |         self.neg_slope = neg_slope
21 | 
22 |     def __call__(self, module):
23 |         if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
24 |             # use the configured slope (was hardcoded to 1e-2);
25 |             # the init_ functions modify the parameter in place
26 |             nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
27 |             if module.bias is not None:
28 |                 nn.init.constant_(module.bias, 0)
29 | 
30 | 
31 | class InitWeights_XavierUniform(object):
32 |     def __init__(self, gain=1):
33 |         self.gain = gain
34 | 
35 |     def __call__(self, module):
36 |         if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
37 |             nn.init.xavier_uniform_(module.weight, self.gain)
38 |             if module.bias is not None:
39 |                 nn.init.constant_(module.bias, 0)
40 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(name='kits19cnn',
4 |       version='0.01.5',
5 |       description='Submission for the KiTS 2019 Challenge',
6 |       url='https://github.com/jchen42703/kits19-cnn',
7 |       author='Joseph Chen, Benson Jin',
8 |       author_email='jchen42703@gmail.com, jinb2@bxsci.edu',
9 |       license='Apache License Version 2.0, January 2004',
10 |       packages=find_packages(),
11 |       install_requires=[
12 |             "numpy>=1.10.2",
13 |             "scipy",
14 |             "scikit-image",
15 |             "future",
16 |             "keras",
17 |             "tensorflow",
18 |             "nibabel",
19 |             "pandas",
20 |             "scikit-learn",
21 |             "batchgenerators",
22 |             "torch>=1.2.0",
23 |             "torchvision>=0.4.0",
24 |             "catalyst",
25 |             "pytorch_toolbelt",
26 |             "segmentation_models_pytorch",
27 |       ],
28 |       classifiers=[
29 |           'Development Status :: 3 - Alpha',
30 |           'Intended Audience :: Developers',
31 |           'Topic :: Software Development :: Build Tools',
32 |           'License :: OSI Approved :: Apache Software License',
33 |           'Programming Language :: Python :: 3',
34 |           'Programming Language :: Python :: 3.6',
35 |       ],
36 |       keywords=['deep learning', 'image segmentation', 'image classification', 'medical image analysis',
37 |                 'medical image segmentation', 'data augmentation'],
38 |       )
39 | 
--------------------------------------------------------------------------------
/scripts/predict.py:
--------------------------------------------------------------------------------
1 | from kits19cnn.inference import Predictor
2 | from kits19cnn.experiments import SegmentationInferenceExperiment, \
3 |                                   SegmentationInferenceExperiment2D, \
4 |                                   seed_everything
5 | 
6 | def main(config):
7 |     """
8 |     Main code for running 3D predictions with a trained model.
9 | 
10 |     Args:
11 |         config (dict): dictionary read from a yaml file
12 |             i.e. 
script_configs/pred.yml
13 |     Returns:
14 |         None
15 |     """
16 |     # setting up the train/val split with filenames
17 |     seed = config["io_params"]["split_seed"]
18 |     seed_everything(seed)
19 |     dim = len(config["predict_3D_params"]["patch_size"])
20 |     mode = config["mode"].lower()
21 |     assert mode in ["classification", "segmentation"], \
22 |         "The `mode` must be one of ['classification', 'segmentation']."
23 |     if mode == "classification":
24 |         raise NotImplementedError
25 |     elif mode == "segmentation":
26 |         if dim == 2:
27 |             exp = SegmentationInferenceExperiment2D(config)
28 |         elif dim == 3:
29 |             exp = SegmentationInferenceExperiment(config)
30 | 
31 |     print(f"Seed: {seed}\nMode: {mode}")
32 |     pred = Predictor(out_dir=config["out_dir"],
33 |                      checkpoint_path=config["checkpoint_path"],
34 |                      model=exp.model, test_loader=exp.loaders["test"],
35 |                      pred_3D_params=config["predict_3D_params"],
36 |                      pseudo_3D=config.get("pseudo_3D"))
37 |     pred.run_3D_predictions()
38 | 
39 | if __name__ == "__main__":
40 |     import yaml
41 |     import argparse
42 | 
43 |     parser = argparse.ArgumentParser(description="For prediction.")
44 |     parser.add_argument("--yml_path", type=str, required=True,
45 |                         help="Path to the .yml config.")
46 |     args = parser.parse_args()
47 | 
48 |     with open(args.yml_path, 'r') as stream:
49 |         try:
50 |             config = yaml.safe_load(stream)
51 |         except yaml.YAMLError as exc:
52 |             print(exc)
53 | 
54 |     main(config)
55 | 
--------------------------------------------------------------------------------
/kits19cnn/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import numpy as np
16 | 
17 | def flip(x, dim):
18 |     """
19 |     flips the tensor at dimension dim (mirroring!)
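    e.g. flip(x, 2) reverses a (b, c, x, y, z) tensor along its x axis;
    equivalent to torch.flip(x, dims=[dim]).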
20 |     :param x:
21 |     :param dim:
22 |     :return:
23 |     """
24 |     indices = [slice(None)] * x.dim()
25 |     indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
26 |                                 dtype=torch.long, device=x.device)
27 |     return x[tuple(indices)]
28 | 
29 | def sum_tensor(inp, axes, keepdim=False):
30 |     axes = np.unique(axes).astype(int)
31 |     if keepdim:
32 |         for ax in axes:
33 |             inp = inp.sum(int(ax), keepdim=True)
34 |     else:
35 |         for ax in sorted(axes, reverse=True):
36 |             inp = inp.sum(int(ax))
37 |     return inp
38 | 
39 | def maybe_to_torch(d):
40 |     if isinstance(d, list):
41 |         d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
42 |     elif not isinstance(d, torch.Tensor):
43 |         d = torch.from_numpy(d).float()
44 |     return d
45 | 
46 | 
47 | def to_cuda(data, non_blocking=True, gpu_id=0):
48 |     if isinstance(data, list):
49 |         data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
50 |     else:
51 |         # honor the `non_blocking` argument (was hardcoded to True)
52 |         data = data.cuda(gpu_id, non_blocking=non_blocking)
53 |     return data
54 | 
55 | def softmax_helper(x):
56 |     rpt = [1 for _ in range(len(x.size()))]
57 |     rpt[1] = x.size(1)
58 |     x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
59 |     e_x = torch.exp(x - x_max)
60 |     return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
61 | 
--------------------------------------------------------------------------------
/scripts/train_yaml.py:
--------------------------------------------------------------------------------
1 | from catalyst.dl.runner import SupervisedRunner
2 | 
3 | from kits19cnn.experiments import TrainSegExperiment, TrainClfSegExperiment3D, \
4 |                                   TrainSegExperiment2D, TrainClfSegExperiment2D, \
5 |                                   seed_everything
6 | from kits19cnn.visualize import plot_metrics, save_figs
7 | 
8 | def main(config):
9 |     """
10 |     Main code for training a model (segmentation, classification, or both).
11 | 
12 |     Args:
13 |         config (dict): dictionary read from a yaml file
14 |             i.e. script_configs/train.yml
15 |     Returns:
16 |         None
17 |     """
18 |     # setting up the train/val split with filenames
19 |     seed = config["io_params"]["split_seed"]
20 |     seed_everything(seed)
21 |     mode = config["mode"].lower()
22 |     assert mode in ["classification", "segmentation", "both"], \
23 |         "The `mode` must be one of ['classification', 'segmentation', 'both']."
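    # dispatch on `mode` and `dim` to the matching experiment class
    # from kits19cnn.experiments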
24 | if mode == "classification": 25 | raise NotImplementedError 26 | elif mode == "segmentation": 27 | if config["dim"] == 2: 28 | exp = TrainSegExperiment2D(config) 29 | elif config["dim"] == 3: 30 | exp = TrainSegExperiment(config) 31 | output_key = "logits" 32 | elif mode == "both": 33 | if config["dim"] == 2: 34 | exp = TrainClfSegExperiment2D(config) 35 | elif config["dim"] == 3: 36 | exp = TrainClfSegExperiment3D(config) 37 | output_key = ["seg_logits", "clf_logits"] 38 | 39 | print(f"Seed: {seed}\nMode: {mode}") 40 | 41 | runner = SupervisedRunner(output_key=output_key) 42 | 43 | runner.train(model=exp.model, criterion=exp.criterion, optimizer=exp.opt, 44 | scheduler=exp.lr_scheduler, loaders=exp.loaders, 45 | callbacks=exp.cb_list, **config["runner_params"]) 46 | # Not saving plots if plot_params not specified in config 47 | if not config.get("plot_params"): 48 | figs = plot_metrics(logdir=config["runner_params"]["logdir"], 49 | metrics=config["plot_params"]["metrics"]) 50 | save_figs(figs, save_dir=config["plot_params"]["save_dir"]) 51 | 52 | if __name__ == "__main__": 53 | import yaml 54 | import argparse 55 | 56 | parser = argparse.ArgumentParser(description="For training.") 57 | parser.add_argument("--yml_path", type=str, required=True, 58 | help="Path to the .yml config.") 59 | args = parser.parse_args() 60 | 61 | with open(args.yml_path, 'r') as stream: 62 | try: 63 | config = yaml.safe_load(stream) 64 | except yaml.YAMLError as exc: 65 | print(exc) 66 | 67 | main(config) 68 | -------------------------------------------------------------------------------- /kits19cnn/inference/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, isdir 3 | 4 | import numpy as np 5 | import nibabel as nib 6 | import torch 7 | 8 | def load_weights_infer(checkpoint_path, model): 9 | """ 10 | Loads pytorch model from a checkpoint and into inference mode. 11 | 12 | Args: 13 | checkpoint_path (str): path to a .pt or .pth checkpoint 14 | model (torch.nn.Module): <- 15 | Returns: 16 | Model with loaded weights and in evaluation mode 17 | """ 18 | try: 19 | # catalyst weights 20 | state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"] 21 | except: 22 | # anything else 23 | state_dict = torch.load(checkpoint_path, map_location="cpu") 24 | try: 25 | model.load_state_dict(state_dict, strict=True) 26 | except: 27 | # for clf + seg for seg only prediction 28 | print(f"Non-strict loading of weights from {checkpoint_path}") 29 | model.load_state_dict(state_dict, strict=False) 30 | model.eval() 31 | return model 32 | 33 | def create_submission(pred_dir, out_dir, orig_dir, cases=None): 34 | """ 35 | Creates a submission directory from the predictions. To complete the 36 | submission, run this from command prompt: 37 | zip predictions.zip out_dir/predictions/prediction_*.nii.gz; 38 | change out_dir to whatever out_dir you specified for this function. 39 | Args: 40 | pred_dir: file path to the predictions directory, generated by 41 | infer.Predictor 42 | out_dir: file path to the directory where the predictions.zip is to 43 | be saved. 44 | orig_dir: file path to the original image directory (kits19/data). 
45 |         cases (iterable): iterable of case folder names
46 |     """
47 |     # setting up the output directory
48 |     out_pred_dir = join(out_dir, "predictions")
49 | 
50 |     if not isdir(out_pred_dir):
51 |         os.mkdir(out_pred_dir)
52 |         print("Created directory: {0}".format(out_pred_dir))
53 | 
54 |     if cases is None:
55 |         cases = ["case_{:05d}".format(i) for i in range(210, 300)]
56 | 
57 |     # iteratively converting each prediction
58 |     for (i, case) in enumerate(cases):
59 |         print("Processing {0}/{1}: {2}".format(i+1, len(cases), case))
60 |         x_nib = nib.load(join(orig_dir, case, "imaging.nii.gz"))
61 |         pred = np.load(join(pred_dir, case, "pred_{0}.npy".format(case)))
62 |         # converting the prediction to a .nii.gz file
63 |         p_nib = nib.Nifti1Image(pred, x_nib.affine)
64 |         id_ = case[-5:]  # number id attached to each case
65 |         out_fpath = join(out_pred_dir, "prediction_{0}.nii.gz".format(id_))
66 |         nib.save(p_nib, out_fpath)
67 | 
--------------------------------------------------------------------------------
/script_configs/train.yml:
--------------------------------------------------------------------------------
1 | mode: segmentation # classification # both
2 | dim: 3 # 2
3 | data_folder: /content/kits_preprocessed/ #/content/kits19/data
4 | 
5 | runner_params:
6 |   logdir: /content/logs/segmentation/
7 |   num_epochs: 85
8 |   fp16: False
9 |   verbose: True
10 | 
11 | io_params:
12 |   test_size: 0.2
13 |   split_seed: 200
14 |   batch_size: 2
15 |   num_workers: 2
16 |   aug_key: aug2
17 |   file_ending: .npy
18 |   # slice_indices_path: /content/kits_preprocessed/slice_indices.json # 2D
19 |   # p_pos_per_sample: 0.33 # 2D
20 |   # pseudo_3D: False
21 |   # num_pseudo_slices: 7
22 | 
23 | criterion_params:
24 |   loss: ce_dice_loss
25 |   ce_dice_loss:
26 |     soft_dice_kwargs:
27 |       batch_dice: True
28 |       smooth: 0.00001 #1e-5
29 |       do_bg: False
30 |       square: False
31 |     ce_kwargs: {}
32 |   # for clf_seg
33 |   # seg_loss: ce_dice_loss
34 |   # ce_dice_loss:
35 |   #   soft_dice_kwargs:
36 |   #     batch_dice: True
37 |   #     smooth: 0.00001 #1e-5
38 |   #     do_bg: False
39 |   #     square: False
40 |   #   ce_kwargs: {}
41 |   # clf_loss: bce_dice_loss
42 |   # bce_dice_loss:
43 |   #   eps: 0.0000001 # 1e-7
44 |   #   activation: sigmoid
45 | 
46 | model_params:
47 |   architecture: nnunet
48 |   nnunet:
49 |     input_channels: 1
50 |     base_num_features: 30
51 |     num_classes: 3
52 |     num_pool: 5
53 |     num_conv_per_stage: 2
54 |     feat_map_mul_on_downscale: 2
55 |     deep_supervision: False
56 |     convolutional_pooling: True
57 |     convolutional_upsampling: True
58 |     max_num_features: 320
59 |     classification: False #True
60 |     dropout_op_kwargs:
61 |       p: 0
62 |       inplace: True
63 |   ## 2D ONLY
64 |   # encoder: resnet34
65 |   # unet_smp:
66 |   #   attention_type: ~ # scse
67 |   #   classes: 3
68 |   #   decoder_use_batchnorm: True # inplace for InplaceABN
69 |   # fpn_smp:
70 |   #   classes: 3
71 |   #   dropout: 0.2
72 | 
73 | opt_params:
74 |   opt: SGD
75 |   SGD:
76 |     lr: 0.0001
77 |     momentum: 0.9
78 |     weight_decay: 0.0001
79 |   scheduler_params:
80 |     scheduler: ReduceLROnPlateau
81 |     ReduceLROnPlateau:
82 |       factor: 0.15
83 |       patience: 30 #2
84 |       mode: min
85 |       verbose: True
86 |       threshold: 0.001
87 |       threshold_mode: abs
88 |
89 | callback_params: 90 | EarlyStoppingCallback: 91 | patience: 60 92 | min_delta: 0.001 93 | # AccuracyCallback: 94 | # threshold: 0.5 95 | # activation: Softmax 96 | # PrecisionRecallF1ScoreCallback: 97 | # num_classes: 3 98 | # threshold: 0.5 99 | # activation: Softmax 100 | checkpoint_params: 101 | checkpoint_path: ~ #/content/logs/segmentation/checkpoints/last.pth #/content/logs/segmentation/checkpoints/last_full.pth 102 | mode: model_only 103 | 104 | # specify if you want to save plotly plots as .pngs 105 | ## Requires separate installation of xvfb on Colab. 106 | # plot_params: 107 | # metrics: 108 | # - loss/epoch 109 | # # - ppv/class_0/epoch 110 | # # - f1/class_0/epoch 111 | # # - tpr/class_0/epoch 112 | # save_dir: /content/logs/segmentation/ 113 | -------------------------------------------------------------------------------- /kits19cnn/experiments/infer_2d.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | import torch 3 | 4 | from kits19cnn.utils import softmax_helper 5 | from kits19cnn.models import Generic_UNet, wrap_smp_model 6 | from .utils import get_preprocessing 7 | from kits19cnn.io import TestVoxelDataset 8 | from .infer import BaseInferenceExperiment 9 | 10 | class SegmentationInferenceExperiment2D(BaseInferenceExperiment): 11 | """ 12 | Inference Experiment to support prediction experiments 13 | """ 14 | def __init__(self, config: dict): 15 | """ 16 | Args: 17 | config (dict): 18 | """ 19 | self.model_params = config["model_params"] 20 | super().__init__(config=config) 21 | 22 | def get_datasets(self, test_ids): 23 | """ 24 | Creates and returns the test dataset. 25 | """ 26 | use_rgb = "smp" in self.model_params["architecture"] 27 | preprocess_t = get_preprocessing(use_rgb) 28 | # creating the datasets 29 | test_dataset = TestVoxelDataset(im_ids=test_ids, 30 | transforms=None, 31 | preprocessing=preprocess_t, 32 | file_ending=self.io_params["file_ending"]) 33 | return test_dataset 34 | 35 | def get_model(self): 36 | """ 37 | Fetches the 2D model: the nnU-Net, smp U-Net or smp FPN for prediction. 
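        The returned model exposes `predict_3D` for sliding-window
        inference (smp models gain it through `wrap_smp_model`).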
38 | """ 39 | architecture = self.model_params["architecture"] 40 | # creating model 41 | if architecture == "nnunet": 42 | unet_kwargs = self.model_params[architecture] 43 | unet_kwargs = self.setup_2D_UNet_params(unet_kwargs) 44 | model = Generic_UNet(**unet_kwargs) 45 | model.inference_apply_nonlin = softmax_helper 46 | # smp models 47 | elif "smp" in architecture: 48 | model_type = smp.FPN if architecture == "fpn_smp" else smp.Unet 49 | print(f"Model type: {model_type}") 50 | model_kwargs = {"encoder_name": self.model_params["encoder"], 51 | "encoder_weights": None, "classes": 3, 52 | "activation": None} 53 | model_kwargs.update(self.model_params[architecture]) 54 | # adds the `predict_3D` method for the smp model 55 | model = wrap_smp_model(model_type, model_kwargs, 56 | num_classes=model_kwargs["classes"], 57 | activation=self.model_params["activation"]) 58 | # calculating # of parameters 59 | total = sum(p.numel() for p in model.parameters()) 60 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 61 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 62 | 63 | return model.cuda() 64 | 65 | def setup_2D_UNet_params(self, unet_kwargs): 66 | """ 67 | ^^^^^^^^^^^^^ 68 | """ 69 | unet_kwargs["conv_op"] = torch.nn.Conv2d 70 | if self.model_params.get("instance_norm"): 71 | unet_kwargs["norm_op"] = torch.nn.InstanceNorm2d 72 | unet_kwargs["dropout_op"] = torch.nn.Dropout2d 73 | unet_kwargs["nonlin"] = torch.nn.ReLU 74 | unet_kwargs["nonlin_kwargs"] = {"inplace": True} 75 | unet_kwargs["final_nonlin"] = lambda x: x 76 | return unet_kwargs 77 | -------------------------------------------------------------------------------- /kits19cnn/inference/inference_class.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from os.path import join, isdir 3 | from tqdm import tqdm 4 | import os 5 | import numpy as np 6 | import inspect 7 | import nibabel as nib 8 | import torch 9 | 10 | from kits19cnn.inference.utils import load_weights_infer 11 | 12 | class Predictor(object): 13 | """ 14 | Inference for a single model for every file generated by `test_loader`. 15 | Predictions are saved in `out_dir`. 16 | """ 17 | def __init__(self, out_dir, checkpoint_path, model, 18 | test_loader, pred_3D_params={"do_mirroring": True}, 19 | pseudo_3D: bool = False): 20 | """ 21 | Attributes 22 | out_dir (str): path to the output directory to store predictions 23 | checkpoint_path (str): path to a model checkpoint for `model` 24 | model (torch.nn.Module): class with the `predict_3D` method for 25 | predicting a single patient volume. 26 | test_loader: Iterable instance for generating data 27 | (pref. torch DataLoader) 28 | must have the __len__ arg. 
29 |             pred_3D_params (dict): kwargs for `model.predict_3D`
30 |             pseudo_3D (bool): whether or not to have pseudo 3D inputs
31 |         """
32 |         self.out_dir = out_dir
33 |         if not isdir(self.out_dir):
34 |             os.mkdir(self.out_dir)
35 |             print(f"Created {self.out_dir}!")
36 |         assert inspect.ismethod(model.predict_3D), \
37 |             "model must have the method `predict_3D`"
38 |         self.model = load_weights_infer(checkpoint_path, model)
39 |         self.test_loader = test_loader
40 |         self.pred_3D_params = pred_3D_params
41 |         self.pseudo_3D = pseudo_3D
42 | 
43 |     def run_3D_predictions(self):
44 |         """
45 |         Runs predictions on the dataset (specified in test_loader)
46 |         """
47 |         cases = self.test_loader.dataset.im_ids
48 |         assert len(cases) == len(self.test_loader)
49 |         for (test_batch, case) in tqdm(zip(self.test_loader, cases), total=len(cases)):
50 |             test_x = torch.squeeze(test_batch[0], dim=0)
51 |             if self.pseudo_3D:
52 |                 pred, _, act, _ = self.model.predict_3D_pseudo3D_2Dconv(test_x,
53 |                                                                         **self.pred_3D_params)
54 |             else:
55 |                 pred, _, act, _ = self.model.predict_3D(test_x,
56 |                                                         **self.pred_3D_params)
57 |             assert len(pred.shape) == 3
58 |             assert len(act.shape) == 4
59 |             ### possible place to threshold ROI size ###
60 |             self.save_pred(pred, act, case)
61 | 
62 |     def save_pred(self, pred, act, case):
63 |         """
64 |         Saves both prediction and activation maps in `out_dir` in the
65 |         KiTS19 format.
66 |         Args:
67 |             pred (np.ndarray): shape (x, y, z)
68 |             act (np.ndarray): shape (n_classes, x, y, z)
69 |             case: path to a case folder (an element of the test loader's im_ids)
70 |         Returns:
71 |             None
72 |         """
73 |         # extracting the raw case folder name
74 |         case = Path(case).name
75 |         out_case_dir = join(self.out_dir, case)
76 |         # checking to make sure that the output directories exist
77 |         if not isdir(out_case_dir):
78 |             os.mkdir(out_case_dir)
79 | 
80 |         np.save(join(out_case_dir, "pred.npy"), pred)
81 |         np.save(join(out_case_dir, "pred_act.npy"), act)
82 | 
83 |     def resample_predictions(self, orig_spacing, target_spacing,
84 |                              resampled_preds_dir):
85 |         """
86 |         Iterates through `out_dir` and creates resampled .npy arrays to
87 |         the specified spacing and saves them in `resampled_preds_dir`
88 |         """
89 |         from kits19cnn.io.resample import resample_patient
90 |         raise NotImplementedError
91 | 
92 |     def prepare_submission(self):
93 |         """
94 |         Resamples predictions and converts them to a .zip with .nii.gz files.
95 |         """
96 |         raise NotImplementedError
97 | 
--------------------------------------------------------------------------------
/kits19cnn/experiments/train_3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from kits19cnn.models import Generic_UNet
4 | from kits19cnn.io import VoxelDataset, ClfSegVoxelDataset
5 | 
6 | from .train import TrainExperiment, TrainClfSegExperiment
7 | from .utils import get_preprocessing, get_training_augmentation, \
8 |                    get_validation_augmentation
9 | 
10 | class TrainSegExperiment(TrainExperiment):
11 |     """
12 |     Stores the main parts of a segmentation experiment:
13 |       - df split
14 |       - datasets
15 |       - loaders
16 |       - model
17 |       - optimizer
18 |       - lr_scheduler
19 |       - criterion
20 |       - callbacks
21 |     """
22 |     def __init__(self, config: dict):
23 |         """
24 |         Args:
25 |             config (dict): from `scripts/train_yaml.py`
26 |         """
27 |         self.model_params = config["model_params"]
28 |         super().__init__(config=config)
29 | 
30 |     def get_datasets(self, train_ids, valid_ids):
31 |         """
32 |         Creates and returns the train and validation datasets.
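        Both use the augmentation pipeline selected by
        io_params["aug_key"] and share the same preprocessing.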
33 | """ 34 | # preparing transforms 35 | train_aug = get_training_augmentation(self.io_params["aug_key"]) 36 | val_aug = get_validation_augmentation(self.io_params["aug_key"]) 37 | # creating the datasets 38 | train_dataset = VoxelDataset(im_ids=train_ids, 39 | transforms=train_aug, 40 | preprocessing=get_preprocessing()) 41 | valid_dataset = VoxelDataset(im_ids=valid_ids, 42 | transforms=val_aug, 43 | preprocessing=get_preprocessing()) 44 | return (train_dataset, valid_dataset) 45 | 46 | def get_model(self): 47 | architecture = self.model_params["architecture"] 48 | if architecture == "nnunet": 49 | architecture_kwargs = self.model_params[architecture] 50 | architecture_kwargs["conv_op"] = torch.nn.Conv3d 51 | architecture_kwargs["norm_op"] = torch.nn.InstanceNorm3d 52 | architecture_kwargs["dropout_op"] = torch.nn.Dropout3d 53 | architecture_kwargs["nonlin"] = torch.nn.ReLU 54 | architecture_kwargs["nonlin_kwargs"] = {"inplace": True} 55 | architecture_kwargs["final_nonlin"] = lambda x: x 56 | model = Generic_UNet(**architecture_kwargs) 57 | else: 58 | raise NotImplementedError 59 | # calculating # of parameters 60 | total = sum(p.numel() for p in model.parameters()) 61 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 62 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 63 | 64 | return model 65 | 66 | class TrainClfSegExperiment3D(TrainClfSegExperiment, TrainSegExperiment): 67 | """ 68 | Stores the main parts of a classification+segmentation experiment: 69 | - df split 70 | - datasets 71 | - loaders 72 | - model 73 | - optimizer 74 | - lr_scheduler 75 | - criterion 76 | - callbacks 77 | """ 78 | def __init__(self, config: dict): 79 | """ 80 | Args: 81 | config (dict): from `train_seg_yaml.py` 82 | """ 83 | self.model_params = config["model_params"] 84 | super().__init__(config=config) 85 | 86 | def get_datasets(self, train_ids, valid_ids): 87 | """ 88 | Creates and returns the train and validation datasets. 89 | """ 90 | # preparing transforms 91 | train_aug = get_training_augmentation(self.io_params["aug_key"]) 92 | val_aug = get_validation_augmentation(self.io_params["aug_key"]) 93 | # creating the datasets 94 | preprocess = get_preprocessing() 95 | train_dataset = ClfSegVoxelDataset(im_ids=train_ids, 96 | transforms=train_aug, 97 | preprocessing=preprocess, 98 | mode="both", num_classes=3) 99 | valid_dataset = ClfSegVoxelDataset(im_ids=valid_ids, 100 | transforms=val_aug, 101 | preprocessing=preprocess, 102 | mode="both", num_classes=3) 103 | return (train_dataset, valid_dataset) 104 | -------------------------------------------------------------------------------- /kits19cnn/loss_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import torch 15 | from torch import nn 16 | from segmentation_models_pytorch.utils.losses import DiceLoss 17 | 18 | from kits19cnn.utils import softmax_helper, sum_tensor 19 | 20 | class BCEDiceLoss(DiceLoss): 21 | __name__ = 'bce_dice_loss' 22 | 23 | def __init__(self, eps=1e-7, activation='sigmoid'): 24 | super().__init__(eps, activation) 25 | self.bce = nn.BCEWithLogitsLoss(reduction='mean') 26 | 27 | def forward(self, y_pr, y_gt): 28 | y_pr = y_pr.float() 29 | y_gt = y_gt.float() 30 | dice = super().forward(y_pr, y_gt) 31 | bce = self.bce(y_pr, y_gt) 32 | return dice + bce 33 | 34 | class CrossentropyND(nn.CrossEntropyLoss): 35 | """ 36 | Network has to have NO NONLINEARITY! 37 | """ 38 | def forward(self, inp, target): 39 | target = target.long() 40 | num_classes = inp.size()[1] 41 | 42 | i0 = 1 43 | i1 = 2 44 | 45 | while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once 46 | inp = inp.transpose(i0, i1) 47 | i0 += 1 48 | i1 += 1 49 | 50 | inp = inp.contiguous() 51 | inp = inp.view(-1, num_classes) 52 | 53 | target = target.view(-1,) 54 | 55 | return super(CrossentropyND, self).forward(inp, target) 56 | 57 | def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False): 58 | """ 59 | net_output must be (b, c, x, y(, z))) 60 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) 61 | if mask is provided it must have shape (b, 1, x, y(, z))) 62 | :param net_output: 63 | :param gt: 64 | :param axes: 65 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels 66 | :param square: if True then fp, tp and fn will be squared before summation 67 | :return: 68 | """ 69 | if axes is None: 70 | axes = tuple(range(2, len(net_output.size()))) 71 | 72 | shp_x = net_output.shape 73 | shp_y = gt.shape 74 | 75 | with torch.no_grad(): 76 | if len(shp_x) != len(shp_y): 77 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 78 | 79 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 80 | # if this is the case then gt is probably already a one hot encoding 81 | y_onehot = gt 82 | else: 83 | gt = gt.long() 84 | y_onehot = torch.zeros(shp_x) 85 | if net_output.device.type == "cuda": 86 | y_onehot = y_onehot.cuda(net_output.device.index) 87 | y_onehot.scatter_(1, gt, 1) 88 | 89 | tp = net_output * y_onehot 90 | fp = net_output * (1 - y_onehot) 91 | fn = (1 - net_output) * y_onehot 92 | 93 | if mask is not None: 94 | tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) 95 | fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) 96 | fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) 97 | 98 | if square: 99 | tp = tp ** 2 100 | fp = fp ** 2 101 | fn = fn ** 2 102 | 103 | tp = sum_tensor(tp, axes, keepdim=False) 104 | fp = sum_tensor(fp, axes, keepdim=False) 105 | fn = sum_tensor(fn, axes, keepdim=False) 106 | 107 | return tp, fp, fn 108 | 109 | 110 | class SoftDiceLoss(nn.Module): 111 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, 112 | smooth=1., square=False): 113 | """ 114 | """ 115 | super(SoftDiceLoss, self).__init__() 116 | 117 | self.square = square 118 | self.do_bg = do_bg 119 | self.batch_dice = batch_dice 120 | self.apply_nonlin = apply_nonlin 121 | self.smooth = smooth 122 | 123 | def forward(self, x, y, loss_mask=None): 124 | shp_x = x.shape 125 | 126 | if self.batch_dice: 127 | axes = [0] + list(range(2, len(shp_x))) 128 | else: 129 | axes = list(range(2, 
len(shp_x))) 130 | 131 | if self.apply_nonlin is not None: 132 | x = self.apply_nonlin(x) 133 | 134 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square) 135 | 136 | dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth) 137 | 138 | if not self.do_bg: 139 | if self.batch_dice: 140 | dc = dc[1:] 141 | else: 142 | dc = dc[:, 1:] 143 | dc = dc.mean() 144 | 145 | return -dc 146 | 147 | 148 | class DC_and_CE_loss(nn.Module): 149 | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"): 150 | super(DC_and_CE_loss, self).__init__() 151 | self.aggregate = aggregate 152 | self.ce = CrossentropyND(**ce_kwargs) 153 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs) 154 | 155 | def forward(self, net_output, target): 156 | dc_loss = self.dc(net_output, target) 157 | ce_loss = self.ce(net_output, target) 158 | if self.aggregate == "sum": 159 | result = ce_loss + dc_loss 160 | else: 161 | raise NotImplementedError("nah son") # reserved for other stuff (later) 162 | return result 163 | -------------------------------------------------------------------------------- /kits19cnn/experiments/infer.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from tqdm import tqdm 3 | from abc import abstractmethod 4 | import os 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from sklearn.model_selection import train_test_split 9 | 10 | from kits19cnn.io import TestVoxelDataset 11 | from kits19cnn.utils import softmax_helper 12 | from kits19cnn.models import Generic_UNet 13 | from .utils import get_preprocessing 14 | 15 | class BaseInferenceExperiment(object): 16 | def __init__(self, config: dict): 17 | """ 18 | Args: 19 | config (dict): 20 | 21 | Attributes: 22 | config-related: 23 | config (dict): 24 | io_params (dict): 25 | in_dir (key: str): path to the data folder 26 | test_size (key: float): split size for test 27 | split_seed (key: int): seed 28 | batch_size (key: int): <- 29 | num_workers (key: int): # of workers for data loaders 30 | split_dict (dict): test_ids 31 | test_dset (torch.data.Dataset): <- 32 | loaders (dict): train/validation loaders 33 | model (torch.nn.Module): <- 34 | """ 35 | # for reuse 36 | self.config = config 37 | self.io_params = config["io_params"] 38 | # initializing the experiment components 39 | self.case_list = self.setup_im_ids() 40 | test_ids = self.get_split()[-1] if config["with_masks"] else self.case_list 41 | print(f"Inferring on {len(test_ids)} test cases") 42 | self.test_dset = self.get_datasets(test_ids) 43 | self.loaders = self.get_loaders() 44 | self.model = self.get_model() 45 | 46 | @abstractmethod 47 | def get_datasets(self, test_ids): 48 | """ 49 | Initializes the data augmentation and preprocessing transforms. Creates 50 | and returns the train and validation datasets. 51 | """ 52 | return 53 | 54 | @abstractmethod 55 | def get_model(self): 56 | """ 57 | Creates and returns the model. 
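        Implementations must return a module exposing `predict_3D`
        (Predictor asserts this before running inference).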
58 | """ 59 | return 60 | 61 | def setup_im_ids(self): 62 | """ 63 | Creates a list of all paths to case folders for the dataset split 64 | """ 65 | search_path = os.path.join(self.config["in_dir"], "*/") 66 | case_list = sorted(glob(search_path)) 67 | case_list = case_list[:210] if self.config["with_masks"] else case_list[210:] 68 | return case_list 69 | 70 | def get_split(self): 71 | """ 72 | Creates train/valid filename splits 73 | """ 74 | # setting up the train/val split with filenames 75 | split_seed: int = self.io_params["split_seed"] 76 | test_size: float = self.io_params["test_size"] 77 | # doing the splits: 1-test_size, test_size//2, test_size//2 78 | print("Splitting the dataset normally...") 79 | train_ids, total_test = train_test_split(self.case_list, 80 | random_state=split_seed, 81 | test_size=test_size) 82 | val_ids, test_ids = train_test_split(sorted(total_test), 83 | random_state=split_seed, 84 | test_size=0.5) 85 | return (train_ids, val_ids, test_ids) 86 | 87 | def get_loaders(self): 88 | """ 89 | Creates train/val loaders from datasets created in self.get_datasets. 90 | Returns the loaders. 91 | """ 92 | # setting up the loaders 93 | b_size, num_workers = self.io_params["batch_size"], self.io_params["num_workers"] 94 | test_loader = DataLoader(self.test_dset, batch_size=b_size, 95 | shuffle=False, num_workers=num_workers) 96 | return {"test": test_loader} 97 | 98 | class SegmentationInferenceExperiment(BaseInferenceExperiment): 99 | """ 100 | Inference Experiment to support prediction experiments 101 | """ 102 | def __init__(self, config: dict): 103 | """ 104 | Args: 105 | config (dict): 106 | """ 107 | self.model_params = config["model_params"] 108 | super().__init__(config=config) 109 | 110 | def get_datasets(self, test_ids): 111 | """ 112 | Creates and returns the test dataset. 
113 | """ 114 | # creating the datasets 115 | test_dataset = TestVoxelDataset(im_ids=test_ids, 116 | transforms=None, 117 | preprocessing=get_preprocessing(), 118 | file_ending=self.io_params["file_ending"]) 119 | return test_dataset 120 | 121 | def get_model(self): 122 | architecture = self.model_params["architecture"] 123 | # creating model 124 | if architecture == "nnunet": 125 | unet_kwargs = self.model_params[architecture] 126 | unet_kwargs = self.setup_3D_UNet_params(unet_kwargs) 127 | model = Generic_UNet(**unet_kwargs) 128 | model.inference_apply_nonlin = softmax_helper 129 | else: 130 | raise NotImplementedError 131 | # calculating # of parameters 132 | total = sum(p.numel() for p in model.parameters()) 133 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 134 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 135 | 136 | return model.cuda() 137 | 138 | def setup_3D_UNet_params(self, unet_kwargs): 139 | """ 140 | ^^^^^^^^^^^^^ 141 | """ 142 | unet_kwargs["conv_op"] = torch.nn.Conv3d 143 | unet_kwargs["norm_op"] = torch.nn.InstanceNorm3d 144 | unet_kwargs["dropout_op"] = torch.nn.Dropout3d 145 | unet_kwargs["nonlin"] = torch.nn.ReLU 146 | unet_kwargs["nonlin_kwargs"] = {"inplace": True} 147 | unet_kwargs["final_nonlin"] = lambda x: x 148 | return unet_kwargs 149 | -------------------------------------------------------------------------------- /kits19cnn/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | 4 | from typing import Dict, List, Optional, Union # isort:skip 5 | from collections import defaultdict 6 | from pathlib import Path 7 | 8 | import plotly.graph_objs as go 9 | from plotly.offline import init_notebook_mode, iplot 10 | 11 | from catalyst.utils.tensorboard import SummaryItem, SummaryReader 12 | 13 | print("If you're using a notebook, " 14 | "make sure to run %matplotlib inline beforehand.") 15 | 16 | def plot_scan(scan, start_with, show_every, rows=3, cols=3): 17 | """ 18 | Plots multiple scans throughout your medical image. 19 | Args: 20 | scan: numpy array with shape (x,y,z) 21 | start_with: slice to start with 22 | show_every: size of the step between each slice iteration 23 | rows: rows of plot 24 | cols: cols of plot 25 | Returns: 26 | a plot of multiple scans from the same image 27 | """ 28 | fig,ax = plt.subplots(rows, cols, figsize=[3*cols,3*rows]) 29 | for i in range(rows*cols): 30 | ind = start_with + i*show_every 31 | ax[int(i/cols), int(i%cols)].set_title("slice %d" % ind) 32 | ax[int(i/cols), int(i%cols)].axis("off") 33 | 34 | ax[int(i/cols), int(i%cols)].imshow(scan[ind], cmap="gray") 35 | plt.show() 36 | 37 | def plot_scan_and_mask(scan, mask, start_with, show_every, rows=3, cols=3): 38 | """ 39 | Plots multiple scans with the mask overlay throughout your medical image. 
40 |     Args:
41 |         scan, mask: numpy arrays with shape (x,y,z); `mask` is overlaid on `scan`
42 |         start_with: slice to start with
43 |         show_every: size of the step between each slice iteration
44 |         rows: rows of plot
45 |         cols: cols of plot
46 |     Returns:
47 |         a plot of multiple slices from the same image with the mask overlaid
48 |     """
49 |     fig, ax = plt.subplots(rows, cols, figsize=[4*cols, 4*rows])
50 |     for i in range(rows*cols):
51 |         ind = start_with + i*show_every
52 |         ax[int(i/cols), int(i%cols)].set_title("slice %d" % ind)
53 |         ax[int(i/cols), int(i%cols)].axis("off")
54 | 
55 |         ax[int(i/cols), int(i%cols)].imshow(scan[ind], cmap="gray")
56 |         ax[int(i/cols), int(i%cols)].imshow(mask[ind], cmap="jet", alpha=0.5)
57 |     plt.show()
58 | 
59 | # FROM: https://github.com/catalyst-team/catalyst/blob/master/catalyst/dl/utils/visualization.py
60 | def _get_tensorboard_scalars(
61 |     logdir: Union[str, Path], metrics: Optional[List[str]], step: str
62 | ) -> Dict[str, List]:
63 |     summary_reader = SummaryReader(logdir, types=["scalar"])
64 | 
65 |     items = defaultdict(list)
66 |     for item in summary_reader:
67 |         if step in item.tag and (
68 |             metrics is None or any(m in item.tag for m in metrics)
69 |         ):
70 |             items[item.tag].append(item)
71 |     return items
72 | 
73 | def _get_scatter(scalars: List[SummaryItem], name: str) -> go.Scatter:
74 |     xs = [s.step for s in scalars]
75 |     ys = [s.value for s in scalars]
76 |     return go.Scatter(x=xs, y=ys, name=name)
77 | 
78 | def plot_tensorboard_log(
79 |     logdir: Union[str, Path],
80 |     step: Optional[str] = "batch",
81 |     metrics: Optional[List[str]] = None,
82 |     height: Optional[int] = None,
83 |     width: Optional[int] = None
84 | ) -> List[go.Figure]:
85 |     init_notebook_mode()
86 |     logdir = Path(logdir)
87 | 
88 |     logdirs = {
89 |         x.name.replace("_log", ""): x
90 |         for x in logdir.glob("**/*") if x.is_dir() and str(x).endswith("_log")
91 |     }
92 | 
93 |     scalars_per_loader = {
94 |         key: _get_tensorboard_scalars(inner_logdir, metrics, step)
95 |         for key, inner_logdir in logdirs.items()
96 |     }
97 | 
98 |     scalars_per_metric = defaultdict(dict)
99 |     for key, value in scalars_per_loader.items():
100 |         for key2, value2 in value.items():
101 |             scalars_per_metric[key2][key] = value2
102 | 
103 |     figs = []
104 |     for metric_name, metric_logs in scalars_per_metric.items():
105 |         metric_data = []
106 |         for key, value in metric_logs.items():
107 |             try:
108 |                 data_ = _get_scatter(value, f"{key}/{metric_name}")
109 |                 metric_data.append(data_)
110 |             except:  # noqa: E722
111 |                 pass
112 | 
113 |         layout = go.Layout(
114 |             title=metric_name,
115 |             height=height,
116 |             width=width,
117 |             yaxis=dict(hoverformat=".5f")
118 |         )
119 |         fig = go.Figure(data=metric_data, layout=layout)
120 |         iplot(fig)
121 |         figs.append(fig)
122 |     return figs
123 | 
124 | def plot_metrics(
125 |     logdir: Union[str, Path],
126 |     step: Optional[str] = "epoch",
127 |     metrics: Optional[List[str]] = None,
128 |     height: Optional[int] = None,
129 |     width: Optional[int] = None
130 | ) -> List[go.Figure]:
131 |     """Plots your learning results.
132 |     Args:
133 |         logdir: the logdir that was specified during training.
134 |         step: 'batch' or 'epoch' - what logs to show: for batches or
135 |             for epochs
136 |         metrics: list of metrics to plot.
The loss should be specified as 137 | 'loss', learning rate = '_base/lr' and other metrics should be 138 | specified as names in metrics dict 139 | that was specified during training 140 | height: the height of the whole resulting plot 141 | width: the width of the whole resulting plot 142 | """ 143 | assert step in ["batch", "epoch"], \ 144 | f"Step should be either 'batch' or 'epoch', got '{step}'" 145 | metrics = metrics or ["loss"] 146 | return plot_tensorboard_log(logdir, step, metrics, height, width) 147 | 148 | def save_figs(figs_list, save_dir=None): 149 | """ 150 | Saves plotly figures. (from plot_metrics) 151 | """ 152 | if save_dir is None: 153 | save_dir = os.getcwd() 154 | 155 | for fig in figs_list: 156 | # takes a metric like train/f1/class_0/epoch to f1_class_0_epoch 157 | train_metric_name = fig["data"][0]["name"] 158 | split = train_metric_name.split("/") 159 | metric_name = "".join([f"{name}_" for name in split 160 | if not name in ["train", "valid"]])[:-1] 161 | save_name = os.path.join(save_dir, f"{metric_name}.png") 162 | fig.write_image(save_name) 163 | print(f"Saved {save_name}...") 164 | -------------------------------------------------------------------------------- /notebooks/Visualizing Volumes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualizing the Data\n", 8 | "This notebooks uses [`ipyvolume`](https://github.com/maartenbreddels/ipyvolume)\n", 9 | "To install: (make sure you have [`jupyter_contrib_nbextensions`](https://github.com/ipython-contrib/jupyter_contrib_nbextensions) installed)\n", 10 | "i.e.\n", 11 | "```\n", 12 | "conda install -c conda-forge jupyter_contrib_nbextensions\n", 13 | "```\n", 14 | "\n", 15 | "```\n", 16 | "conda install -c conda-forge ipyvolume OR pip install ipyvolume\n", 17 | "jupyter nbextension enable --py --sys-prefix ipyvolume\n", 18 | "jupyter nbextension enable --py --sys-prefix widgetsnbextension\n", 19 | "```\n", 20 | "* Holes in volume\n", 21 | " * patients 0-3, 7, 9, 14..." 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Raw Interpolated" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "Case: case_00014; Shape: (146, 556, 556)\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "import os\n", 46 | "import nibabel as nib\n", 47 | "\n", 48 | "dset_dir = r\"C:\\Users\\jchen\\Desktop\\Datasets\\kits19\\data\"\n", 49 | "i = 14\n", 50 | "case = f\"case_000{i}\"\n", 51 | "image = nib.load(os.path.join(dset_dir, case, \"imaging.nii.gz\")).get_fdata().squeeze()\n", 52 | "mask = nib.load(os.path.join(dset_dir, case, \"segmentation.nii.gz\")).get_fdata().squeeze()\n", 53 | "\n", 54 | "print(f\"Case: {case}; Shape: {image.shape}\")\n", 55 | "assert mask.shape == image.shape" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "[[ 0. 0. -0.78162497 0. ]\n", 68 | " [ 0. -0.78162497 0. 0. ]\n", 69 | " [-3. 0. 0. 0. ]\n", 70 | " [ 0. 0. 0. 1. 
]]\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "print(nib.load(os.path.join(dset_dir, case, \"imaging.nii.gz\")).affine)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stderr", 85 | "output_type": "stream", 86 | "text": [ 87 | "C:\\Users\\jchen\\Miniconda3\\envs\\py36\\lib\\site-packages\\ipyvolume\\widgets.py:179: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n", 88 | " data_view = self.data_original[view]\n", 89 | "C:\\Users\\jchen\\Miniconda3\\envs\\py36\\lib\\site-packages\\ipyvolume\\utils.py:204: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n", 90 | " data = (data[slices1] + data[slices2])/2\n", 91 | "C:\\Users\\jchen\\Miniconda3\\envs\\py36\\lib\\site-packages\\ipyvolume\\serialize.py:81: RuntimeWarning: invalid value encountered in true_divide\n", 92 | " gradient = gradient / np.sqrt(gradient[0]**2 + gradient[1]**2 + gradient[2]**2)\n" 93 | ] 94 | }, 95 | { 96 | "data": { 97 | "application/vnd.jupyter.widget-view+json": { 98 | "model_id": "49c7b4f7217a417a9edf3de52124d107", 99 | "version_major": 2, 100 | "version_minor": 0 101 | }, 102 | "text/plain": [ 103 | "VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.1, max=1.0, step=0.00…" 104 | ] 105 | }, 106 | "metadata": {}, 107 | "output_type": "display_data" 108 | } 109 | ], 110 | "source": [ 111 | "import ipyvolume as ipv\n", 112 | "ipv.quickvolshow(mask)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Preprocessed" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 6, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "Case: case_00014; Shape: (136, 268, 268)\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "import os\n", 137 | "import numpy as np\n", 138 | "\n", 139 | "dset_dir = r\"C:\\Users\\jchen\\Desktop\\Datasets\\kits_preprocessed_isensee_spacing\"\n", 140 | "case = f\"case_000{i}\"\n", 141 | "image = np.load(os.path.join(dset_dir, case, \"imaging.npy\")).squeeze()\n", 142 | "mask = np.load(os.path.join(dset_dir, case, \"segmentation.npy\")).squeeze()\n", 143 | "\n", 144 | "print(f\"Case: {case}; Shape: {image.shape}\")\n", 145 | "assert mask.shape == image.shape" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 7, 151 | "metadata": { 152 | "scrolled": false 153 | }, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "application/vnd.jupyter.widget-view+json": { 158 | "model_id": "fb519330c64e4f959797013e325e8cc7", 159 | "version_major": 2, 160 | "version_minor": 0 161 | }, 162 | "text/plain": [ 163 | "VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.1, max=1.0, step=0.00…" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "import ipyvolume as ipv\n", 172 | "ipv.quickvolshow(mask)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 
178 |    "metadata": {},
179 |    "outputs": [],
180 |    "source": []
181 |   },
182 |   {
183 |    "cell_type": "markdown",
184 |    "metadata": {},
185 |    "source": [
186 |     "# Predictions"
187 |    ]
188 |   },
189 |   {
190 |    "cell_type": "code",
191 |    "execution_count": 5,
192 |    "metadata": {},
193 |    "outputs": [],
194 |    "source": [
195 |     "# ..."
196 |    ]
197 |   },
198 |   {
199 |    "cell_type": "code",
200 |    "execution_count": null,
201 |    "metadata": {},
202 |    "outputs": [],
203 |    "source": []
204 |   }
205 |  ],
206 |  "metadata": {
207 |   "kernelspec": {
208 |    "display_name": "Python 3",
209 |    "language": "python",
210 |    "name": "python3"
211 |   },
212 |   "language_info": {
213 |    "codemirror_mode": {
214 |     "name": "ipython",
215 |     "version": 3
216 |    },
217 |    "file_extension": ".py",
218 |    "mimetype": "text/x-python",
219 |    "name": "python",
220 |    "nbconvert_exporter": "python",
221 |    "pygments_lexer": "ipython3",
222 |    "version": "3.6.8"
223 |   }
224 |  },
225 |  "nbformat": 4,
226 |  "nbformat_minor": 2
227 | }
228 | 
--------------------------------------------------------------------------------
/kits19cnn/io/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import warnings
4 | import numpy as np
5 | from batchgenerators.transforms import AbstractTransform
6 | from kits19cnn.io.custom_augmentations import foreground_crop, center_crop, \
7 |     random_resized_crop
8 | 
9 | class RandomResizedCropTransform(AbstractTransform):
10 |     """
11 |     Crop the given array to random size and aspect ratio.
12 |     Doesn't resize across the depth dimension (assumes it is dim=0) if
13 |     the data is 3D.
14 | 
15 |     A crop of random size (default: 0.08 to 1.0 of the original area) and a
16 |     random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio
17 |     is made. This crop is finally resized to the given size.
18 |     This is popularly used to train the Inception networks.
19 | 
20 |     Assumes the data and segmentation masks are the same size.
21 |     """
22 |     def __init__(self, target_size, scale=(0.08, 1.0),
23 |                  ratio=(3. / 4., 4. / 3.),
24 |                  data_key="data", label_key="seg", p_per_sample=0.33,
25 |                  crop_kwargs={}, resize_kwargs={}):
26 |         """
27 |         Attributes:
28 |             same as the constructor arguments; each one is stored as-is
29 |         """
30 |         if len(target_size) > 2:
31 |             print("Currently only adjusts the aspect ratio for the 2D dims.")
32 |         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
33 |             warnings.warn("range should be of kind (min, max)")
34 |         self.target_size = target_size
35 |         self.scale = scale
36 |         self.ratio = ratio
37 |         self.data_key = data_key
38 |         self.label_key = label_key
39 |         self.p_per_sample = p_per_sample
40 |         self.crop_kwargs = crop_kwargs
41 |         self.resize_kwargs = resize_kwargs
42 | 
43 |     def _get_image_size(self, data):
44 |         """
45 |         Assumes data has shape (b, c, h, w (, d)). Fetches the h and w, and
46 |         the depth if applicable.
47 |         """
48 |         return data.shape[2:]
49 | 
50 |     def get_crop_size(self, data, scale, ratio):
51 |         """
52 |         Get parameters for ``crop`` for a random sized crop.
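        Returns:
            np.ndarray: a crop size that fits inside `data`: (h, w) for 2D
            inputs, or (depth, h, w) for 3D inputs (the depth is kept as-is)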
53 |         """
54 |         shape_dims = self._get_image_size(data)
55 |         area = np.prod(shape_dims)
56 | 
57 |         while True:
58 |             target_area = random.uniform(*scale) * area
59 |             log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
60 |             aspect_ratio = math.exp(random.uniform(*log_ratio))
61 | 
62 |             w = int(round(math.sqrt(target_area * aspect_ratio)))
63 |             h = int(round(math.sqrt(target_area / aspect_ratio)))
64 | 
65 |             if len(shape_dims) == 3:
66 |                 depth = shape_dims[0]
67 |                 crop_size = np.array([depth, h, w])
68 |             else:
69 |                 crop_size = np.array([h, w])
70 | 
71 |             if (crop_size <= shape_dims).all() and (crop_size > 0).all():
72 |                 return crop_size
73 | 
74 |     def __call__(self, **data_dict):
75 |         """
76 |         Actually doing the cropping.
77 |         """
78 |         data = data_dict.get(self.data_key)
79 |         seg = data_dict.get(self.label_key)
80 |         if np.random.uniform() < self.p_per_sample:
81 |             crop_size = self.get_crop_size(data, self.scale, self.ratio)
82 |             data, seg = random_resized_crop(data, seg,
83 |                                             target_size=self.target_size,
84 |                                             crop_size=crop_size,
85 |                                             crop_kwargs=self.crop_kwargs,
86 |                                             resize_kwargs=self.resize_kwargs)
87 |         else:
88 |             data, seg = center_crop(data, self.target_size, seg,
89 |                                     crop_kwargs=self.crop_kwargs)
90 | 
91 |         data_dict[self.data_key] = data
92 |         if seg is not None:
93 |             data_dict[self.label_key] = seg.astype(np.float32)
94 | 
95 |         return data_dict
96 | 
97 | class ROICropTransform(AbstractTransform):
98 |     """
99 |     Crops the foreground in images `p_per_sample` of the time. The
100 |     fallback cropping is center cropping.
101 |     """
102 |     def __init__(self, crop_size=128, margins=(0, 0, 0), data_key="data",
103 |                  label_key="seg", coords_key="bbox_coords",
104 |                  p_per_sample=0.33, crop_kwargs={}):
105 |         self.data_key = data_key
106 |         self.label_key = label_key
107 |         self.coords_key = coords_key
108 |         self.margins = margins
109 |         self.crop_size = crop_size
110 |         self.p_per_sample = p_per_sample
111 |         self.crop_kwargs = crop_kwargs
112 | 
113 |     def __call__(self, **data_dict):
114 |         """
115 |         Actually doing the cropping. Make sure that data_dict has
116 |         a key for the coords (self.coords_key) if p>0.
117 |         (If the output of data_dict.get(self.coords_key) is None, then foreground
118 |         crops are done on-the-fly).
119 |         """
120 |         data = data_dict.get(self.data_key)
121 |         seg = data_dict.get(self.label_key)
122 |         if np.random.uniform() < self.p_per_sample:
123 |             coords = data_dict.get(self.coords_key)
124 | 
125 |             data, seg = foreground_crop(data, seg, patch_size=self.crop_size,
126 |                                         margins=self.margins,
127 |                                         bbox_coords=coords,
128 |                                         crop_kwargs=self.crop_kwargs)
129 |         else:
130 |             data, seg = center_crop(data, self.crop_size, seg,
131 |                                     crop_kwargs=self.crop_kwargs)
132 | 
133 |         data_dict[self.data_key] = data
134 |         if seg is not None:
135 |             data_dict[self.label_key] = seg
136 | 
137 |         return data_dict
138 | 
139 | class MultiClassToBinaryTransform(AbstractTransform):
140 |     """
141 |     For changing a multi-class case to a binary one. Specify the label to
142 |     change to binary with `roi_label`.
143 |     - Don't forget to adjust `remove_label` accordingly!
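    - e.g. with the defaults (roi_label="2", remove_label="1"), a mask
      containing {0, 1, 2} is mapped to {0, 0, 1} (tumor-only)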
144 |     - label will be turned to a binary label with only `roi_label`
145 |       existing as 1s
146 |     """
147 |     def __init__(self, roi_label="2", remove_label="1", label_key="seg"):
148 |         self.roi_label = int(roi_label)
149 |         self.remove_label = int(remove_label)
150 |         self.label_key = label_key
151 | 
152 |     def __call__(self, **data_dict):
153 |         """
154 |         Replaces the label values
155 |         """
156 |         label = data_dict.get(self.label_key)
157 |         # changing labels
158 |         label[label == self.remove_label] = 0
159 |         label[label == self.roi_label] = 1
160 | 
161 |         data_dict[self.label_key] = label
162 | 
163 |         return data_dict
164 | 
165 | class RepeatChannelsTransform(AbstractTransform):
166 |     """
167 |     Repeats across the channels dimension `num_repeats` number of times.
168 |     """
169 |     def __init__(self, num_repeats=3, data_key="data"):
170 |         self.num_repeats = num_repeats
171 |         self.data_key = data_key
172 | 
173 |     def __call__(self, **data_dict):
174 |         """
175 |         Repeats across the channels dimension (axis=1).
176 |         """
177 |         data = data_dict.get(self.data_key)
178 | 
179 |         data_dict[self.data_key] = np.repeat(data, self.num_repeats, axis=1)
180 | 
181 |         return data_dict
182 | 
--------------------------------------------------------------------------------
/kits19cnn/experiments/train_2d.py:
--------------------------------------------------------------------------------
1 | import json
2 | from abc import abstractmethod
3 | 
4 | import torch
5 | import segmentation_models_pytorch as smp
6 | 
7 | from kits19cnn.io import SliceDataset, PseudoSliceDataset
8 | from kits19cnn.models import Generic_UNet
9 | from .utils import get_training_augmentation, get_validation_augmentation, \
10 |     get_preprocessing
11 | from .train import TrainExperiment, TrainClfSegExperiment
12 | 
13 | class TrainExperiment2D(TrainExperiment):
14 |     """
15 |     Stores the main parts of an experiment with 2D images:
16 |     - df split
17 |     - datasets
18 |     - loaders
19 |     - model
20 |     - optimizer
21 |     - lr_scheduler
22 |     - criterion
23 |     - callbacks
24 |     """
25 |     def __init__(self, config: dict):
26 |         """
27 |         Args:
28 |             config (dict): from `train_seg_yaml.py`
29 |         """
30 |         self.model_params = config["model_params"]
31 |         super().__init__(config=config)
32 | 
33 |     @abstractmethod
34 |     def get_model(self):
35 |         """
36 |         Creates and returns the model.
37 |         """
38 |         return
39 | 
40 |     def get_datasets(self, train_ids, valid_ids):
41 |         """
42 |         Creates and returns the train and validation datasets.
43 |         """
44 |         # preparing transforms
45 |         train_aug = get_training_augmentation(self.io_params["aug_key"])
46 |         val_aug = get_validation_augmentation(self.io_params["aug_key"])
47 |         use_rgb = "smp" in self.model_params["architecture"]
48 |         # creating the datasets
49 |         with open(self.io_params["slice_indices_path"], "r") as fp:
50 |             pos_slice_dict = json.load(fp)
51 |         p_pos_per_sample = self.io_params["p_pos_per_sample"]
52 |         if self.io_params.get("pseudo_3D"):
53 |             assert not use_rgb, \
54 |                 "Currently architectures that require RGB inputs cannot use pseudo slices."
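            # pseudo-3D (2.5D) training: each sample stacks
            # `num_pseudo_slices` adjacent slices channel-wise, which is
            # incompatible with encoders expecting fixed 3-channel RGB inputs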
55 | train_dataset = PseudoSliceDataset(im_ids=train_ids, 56 | pos_slice_dict=pos_slice_dict, 57 | transforms=train_aug, 58 | preprocessing=get_preprocessing(use_rgb), 59 | p_pos_per_sample=p_pos_per_sample, 60 | mode=self.config["mode"], 61 | num_pseudo_slices=self.io_params["num_pseudo_slices"]) 62 | valid_dataset = PseudoSliceDataset(im_ids=valid_ids, 63 | pos_slice_dict=pos_slice_dict, 64 | transforms=val_aug, 65 | preprocessing=get_preprocessing(use_rgb), 66 | p_pos_per_sample=p_pos_per_sample, 67 | mode=self.config["mode"], 68 | num_pseudo_slices=self.io_params["num_pseudo_slices"]) 69 | else: 70 | train_dataset = SliceDataset(im_ids=train_ids, 71 | pos_slice_dict=pos_slice_dict, 72 | transforms=train_aug, 73 | preprocessing=get_preprocessing(use_rgb), 74 | p_pos_per_sample=p_pos_per_sample, 75 | mode=self.config["mode"]) 76 | valid_dataset = SliceDataset(im_ids=valid_ids, 77 | pos_slice_dict=pos_slice_dict, 78 | transforms=val_aug, 79 | preprocessing=get_preprocessing(use_rgb), 80 | p_pos_per_sample=p_pos_per_sample, 81 | mode=self.config["mode"]) 82 | 83 | return (train_dataset, valid_dataset) 84 | 85 | class TrainSegExperiment2D(TrainExperiment2D): 86 | """ 87 | Stores the main parts of a segmentation experiment: 88 | - df split 89 | - datasets 90 | - loaders 91 | - model 92 | - optimizer 93 | - lr_scheduler 94 | - criterion 95 | - callbacks 96 | """ 97 | def __init__(self, config: dict): 98 | """ 99 | Args: 100 | config (dict): from `train_seg_yaml.py` 101 | """ 102 | self.model_params = config["model_params"] 103 | super().__init__(config=config) 104 | 105 | def get_model(self): 106 | architecture = self.model_params["architecture"] 107 | if architecture.lower() == "nnunet": 108 | architecture_kwargs = self.model_params[architecture] 109 | architecture_kwargs["norm_op"] = torch.nn.InstanceNorm2d 110 | architecture_kwargs["nonlin"] = torch.nn.ReLU 111 | architecture_kwargs["nonlin_kwargs"] = {"inplace": True} 112 | architecture_kwargs["final_nonlin"] = lambda x: x 113 | model = Generic_UNet(**architecture_kwargs) 114 | elif architecture.lower() == "unet_smp": 115 | model = smp.Unet(encoder_name=self.model_params["encoder"], 116 | encoder_weights="imagenet", activation=None, 117 | **self.model_params[architecture]) 118 | elif architecture.lower() == "fpn_smp": 119 | model = smp.FPN(encoder_name=self.model_params["encoder"], 120 | encoder_weights="imagenet", activation=None, 121 | **self.model_params[architecture]) 122 | # calculating # of parameters 123 | total = sum(p.numel() for p in model.parameters()) 124 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 125 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 126 | 127 | return model 128 | 129 | class TrainClfSegExperiment2D(TrainExperiment2D, TrainClfSegExperiment): 130 | """ 131 | Stores the main parts of a classification+segmentation experiment: 132 | - df split 133 | - datasets 134 | - loaders 135 | - model 136 | - optimizer 137 | - lr_scheduler 138 | - criterion 139 | - callbacks 140 | """ 141 | def __init__(self, config: dict): 142 | """ 143 | Args: 144 | config (dict): from `train_seg_yaml.py` 145 | """ 146 | self.model_params = config["model_params"] 147 | super().__init__(config=config) 148 | 149 | def get_model(self): 150 | architecture = self.model_params["architecture"] 151 | if architecture.lower() == "nnunet": 152 | architecture_kwargs = self.model_params[architecture] 153 | if self.io_params["batch_size"] < 10: 154 | architecture_kwargs["norm_op"] = 
torch.nn.InstanceNorm2d
155 |             architecture_kwargs["nonlin"] = torch.nn.ReLU
156 |             architecture_kwargs["nonlin_kwargs"] = {"inplace": True}
157 |             architecture_kwargs["final_nonlin"] = lambda x: x
158 |             model = Generic_UNet(**architecture_kwargs)
159 |         else:
160 |             raise NotImplementedError
161 |         # calculating # of parameters
162 |         total = sum(p.numel() for p in model.parameters())
163 |         trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
164 |         print(f"Total # of Params: {total}\nTrainable params: {trainable}")
165 | 
166 |         return model
167 | 
--------------------------------------------------------------------------------
/kits19cnn/inference/evaluate.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 | import nibabel as nib
4 | import pandas as pd
5 | import os
6 | from os.path import isdir, join
7 | from pathlib import Path
8 | 
9 | from kits19cnn.metrics import evaluate_official
10 | from sklearn.metrics import precision_recall_fscore_support
11 | 
12 | class Evaluator(object):
13 |     """
14 |     Evaluates all of the predictions in a user-specified directory and logs
15 |     them in a csv. Assumes that the output is in the KiTS19 file structure.
16 |     """
17 |     def __init__(self, orig_img_dir, pred_dir, cases=None,
18 |                  label_file_ending=".npy", binary_tumor=False):
19 |         """
20 |         Attributes:
21 |             orig_img_dir: path to the directory containing the
22 |                 labels to evaluate with
23 |                 i.e. original kits19/data directory or the preprocessed imgs
24 |                 directory
25 |                 assumes structure:
26 |                 orig_img_dir
27 |                     case_xxxxx
28 |                         imaging{file_ending}
29 |                         segmentation{file_ending}
30 |             pred_dir: path to the predictions directory, created by Predictor
31 |                 assumes structure:
32 |                 pred_dir
33 |                     case_xxxxx
34 |                         pred.npy
35 |                         act.npy
36 |             cases: list of filepaths to case folders or just case folder names.
37 |                 Defaults to None.
38 |             label_file_ending (str): one of ['.npy', '.nii', '.nii.gz']
39 |             binary_tumor (bool): whether or not to treat predicted 1s as tumor
40 |         """
41 |         self.orig_img_dir = orig_img_dir
42 |         self.pred_dir = pred_dir
43 |         self.file_ending = label_file_ending
44 |         assert self.file_ending in [".npy", ".nii", ".nii.gz"], \
45 |             "label_file_ending must be one of ['.npy', '.nii', '.nii.gz']"
46 |         # converting cases from filepaths to raw folder names
47 |         if cases is None:
48 |             self.cases_raw = [case \
49 |                               for case in os.listdir(self.pred_dir) \
50 |                               if case.startswith("case")]
51 |             assert len(self.cases_raw) > 0, \
52 |                 "Please make sure that pred_dir has the case folders"
53 |         else:
54 |             # extracting raw cases from filepath cases
55 |             cases_raw = [Path(case).name for case in cases]
56 |             # filtering them down to only cases in pred_dir
57 |             self.cases_raw = [case for case in cases_raw \
58 |                               if isdir(join(self.pred_dir, case))]
59 |         self.binary_tumor = binary_tumor
60 |         if self.binary_tumor:
61 |             print("Evaluating predicted 1s as tumor (changed to 2).")
62 | 
63 |     def evaluate_all(self, print_metrics=False):
64 |         """
65 |         Evaluates all cases and creates the results.csv, which stores all of
66 |         the metrics and the averages.
67 |         Args:
68 |             print_metrics (bool): whether or not to print metrics.
69 |                 Defaults to False to be cleaner with tqdm.
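        Example (paths are illustrative):
            Evaluator(orig_img_dir="path/to/kits_preprocessed",
                      pred_dir="path/to/predictions").evaluate_all()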
70 |         """
71 |         metrics_dict = {"cases": [],
72 |                         "tk_dice": [], "tu_dice": [],
73 |                         "precision": [], "recall": [],
74 |                         "fpr": [], "orig_shape": [],
75 |                         "support": [], "pred_support": []}
76 | 
77 |         for case in tqdm(self.cases_raw):
78 |             # loading the necessary arrays
79 |             label, pred = self.load_masks_and_pred(case)
80 |             metrics_dict = self.eval_all_metrics_per_case(metrics_dict, label,
81 |                                                           pred, case,
82 |                                                           print_metrics)
83 | 
84 |         metrics_dict = self.round_all(self.average_all_cases_per_metric(metrics_dict))
85 |         df = pd.DataFrame(metrics_dict)
86 |         metrics_path = join(self.pred_dir, "results.csv")
87 |         print(f"Saving {metrics_path}...")
88 |         df.to_csv(metrics_path)
89 | 
90 |     def load_masks_and_pred(self, case):
91 |         """
92 |         Loads mask and prediction from `case`
93 |         Args:
94 |             case (str): case folder names to use
95 |         Returns:
96 |             label (np.ndarray): shape (x, y, z)
97 |             pred (np.ndarray): shape (x, y, z)
98 |         """
99 |         y_path = join(self.orig_img_dir, case, f"segmentation{self.file_ending}")
100 |         if self.file_ending == ".npy":
101 |             label = np.load(y_path)
102 |         elif self.file_ending == ".nii.gz" or self.file_ending == ".nii":
103 |             label = nib.load(y_path).get_fdata()
104 |         pred = np.load(join(self.pred_dir, case, "pred.npy")).squeeze()
105 |         if self.binary_tumor:
106 |             # treating predicted 1s as tumor (2)
107 |             pred[pred == 1] = 2
108 |         return (label, pred)
109 | 
110 |     def eval_all_metrics_per_case(self, metrics_dict, y_true, y_pred,
111 |                                   case, print_metrics=False):
112 |         """
113 |         Calculates the official metrics, precision, recall, the per-class
114 |         miss rate (1 - recall, stored under "fpr"), and some metadata such
115 |         as the original shape and support (# of pixels for each class). They
116 |         are appended to the main metrics dictionary that becomes results.csv.
117 |         """
118 |         # calculating metrics
119 |         tk_dice, tu_dice = evaluate_official(y_true, y_pred)
120 |         prec, recall, _, supp = precision_recall_fscore_support(y_true.ravel(),
121 |                                                                 y_pred.ravel(),
122 |                                                                 labels=[0, 1, 2])
123 |         pred_supp = np.unique(y_pred, return_counts=True)[-1]
124 |         fpr = 1 - recall  # per-class miss rate (1 - TPR), kept under "fpr"
125 |         orig_shape = y_true.shape
126 | 
127 |         if print_metrics:
128 |             print(f"PPV: {prec}\nTPR: {recall}\nSupp: {supp}")
129 |             print(f"Tumour and Kidney Dice: {tk_dice}; Tumour Dice: {tu_dice}")
130 |         # order for appending (sorted keys)
131 |         # ['cases', 'fpr', 'orig_shape', 'precision', 'pred_support', 'recall',
132 |         #  'support', 'tk_dice', 'tu_dice']
133 |         append_list = [case, fpr, orig_shape, prec, pred_supp, recall, supp,
134 |                        tk_dice, tu_dice]
135 |         sorted_keys = sorted(metrics_dict.keys())
136 |         assert len(append_list) == len(sorted_keys)
137 |         # appending to each key's list
138 |         for (key_, value_) in zip(sorted_keys, append_list):
139 |             metrics_dict[key_].append(value_)
140 |         return metrics_dict
141 | 
142 |     def average_all_cases_per_metric(self, metrics_dict):
143 |         """
144 |         Averages the metrics (each key of metrics_dict).
145 |         """
146 |         metrics_dict["cases"].append("average")
147 |         for key in list(metrics_dict.keys()):
148 |             if key == "cases":
149 |                 pass
150 |             else:
151 |                 # axis=0 averages each sub-axis of orig_shape and
152 |                 # support as well
153 |                 try:
154 |                     metrics_dict[key].append(np.mean(metrics_dict[key], axis=0))
155 |                 except Exception:
156 |                     metrics_dict[key].append("N/A")
157 |         return metrics_dict
158 | 
159 |     def round_all(self, metrics_dict):
160 |         """
161 |         Rounding all relevant metrics to three decimal places for cleanliness.
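        The "cases" column (strings) and "pred_support" (potentially ragged
        across cases) are left as-is.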
162 |         """
163 |         for key in list(metrics_dict.keys()):
164 |             if key in ["cases", "pred_support"]:
165 |                 pass
166 |             else:
167 |                 metrics_dict[key] = np.round(metrics_dict[key],
168 |                                              decimals=3).tolist()
169 |         return metrics_dict
170 | 
--------------------------------------------------------------------------------
/kits19cnn/io/dataset.py:
--------------------------------------------------------------------------------
1 | from os.path import isfile, join
2 | import numpy as np
3 | import nibabel as nib
4 | 
5 | import torch
6 | from torch.utils.data import Dataset
7 | 
8 | class VoxelDataset(Dataset):
9 |     def __init__(self, im_ids: np.array,
10 |                  transforms=None,
11 |                  preprocessing=None,
12 |                  file_ending: str = ".npy"):
13 |         """
14 |         Attributes
15 |             im_ids (np.ndarray): array of image names.
16 |             transforms (callable): batchgenerators-style transforms, called
17 |                 as transforms(data=..., seg=...). Defaults to None.
18 |             preprocessing: ops to perform after transforms, such as
19 |                 z-score standardization. Defaults to None.
20 |             file_ending (str): one of ['.npy', '.nii', '.nii.gz']
21 |         """
22 |         self.im_ids = im_ids
23 |         self.transforms = transforms
24 |         self.preprocessing = preprocessing
25 |         self.file_ending = file_ending
26 |         print(f"Using the {file_ending} files...")
27 | 
28 |     def __getitem__(self, idx):
29 |         # loads data as a numpy arr and then adds the channel + batch size dimensions
30 |         case_id = self.im_ids[idx]
31 |         x, y = self.load_volume(case_id)
32 |         if self.transforms:
33 |             x = x[None] if len(x.shape) == 4 else x
34 |             y = y[None] if len(y.shape) == 4 else y
35 |             # batchgenerators requires shape: (b, c, ...)
36 |             data_dict = self.transforms(**{"data": x, "seg": y})
37 |             x, y = data_dict["data"], data_dict["seg"]
38 |         if self.preprocessing:
39 |             preprocessed = self.preprocessing(**{"data": x, "seg": y})
40 |             x, y = preprocessed["data"], preprocessed["seg"]
41 |         # squeeze to remove batch size dim
42 |         x = torch.squeeze(x, dim=0).float()
43 |         y = torch.squeeze(y, dim=0)
44 |         return (x, y)
45 | 
46 |     def __len__(self):
47 |         return len(self.im_ids)
48 | 
49 |     def load_volume(self, case_id):
50 |         """
51 |         Loads volume from either .npy or nifti files.
52 |         Args:
53 |             case_id: path to the case folder
54 |                 i.e. /content/kits19/data/case_00001
55 |         Returns:
56 |             Tuple of:
57 |                 - x (np.ndarray): shape (1, d, h, w)
58 |                 - y (np.ndarray): same shape as x
59 |         """
60 |         x_path = join(case_id, f"imaging{self.file_ending}")
61 |         y_path = join(case_id, f"segmentation{self.file_ending}")
62 |         if self.file_ending == ".npy":
63 |             x, y = np.load(x_path), np.load(y_path)
64 |         elif self.file_ending == ".nii.gz" or self.file_ending == ".nii":
65 |             x, y = nib.load(x_path).get_fdata(), nib.load(y_path).get_fdata()
66 |         return (x[None], y[None])
67 | 
68 | class ClfSegVoxelDataset(VoxelDataset):
69 |     """
70 |     Can handle classification+segmentation, classification only, and seg only
71 |     outputs.
72 |     """
73 |     def __init__(self, im_ids: np.array,
74 |                  transforms=None,
75 |                  preprocessing=None,
76 |                  file_ending: str = ".npy",
77 |                  mode: str = "both",
78 |                  num_classes: int = 3):
79 |         """
80 |         Attributes
81 |             im_ids (np.ndarray): array of image names.
82 |             transforms (callable): batchgenerators-style transforms, called
83 |                 as transforms(data=..., seg=...). Defaults to None.
84 |             preprocessing: ops to perform after transforms, such as
85 |                 z-score standardization. Defaults to None.
86 |             file_ending (str): one of ['.npy', '.nii', '.nii.gz']
87 |             mode (str): decides the nature of the outputs;
88 |                 must be one of ['both', 'clf_only', 'seg_only']
89 |             num_classes (int): number of classes. Defaults to 3.
90 |         """
91 |         super().__init__(im_ids=im_ids, transforms=transforms,
92 |                          preprocessing=preprocessing, file_ending=file_ending)
93 |         assert mode.lower() in ["both", "clf_only", "seg_only"], \
94 |             "`mode` must be one of ['both', 'clf_only', 'seg_only']"
95 |         self.mode = mode
96 |         self.num_classes = num_classes
97 | 
98 |     def __getitem__(self, idx):
99 |         # loads data as a numpy arr and then adds the channel + batch size dimensions
100 |         case_id = self.im_ids[idx]
101 |         x, y = self.load_volume(case_id)
102 |         if self.transforms:
103 |             x = x[None] if len(x.shape) == 4 else x
104 |             y = y[None] if len(y.shape) == 4 else y
105 |             # batchgenerators requires shape: (b, c, ...)
106 |             data_dict = self.transforms(**{"data": x, "seg": y})
107 |             x, y = data_dict["data"], data_dict["seg"]
108 | 
109 |         if self.mode in ["both", "clf_only"]:
110 |             y_clf = torch.from_numpy(self.get_clf_label_from_cropped_mask(y))
111 | 
112 |         if self.preprocessing:
113 |             preprocessed = self.preprocessing(**{"data": x, "seg": y})
114 |             x, y = preprocessed["data"], preprocessed["seg"]
115 |         # squeeze to remove batch size dim
116 |         if torch.is_tensor(x):
117 |             x = torch.squeeze(x, dim=0).float()
118 |         if torch.is_tensor(y):
119 |             y = torch.squeeze(y, dim=0)
120 | 
121 |         if self.mode == "both":
122 |             return {"features": x, "seg_targets": y, "clf_targets": y_clf}
123 |         elif self.mode == "clf_only":
124 |             return (x, y_clf)
125 |         elif self.mode == "seg_only":
126 |             return (x, y)
127 | 
128 |     def get_clf_label_from_cropped_mask(self, cropped_mask: np.array):
129 |         """
130 |         Multi-label one-hot encoding of mask to get the classification
131 |         label.
132 |         Args:
133 |             cropped_mask (np.ndarray): contains int in [0, num_classes-1]
134 |         Returns:
135 |             one_hot (np.ndarray): multi-label one hot encoded array
136 |                 i.e. [0, 1, 0] or [1, 0, 1], etc.
137 |         """
138 |         unique = np.unique(cropped_mask).astype(np.int32)
139 |         one_hot = np.zeros(self.num_classes)
140 |         one_hot[unique] = 1
141 |         return one_hot
142 | 
143 | class TestVoxelDataset(VoxelDataset):
144 |     """
145 |     Same as VoxelDataset, but can handle when there are no masks (just returns
146 |     blank masks). This is a separate class to prevent subtle errors with
147 |     blank masks--VoxelDataset explicitly fails when there are no masks.
148 |     """
149 |     def __init__(self, im_ids: np.array,
150 |                  transforms=None,
151 |                  preprocessing=None,
152 |                  file_ending=".npy"):
153 |         """
154 |         Attributes
155 |             im_ids (np.ndarray): array of image names.
156 |             transforms (callable): batchgenerators-style transforms, called
157 |                 as transforms(data=..., seg=...). Defaults to None.
158 |             preprocessing: ops to perform after transforms, such as
159 |                 z-score standardization. Defaults to None.
160 |             file_ending (str): one of ['.npy', '.nii', '.nii.gz']
161 |         """
162 |         super().__init__(im_ids=im_ids, transforms=transforms,
163 |                          preprocessing=preprocessing, file_ending=file_ending)
164 | 
165 |     def load_volume(self, case_id):
166 |         """
167 |         Loads volume from either .npy or nifti files.
168 |         Args:
169 |             case_id: path to the case folder
170 |                 i.e.
/content/kits19/data/case_00001
171 |         Returns:
172 |             Tuple of:
173 |                 - x (np.ndarray): shape (1, d, h, w)
174 |                 - y (np.ndarray): same shape as x
175 |                     if this does not exist, it's returned as a blank mask
176 |         """
177 |         x_path = join(case_id, f"imaging{self.file_ending}")
178 |         y_path = join(case_id, f"segmentation{self.file_ending}")
179 |         if self.file_ending == ".npy":
180 |             x = np.load(x_path)
181 |             y = np.load(y_path) if isfile(y_path) else np.zeros(x.shape)
182 |         elif self.file_ending == ".nii.gz" or self.file_ending == ".nii":
183 |             x = nib.load(x_path).get_fdata()
184 |             y = nib.load(y_path).get_fdata() if isfile(y_path) else np.zeros(x.shape)
185 |         return (x[None], y[None])
186 | 
--------------------------------------------------------------------------------
/kits19cnn/io/resample.py:
--------------------------------------------------------------------------------
1 | 
2 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | from collections import OrderedDict
17 | from skimage.transform import resize
18 | from scipy.ndimage.interpolation import map_coordinates
19 | import numpy as np
20 | from batchgenerators.augmentations.utils import resize_segmentation
21 | RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD = 3  # nnU-Net's default; this constant was missing here
22 | def get_do_separate_z(spacing):
23 |     do_separate_z = (np.max(spacing) / np.min(spacing)) > RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD
24 |     return do_separate_z
25 | 
26 | 
27 | def get_lowres_axis(new_spacing):
28 |     axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0]  # the axis (or axes) with the largest spacing, i.e. lowest resolution
29 |     return axis
30 | 
31 | def resample_patient(data, seg, original_spacing, target_spacing, order_data=3,
32 |                      order_seg=0, force_separate_z=False,
33 |                      cval_data=0, cval_seg=0, order_z_data=0, order_z_seg=0):
34 |     """
35 |     :param cval_seg: fill value for the seg resampling
36 |     :param cval_data: fill value for the data resampling
37 |     :param data: image array of shape (c, x, y, z), or None
38 |     :param seg: segmentation array of shape (c, x, y, z), or None
39 |     :param original_spacing: current voxel spacing of data/seg
40 |     :param target_spacing: desired voxel spacing, e.g. (3.22, 1.62, 1.62); the new shape is round(original_spacing / target_spacing * shape)
41 |     :param order_data: spline interpolation order for the data
42 |     :param order_seg: interpolation order for the seg
43 |     :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then always
44 |         /never resample along z separately
45 |     :param order_z_seg: only applies if do_separate_z is True
46 |     :param order_z_data: only applies if do_separate_z is True
47 |     :return: (data_reshaped, seg_reshaped)
48 |     """
49 |     assert not ((data is None) and (seg is None))
50 |     if data is not None:
51 |         assert len(data.shape) == 4, "data must be c x y z"
52 |     if seg is not None:
53 |         assert len(seg.shape) == 4, "seg must be c x y z"
54 | 
55 |     if data is not None:
56 |         shape = np.array(data[0].shape)
57 |     else:
58 |         shape = np.array(seg[0].shape)
59 |     new_shape = np.round(((np.array(original_spacing) / np.array(target_spacing)).astype(float) * shape)).astype(int)
60 | 
61 |     if force_separate_z is not None:
62 |         do_separate_z = force_separate_z
63 |         if force_separate_z:
64 |             axis = get_lowres_axis(original_spacing)
65 |         else:
66 |             axis = None
67 |     else:
68 |         if 
get_do_separate_z(original_spacing): 69 | do_separate_z = True 70 | axis = get_lowres_axis(original_spacing) 71 | elif get_do_separate_z(target_spacing): 72 | do_separate_z = True 73 | axis = get_lowres_axis(target_spacing) 74 | else: 75 | do_separate_z = False 76 | axis = None 77 | 78 | if data is not None: 79 | data_reshaped = resample_data_or_seg(data, new_shape, False, axis, 80 | order_data, do_separate_z, 81 | cval=cval_data, 82 | order_z=order_z_data) 83 | else: 84 | data_reshaped = None 85 | if seg is not None: 86 | seg_reshaped = resample_data_or_seg(seg, new_shape, True, axis, 87 | order_seg, do_separate_z, 88 | cval=cval_seg, 89 | order_z=order_z_seg) 90 | else: 91 | seg_reshaped = None 92 | return data_reshaped, seg_reshaped 93 | 94 | 95 | def resample_data_or_seg(data, new_shape, is_seg, axis=None, order=3, 96 | do_separate_z=False, cval=0, order_z=0): 97 | """ 98 | separate_z=True will resample with order 0 along z 99 | :param data: 100 | :param new_shape: 101 | :param is_seg: 102 | :param axis: 103 | :param order: 104 | :param do_separate_z: 105 | :param cval: 106 | :param order_z: only applies if do_separate_z is True 107 | :return: 108 | """ 109 | assert len(data.shape) == 4, "data must be (c, x, y, z)" 110 | if is_seg: 111 | resize_fn = resize_segmentation 112 | kwargs = OrderedDict() 113 | else: 114 | resize_fn = resize 115 | kwargs = {'mode': 'edge', 'anti_aliasing': False} 116 | dtype_data = data.dtype 117 | data = data.astype(float) 118 | shape = np.array(data[0].shape) 119 | new_shape = np.array(new_shape) 120 | if np.any(shape != new_shape): 121 | if do_separate_z: 122 | print("separate z") 123 | assert len(axis) == 1, "only one anisotropic axis supported" 124 | axis = axis[0] 125 | if axis == 0: 126 | new_shape_2d = new_shape[1:] 127 | elif axis == 1: 128 | new_shape_2d = new_shape[[0, 2]] 129 | else: 130 | new_shape_2d = new_shape[:-1] 131 | 132 | reshaped_final_data = [] 133 | for c in range(data.shape[0]): 134 | reshaped_data = [] 135 | for slice_id in range(shape[axis]): 136 | if axis == 0: 137 | reshaped_data.append(resize_fn(data[c, slice_id], 138 | new_shape_2d, order, 139 | cval=cval, **kwargs)) 140 | elif axis == 1: 141 | reshaped_data.append(resize_fn(data[c, :, slice_id], 142 | new_shape_2d, order, 143 | cval=cval, **kwargs)) 144 | else: 145 | reshaped_data.append(resize_fn(data[c, :, :, slice_id], 146 | new_shape_2d, order, 147 | cval=cval, **kwargs)) 148 | reshaped_data = np.stack(reshaped_data, axis) 149 | if shape[axis] != new_shape[axis]: 150 | 151 | # The following few lines are blatantly copied and modified from sklearn's resize() 152 | rows, cols, dim = new_shape[0], new_shape[1], new_shape[2] 153 | orig_rows, orig_cols, orig_dim = reshaped_data.shape 154 | 155 | row_scale = float(orig_rows) / rows 156 | col_scale = float(orig_cols) / cols 157 | dim_scale = float(orig_dim) / dim 158 | 159 | map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim] 160 | map_rows = row_scale * (map_rows + 0.5) - 0.5 161 | map_cols = col_scale * (map_cols + 0.5) - 0.5 162 | map_dims = dim_scale * (map_dims + 0.5) - 0.5 163 | 164 | coord_map = np.array([map_rows, map_cols, map_dims]) 165 | if not is_seg or order_z == 0: 166 | reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, 167 | order=order_z, cval=cval, 168 | mode='nearest')[None]) 169 | else: 170 | unique_labels = np.unique(reshaped_data) 171 | reshaped = np.zeros(new_shape, dtype=dtype_data) 172 | 173 | for i, cl in enumerate(unique_labels): 174 | reshaped_multihot = np.round( 175 | 
map_coordinates((reshaped_data == cl).astype(float), 176 | coord_map, order=order_z, 177 | cval=cval, mode='nearest')) 178 | reshaped[reshaped_multihot >= 0.5] = cl 179 | reshaped_final_data.append(reshaped[None]) 180 | else: 181 | reshaped_final_data.append(reshaped_data[None]) 182 | reshaped_final_data = np.vstack(reshaped_final_data) 183 | else: 184 | reshaped = [] 185 | for c in range(data.shape[0]): 186 | reshaped.append(resize_fn(data[c], new_shape, order, 187 | cval=cval, **kwargs)[None]) 188 | reshaped_final_data = np.vstack(reshaped) 189 | return reshaped_final_data.astype(dtype_data) 190 | else: 191 | print("no resampling necessary") 192 | return data 193 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kits19-cnn 2 | Using 2D & 3D convolutional neural networks for the [2019 Kidney and Kidney Tumor Segmentation Challenge](https://kits19.grand-challenge.org/). This repository is associated with this [conference paper (older version)](https://www.researchgate.net/publication/336247303_A_2D_U-Net_for_Automated_Kidney_and_Renal_Tumor_Segmentation). 3 | 4 |  5 |  6 | 7 | ## Disclaimer 8 | I'm not sure why the tumor scores are so low for all of the architectures, so I'm open to any suggestions and PRs! Am actively working on improving them. 9 | 10 | ## Credits 11 | Major credits to: 12 | * Isensee's [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) and [batchgenerators](https://github.com/MIC-DKFZ/batchgenerators) 13 | * qubvel's [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) 14 | * Nick Heller for hosting KiTS19! 15 | 16 | ## Torch/Catalyst Pipeline (Overview) 17 | ### Preprocessing 18 | * Resampling to 3.22 × 1.62 × 1.62 mm 19 | * Isensee's nnU-Net methodology 20 | * Clipping to the [0.5, 99.5] percentiles and applying z-score standardization 21 | 22 | ### Training 23 | * Foreground class sampling 24 | * __2D:__ Done by sampling per slice (loading only 2D arrays) 25 | * __SO MUCH FASTER THAN LOADING 3D ARRAYS__ 26 | * Difference: 3 seconds v. 5 minutes per epoch 27 | * __3D:__ Done through `ROICropTransform` 28 | * Data Augmentation 29 | * Located in `kits19cnn/experiments/utils.py` 30 | * Pay attention to the `augmentation_key`s in `get_training_augmentation` and `get_validation_augmentation` 31 | * Done through `batchgenerators` + my own custom transforms 32 | * SGD (lr=1e-4) and LRPlateau (factor=0.15 and patience=5); BCEDiceLoss 33 | * 2D: batch size = 18 (regular training) 34 | * 3D: batch size = 4 (fp16 training) 35 | 36 | ### Architectures 37 | * 2D (patch size: (256, 256)) 38 | * Vanilla 2D nnU-Net 39 | * 6 pools with convolutional downsampling and upsampling 40 | * max number of filters set to 320 and the starting number is 30 41 | * 2D U-Net with pretrained ImageNet classifiers 42 | * 2D FPN with pretrained ImageNet classifiers 43 | * 3D (patch size: (96, 160, 160)) 44 | * 3D nnU-Net 45 | * 5 pools with convolutional downsampling and upsampling 46 | * max number of filters set to 320 and the starting number is 30 47 | * 3D nnU-Net (Classification + Segmentation) 48 | 49 | ## Results 50 |
| Neural Network | 55 |Parameters | 56 |Local Test (Tumor-Kidney) Dice | 57 |Local Test (Tumor Only) Dice | 58 |Weights | 59 |
|---|---|---|---|---|
| 2D nnU-Net | 63 |12M | 64 |0.90 | 65 |0.26 | 66 |... | 67 |
| 3D nnU-Net | 70 |29.6M | 71 |0.86 | 72 |0.22 | 73 |... | 74 |
| ResNet34 + U-Net Decoder | 77 |24M | 78 |0.90 | 79 |0.29 | 80 |... | 81 |
| ResNet34 + FPN Decoder | 84 |22M | 85 |0.83 | 86 |0.29 | 87 |... | 88 |