├── kits19cnn ├── __init__.py ├── models │ ├── __init__.py │ ├── nnunet │ │ ├── __init__.py │ │ ├── initialization.py │ │ └── generic_UNet.py │ └── smp_models.py ├── inference │ ├── __init__.py │ ├── ensemble.py │ ├── utils.py │ ├── inference_class.py │ └── evaluate.py ├── experiments │ ├── __init__.py │ ├── infer_2d.py │ ├── train_3d.py │ ├── infer.py │ ├── train_2d.py │ ├── utils.py │ └── train.py ├── io │ ├── __init__.py │ ├── custom_transforms.py │ ├── dataset.py │ ├── resample.py │ ├── preprocess.py │ ├── custom_augmentations.py │ └── dataset_2d.py ├── metrics.py ├── utils.py ├── loss_functions.py └── visualize.py ├── images ├── label_case_00113.png └── pred2d_case_00113.png ├── script_configs ├── eval.yml ├── infer_tu_only │ ├── eval.yml │ └── pred.yml ├── pred.yml └── train.yml ├── .gitignore ├── scripts ├── evaluate.py ├── predict.py └── train_yaml.py ├── setup.py ├── notebooks └── Visualizing Volumes.ipynb ├── README.md └── LICENSE /kits19cnn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/label_case_00113.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchen42703/kits19-cnn/HEAD/images/label_case_00113.png -------------------------------------------------------------------------------- /images/pred2d_case_00113.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jchen42703/kits19-cnn/HEAD/images/pred2d_case_00113.png -------------------------------------------------------------------------------- /kits19cnn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .nnunet import SegmentationNetwork, Generic_UNet 2 | from .smp_models import wrap_smp_model 3 | -------------------------------------------------------------------------------- /kits19cnn/models/nnunet/__init__.py: -------------------------------------------------------------------------------- 1 | from .neural_network import SegmentationNetwork 2 | from .generic_UNet import Generic_UNet 3 | -------------------------------------------------------------------------------- /kits19cnn/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference_class import Predictor 2 | from .evaluate import Evaluator 3 | from .utils import create_submission, load_weights_infer 4 | from .ensemble import Ensembler 5 | -------------------------------------------------------------------------------- /script_configs/eval.yml: -------------------------------------------------------------------------------- 1 | orig_img_dir: /content/kits_preprocessed 2 | pred_dir: /content/kits19_predictions 3 | label_file_ending: .npy 4 | print_metrics: False 5 | binary_tumor: False 6 | -------------------------------------------------------------------------------- /script_configs/infer_tu_only/eval.yml: -------------------------------------------------------------------------------- 1 | orig_img_dir: /content/kits_preprocessed 2 | pred_dir: /content/kits19_predictions 3 | label_file_ending: .npy 4 | print_metrics: False 5 | binary_tumor: True 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.ipynb_checkpoints 3 | __pycache__/ 4 
| 5 | bin/ 6 | build/ 7 | develop-eggs/ 8 | dist/ 9 | eggs/ 10 | lib/ 11 | lib64/ 12 | parts/ 13 | sdist/ 14 | var/ 15 | *.egg-info/ 16 | .installed.cfg 17 | *.egg 18 | -------------------------------------------------------------------------------- /kits19cnn/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import get_training_augmentation, get_validation_augmentation, \ 2 | get_preprocessing, seed_everything 3 | from .train_3d import TrainSegExperiment, TrainClfSegExperiment3D 4 | from .train_2d import TrainSegExperiment2D, TrainClfSegExperiment2D 5 | from .infer import SegmentationInferenceExperiment 6 | from .infer_2d import SegmentationInferenceExperiment2D 7 | -------------------------------------------------------------------------------- /kits19cnn/io/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import ClfSegVoxelDataset, VoxelDataset, TestVoxelDataset 2 | from .dataset_2d import SliceDataset, PseudoSliceDataset 3 | from .preprocess import Preprocessor 4 | from .resample import resample_patient 5 | from .custom_transforms import ROICropTransform, RepeatChannelsTransform, \ 6 | MultiClassToBinaryTransform, \ 7 | RandomResizedCropTransform 8 | -------------------------------------------------------------------------------- /kits19cnn/inference/ensemble.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Ensembler(object): 4 | """ 5 | Iterates through multiple directories of predicted activation maps 6 | and averages them. The ensembled results are saved in a separate 7 | directory, `out_dir`. 8 | * Assumes the predicted activation maps are called `pred_act.npy` 9 | in their respective case directories. 10 | """ 11 | def __init__(self): 12 | pass 13 | -------------------------------------------------------------------------------- /kits19cnn/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def evaluate_official(y_true, y_pred): 4 | """ 5 | Official evaluation metric. (numpy) 6 | """ 7 | try: 8 | # Compute tumor+kidney Dice 9 | tk_pd = np.greater(y_pred, 0) 10 | tk_gt = np.greater(y_true, 0) 11 | intersection = np.logical_and(tk_pd, tk_gt).sum() 12 | tk_dice = 2*intersection/(tk_pd.sum() + tk_gt.sum()) 13 | except ZeroDivisionError: 14 | return 0.0, 0.0 15 | 16 | try: 17 | # Compute tumor Dice 18 | tu_pd = np.greater(y_pred, 1) 19 | tu_gt = np.greater(y_true, 1) 20 | intersection = np.logical_and(tu_pd, tu_gt).sum() 21 | tu_dice = 2*intersection/(tu_pd.sum() + tu_gt.sum()) 22 | except ZeroDivisionError: 23 | return tk_dice, 0.0 24 | 25 | return tk_dice, tu_dice 26 | -------------------------------------------------------------------------------- /kits19cnn/models/smp_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kits19cnn.models.nnunet.neural_network import SegmentationNetwork 3 | from kits19cnn.utils import softmax_helper 4 | 5 | def wrap_smp_model(smp_model_type, smp_model_kwargs={}, num_classes=3, 6 | activation="softmax"): 7 | """ 8 | Wraps a 2D smp model with SegmentationNetwork's methods. Mainly for 9 | inference, so that the smp_model can use the `predict_3D` method. 
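    Args (as visible from the signature below):
        smp_model_type: an smp model class, e.g. smp.Unet or smp.FPN
        smp_model_kwargs (dict): kwargs forwarded to the smp model constructor
        num_classes (int): number of output classes
        activation (str): "softmax" or "sigmoid"; sets `inference_apply_nonlin`
    Returns:
        the instantiated wrapped model (a subclass of both the smp model
        and SegmentationNetwork)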
10 | """ 11 | class WrappedModel(smp_model_type, SegmentationNetwork): 12 | def __init__(self, model_kwargs={}): 13 | super().__init__(**model_kwargs) 14 | self.conv_op = torch.nn.Conv2d 15 | self.num_classes = num_classes 16 | if activation == "softmax": 17 | self.inference_apply_nonlin = softmax_helper 18 | elif activation == "sigmoid": 19 | self.inference_apply_nonlin = torch.sigmoid 20 | wrapped_model = WrappedModel(smp_model_kwargs) 21 | return wrapped_model 22 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | from kits19cnn.inference import Evaluator 2 | 3 | def main(config): 4 | """ 5 | Main code for running the evaluation of 3D volumes. 6 | Args: 7 | config (dict): dictionary read from a yaml file 8 | i.e. script_configs/eval.yml 9 | Returns: 10 | None 11 | """ 12 | evaluator = Evaluator(config["orig_img_dir"], config["pred_dir"], 13 | label_file_ending=config["label_file_ending"], 14 | binary_tumor=config["binary_tumor"]) 15 | evaluator.evaluate_all(print_metrics=config["print_metrics"]) 16 | 17 | if __name__ == "__main__": 18 | import yaml 19 | import argparse 20 | 21 | parser = argparse.ArgumentParser(description="For training.") 22 | parser.add_argument("--yml_path", type=str, required=True, 23 | help="Path to the .yml config.") 24 | args = parser.parse_args() 25 | 26 | with open(args.yml_path, 'r') as stream: 27 | try: 28 | config = yaml.safe_load(stream) 29 | except yaml.YAMLError as exc: 30 | print(exc) 31 | 32 | main(config) 33 | -------------------------------------------------------------------------------- /script_configs/pred.yml: -------------------------------------------------------------------------------- 1 | in_dir: /content/kits_preprocessed 2 | out_dir: /content/kits19_predictions 3 | with_masks: True 4 | mode: segmentation 5 | checkpoint_path: nnunet3d_exp2_94epochs_last_full.pth 6 | pseudo_3D: False 7 | 8 | io_params: 9 | test_size: 0.2 10 | split_seed: 200 11 | batch_size: 1 12 | num_workers: 2 13 | # aug_key: aug2 14 | file_ending: .npy # nii.gz 15 | 16 | model_params: 17 | architecture: nnunet 18 | instance_norm: False 19 | nnunet: 20 | input_channels: 1 21 | base_num_features: 30 22 | num_classes: 3 23 | num_pool: 5 24 | num_conv_per_stage: 2 25 | feat_map_mul_on_downscale: 2 26 | deep_supervision: False 27 | convolutional_pooling: True 28 | convolutional_upsampling: True 29 | max_num_features: 320 30 | # # 2D ONLY 31 | # encoder: resnet34 32 | # activation: softmax 33 | # unet_smp: 34 | # attention_type: ~ # scse 35 | # decoder_use_batchnorm: True # inplace for InplaceABN 36 | # fpn_smp: 37 | # dropout: 0.2 38 | 39 | predict_3D_params: 40 | do_mirroring: True 41 | num_repeats: 1 42 | use_train_mode: False 43 | batch_size: 1 44 | mirror_axes: 45 | - 0 46 | - 1 47 | - 2 48 | tiled: True 49 | tile_in_z: True 50 | step: 2 51 | patch_size: 52 | - 96 53 | - 160 54 | - 160 55 | regions_class_order: ~ #argmax 56 | use_gaussian: False 57 | pad_border_mode: edge 58 | pad_kwargs: {} 59 | all_in_gpu: False 60 | -------------------------------------------------------------------------------- /script_configs/infer_tu_only/pred.yml: -------------------------------------------------------------------------------- 1 | in_dir: /content/kits_preprocessed 2 | out_dir: /content/kits19_predictions 3 | with_masks: True 4 | mode: segmentation 5 | checkpoint_path: resnet34unet_seg_tuonly2d2_381epochs_seed15_best.pth 6 | pseudo_3D: False 7 | 8 
| io_params: 9 | test_size: 0.2 10 | split_seed: 15 11 | batch_size: 1 12 | num_workers: 2 13 | file_ending: .npy # nii.gz 14 | 15 | model_params: 16 | architecture: unet_smp 17 | # instance_norm: False 18 | # nnunet: 19 | # input_channels: 1 20 | # base_num_features: 30 21 | # num_classes: 3 22 | # num_pool: 5 23 | # num_conv_per_stage: 2 24 | # feat_map_mul_on_downscale: 2 25 | # deep_supervision: False 26 | # convolutional_pooling: True 27 | # convolutional_upsampling: True 28 | # max_num_features: 320 29 | # # 2D ONLY 30 | encoder: resnet34 31 | activation: sigmoid 32 | unet_smp: 33 | attention_type: ~ # scse 34 | classes: 1 35 | decoder_use_batchnorm: True # inplace for InplaceABN 36 | # fpn_smp: 37 | # classes: 1 38 | # dropout: 0.2 39 | 40 | predict_3D_params: 41 | do_mirroring: True 42 | num_repeats: 1 43 | use_train_mode: False 44 | batch_size: 1 45 | mirror_axes: 46 | - 0 47 | - 1 48 | # - 2 49 | tiled: True 50 | tile_in_z: True 51 | step: 2 52 | patch_size: 53 | - 256 54 | - 256 55 | # - 96 56 | # - 160 57 | # - 160 58 | regions_class_order: # ~ #argmax 59 | - 0 60 | - 1 61 | use_gaussian: False 62 | pad_border_mode: edge 63 | pad_kwargs: {} 64 | all_in_gpu: False 65 | -------------------------------------------------------------------------------- /kits19cnn/models/nnunet/initialization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
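# Weight-initialization callables for the U-Net. Each instance is meant to be
# applied module-by-module, typically via `model.apply(...)`; illustrative
# usage (not part of this file):
#   net.apply(InitWeights_He(neg_slope=1e-2))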
14 | 
15 | from torch import nn
16 | 
17 | 
18 | class InitWeights_He(object):
19 |     def __init__(self, neg_slope=1e-2):
20 |         self.neg_slope = neg_slope
21 | 
22 |     def __call__(self, module):
23 |         if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):
24 |             module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
25 |             if module.bias is not None:
26 |                 module.bias = nn.init.constant_(module.bias, 0)
27 | 
28 | 
29 | class InitWeights_XavierUniform(object):
30 |     def __init__(self, gain=1):
31 |         self.gain = gain
32 | 
33 |     def __call__(self, module):
34 |         if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):
35 |             module.weight = nn.init.xavier_uniform_(module.weight, self.gain)
36 |             if module.bias is not None:
37 |                 module.bias = nn.init.constant_(module.bias, 0)
38 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(name='kits19cnn',
4 |       version='0.01.5',
5 |       description='Submission for the KiTS 2019 Challenge',
6 |       url='https://github.com/jchen42703/kits19-cnn',
7 |       author='Joseph Chen, Benson Jin',
8 |       author_email='jchen42703@gmail.com, jinb2@bxsci.edu',
9 |       license='Apache License Version 2.0, January 2004',
10 |       packages=find_packages(),
11 |       install_requires=[
12 |             "numpy>=1.10.2",
13 |             "scipy",
14 |             "scikit-image",
15 |             "future",
16 |             "keras",
17 |             "tensorflow",
18 |             "nibabel",
19 |             "pandas",
20 |             "scikit-learn",
21 |             "batchgenerators",
22 |             "torch>=1.2.0",
23 |             "torchvision>=0.4.0",
24 |             "catalyst",
25 |             "pytorch_toolbelt",
26 |             "segmentation_models_pytorch",
27 |       ],
28 |       classifiers=[
29 |           'Development Status :: 3 - Alpha',
30 |           'Intended Audience :: Developers',
31 |           'Topic :: Software Development :: Build Tools',
32 |           'License :: OSI Approved :: Apache Software License',
33 |           'Programming Language :: Python :: 3',
34 |           'Programming Language :: Python :: 3.6',
35 |       ],
36 |       keywords=['deep learning', 'image segmentation', 'image classification', 'medical image analysis',
37 |                 'medical image segmentation', 'data augmentation'],
38 |       )
39 | 
--------------------------------------------------------------------------------
/scripts/predict.py:
--------------------------------------------------------------------------------
1 | from catalyst.dl.runner import SupervisedRunner
2 | 
3 | from kits19cnn.inference import Predictor
4 | from kits19cnn.experiments import SegmentationInferenceExperiment, \
5 |                                   SegmentationInferenceExperiment2D, \
6 |                                   seed_everything
7 | 
8 | def main(config):
9 |     """
10 |     Main code for running 2D/3D segmentation inference with a trained model.
11 | 
12 |     Args:
13 |         config (dict): dictionary read from a yaml file
14 |             i.e. script_configs/pred.yml
15 |     Returns:
16 |         None
17 |     """
18 |     # setting up the train/val split with filenames
19 |     seed = config["io_params"]["split_seed"]
20 |     seed_everything(seed)
21 |     dim = len(config["predict_3D_params"]["patch_size"])
22 |     mode = config["mode"].lower()
23 |     assert mode in ["classification", "segmentation"], \
24 |         "The `mode` must be one of ['classification', 'segmentation']."
25 |     if mode == "classification":
26 |         raise NotImplementedError
27 |     elif mode == "segmentation":
28 |         if dim == 2:
29 |             exp = SegmentationInferenceExperiment2D(config)
30 |         elif dim == 3:
31 |             exp = SegmentationInferenceExperiment(config)
32 | 
33 |     print(f"Seed: {seed}\nMode: {mode}")
34 |     pred = Predictor(out_dir=config["out_dir"],
35 |                      checkpoint_path=config["checkpoint_path"],
36 |                      model=exp.model, test_loader=exp.loaders["test"],
37 |                      pred_3D_params=config["predict_3D_params"],
38 |                      pseudo_3D=config.get("pseudo_3D"))
39 |     pred.run_3D_predictions()
40 | 
41 | if __name__ == "__main__":
42 |     import yaml
43 |     import argparse
44 | 
45 |     parser = argparse.ArgumentParser(description="For inference.")
46 |     parser.add_argument("--yml_path", type=str, required=True,
47 |                         help="Path to the .yml config.")
48 |     args = parser.parse_args()
49 | 
50 |     with open(args.yml_path, 'r') as stream:
51 |         try:
52 |             config = yaml.safe_load(stream)
53 |         except yaml.YAMLError as exc:
54 |             print(exc)
55 | 
56 |     main(config)
57 | 
--------------------------------------------------------------------------------
/kits19cnn/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import numpy as np
16 | 
17 | def flip(x, dim):
18 |     """
19 |     flips the tensor at dimension dim (mirroring!)
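    Example: flip(x, 2) returns x with its third axis reversed
    (equivalent to torch.flip(x, dims=[2])).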
20 |     :param x: tensor to flip
21 |     :param dim: dimension along which to flip
22 |     :return: tensor with dimension `dim` reversed
23 |     """
24 |     indices = [slice(None)] * x.dim()
25 |     indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
26 |                                 dtype=torch.long, device=x.device)
27 |     return x[tuple(indices)]
28 | 
29 | def sum_tensor(inp, axes, keepdim=False):
30 |     axes = np.unique(axes).astype(int)
31 |     if keepdim:
32 |         for ax in axes:
33 |             inp = inp.sum(int(ax), keepdim=True)
34 |     else:
35 |         for ax in sorted(axes, reverse=True):
36 |             inp = inp.sum(int(ax))
37 |     return inp
38 | 
39 | def maybe_to_torch(d):
40 |     if isinstance(d, list):
41 |         d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
42 |     elif not isinstance(d, torch.Tensor):
43 |         d = torch.from_numpy(d).float()
44 |     return d
45 | 
46 | 
47 | def to_cuda(data, non_blocking=True, gpu_id=0):
48 |     if isinstance(data, list):
49 |         data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
50 |     else:
51 |         data = data.cuda(gpu_id, non_blocking=non_blocking)
52 |     return data
53 | 
54 | def softmax_helper(x):
55 |     rpt = [1 for _ in range(len(x.size()))]
56 |     rpt[1] = x.size(1)
57 |     x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
58 |     e_x = torch.exp(x - x_max)
59 |     return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
60 | 
--------------------------------------------------------------------------------
/scripts/train_yaml.py:
--------------------------------------------------------------------------------
1 | from catalyst.dl.runner import SupervisedRunner
2 | 
3 | from kits19cnn.experiments import TrainSegExperiment, TrainClfSegExperiment3D, \
4 |                                   TrainSegExperiment2D, TrainClfSegExperiment2D, \
5 |                                   seed_everything
6 | from kits19cnn.visualize import plot_metrics, save_figs
7 | 
8 | def main(config):
9 |     """
10 |     Main code for training a segmentation, classification, or joint model.
11 | 
12 |     Args:
13 |         config (dict): dictionary read from a yaml file
14 |             i.e. script_configs/train.yml
15 |     Returns:
16 |         None
17 |     """
18 |     # setting up the train/val split with filenames
19 |     seed = config["io_params"]["split_seed"]
20 |     seed_everything(seed)
21 |     mode = config["mode"].lower()
22 |     assert mode in ["classification", "segmentation", "both"], \
23 |         "The `mode` must be one of ['classification', 'segmentation', 'both']."
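    # pick the experiment class from `mode` and `dim`; each experiment builds
    # the model, loaders, criterion, optimizer, scheduler, and callbacks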
24 |     if mode == "classification":
25 |         raise NotImplementedError
26 |     elif mode == "segmentation":
27 |         if config["dim"] == 2:
28 |             exp = TrainSegExperiment2D(config)
29 |         elif config["dim"] == 3:
30 |             exp = TrainSegExperiment(config)
31 |         output_key = "logits"
32 |     elif mode == "both":
33 |         if config["dim"] == 2:
34 |             exp = TrainClfSegExperiment2D(config)
35 |         elif config["dim"] == 3:
36 |             exp = TrainClfSegExperiment3D(config)
37 |         output_key = ["seg_logits", "clf_logits"]
38 | 
39 |     print(f"Seed: {seed}\nMode: {mode}")
40 | 
41 |     runner = SupervisedRunner(output_key=output_key)
42 | 
43 |     runner.train(model=exp.model, criterion=exp.criterion, optimizer=exp.opt,
44 |                  scheduler=exp.lr_scheduler, loaders=exp.loaders,
45 |                  callbacks=exp.cb_list, **config["runner_params"])
46 |     # only save plots when plot_params is specified in the config
47 |     if config.get("plot_params"):
48 |         figs = plot_metrics(logdir=config["runner_params"]["logdir"],
49 |                             metrics=config["plot_params"]["metrics"])
50 |         save_figs(figs, save_dir=config["plot_params"]["save_dir"])
51 | 
52 | if __name__ == "__main__":
53 |     import yaml
54 |     import argparse
55 | 
56 |     parser = argparse.ArgumentParser(description="For training.")
57 |     parser.add_argument("--yml_path", type=str, required=True,
58 |                         help="Path to the .yml config.")
59 |     args = parser.parse_args()
60 | 
61 |     with open(args.yml_path, 'r') as stream:
62 |         try:
63 |             config = yaml.safe_load(stream)
64 |         except yaml.YAMLError as exc:
65 |             print(exc)
66 | 
67 |     main(config)
68 | 
--------------------------------------------------------------------------------
/kits19cnn/inference/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join, isdir
3 | 
4 | import numpy as np
5 | import nibabel as nib
6 | import torch
7 | 
8 | def load_weights_infer(checkpoint_path, model):
9 |     """
10 |     Loads pytorch model from a checkpoint and into inference mode.
11 | 
12 |     Args:
13 |         checkpoint_path (str): path to a .pt or .pth checkpoint
14 |         model (torch.nn.Module): the model to load the weights into
15 |     Returns:
16 |         Model with loaded weights and in evaluation mode
17 |     """
18 |     try:
19 |         # catalyst weights
20 |         state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]
21 |     except:
22 |         # anything else
23 |         state_dict = torch.load(checkpoint_path, map_location="cpu")
24 |     try:
25 |         model.load_state_dict(state_dict, strict=True)
26 |     except:
27 |         # e.g. loading a clf+seg checkpoint for segmentation-only prediction
28 |         print(f"Non-strict loading of weights from {checkpoint_path}")
29 |         model.load_state_dict(state_dict, strict=False)
30 |     model.eval()
31 |     return model
32 | 
33 | def create_submission(pred_dir, out_dir, orig_dir, cases=None):
34 |     """
35 |     Creates a submission directory from the predictions. To complete the
36 |     submission, run this from command prompt:
37 |         zip predictions.zip out_dir/predictions/prediction_*.nii.gz;
38 |     change out_dir to whatever out_dir you specified for this function.
39 |     Args:
40 |         pred_dir: file path to the predictions directory, generated by
41 |             infer.Predictor
42 |         out_dir: file path to the directory where the predictions.zip is to
43 |             be saved.
44 |         orig_dir: file path to the original image directory (kits19/data).
45 |         cases (iterable): iterable of case folder names
46 |     """
47 |     # setting up the output directory
48 |     out_pred_dir = join(out_dir, "predictions")
49 | 
50 |     if not isdir(out_pred_dir):
51 |         os.mkdir(out_pred_dir)
52 |         print("Created directory: {0}".format(out_pred_dir))
53 | 
54 |     if cases is None:
55 |         cases = ["case_{:05d}".format(i) for i in range(210, 300)]
56 | 
57 |     # iteratively converting each prediction
58 |     for (i, case) in enumerate(cases):
59 |         print("Processing {0}/{1}: {2}".format(i+1, len(cases), case))
60 |         x_nib = nib.load(join(orig_dir, case, "imaging.nii.gz"))
61 |         pred = np.load(join(pred_dir, case, "pred_{0}.npy".format(case)))
62 |         # converting the prediction to a .nii.gz file
63 |         p_nib = nib.Nifti1Image(pred, x_nib.affine)
64 |         id_ = case[-5:]  # number id attached to each case
65 |         out_fpath = join(out_pred_dir, "prediction_{0}.nii.gz".format(id_))
66 |         nib.save(p_nib, out_fpath)
67 | 
--------------------------------------------------------------------------------
/script_configs/train.yml:
--------------------------------------------------------------------------------
1 | mode: segmentation # classification # both
2 | dim: 3 # 2
3 | data_folder: /content/kits_preprocessed/ #/content/kits19/data
4 | 
5 | runner_params:
6 |   logdir: /content/logs/segmentation/
7 |   num_epochs: 85
8 |   fp16: False
9 |   verbose: True
10 | 
11 | io_params:
12 |   test_size: 0.2
13 |   split_seed: 200
14 |   batch_size: 2
15 |   num_workers: 2
16 |   aug_key: aug2
17 |   file_ending: .npy
18 |   # slice_indices_path: /content/kits_preprocessed/slice_indices.json # 2D
19 |   # p_pos_per_sample: 0.33 # 2D
20 |   # pseudo_3D: False
21 |   # num_pseudo_slices: 7
22 | 
23 | criterion_params:
24 |   loss: ce_dice_loss
25 |   ce_dice_loss:
26 |     soft_dice_kwargs:
27 |       batch_dice: True
28 |       smooth: 0.00001 #1e-5
29 |       do_bg: False
30 |       square: False
31 |     ce_kwargs: {}
32 |   # for clf_seg
33 |   # seg_loss: ce_dice_loss
34 |   # ce_dice_loss:
35 |   #   soft_dice_kwargs:
36 |   #     batch_dice: True
37 |   #     smooth: 0.00001 #1e-5
38 |   #     do_bg: False
39 |   #     square: False
40 |   #   ce_kwargs: {}
41 |   # clf_loss: bce_dice_loss
42 |   # bce_dice_loss:
43 |   #   eps: 0.0000001 # 1e-7
44 |   #   activation: sigmoid
45 | 
46 | model_params:
47 |   architecture: nnunet
48 |   nnunet:
49 |     input_channels: 1
50 |     base_num_features: 30
51 |     num_classes: 3
52 |     num_pool: 5
53 |     num_conv_per_stage: 2
54 |     feat_map_mul_on_downscale: 2
55 |     deep_supervision: False
56 |     convolutional_pooling: True
57 |     convolutional_upsampling: True
58 |     max_num_features: 320
59 |     classification: False #True
60 |     dropout_op_kwargs:
61 |       p: 0
62 |       inplace: True
63 |   ## 2D ONLY
64 |   # encoder: resnet34
65 |   # unet_smp:
66 |   #   attention_type: ~ # scse
67 |   #   classes: 3
68 |   #   decoder_use_batchnorm: True # inplace for InplaceABN
69 |   # fpn_smp:
70 |   #   classes: 3
71 |   #   dropout: 0.2
72 | 
73 | opt_params:
74 |   opt: SGD
75 |   SGD:
76 |     lr: 0.0001
77 |     momentum: 0.9
78 |     weight_decay: 0.0001
79 | scheduler_params:
80 |   scheduler: ReduceLROnPlateau
81 |   ReduceLROnPlateau:
82 |     factor: 0.15
83 |     patience: 30 #2
84 |     mode: min
85 |     verbose: True
86 |     threshold: 0.001
87 |     threshold_mode: abs
88 | 
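# callbacks for the catalyst runner (exposed as `exp.cb_list` in the experiment classes)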
89 | callback_params: 90 | EarlyStoppingCallback: 91 | patience: 60 92 | min_delta: 0.001 93 | # AccuracyCallback: 94 | # threshold: 0.5 95 | # activation: Softmax 96 | # PrecisionRecallF1ScoreCallback: 97 | # num_classes: 3 98 | # threshold: 0.5 99 | # activation: Softmax 100 | checkpoint_params: 101 | checkpoint_path: ~ #/content/logs/segmentation/checkpoints/last.pth #/content/logs/segmentation/checkpoints/last_full.pth 102 | mode: model_only 103 | 104 | # specify if you want to save plotly plots as .pngs 105 | ## Requires separate installation of xvfb on Colab. 106 | # plot_params: 107 | # metrics: 108 | # - loss/epoch 109 | # # - ppv/class_0/epoch 110 | # # - f1/class_0/epoch 111 | # # - tpr/class_0/epoch 112 | # save_dir: /content/logs/segmentation/ 113 | -------------------------------------------------------------------------------- /kits19cnn/experiments/infer_2d.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | import torch 3 | 4 | from kits19cnn.utils import softmax_helper 5 | from kits19cnn.models import Generic_UNet, wrap_smp_model 6 | from .utils import get_preprocessing 7 | from kits19cnn.io import TestVoxelDataset 8 | from .infer import BaseInferenceExperiment 9 | 10 | class SegmentationInferenceExperiment2D(BaseInferenceExperiment): 11 | """ 12 | Inference Experiment to support prediction experiments 13 | """ 14 | def __init__(self, config: dict): 15 | """ 16 | Args: 17 | config (dict): 18 | """ 19 | self.model_params = config["model_params"] 20 | super().__init__(config=config) 21 | 22 | def get_datasets(self, test_ids): 23 | """ 24 | Creates and returns the test dataset. 25 | """ 26 | use_rgb = "smp" in self.model_params["architecture"] 27 | preprocess_t = get_preprocessing(use_rgb) 28 | # creating the datasets 29 | test_dataset = TestVoxelDataset(im_ids=test_ids, 30 | transforms=None, 31 | preprocessing=preprocess_t, 32 | file_ending=self.io_params["file_ending"]) 33 | return test_dataset 34 | 35 | def get_model(self): 36 | """ 37 | Fetches the 2D model: the nnU-Net, smp U-Net or smp FPN for prediction. 
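        Returns:
            the model on the GPU (`.cuda()`), with `inference_apply_nonlin`
            set so that `predict_3D` can be used for inference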
38 | """ 39 | architecture = self.model_params["architecture"] 40 | # creating model 41 | if architecture == "nnunet": 42 | unet_kwargs = self.model_params[architecture] 43 | unet_kwargs = self.setup_2D_UNet_params(unet_kwargs) 44 | model = Generic_UNet(**unet_kwargs) 45 | model.inference_apply_nonlin = softmax_helper 46 | # smp models 47 | elif "smp" in architecture: 48 | model_type = smp.FPN if architecture == "fpn_smp" else smp.Unet 49 | print(f"Model type: {model_type}") 50 | model_kwargs = {"encoder_name": self.model_params["encoder"], 51 | "encoder_weights": None, "classes": 3, 52 | "activation": None} 53 | model_kwargs.update(self.model_params[architecture]) 54 | # adds the `predict_3D` method for the smp model 55 | model = wrap_smp_model(model_type, model_kwargs, 56 | num_classes=model_kwargs["classes"], 57 | activation=self.model_params["activation"]) 58 | # calculating # of parameters 59 | total = sum(p.numel() for p in model.parameters()) 60 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 61 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 62 | 63 | return model.cuda() 64 | 65 | def setup_2D_UNet_params(self, unet_kwargs): 66 | """ 67 | ^^^^^^^^^^^^^ 68 | """ 69 | unet_kwargs["conv_op"] = torch.nn.Conv2d 70 | if self.model_params.get("instance_norm"): 71 | unet_kwargs["norm_op"] = torch.nn.InstanceNorm2d 72 | unet_kwargs["dropout_op"] = torch.nn.Dropout2d 73 | unet_kwargs["nonlin"] = torch.nn.ReLU 74 | unet_kwargs["nonlin_kwargs"] = {"inplace": True} 75 | unet_kwargs["final_nonlin"] = lambda x: x 76 | return unet_kwargs 77 | -------------------------------------------------------------------------------- /kits19cnn/inference/inference_class.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from os.path import join, isdir 3 | from tqdm import tqdm 4 | import os 5 | import numpy as np 6 | import inspect 7 | import nibabel as nib 8 | import torch 9 | 10 | from kits19cnn.inference.utils import load_weights_infer 11 | 12 | class Predictor(object): 13 | """ 14 | Inference for a single model for every file generated by `test_loader`. 15 | Predictions are saved in `out_dir`. 16 | """ 17 | def __init__(self, out_dir, checkpoint_path, model, 18 | test_loader, pred_3D_params={"do_mirroring": True}, 19 | pseudo_3D: bool = False): 20 | """ 21 | Attributes 22 | out_dir (str): path to the output directory to store predictions 23 | checkpoint_path (str): path to a model checkpoint for `model` 24 | model (torch.nn.Module): class with the `predict_3D` method for 25 | predicting a single patient volume. 26 | test_loader: Iterable instance for generating data 27 | (pref. torch DataLoader) 28 | must have the __len__ arg. 
29 | pred_3D_params (dict): kwargs for `model.predict_3D` 30 | pseudo_3D (bool): whether or not to have pseudo 3D inputs 31 | """ 32 | self.out_dir = out_dir 33 | if not isdir(self.out_dir): 34 | os.mkdir(self.out_dir) 35 | print(f"Created {self.out_dir}!") 36 | assert inspect.ismethod(model.predict_3D), \ 37 | "model must have the method `predict_3D`" 38 | self.model = load_weights_infer(checkpoint_path, model) 39 | self.test_loader = test_loader 40 | self.pred_3D_params = pred_3D_params 41 | self.pseudo_3D = pseudo_3D 42 | 43 | def run_3D_predictions(self): 44 | """ 45 | Runs predictions on the dataset (specified in test_loader) 46 | """ 47 | cases = self.test_loader.dataset.im_ids 48 | assert len(cases) == len(self.test_loader) 49 | for (test_batch, case) in tqdm(zip(self.test_loader, cases), total=len(cases)): 50 | test_x = torch.squeeze(test_batch[0], dim=0) 51 | if self.pseudo_3D: 52 | pred, _, act, _ = self.model.predict_3D_pseudo3D_2Dconv(test_x, 53 | **self.pred_3D_params) 54 | else: 55 | pred, _, act, _ = self.model.predict_3D(test_x, 56 | **self.pred_3D_params) 57 | assert len(pred.shape) == 3 58 | assert len(act.shape) == 4 59 | ### possible place to threshold ROI size ### 60 | self.save_pred(pred, act, case) 61 | 62 | def save_pred(self, pred, act, case): 63 | """ 64 | Saves both prediction and activation maps in `out_dir` in the 65 | KiTS19 format. 66 | Args: 67 | pred (np.ndarray): shape (x, y, z) 68 | act (np.ndarray): shape (n_classes, x, y, z) 69 | case: path to a case folder (an element of self.cases) 70 | Returns: 71 | None 72 | """ 73 | # extracting the raw case folder name 74 | case = Path(case).name 75 | out_case_dir = join(self.out_dir, case) 76 | # checking to make sure that the output directories exist 77 | if not isdir(out_case_dir): 78 | os.mkdir(out_case_dir) 79 | 80 | np.save(join(out_case_dir, "pred.npy"), pred) 81 | np.save(join(out_case_dir, "pred_act.npy"), act) 82 | 83 | def resample_predictions(self, orig_spacing, target_spacing, 84 | resampled_preds_dir): 85 | """ 86 | Iterates through `out_dir` and creates resampled .npy arrays to 87 | the specified spacing and saves them in `resampled_preds_dir` 88 | """ 89 | from kits19cnn.io.resample import resample_patient 90 | raise NotImplementedError 91 | 92 | def prepare_submission(self): 93 | """ 94 | Resamples predictions and converts them to a .zip with .nii.gz files. 95 | """ 96 | raise NotImplementedError 97 | -------------------------------------------------------------------------------- /kits19cnn/experiments/train_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from kits19cnn.models import Generic_UNet 4 | from kits19cnn.io import VoxelDataset, ClfSegVoxelDataset 5 | 6 | from .train import TrainExperiment, TrainClfSegExperiment 7 | from .utils import get_preprocessing, get_training_augmentation, \ 8 | get_validation_augmentation, seed_everything 9 | 10 | class TrainSegExperiment(TrainExperiment): 11 | """ 12 | Stores the main parts of a segmentation experiment: 13 | - df split 14 | - datasets 15 | - loaders 16 | - model 17 | - optimizer 18 | - lr_scheduler 19 | - criterion 20 | - callbacks 21 | """ 22 | def __init__(self, config: dict): 23 | """ 24 | Args: 25 | config (dict): from `train_seg_yaml.py` 26 | """ 27 | self.model_params = config["model_params"] 28 | super().__init__(config=config) 29 | 30 | def get_datasets(self, train_ids, valid_ids): 31 | """ 32 | Creates and returns the train and validation datasets. 
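        Returns:
            (train_dataset, valid_dataset): VoxelDataset instances with the
            chosen augmentation and preprocessing transforms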
33 | """ 34 | # preparing transforms 35 | train_aug = get_training_augmentation(self.io_params["aug_key"]) 36 | val_aug = get_validation_augmentation(self.io_params["aug_key"]) 37 | # creating the datasets 38 | train_dataset = VoxelDataset(im_ids=train_ids, 39 | transforms=train_aug, 40 | preprocessing=get_preprocessing()) 41 | valid_dataset = VoxelDataset(im_ids=valid_ids, 42 | transforms=val_aug, 43 | preprocessing=get_preprocessing()) 44 | return (train_dataset, valid_dataset) 45 | 46 | def get_model(self): 47 | architecture = self.model_params["architecture"] 48 | if architecture == "nnunet": 49 | architecture_kwargs = self.model_params[architecture] 50 | architecture_kwargs["conv_op"] = torch.nn.Conv3d 51 | architecture_kwargs["norm_op"] = torch.nn.InstanceNorm3d 52 | architecture_kwargs["dropout_op"] = torch.nn.Dropout3d 53 | architecture_kwargs["nonlin"] = torch.nn.ReLU 54 | architecture_kwargs["nonlin_kwargs"] = {"inplace": True} 55 | architecture_kwargs["final_nonlin"] = lambda x: x 56 | model = Generic_UNet(**architecture_kwargs) 57 | else: 58 | raise NotImplementedError 59 | # calculating # of parameters 60 | total = sum(p.numel() for p in model.parameters()) 61 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 62 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 63 | 64 | return model 65 | 66 | class TrainClfSegExperiment3D(TrainClfSegExperiment, TrainSegExperiment): 67 | """ 68 | Stores the main parts of a classification+segmentation experiment: 69 | - df split 70 | - datasets 71 | - loaders 72 | - model 73 | - optimizer 74 | - lr_scheduler 75 | - criterion 76 | - callbacks 77 | """ 78 | def __init__(self, config: dict): 79 | """ 80 | Args: 81 | config (dict): from `train_seg_yaml.py` 82 | """ 83 | self.model_params = config["model_params"] 84 | super().__init__(config=config) 85 | 86 | def get_datasets(self, train_ids, valid_ids): 87 | """ 88 | Creates and returns the train and validation datasets. 89 | """ 90 | # preparing transforms 91 | train_aug = get_training_augmentation(self.io_params["aug_key"]) 92 | val_aug = get_validation_augmentation(self.io_params["aug_key"]) 93 | # creating the datasets 94 | preprocess = get_preprocessing() 95 | train_dataset = ClfSegVoxelDataset(im_ids=train_ids, 96 | transforms=train_aug, 97 | preprocessing=preprocess, 98 | mode="both", num_classes=3) 99 | valid_dataset = ClfSegVoxelDataset(im_ids=valid_ids, 100 | transforms=val_aug, 101 | preprocessing=preprocess, 102 | mode="both", num_classes=3) 103 | return (train_dataset, valid_dataset) 104 | -------------------------------------------------------------------------------- /kits19cnn/loss_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
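# Segmentation losses: BCEDiceLoss (sigmoid/binary), CrossentropyND,
# SoftDiceLoss, and the combined DC_and_CE_loss (Dice + cross-entropy;
# cf. the `ce_dice_loss` entry in script_configs/train.yml, which passes the
# same `soft_dice_kwargs`/`ce_kwargs`).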
14 | import torch 15 | from torch import nn 16 | from segmentation_models_pytorch.utils.losses import DiceLoss 17 | 18 | from kits19cnn.utils import softmax_helper, sum_tensor 19 | 20 | class BCEDiceLoss(DiceLoss): 21 | __name__ = 'bce_dice_loss' 22 | 23 | def __init__(self, eps=1e-7, activation='sigmoid'): 24 | super().__init__(eps, activation) 25 | self.bce = nn.BCEWithLogitsLoss(reduction='mean') 26 | 27 | def forward(self, y_pr, y_gt): 28 | y_pr = y_pr.float() 29 | y_gt = y_gt.float() 30 | dice = super().forward(y_pr, y_gt) 31 | bce = self.bce(y_pr, y_gt) 32 | return dice + bce 33 | 34 | class CrossentropyND(nn.CrossEntropyLoss): 35 | """ 36 | Network has to have NO NONLINEARITY! 37 | """ 38 | def forward(self, inp, target): 39 | target = target.long() 40 | num_classes = inp.size()[1] 41 | 42 | i0 = 1 43 | i1 = 2 44 | 45 | while i1 < len(inp.shape): # this is ugly but torch only allows to transpose two axes at once 46 | inp = inp.transpose(i0, i1) 47 | i0 += 1 48 | i1 += 1 49 | 50 | inp = inp.contiguous() 51 | inp = inp.view(-1, num_classes) 52 | 53 | target = target.view(-1,) 54 | 55 | return super(CrossentropyND, self).forward(inp, target) 56 | 57 | def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False): 58 | """ 59 | net_output must be (b, c, x, y(, z))) 60 | gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z)) 61 | if mask is provided it must have shape (b, 1, x, y(, z))) 62 | :param net_output: 63 | :param gt: 64 | :param axes: 65 | :param mask: mask must be 1 for valid pixels and 0 for invalid pixels 66 | :param square: if True then fp, tp and fn will be squared before summation 67 | :return: 68 | """ 69 | if axes is None: 70 | axes = tuple(range(2, len(net_output.size()))) 71 | 72 | shp_x = net_output.shape 73 | shp_y = gt.shape 74 | 75 | with torch.no_grad(): 76 | if len(shp_x) != len(shp_y): 77 | gt = gt.view((shp_y[0], 1, *shp_y[1:])) 78 | 79 | if all([i == j for i, j in zip(net_output.shape, gt.shape)]): 80 | # if this is the case then gt is probably already a one hot encoding 81 | y_onehot = gt 82 | else: 83 | gt = gt.long() 84 | y_onehot = torch.zeros(shp_x) 85 | if net_output.device.type == "cuda": 86 | y_onehot = y_onehot.cuda(net_output.device.index) 87 | y_onehot.scatter_(1, gt, 1) 88 | 89 | tp = net_output * y_onehot 90 | fp = net_output * (1 - y_onehot) 91 | fn = (1 - net_output) * y_onehot 92 | 93 | if mask is not None: 94 | tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1) 95 | fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1) 96 | fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1) 97 | 98 | if square: 99 | tp = tp ** 2 100 | fp = fp ** 2 101 | fn = fn ** 2 102 | 103 | tp = sum_tensor(tp, axes, keepdim=False) 104 | fp = sum_tensor(fp, axes, keepdim=False) 105 | fn = sum_tensor(fn, axes, keepdim=False) 106 | 107 | return tp, fp, fn 108 | 109 | 110 | class SoftDiceLoss(nn.Module): 111 | def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, 112 | smooth=1., square=False): 113 | """ 114 | """ 115 | super(SoftDiceLoss, self).__init__() 116 | 117 | self.square = square 118 | self.do_bg = do_bg 119 | self.batch_dice = batch_dice 120 | self.apply_nonlin = apply_nonlin 121 | self.smooth = smooth 122 | 123 | def forward(self, x, y, loss_mask=None): 124 | shp_x = x.shape 125 | 126 | if self.batch_dice: 127 | axes = [0] + list(range(2, len(shp_x))) 128 | else: 129 | axes = list(range(2, 
len(shp_x))) 130 | 131 | if self.apply_nonlin is not None: 132 | x = self.apply_nonlin(x) 133 | 134 | tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square) 135 | 136 | dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth) 137 | 138 | if not self.do_bg: 139 | if self.batch_dice: 140 | dc = dc[1:] 141 | else: 142 | dc = dc[:, 1:] 143 | dc = dc.mean() 144 | 145 | return -dc 146 | 147 | 148 | class DC_and_CE_loss(nn.Module): 149 | def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum"): 150 | super(DC_and_CE_loss, self).__init__() 151 | self.aggregate = aggregate 152 | self.ce = CrossentropyND(**ce_kwargs) 153 | self.dc = SoftDiceLoss(apply_nonlin=softmax_helper, **soft_dice_kwargs) 154 | 155 | def forward(self, net_output, target): 156 | dc_loss = self.dc(net_output, target) 157 | ce_loss = self.ce(net_output, target) 158 | if self.aggregate == "sum": 159 | result = ce_loss + dc_loss 160 | else: 161 | raise NotImplementedError("nah son") # reserved for other stuff (later) 162 | return result 163 | -------------------------------------------------------------------------------- /kits19cnn/experiments/infer.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from tqdm import tqdm 3 | from abc import abstractmethod 4 | import os 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from sklearn.model_selection import train_test_split 9 | 10 | from kits19cnn.io import TestVoxelDataset 11 | from kits19cnn.utils import softmax_helper 12 | from kits19cnn.models import Generic_UNet 13 | from .utils import get_preprocessing 14 | 15 | class BaseInferenceExperiment(object): 16 | def __init__(self, config: dict): 17 | """ 18 | Args: 19 | config (dict): 20 | 21 | Attributes: 22 | config-related: 23 | config (dict): 24 | io_params (dict): 25 | in_dir (key: str): path to the data folder 26 | test_size (key: float): split size for test 27 | split_seed (key: int): seed 28 | batch_size (key: int): <- 29 | num_workers (key: int): # of workers for data loaders 30 | split_dict (dict): test_ids 31 | test_dset (torch.data.Dataset): <- 32 | loaders (dict): train/validation loaders 33 | model (torch.nn.Module): <- 34 | """ 35 | # for reuse 36 | self.config = config 37 | self.io_params = config["io_params"] 38 | # initializing the experiment components 39 | self.case_list = self.setup_im_ids() 40 | test_ids = self.get_split()[-1] if config["with_masks"] else self.case_list 41 | print(f"Inferring on {len(test_ids)} test cases") 42 | self.test_dset = self.get_datasets(test_ids) 43 | self.loaders = self.get_loaders() 44 | self.model = self.get_model() 45 | 46 | @abstractmethod 47 | def get_datasets(self, test_ids): 48 | """ 49 | Initializes the data augmentation and preprocessing transforms. Creates 50 | and returns the train and validation datasets. 51 | """ 52 | return 53 | 54 | @abstractmethod 55 | def get_model(self): 56 | """ 57 | Creates and returns the model. 
58 | """ 59 | return 60 | 61 | def setup_im_ids(self): 62 | """ 63 | Creates a list of all paths to case folders for the dataset split 64 | """ 65 | search_path = os.path.join(self.config["in_dir"], "*/") 66 | case_list = sorted(glob(search_path)) 67 | case_list = case_list[:210] if self.config["with_masks"] else case_list[210:] 68 | return case_list 69 | 70 | def get_split(self): 71 | """ 72 | Creates train/valid filename splits 73 | """ 74 | # setting up the train/val split with filenames 75 | split_seed: int = self.io_params["split_seed"] 76 | test_size: float = self.io_params["test_size"] 77 | # doing the splits: 1-test_size, test_size//2, test_size//2 78 | print("Splitting the dataset normally...") 79 | train_ids, total_test = train_test_split(self.case_list, 80 | random_state=split_seed, 81 | test_size=test_size) 82 | val_ids, test_ids = train_test_split(sorted(total_test), 83 | random_state=split_seed, 84 | test_size=0.5) 85 | return (train_ids, val_ids, test_ids) 86 | 87 | def get_loaders(self): 88 | """ 89 | Creates train/val loaders from datasets created in self.get_datasets. 90 | Returns the loaders. 91 | """ 92 | # setting up the loaders 93 | b_size, num_workers = self.io_params["batch_size"], self.io_params["num_workers"] 94 | test_loader = DataLoader(self.test_dset, batch_size=b_size, 95 | shuffle=False, num_workers=num_workers) 96 | return {"test": test_loader} 97 | 98 | class SegmentationInferenceExperiment(BaseInferenceExperiment): 99 | """ 100 | Inference Experiment to support prediction experiments 101 | """ 102 | def __init__(self, config: dict): 103 | """ 104 | Args: 105 | config (dict): 106 | """ 107 | self.model_params = config["model_params"] 108 | super().__init__(config=config) 109 | 110 | def get_datasets(self, test_ids): 111 | """ 112 | Creates and returns the test dataset. 
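        Returns:
            a TestVoxelDataset over `test_ids` (no augmentation; preprocessing only)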
113 | """ 114 | # creating the datasets 115 | test_dataset = TestVoxelDataset(im_ids=test_ids, 116 | transforms=None, 117 | preprocessing=get_preprocessing(), 118 | file_ending=self.io_params["file_ending"]) 119 | return test_dataset 120 | 121 | def get_model(self): 122 | architecture = self.model_params["architecture"] 123 | # creating model 124 | if architecture == "nnunet": 125 | unet_kwargs = self.model_params[architecture] 126 | unet_kwargs = self.setup_3D_UNet_params(unet_kwargs) 127 | model = Generic_UNet(**unet_kwargs) 128 | model.inference_apply_nonlin = softmax_helper 129 | else: 130 | raise NotImplementedError 131 | # calculating # of parameters 132 | total = sum(p.numel() for p in model.parameters()) 133 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 134 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 135 | 136 | return model.cuda() 137 | 138 | def setup_3D_UNet_params(self, unet_kwargs): 139 | """ 140 | ^^^^^^^^^^^^^ 141 | """ 142 | unet_kwargs["conv_op"] = torch.nn.Conv3d 143 | unet_kwargs["norm_op"] = torch.nn.InstanceNorm3d 144 | unet_kwargs["dropout_op"] = torch.nn.Dropout3d 145 | unet_kwargs["nonlin"] = torch.nn.ReLU 146 | unet_kwargs["nonlin_kwargs"] = {"inplace": True} 147 | unet_kwargs["final_nonlin"] = lambda x: x 148 | return unet_kwargs 149 | -------------------------------------------------------------------------------- /kits19cnn/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | 4 | from typing import Dict, List, Optional, Union # isort:skip 5 | from collections import defaultdict 6 | from pathlib import Path 7 | 8 | import plotly.graph_objs as go 9 | from plotly.offline import init_notebook_mode, iplot 10 | 11 | from catalyst.utils.tensorboard import SummaryItem, SummaryReader 12 | 13 | print("If you're using a notebook, " 14 | "make sure to run %matplotlib inline beforehand.") 15 | 16 | def plot_scan(scan, start_with, show_every, rows=3, cols=3): 17 | """ 18 | Plots multiple scans throughout your medical image. 19 | Args: 20 | scan: numpy array with shape (x,y,z) 21 | start_with: slice to start with 22 | show_every: size of the step between each slice iteration 23 | rows: rows of plot 24 | cols: cols of plot 25 | Returns: 26 | a plot of multiple scans from the same image 27 | """ 28 | fig,ax = plt.subplots(rows, cols, figsize=[3*cols,3*rows]) 29 | for i in range(rows*cols): 30 | ind = start_with + i*show_every 31 | ax[int(i/cols), int(i%cols)].set_title("slice %d" % ind) 32 | ax[int(i/cols), int(i%cols)].axis("off") 33 | 34 | ax[int(i/cols), int(i%cols)].imshow(scan[ind], cmap="gray") 35 | plt.show() 36 | 37 | def plot_scan_and_mask(scan, mask, start_with, show_every, rows=3, cols=3): 38 | """ 39 | Plots multiple scans with the mask overlay throughout your medical image. 
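    The mask is drawn over each grayscale slice with the "jet" colormap at
    50% opacity; `mask` should have the same shape as `scan`.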
40 | Args: 41 | scan: numpy array with shape (x,y,z) 42 | start_with: slice to start with 43 | show_every: size of the step between each slice iteration 44 | rows: rows of plot 45 | cols: cols of plot 46 | Returns: 47 | a plot of multiple scans from the same image 48 | """ 49 | fig,ax = plt.subplots(rows, cols, figsize=[4*cols, 4*rows]) 50 | for i in range(rows*cols): 51 | ind = start_with + i*show_every 52 | ax[int(i/cols), int(i%cols)].set_title("slice %d" % ind) 53 | ax[int(i/cols), int(i%cols)].axis("off") 54 | 55 | ax[int(i/cols), int(i%cols)].imshow(scan[ind], cmap="gray") 56 | ax[int(i/cols), int(i%cols)].imshow(mask[ind], cmap="jet", alpha=0.5) 57 | plt.show() 58 | 59 | # FROM: https://github.com/catalyst-team/catalyst/blob/master/catalyst/dl/utils/visualization.py 60 | def _get_tensorboard_scalars( 61 | logdir: Union[str, Path], metrics: Optional[List[str]], step: str 62 | ) -> Dict[str, List]: 63 | summary_reader = SummaryReader(logdir, types=["scalar"]) 64 | 65 | items = defaultdict(list) 66 | for item in summary_reader: 67 | if step in item.tag and ( 68 | metrics is None or any(m in item.tag for m in metrics) 69 | ): 70 | items[item.tag].append(item) 71 | return items 72 | 73 | def _get_scatter(scalars: List[SummaryItem], name: str) -> go.Scatter: 74 | xs = [s.step for s in scalars] 75 | ys = [s.value for s in scalars] 76 | return go.Scatter(x=xs, y=ys, name=name) 77 | 78 | def plot_tensorboard_log( 79 | logdir: Union[str, Path], 80 | step: Optional[str] = "batch", 81 | metrics: Optional[List[str]] = None, 82 | height: Optional[int] = None, 83 | width: Optional[int] = None 84 | ) -> None: 85 | init_notebook_mode() 86 | logdir = Path(logdir) 87 | 88 | logdirs = { 89 | x.name.replace("_log", ""): x 90 | for x in logdir.glob("**/*") if x.is_dir() and str(x).endswith("_log") 91 | } 92 | 93 | scalars_per_loader = { 94 | key: _get_tensorboard_scalars(inner_logdir, metrics, step) 95 | for key, inner_logdir in logdirs.items() 96 | } 97 | 98 | scalars_per_metric = defaultdict(dict) 99 | for key, value in scalars_per_loader.items(): 100 | for key2, value2 in value.items(): 101 | scalars_per_metric[key2][key] = value2 102 | 103 | figs = [] 104 | for metric_name, metric_logs in scalars_per_metric.items(): 105 | metric_data = [] 106 | for key, value in metric_logs.items(): 107 | try: 108 | data_ = _get_scatter(value, f"{key}/{metric_name}") 109 | metric_data.append(data_) 110 | except: # noqa: E722 111 | pass 112 | 113 | layout = go.Layout( 114 | title=metric_name, 115 | height=height, 116 | width=width, 117 | yaxis=dict(hoverformat=".5f") 118 | ) 119 | fig = go.Figure(data=metric_data, layout=layout) 120 | iplot(fig) 121 | figs.append(fig) 122 | return figs 123 | 124 | def plot_metrics( 125 | logdir: Union[str, Path], 126 | step: Optional[str] = "epoch", 127 | metrics: Optional[List[str]] = None, 128 | height: Optional[int] = None, 129 | width: Optional[int] = None 130 | ) -> None: 131 | """Plots your learning results. 132 | Args: 133 | logdir: the logdir that was specified during training. 134 | step: 'batch' or 'epoch' - what logs to show: for batches or 135 | for epochs 136 | metrics: list of metrics to plot. 
The loss should be specified as 137 | 'loss', learning rate = '_base/lr' and other metrics should be 138 | specified as names in metrics dict 139 | that was specified during training 140 | height: the height of the whole resulting plot 141 | width: the width of the whole resulting plot 142 | """ 143 | assert step in ["batch", "epoch"], \ 144 | f"Step should be either 'batch' or 'epoch', got '{step}'" 145 | metrics = metrics or ["loss"] 146 | return plot_tensorboard_log(logdir, step, metrics, height, width) 147 | 148 | def save_figs(figs_list, save_dir=None): 149 | """ 150 | Saves plotly figures. (from plot_metrics) 151 | """ 152 | if save_dir is None: 153 | save_dir = os.getcwd() 154 | 155 | for fig in figs_list: 156 | # takes a metric like train/f1/class_0/epoch to f1_class_0_epoch 157 | train_metric_name = fig["data"][0]["name"] 158 | split = train_metric_name.split("/") 159 | metric_name = "".join([f"{name}_" for name in split 160 | if not name in ["train", "valid"]])[:-1] 161 | save_name = os.path.join(save_dir, f"{metric_name}.png") 162 | fig.write_image(save_name) 163 | print(f"Saved {save_name}...") 164 | -------------------------------------------------------------------------------- /notebooks/Visualizing Volumes.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualizing the Data\n", 8 | "This notebooks uses [`ipyvolume`](https://github.com/maartenbreddels/ipyvolume)\n", 9 | "To install: (make sure you have [`jupyter_contrib_nbextensions`](https://github.com/ipython-contrib/jupyter_contrib_nbextensions) installed)\n", 10 | "i.e.\n", 11 | "```\n", 12 | "conda install -c conda-forge jupyter_contrib_nbextensions\n", 13 | "```\n", 14 | "\n", 15 | "```\n", 16 | "conda install -c conda-forge ipyvolume OR pip install ipyvolume\n", 17 | "jupyter nbextension enable --py --sys-prefix ipyvolume\n", 18 | "jupyter nbextension enable --py --sys-prefix widgetsnbextension\n", 19 | "```\n", 20 | "* Holes in volume\n", 21 | " * patients 0-3, 7, 9, 14..." 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Raw Interpolated" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 1, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "Case: case_00014; Shape: (146, 556, 556)\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "import os\n", 46 | "import nibabel as nib\n", 47 | "\n", 48 | "dset_dir = r\"C:\\Users\\jchen\\Desktop\\Datasets\\kits19\\data\"\n", 49 | "i = 14\n", 50 | "case = f\"case_000{i}\"\n", 51 | "image = nib.load(os.path.join(dset_dir, case, \"imaging.nii.gz\")).get_fdata().squeeze()\n", 52 | "mask = nib.load(os.path.join(dset_dir, case, \"segmentation.nii.gz\")).get_fdata().squeeze()\n", 53 | "\n", 54 | "print(f\"Case: {case}; Shape: {image.shape}\")\n", 55 | "assert mask.shape == image.shape" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "[[ 0. 0. -0.78162497 0. ]\n", 68 | " [ 0. -0.78162497 0. 0. ]\n", 69 | " [-3. 0. 0. 0. ]\n", 70 | " [ 0. 0. 0. 1. 
]]\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "print(nib.load(os.path.join(dset_dir, case, \"imaging.nii.gz\")).affine)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "outputs": [ 83 | { 84 | "name": "stderr", 85 | "output_type": "stream", 86 | "text": [ 87 | "C:\\Users\\jchen\\Miniconda3\\envs\\py36\\lib\\site-packages\\ipyvolume\\widgets.py:179: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n", 88 | " data_view = self.data_original[view]\n", 89 | "C:\\Users\\jchen\\Miniconda3\\envs\\py36\\lib\\site-packages\\ipyvolume\\utils.py:204: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n", 90 | " data = (data[slices1] + data[slices2])/2\n", 91 | "C:\\Users\\jchen\\Miniconda3\\envs\\py36\\lib\\site-packages\\ipyvolume\\serialize.py:81: RuntimeWarning: invalid value encountered in true_divide\n", 92 | " gradient = gradient / np.sqrt(gradient[0]**2 + gradient[1]**2 + gradient[2]**2)\n" 93 | ] 94 | }, 95 | { 96 | "data": { 97 | "application/vnd.jupyter.widget-view+json": { 98 | "model_id": "49c7b4f7217a417a9edf3de52124d107", 99 | "version_major": 2, 100 | "version_minor": 0 101 | }, 102 | "text/plain": [ 103 | "VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.1, max=1.0, step=0.00…" 104 | ] 105 | }, 106 | "metadata": {}, 107 | "output_type": "display_data" 108 | } 109 | ], 110 | "source": [ 111 | "import ipyvolume as ipv\n", 112 | "ipv.quickvolshow(mask)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Preprocessed" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 6, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "Case: case_00014; Shape: (136, 268, 268)\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "import os\n", 137 | "import numpy as np\n", 138 | "\n", 139 | "dset_dir = r\"C:\\Users\\jchen\\Desktop\\Datasets\\kits_preprocessed_isensee_spacing\"\n", 140 | "case = f\"case_000{i}\"\n", 141 | "image = np.load(os.path.join(dset_dir, case, \"imaging.npy\")).squeeze()\n", 142 | "mask = np.load(os.path.join(dset_dir, case, \"segmentation.npy\")).squeeze()\n", 143 | "\n", 144 | "print(f\"Case: {case}; Shape: {image.shape}\")\n", 145 | "assert mask.shape == image.shape" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 7, 151 | "metadata": { 152 | "scrolled": false 153 | }, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "application/vnd.jupyter.widget-view+json": { 158 | "model_id": "fb519330c64e4f959797013e325e8cc7", 159 | "version_major": 2, 160 | "version_minor": 0 161 | }, 162 | "text/plain": [ 163 | "VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.1, max=1.0, step=0.00…" 164 | ] 165 | }, 166 | "metadata": {}, 167 | "output_type": "display_data" 168 | } 169 | ], 170 | "source": [ 171 | "import ipyvolume as ipv\n", 172 | "ipv.quickvolshow(mask)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 
178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "# Predictions" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 5, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "# ..." 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [] 204 | } 205 | ], 206 | "metadata": { 207 | "kernelspec": { 208 | "display_name": "Python 3", 209 | "language": "python", 210 | "name": "python3" 211 | }, 212 | "language_info": { 213 | "codemirror_mode": { 214 | "name": "ipython", 215 | "version": 3 216 | }, 217 | "file_extension": ".py", 218 | "mimetype": "text/x-python", 219 | "name": "python", 220 | "nbconvert_exporter": "python", 221 | "pygments_lexer": "ipython3", 222 | "version": "3.6.8" 223 | } 224 | }, 225 | "nbformat": 4, 226 | "nbformat_minor": 2 227 | } 228 | -------------------------------------------------------------------------------- /kits19cnn/io/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import random 4 | from batchgenerators.transforms import AbstractTransform 5 | 6 | from kits19cnn.io.custom_augmentations import foreground_crop, center_crop, \ 7 | random_resized_crop 8 | 9 | class RandomResizedCropTransform(AbstractTransform): 10 | """ 11 | Crop the given array to random size and aspect ratio. 12 | Doesn't resize across the depth dimenion (assumes it is dim=0) if 13 | the data is 3D. 14 | 15 | A crop of random size (default: of 0.08 to 1.0) of the original size and a 16 | random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio 17 | is made. This crop is finally resized to given size. 18 | This is popularly used to train the Inception networks. 19 | 20 | Assumes the data and segmentation masks are the same size. 21 | """ 22 | def __init__(self, target_size, scale=(0.08, 1.0), 23 | ratio=(3. / 4., 4. / 3.), 24 | data_key="data", label_key="seg", p_per_sample=0.33, 25 | crop_kwargs={}, resize_kwargs={}): 26 | """ 27 | Attributes: 28 | pass 29 | """ 30 | if len(target_size) > 2: 31 | print("Currently only adjusts the aspect ratio for the 2D dims.") 32 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 33 | warnings.warn("range should be of kind (min, max)") 34 | self.target_size = target_size 35 | self.scale = scale 36 | self.ratio = ratio 37 | self.data_key = data_key 38 | self.label_key = label_key 39 | self.p_per_sample = p_per_sample 40 | self.crop_kwargs = crop_kwargs 41 | self.resize_kwargs = resize_kwargs 42 | 43 | def _get_image_size(self, data): 44 | """ 45 | Assumes data has shape (b, c, h, w (, d)). Fetches the h, w, and d. 46 | depth if applicable. 47 | """ 48 | return data.shape[2:] 49 | 50 | def get_crop_size(self, data, scale, ratio): 51 | """ 52 | Get parameters for ``crop`` for a random sized crop. 
53 | """ 54 | shape_dims = self._get_image_size(data) 55 | area = np.prod(shape_dims) 56 | 57 | while True: 58 | target_area = random.uniform(*scale) * area 59 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 60 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 61 | 62 | w = int(round(math.sqrt(target_area * aspect_ratio))) 63 | h = int(round(math.sqrt(target_area / aspect_ratio))) 64 | 65 | if len(shape_dims) == 3: 66 | depth = shape_dims[0] 67 | crop_size = np.array([depth, h, w]) 68 | else: 69 | crop_size = np.array([h, w]) 70 | 71 | if (crop_size <= shape_dims).all() and (crop_size > 0).all(): 72 | return crop_size 73 | 74 | def __call__(self, **data_dict): 75 | """ 76 | Actually doing the cropping. 77 | """ 78 | data = data_dict.get(self.data_key) 79 | seg = data_dict.get(self.label_key) 80 | if np.random.uniform() < self.p_per_sample: 81 | crop_size = self.get_crop_size(data, self.scale, self.ratio) 82 | data, seg = random_resized_crop(data, seg, 83 | target_size=self.target_size, 84 | crop_size=crop_size, 85 | crop_kwargs=self.crop_kwargs, 86 | resize_kwargs=self.resize_kwargs) 87 | else: 88 | data, seg = center_crop(data, self.target_size, seg, 89 | crop_kwargs=self.crop_kwargs) 90 | 91 | data_dict[self.data_key] = data 92 | if seg is not None: 93 | data_dict[self.label_key] = seg.astype(np.float32) 94 | 95 | return data_dict 96 | 97 | class ROICropTransform(AbstractTransform): 98 | """ 99 | Crops the foreground in images `p_per_sample` part of the time. The 100 | fallback cropping is center cropping. 101 | """ 102 | def __init__(self, crop_size=128, margins=(0, 0, 0), data_key="data", 103 | label_key="seg", coords_key="bbox_coords", 104 | p_per_sample=0.33, crop_kwargs={}): 105 | self.data_key = data_key 106 | self.label_key = label_key 107 | self.coords_key = coords_key 108 | self.margins = margins 109 | self.crop_size = crop_size 110 | self.p_per_sample = p_per_sample 111 | self.crop_kwargs = crop_kwargs 112 | 113 | def __call__(self, **data_dict): 114 | """ 115 | Actually doing the cropping. Make sure that data_dict has the 116 | a key for the coords (self.coords_key) if p>0. 117 | (If the output of data_dict.get(self.coords_key) is None, then foreground 118 | crops are done on-the-fly). 119 | """ 120 | data = data_dict.get(self.data_key) 121 | seg = data_dict.get(self.label_key) 122 | if np.random.uniform() < self.p_per_sample: 123 | coords = data_dict.get(self.coords_key) 124 | 125 | data, seg = foreground_crop(data, seg, patch_size=self.crop_size, 126 | margins=self.margins, 127 | bbox_coords=coords, 128 | crop_kwargs=self.crop_kwargs) 129 | else: 130 | data, seg = center_crop(data, self.crop_size, seg, 131 | crop_kwargs=self.crop_kwargs) 132 | 133 | data_dict[self.data_key] = data 134 | if seg is not None: 135 | data_dict[self.label_key] = seg 136 | 137 | return data_dict 138 | 139 | class MultiClassToBinaryTransform(AbstractTransform): 140 | """ 141 | For changing a multi-class case to a binary one. Specify the label to 142 | change to binary with `roi_label`. 143 | - Don't forget to adjust `remove_label` accordingly! 
144 | - label will be turned to a binary label with only `roi_label` 145 | existing as 1s 146 | """ 147 | def __init__(self, roi_label="2", remove_label="1", label_key="seg"): 148 | self.roi_label = int(roi_label) 149 | self.remove_label = int(remove_label) 150 | self.label_key = label_key 151 | 152 | def __call__(self, **data_dict): 153 | """ 154 | Replaces the label values 155 | """ 156 | label = data_dict.get(self.label_key) 157 | # changing labels 158 | label[label == self.remove_label] = 0 159 | label[label == self.roi_label] = 1 160 | 161 | data_dict[self.label_key] = label 162 | 163 | return data_dict 164 | 165 | class RepeatChannelsTransform(AbstractTransform): 166 | """ 167 | Repeats across the channels dimension `num_tiles` number of times. 168 | """ 169 | def __init__(self, num_repeats=3, data_key="data"): 170 | self.num_repeats = num_repeats 171 | self.data_key = data_key 172 | 173 | def __call__(self, **data_dict): 174 | """ 175 | Repeats across the channels dimension (axis=1). 176 | """ 177 | data = data_dict.get(self.data_key) 178 | 179 | data_dict[self.data_key] = np.repeat(data, self.num_repeats, axis=1) 180 | 181 | return data_dict 182 | -------------------------------------------------------------------------------- /kits19cnn/experiments/train_2d.py: -------------------------------------------------------------------------------- 1 | import json 2 | from abc import abstractmethod 3 | 4 | import torch 5 | import segmentation_models_pytorch as smp 6 | 7 | from kits19cnn.io import SliceDataset, PseudoSliceDataset 8 | from kits19cnn.models import Generic_UNet 9 | from .utils import get_training_augmentation, get_validation_augmentation, \ 10 | get_preprocessing 11 | from .train import TrainExperiment, TrainClfSegExperiment 12 | 13 | class TrainExperiment2D(TrainExperiment): 14 | """ 15 | Stores the main parts of a experiment with 2D images: 16 | - df split 17 | - datasets 18 | - loaders 19 | - model 20 | - optimizer 21 | - lr_scheduler 22 | - criterion 23 | - callbacks 24 | """ 25 | def __init__(self, config: dict): 26 | """ 27 | Args: 28 | config (dict): from `train_seg_yaml.py` 29 | """ 30 | self.model_params = config["model_params"] 31 | super().__init__(config=config) 32 | 33 | @abstractmethod 34 | def get_model(self): 35 | """ 36 | Creates and returns the model. 37 | """ 38 | return 39 | 40 | def get_datasets(self, train_ids, valid_ids): 41 | """ 42 | Creates and returns the train and validation datasets. 43 | """ 44 | # preparing transforms 45 | train_aug = get_training_augmentation(self.io_params["aug_key"]) 46 | val_aug = get_validation_augmentation(self.io_params["aug_key"]) 47 | use_rgb = "smp" in self.model_params["architecture"] 48 | # creating the datasets 49 | with open(self.io_params["slice_indices_path"], "r") as fp: 50 | pos_slice_dict = json.load(fp) 51 | p_pos_per_sample = self.io_params["p_pos_per_sample"] 52 | if self.io_params.get("pseudo_3D"): 53 | assert not use_rgb, \ 54 | "Currently architectures that require RGB inputs cannot use pseudo slices." 
55 | train_dataset = PseudoSliceDataset(im_ids=train_ids, 56 | pos_slice_dict=pos_slice_dict, 57 | transforms=train_aug, 58 | preprocessing=get_preprocessing(use_rgb), 59 | p_pos_per_sample=p_pos_per_sample, 60 | mode=self.config["mode"], 61 | num_pseudo_slices=self.io_params["num_pseudo_slices"]) 62 | valid_dataset = PseudoSliceDataset(im_ids=valid_ids, 63 | pos_slice_dict=pos_slice_dict, 64 | transforms=val_aug, 65 | preprocessing=get_preprocessing(use_rgb), 66 | p_pos_per_sample=p_pos_per_sample, 67 | mode=self.config["mode"], 68 | num_pseudo_slices=self.io_params["num_pseudo_slices"]) 69 | else: 70 | train_dataset = SliceDataset(im_ids=train_ids, 71 | pos_slice_dict=pos_slice_dict, 72 | transforms=train_aug, 73 | preprocessing=get_preprocessing(use_rgb), 74 | p_pos_per_sample=p_pos_per_sample, 75 | mode=self.config["mode"]) 76 | valid_dataset = SliceDataset(im_ids=valid_ids, 77 | pos_slice_dict=pos_slice_dict, 78 | transforms=val_aug, 79 | preprocessing=get_preprocessing(use_rgb), 80 | p_pos_per_sample=p_pos_per_sample, 81 | mode=self.config["mode"]) 82 | 83 | return (train_dataset, valid_dataset) 84 | 85 | class TrainSegExperiment2D(TrainExperiment2D): 86 | """ 87 | Stores the main parts of a segmentation experiment: 88 | - df split 89 | - datasets 90 | - loaders 91 | - model 92 | - optimizer 93 | - lr_scheduler 94 | - criterion 95 | - callbacks 96 | """ 97 | def __init__(self, config: dict): 98 | """ 99 | Args: 100 | config (dict): from `train_seg_yaml.py` 101 | """ 102 | self.model_params = config["model_params"] 103 | super().__init__(config=config) 104 | 105 | def get_model(self): 106 | architecture = self.model_params["architecture"] 107 | if architecture.lower() == "nnunet": 108 | architecture_kwargs = self.model_params[architecture] 109 | architecture_kwargs["norm_op"] = torch.nn.InstanceNorm2d 110 | architecture_kwargs["nonlin"] = torch.nn.ReLU 111 | architecture_kwargs["nonlin_kwargs"] = {"inplace": True} 112 | architecture_kwargs["final_nonlin"] = lambda x: x 113 | model = Generic_UNet(**architecture_kwargs) 114 | elif architecture.lower() == "unet_smp": 115 | model = smp.Unet(encoder_name=self.model_params["encoder"], 116 | encoder_weights="imagenet", activation=None, 117 | **self.model_params[architecture]) 118 | elif architecture.lower() == "fpn_smp": 119 | model = smp.FPN(encoder_name=self.model_params["encoder"], 120 | encoder_weights="imagenet", activation=None, 121 | **self.model_params[architecture]) 122 | # calculating # of parameters 123 | total = sum(p.numel() for p in model.parameters()) 124 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 125 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 126 | 127 | return model 128 | 129 | class TrainClfSegExperiment2D(TrainExperiment2D, TrainClfSegExperiment): 130 | """ 131 | Stores the main parts of a classification+segmentation experiment: 132 | - df split 133 | - datasets 134 | - loaders 135 | - model 136 | - optimizer 137 | - lr_scheduler 138 | - criterion 139 | - callbacks 140 | """ 141 | def __init__(self, config: dict): 142 | """ 143 | Args: 144 | config (dict): from `train_seg_yaml.py` 145 | """ 146 | self.model_params = config["model_params"] 147 | super().__init__(config=config) 148 | 149 | def get_model(self): 150 | architecture = self.model_params["architecture"] 151 | if architecture.lower() == "nnunet": 152 | architecture_kwargs = self.model_params[architecture] 153 | if self.io_params["batch_size"] < 10: 154 | architecture_kwargs["norm_op"] = 
torch.nn.InstanceNorm2d 155 | architecture_kwargs["nonlin"] = torch.nn.ReLU 156 | architecture_kwargs["nonlin_kwargs"] = {"inplace": True} 157 | architecture_kwargs["final_nonlin"] = lambda x: x 158 | model = Generic_UNet(**architecture_kwargs) 159 | else: 160 | raise NotImplementedError 161 | # calculating # of parameters 162 | total = sum(p.numel() for p in model.parameters()) 163 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 164 | print(f"Total # of Params: {total}\nTrainable params: {trainable}") 165 | 166 | return model 167 | -------------------------------------------------------------------------------- /kits19cnn/inference/evaluate.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import nibabel as nib 4 | import pandas as pd 5 | import os 6 | from os.path import isdir, join 7 | from pathlib import Path 8 | 9 | from kits19cnn.metrics import evaluate_official 10 | from sklearn.metrics import precision_recall_fscore_support 11 | 12 | class Evaluator(object): 13 | """ 14 | Evaluates all of the predictions in a user-specified directory and logs them in a csv. Assumes that 15 | the output is in the KiTS19 file structure. 16 | """ 17 | def __init__(self, orig_img_dir, pred_dir, cases=None, 18 | label_file_ending=".npy", binary_tumor=False): 19 | """ 20 | Attributes: 21 | orig_img_dir: path to the directory containing the 22 | labels to evaluate with 23 | i.e. original kits19/data directory or the preprocessed imgs 24 | directory 25 | assumes structure: 26 | orig_img_dir 27 | case_xxxxx 28 | imaging{file_ending} 29 | segmentation{file_ending} 30 | pred_dir: path to the predictions directory, created by Predictor 31 | assumes structure: 32 | pred_dir 33 | case_xxxxx 34 | pred.npy 35 | act.npy 36 | cases: list of filepaths to case folders or just case folder names. 37 | Defaults to None. 38 | label_file_ending (str): one of ['.npy', '.nii', '.nii.gz'] 39 | binary_tumor (bool): whether or not to treat predicted 1s as tumor 40 | """ 41 | self.orig_img_dir = orig_img_dir 42 | self.pred_dir = pred_dir 43 | self.file_ending = label_file_ending 44 | assert self.file_ending in [".npy", ".nii", ".nii.gz"], \ 45 | "label_file_ending must be one of [''.npy', '.nii', '.nii.gz']" 46 | # converting cases from filepaths to raw folder names 47 | if cases is None: 48 | self.cases_raw = [case \ 49 | for case in os.listdir(self.pred_dir) \ 50 | if case.startswith("case")] 51 | assert len(self.cases_raw) > 0, \ 52 | "Please make sure that pred_dir has the case folders" 53 | elif cases is not None: 54 | # extracting raw cases from filepath cases 55 | cases_raw = [Path(case).name for case in cases] 56 | # filtering them down to only cases in pred_dir 57 | self.cases_raw = [case for case in cases_raw \ 58 | if isdir(join(self.pred_dir, case))] 59 | self.binary_tumor = binary_tumor 60 | if self.binary_tumor: 61 | print("Evaluating predicted 1s as tumor (changed to 2).") 62 | 63 | def evaluate_all(self, print_metrics=False): 64 | """ 65 | Evaluates all cases and creates the results.csv, which stores all of 66 | the metrics and the averages. 67 | Args: 68 | print_metrics (bool): whether or not to print metrics. 69 | Defaults to False to be cleaner with tqdm. 
70 | """ 71 | metrics_dict = {"cases": [], 72 | "tk_dice": [], "tu_dice": [], 73 | "precision": [], "recall": [], 74 | "fpr": [], "orig_shape": [], 75 | "support": [], "pred_support": []} 76 | 77 | for case in tqdm(self.cases_raw): 78 | # loading the necessary arrays 79 | label, pred = self.load_masks_and_pred(case) 80 | metrics_dict = self.eval_all_metrics_per_case(metrics_dict, label, 81 | pred, case, 82 | print_metrics) 83 | 84 | metrics_dict = self.round_all(self.average_all_cases_per_metric(metrics_dict)) 85 | df = pd.DataFrame(metrics_dict) 86 | metrics_path = join(self.pred_dir, "results.csv") 87 | print(f"Saving {metrics_path}...") 88 | df.to_csv(metrics_path) 89 | 90 | def load_masks_and_pred(self, case): 91 | """ 92 | Loads mask and prediction from `case` 93 | Args: 94 | case (str): case folder names to use 95 | Returns: 96 | label (np.ndarray): shape (x, y, z) 97 | pred (np.ndarray): shape (x, y, z) 98 | """ 99 | y_path = join(self.orig_img_dir, case, f"segmentation{self.file_ending}") 100 | if self.file_ending == ".npy": 101 | label = np.load(y_path) 102 | elif self.file_ending == ".nii.gz" or self.file_ending == ".nii": 103 | label = nib.load(y_path).get_fdata() 104 | pred = np.load(join(self.pred_dir, case, "pred.npy")).squeeze() 105 | if self.binary_tumor: 106 | # treating prediced 1s as tumor (2) 107 | pred[pred == 1] = 2 108 | return (label, pred) 109 | 110 | def eval_all_metrics_per_case(self, metrics_dict, y_true, y_pred, 111 | case, print_metrics=False): 112 | """ 113 | Calculates the official metrics, precision, recall, specificity (fpr), 114 | and stores some metadata such as the original shape and support 115 | (# of pixels for each class). They are then appended to the main 116 | metrics dictionary that is to become results.csv. 117 | """ 118 | # calculating metrics 119 | tk_dice, tu_dice = evaluate_official(y_true, y_pred) 120 | prec, recall, _, supp = precision_recall_fscore_support(y_true.ravel(), 121 | y_pred.ravel(), 122 | labels=[0, 1, 2]) 123 | pred_supp = np.unique(y_pred, return_counts=True)[-1] 124 | fpr = 1-recall 125 | orig_shape = y_true.shape 126 | 127 | if print_metrics: 128 | print(f"PPV: {prec}\nTPR: {recall}\nSupp: {supp}") 129 | print(f"Tumour and Kidney Dice: {tk_dice}; Tumour Dice: {tu_dice}") 130 | # order for appending (sorted keys) 131 | # ['cases', 'fpr', 'orig_shape', 'precision', 'pred_support', 'recall', 132 | # 'support', 'tk_dice', 'tu_dice'] 133 | append_list = [case, fpr, orig_shape, prec, pred_supp, recall, supp, 134 | tk_dice, tu_dice] 135 | sorted_keys = sorted(metrics_dict.keys()) 136 | assert len(append_list) == len(sorted_keys) 137 | # appending to each key's list 138 | for (key_, value_) in zip(sorted_keys, append_list): 139 | metrics_dict[key_].append(value_) 140 | return metrics_dict 141 | 142 | def average_all_cases_per_metric(self, metrics_dict): 143 | """ 144 | Averages the metrics (each key of metrics_dict). 145 | """ 146 | metrics_dict["cases"].append("average") 147 | for key in list(metrics_dict.keys()): 148 | if key == "cases": 149 | pass 150 | else: 151 | # axis=0 will make it so that each sub-axis of orig_shape and 152 | # support will be averaged 153 | try: 154 | metrics_dict[key].append(np.mean(metrics_dict[key], axis=0)) 155 | except: 156 | metrics_dict[key].append("N/A") 157 | return metrics_dict 158 | 159 | def round_all(self, metrics_dict): 160 | """ 161 | Rounding all relevant metrics to three decimal places for cleanliness. 
162 | """ 163 | for key in list(metrics_dict.keys()): 164 | if key in ["cases", "pred_support"]: 165 | pass 166 | else: 167 | metrics_dict[key] = np.round(metrics_dict[key], 168 | decimals=3).tolist() 169 | return metrics_dict 170 | -------------------------------------------------------------------------------- /kits19cnn/io/dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import isfile, join 2 | import numpy as np 3 | import nibabel as nib 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | class VoxelDataset(Dataset): 9 | def __init__(self, im_ids: np.array, 10 | transforms=None, 11 | preprocessing=None, 12 | file_ending: str = ".npy"): 13 | """ 14 | Attributes 15 | im_ids (np.ndarray): of image names. 16 | transforms (albumentations.augmentation): transforms to apply 17 | before preprocessing. Defaults to HFlip and ToTensor 18 | preprocessing: ops to perform after transforms, such as 19 | z-score standardization. Defaults to None. 20 | file_ending (str): one of ['.npy', '.nii', '.nii.gz'] 21 | """ 22 | self.im_ids = im_ids 23 | self.transforms = transforms 24 | self.preprocessing = preprocessing 25 | self.file_ending = file_ending 26 | print(f"Using the {file_ending} files...") 27 | 28 | def __getitem__(self, idx): 29 | # loads data as a numpy arr and then adds the channel + batch size dimensions 30 | case_id = self.im_ids[idx] 31 | x, y = self.load_volume(case_id) 32 | if self.transforms: 33 | x = x[None] if len(x.shape) == 4 else x 34 | y = y[None] if len(y.shape) == 4 else y 35 | # batchgenerators requires shape: (b, c, ...) 36 | data_dict = self.transforms(**{"data": x, "seg": y}) 37 | x, y = data_dict["data"], data_dict["seg"] 38 | if self.preprocessing: 39 | preprocessed = self.preprocessing(**{"data": x, "seg": y}) 40 | x, y = preprocessed["data"], preprocessed["seg"] 41 | # squeeze to remove batch size dim 42 | x = torch.squeeze(x, dim=0).float() 43 | y = torch.squeeze(y, dim=0) 44 | return (x, y) 45 | 46 | def __len__(self): 47 | return len(self.im_ids) 48 | 49 | def load_volume(self, case_id): 50 | """ 51 | Loads volume from either .npy or nifti files. 52 | Args: 53 | case_id: path to the case folder 54 | i.e. /content/kits19/data/case_00001 55 | Returns: 56 | Tuple of: 57 | - x (np.ndarray): shape (1, d, h, w) 58 | - y (np.ndarray): same shape as x 59 | """ 60 | x_path = join(case_id, f"imaging{self.file_ending}") 61 | y_path = join(case_id, f"segmentation{self.file_ending}") 62 | if self.file_ending == ".npy": 63 | x, y = np.load(x_path), np.load(y_path) 64 | elif self.file_ending == ".nii.gz" or self.file_ending == ".nii": 65 | x, y = nib.load(x_path).get_fdata(), nib.load(y_path).get_fdata() 66 | return (x[None], y[None]) 67 | 68 | class ClfSegVoxelDataset(VoxelDataset): 69 | """ 70 | Can handle classification+segmentation, classification only, and seg only 71 | outputs. 72 | """ 73 | def __init__(self, im_ids: np.array, 74 | transforms=None, 75 | preprocessing=None, 76 | file_ending: str = ".npy", 77 | mode: str = "both", 78 | num_classes: int = 3): 79 | """ 80 | Attributes 81 | im_ids (np.ndarray): of image names. 82 | transforms (albumentations.augmentation): transforms to apply 83 | before preprocessing. Defaults to HFlip and ToTensor 84 | preprocessing: ops to perform after transforms, such as 85 | z-score standardization. Defaults to None. 
86 | file_ending (str): one of ['.npy', '.nii', '.nii.gz'] 87 | mode (str): decides how the nature of the outputs 88 | must be one of ['both', 'clf_only', 'seg_only'] 89 | num_classes (int): number of classes. Defaults to 3. 90 | """ 91 | super().__init__(im_ids=im_ids, transforms=transforms, 92 | preprocessing=preprocessing, file_ending=file_ending) 93 | assert mode.lower() in ["both", "clf_only", "seg_only"], \ 94 | "`mode` must be one of ['both', 'clf_only', 'seg_only']" 95 | self.mode = mode 96 | self.num_classes = num_classes 97 | 98 | def __getitem__(self, idx): 99 | # loads data as a numpy arr and then adds the channel + batch size dimensions 100 | case_id = self.im_ids[idx] 101 | x, y = self.load_volume(case_id) 102 | if self.transforms: 103 | x = x[None] if len(x.shape) == 4 else x 104 | y = y[None] if len(y.shape) == 4 else y 105 | # batchgenerators requires shape: (b, c, ...) 106 | data_dict = self.transforms(**{"data": x, "seg": y}) 107 | x, y = data_dict["data"], data_dict["seg"] 108 | 109 | if self.mode in ["both", "clf_only"]: 110 | y_clf = torch.from_numpy(self.get_clf_label_from_cropped_mask(y)) 111 | 112 | if self.preprocessing: 113 | preprocessed = self.preprocessing(**{"data": x, "seg": y}) 114 | x, y = preprocessed["data"], preprocessed["seg"] 115 | # squeeze to remove batch size dim 116 | if torch.is_tensor(x): 117 | x = torch.squeeze(x, dim=0).float() 118 | if torch.is_tensor(y): 119 | y = torch.squeeze(y, dim=0) 120 | 121 | if self.mode == "both": 122 | return {"features": x, "seg_targets": y, "clf_targets": y_clf} 123 | elif self.mode == "clf_only": 124 | return (x, y_clf) 125 | elif self.mode == "seg_only": 126 | return (x, y) 127 | 128 | def get_clf_label_from_cropped_mask(self, cropped_mask: np.array): 129 | """ 130 | Multi-label one-hot encoding of mask to get the classification 131 | label. 132 | Args: 133 | cropped_mask (np.ndarray): contains int in [0, num_classes-1] 134 | Returns: 135 | one_hot (np.ndarray): multi-label one hot encoded array 136 | i.e. [0, 1, 0] or [1, 0, 1], etc. 137 | """ 138 | unique = np.unique(cropped_mask).astype(np.int32) 139 | one_hot = np.zeros(self.num_classes) 140 | one_hot[unique] = 1 141 | return one_hot 142 | 143 | class TestVoxelDataset(VoxelDataset): 144 | """ 145 | Same as VoxelDataset, but can handle when there are no masks (just returns 146 | blank masks). This is a separate class to prevent lowkey errors with 147 | blank masks--VoxelDataset explicitly fails when there are no masks. 148 | """ 149 | def __init__(self, im_ids: np.array, 150 | transforms=None, 151 | preprocessing=None, 152 | file_ending=".npy"): 153 | """ 154 | Attributes 155 | im_ids (np.ndarray): of image names. 156 | transforms (albumentations.augmentation): transforms to apply 157 | before preprocessing. Defaults to HFlip and ToTensor 158 | preprocessing: ops to perform after transforms, such as 159 | z-score standardization. Defaults to None. 160 | file_ending (str): one of ['.npy', '.nii', '.nii.gz'] 161 | """ 162 | super().__init__(im_ids=im_ids, transforms=transforms, 163 | preprocessing=preprocessing, file_ending=file_ending) 164 | 165 | def load_volume(self, case_id): 166 | """ 167 | Loads volume from either .npy or nifti files. 168 | Args: 169 | case_id: path to the case folder 170 | i.e. 
/content/kits19/data/case_00001 171 | Returns: 172 | Tuple of: 173 | - x (np.ndarray): shape (1, d, h, w) 174 | - y (np.ndarray): same shape as x 175 | if this does not exist, it's returned as a blank mask 176 | """ 177 | x_path = join(case_id, f"imaging{self.file_ending}") 178 | y_path = join(case_id, f"segmentation{self.file_ending}") 179 | if self.file_ending == ".npy": 180 | x = np.load(x_path) 181 | y = np.load(y_path) if isfile(y_path) else np.zeros(x.shape) 182 | elif self.file_ending == ".nii.gz" or self.file_ending == ".nii": 183 | x = nib.load(x_path).get_fdata() 184 | y = nib.load(y_path).get_fdata() if isfile(y_path) else np.zeros(x.shape) 185 | return (x[None], y[None]) 186 | -------------------------------------------------------------------------------- /kits19cnn/io/resample.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from collections import OrderedDict 17 | from skimage.transform import resize 18 | from scipy.ndimage.interpolation import map_coordinates 19 | import numpy as np 20 | from batchgenerators.augmentations.utils import resize_segmentation 21 | 22 | def get_do_separate_z(spacing): 23 | do_separate_z = (np.max(spacing) / np.min(spacing)) > RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD 24 | return do_separate_z 25 | 26 | 27 | def get_lowres_axis(new_spacing): 28 | axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0] # find which axis is anisotropic 29 | return axis 30 | 31 | def resample_patient(data, seg, original_spacing, target_spacing, order_data=3, 32 | order_seg=0, force_separate_z=False, 33 | cval_data=0, cval_seg=0, order_z_data=0, order_z_seg=0): 34 | """ 35 | :param cval_seg: 36 | :param cval_data: 37 | :param data: 38 | :param seg: 39 | :param original_spacing: 40 | :param target_spacing: 41 | :param order_data: 42 | :param order_seg: 43 | :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then always 44 | /never resample along z separately 45 | :param order_z_seg: only applies if do_separate_z is True 46 | :param order_z_data: only applies if do_separate_z is True 47 | :return: 48 | """ 49 | assert not ((data is None) and (seg is None)) 50 | if data is not None: 51 | assert len(data.shape) == 4, "data must be c x y z" 52 | if seg is not None: 53 | assert len(seg.shape) == 4, "seg must be c x y z" 54 | 55 | if data is not None: 56 | shape = np.array(data[0].shape) 57 | else: 58 | shape = np.array(seg[0].shape) 59 | new_shape = np.round(((np.array(original_spacing) / np.array(target_spacing)).astype(float) * shape)).astype(int) 60 | 61 | if force_separate_z is not None: 62 | do_separate_z = force_separate_z 63 | if force_separate_z: 64 | axis = get_lowres_axis(original_spacing) 65 | else: 66 | axis = None 67 | else: 68 | if 
get_do_separate_z(original_spacing): 69 | do_separate_z = True 70 | axis = get_lowres_axis(original_spacing) 71 | elif get_do_separate_z(target_spacing): 72 | do_separate_z = True 73 | axis = get_lowres_axis(target_spacing) 74 | else: 75 | do_separate_z = False 76 | axis = None 77 | 78 | if data is not None: 79 | data_reshaped = resample_data_or_seg(data, new_shape, False, axis, 80 | order_data, do_separate_z, 81 | cval=cval_data, 82 | order_z=order_z_data) 83 | else: 84 | data_reshaped = None 85 | if seg is not None: 86 | seg_reshaped = resample_data_or_seg(seg, new_shape, True, axis, 87 | order_seg, do_separate_z, 88 | cval=cval_seg, 89 | order_z=order_z_seg) 90 | else: 91 | seg_reshaped = None 92 | return data_reshaped, seg_reshaped 93 | 94 | 95 | def resample_data_or_seg(data, new_shape, is_seg, axis=None, order=3, 96 | do_separate_z=False, cval=0, order_z=0): 97 | """ 98 | separate_z=True will resample with order 0 along z 99 | :param data: 100 | :param new_shape: 101 | :param is_seg: 102 | :param axis: 103 | :param order: 104 | :param do_separate_z: 105 | :param cval: 106 | :param order_z: only applies if do_separate_z is True 107 | :return: 108 | """ 109 | assert len(data.shape) == 4, "data must be (c, x, y, z)" 110 | if is_seg: 111 | resize_fn = resize_segmentation 112 | kwargs = OrderedDict() 113 | else: 114 | resize_fn = resize 115 | kwargs = {'mode': 'edge', 'anti_aliasing': False} 116 | dtype_data = data.dtype 117 | data = data.astype(float) 118 | shape = np.array(data[0].shape) 119 | new_shape = np.array(new_shape) 120 | if np.any(shape != new_shape): 121 | if do_separate_z: 122 | print("separate z") 123 | assert len(axis) == 1, "only one anisotropic axis supported" 124 | axis = axis[0] 125 | if axis == 0: 126 | new_shape_2d = new_shape[1:] 127 | elif axis == 1: 128 | new_shape_2d = new_shape[[0, 2]] 129 | else: 130 | new_shape_2d = new_shape[:-1] 131 | 132 | reshaped_final_data = [] 133 | for c in range(data.shape[0]): 134 | reshaped_data = [] 135 | for slice_id in range(shape[axis]): 136 | if axis == 0: 137 | reshaped_data.append(resize_fn(data[c, slice_id], 138 | new_shape_2d, order, 139 | cval=cval, **kwargs)) 140 | elif axis == 1: 141 | reshaped_data.append(resize_fn(data[c, :, slice_id], 142 | new_shape_2d, order, 143 | cval=cval, **kwargs)) 144 | else: 145 | reshaped_data.append(resize_fn(data[c, :, :, slice_id], 146 | new_shape_2d, order, 147 | cval=cval, **kwargs)) 148 | reshaped_data = np.stack(reshaped_data, axis) 149 | if shape[axis] != new_shape[axis]: 150 | 151 | # The following few lines are blatantly copied and modified from sklearn's resize() 152 | rows, cols, dim = new_shape[0], new_shape[1], new_shape[2] 153 | orig_rows, orig_cols, orig_dim = reshaped_data.shape 154 | 155 | row_scale = float(orig_rows) / rows 156 | col_scale = float(orig_cols) / cols 157 | dim_scale = float(orig_dim) / dim 158 | 159 | map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim] 160 | map_rows = row_scale * (map_rows + 0.5) - 0.5 161 | map_cols = col_scale * (map_cols + 0.5) - 0.5 162 | map_dims = dim_scale * (map_dims + 0.5) - 0.5 163 | 164 | coord_map = np.array([map_rows, map_cols, map_dims]) 165 | if not is_seg or order_z == 0: 166 | reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, 167 | order=order_z, cval=cval, 168 | mode='nearest')[None]) 169 | else: 170 | unique_labels = np.unique(reshaped_data) 171 | reshaped = np.zeros(new_shape, dtype=dtype_data) 172 | 173 | for i, cl in enumerate(unique_labels): 174 | reshaped_multihot = np.round( 175 | 
map_coordinates((reshaped_data == cl).astype(float), 176 | coord_map, order=order_z, 177 | cval=cval, mode='nearest')) 178 | reshaped[reshaped_multihot >= 0.5] = cl 179 | reshaped_final_data.append(reshaped[None]) 180 | else: 181 | reshaped_final_data.append(reshaped_data[None]) 182 | reshaped_final_data = np.vstack(reshaped_final_data) 183 | else: 184 | reshaped = [] 185 | for c in range(data.shape[0]): 186 | reshaped.append(resize_fn(data[c], new_shape, order, 187 | cval=cval, **kwargs)[None]) 188 | reshaped_final_data = np.vstack(reshaped) 189 | return reshaped_final_data.astype(dtype_data) 190 | else: 191 | print("no resampling necessary") 192 | return data 193 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kits19-cnn 2 | Using 2D & 3D convolutional neural networks for the [2019 Kidney and Kidney Tumor Segmentation Challenge](https://kits19.grand-challenge.org/). This repository is associated with this [conference paper (older version)](https://www.researchgate.net/publication/336247303_A_2D_U-Net_for_Automated_Kidney_and_Renal_Tumor_Segmentation). 3 | 4 | ![Label](images/label_case_00113.png) 5 | ![Prediction from 2D U-Net (ResNet34)](images/pred2d_case_00113.png) 6 | 7 | ## Disclaimer 8 | I'm not sure why the tumor scores are so low for all of the architectures, so I'm open to any suggestions and PRs! Am actively working on improving them. 9 | 10 | ## Credits 11 | Major credits to: 12 | * Isensee's [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) and [batchgenerators](https://github.com/MIC-DKFZ/batchgenerators) 13 | * qubvel's [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch) 14 | * Nick Heller for hosting KiTS19! 15 | 16 | ## Torch/Catalyst Pipeline (Overview) 17 | ### Preprocessing 18 | * Resampling to 3.22 × 1.62 × 1.62 mm 19 | * Isensee's nnU-Net methodology 20 | * Clipping to the [0.5, 99.5] percentiles and applying z-score standardization 21 | 22 | ### Training 23 | * Foreground class sampling 24 | * __2D:__ Done by sampling per slice (loading only 2D arrays) 25 | * __SO MUCH FASTER THAN LOADING 3D ARRAYS__ 26 | * Difference: 3 seconds v. 5 minutes per epoch 27 | * __3D:__ Done through `ROICropTransform` 28 | * Data Augmentation 29 | * Located in `kits19cnn/experiments/utils.py` 30 | * Pay attention to the `augmentation_key`s in `get_training_augmentation` and `get_validation_augmentation` 31 | * Done through `batchgenerators` + my own custom transforms 32 | * SGD (lr=1e-4) and LRPlateau (factor=0.15 and patience=5); BCEDiceLoss 33 | * 2D: batch size = 18 (regular training) 34 | * 3D: batch size = 4 (fp16 training) 35 | 36 | ### Architectures 37 | * 2D (patch size: (256, 256)) 38 | * Vanilla 2D nnU-Net 39 | * 6 pools with convolutional downsampling and upsampling 40 | * max number of filters set to 320 and the starting number is 30 41 | * 2D U-Net with pretrained ImageNet classifiers 42 | * 2D FPN with pretrained ImageNet classifiers 43 | * 3D (patch size: (96, 160, 160)) 44 | * 3D nnU-Net 45 | * 5 pools with convolutional downsampling and upsampling 46 | * max number of filters set to 320 and the starting number is 30 47 | * 3D nnU-Net (Classification + Segmentation) 48 | 49 | ## Results 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 |
| Neural Network | Parameters | Local Test (Tumor-Kidney) Dice | Local Test (Tumor Only) Dice | Weights |
| --- | --- | --- | --- | --- |
| 2D nnU-Net | 12M | 0.90 | 0.26 | ... |
| 3D nnU-Net | 29.6M | 0.86 | 0.22 | ... |
| ResNet34 + U-Net Decoder | 24M | 0.90 | 0.29 | ... |
| ResNet34 + FPN Decoder | 22M | 0.83 | 0.29 | ... |
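The ResNet34 rows above come from `segmentation_models_pytorch` encoders plugged into U-Net/FPN decoders (see `kits19cnn/experiments/train_2d.py`). A minimal sketch of how such a model is built, assuming a `resnet34` encoder and 3 output classes (background, kidney, tumor); in the actual pipeline these values come from `model_params` in the training YAML rather than being hard-coded:
```
import segmentation_models_pytorch as smp

# Illustrative values only; the training script reads these from the config.
model = smp.Unet(
    encoder_name="resnet34",     # assumed encoder for the "ResNet34 + U-Net Decoder" row
    encoder_weights="imagenet",  # ImageNet-pretrained encoder, as in train_2d.py
    activation=None,             # raw logits; the loss applies the nonlinearity
    classes=3,                   # assumed: background, kidney, tumor
)
```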
91 | 92 | 93 | ## How to Use 94 | 95 | ### Downloading the Dataset 96 | The recommended way is to just follow the instructions on the [original kits19 Github challenge page](https://github.com/neheller/kits19), which utilizes `git lfs`. 97 | Here is a brief run-down for Google Colaboratory: 98 | ``` 99 | ! sudo add-apt-repository ppa:git-core/ppa 100 | ! curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash 101 | ! sudo apt-get install git-lfs 102 | ! git lfs install 103 | % cd "/content/" 104 | ! rm -r kits19 105 | ! git clone https://github.com/neheller/kits19.git 106 | # takes roughly 11 minutes to download 107 | ``` 108 | 109 | ### Preprocessing 110 | To do general preprocessing (resampling): 111 | ``` 112 | # preprocessing 113 | from kits19cnn.io.preprocess import Preprocessor 114 | base_dir = "/content/kits19/data" 115 | out_dir = "/content/kits_preprocessed" 116 | 117 | preprocess = Preprocessor(base_dir, out_dir) 118 | preprocess.cases = sorted(preprocess.cases)[:210] 119 | %time preprocess.gen_data() 120 | ``` 121 | Note that the standardization and clipping is done on-the-fly. 122 | 123 | If you want to do __2D segmentation__: 124 | ``` 125 | # preprocessing 126 | from kits19cnn.io.preprocess import Preprocessor 127 | out_dir = "/content/kits_preprocessed" 128 | 129 | preprocess = Preprocessor(out_dir, out_dir, with_mask=True) 130 | preprocess.cases = sorted(preprocess.cases)[:210] 131 | preprocess.save_dir_as_2d() 132 | ``` 133 | If you want to do __binary 2D segmentation__ (kidney only or renal tumor only). 134 | ``` 135 | import os 136 | from kits19cnn.experiments.utils import parse_fg_slice_dict_single_class 137 | preprocessed_dir = "/content/kits_preprocessed" 138 | 139 | json_path = os.path.join(preprocessed_dir, "slice_indices.json") 140 | out_path = os.path.join(preprocessed_dir, "slice_indices_tu_only.json") 141 | 142 | _ = parse_fg_slice_dict_single_class(json_path, out_path, removed_fg_idx="1") 143 | out_path = os.path.join(preprocessed_dir, "slice_indices_kidney_only.json") 144 | _ = parse_fg_slice_dict_single_class(json_path, out_path, removed_fg_idx="2") 145 | ``` 146 | 147 | ### Training 148 | Please see the example yaml file at `script_configs/train.yml`. Works for 2D, 2.5D, 149 | and 3D. Also, supports binary 2D segmentation if you change the `slice_indices_path`. 150 | Also, supports classification + segmentation for nnU-Net (doesn't work that well). 151 | ``` 152 | python /content/kits19-cnn/scripts/train_yaml.py --yml_path="/content/kits19-cnn/script_configs/train.yml" 153 | ``` 154 | __TensorBoard__: Catalyst automatically supports tensorboard logging, so just run this in Colaboratory: 155 | ``` 156 | # Load the TensorBoard notebook extension 157 | %load_ext tensorboard 158 | # Run this before training 159 | %tensorboard --logdir logs 160 | ``` 161 | __For Plotting Support (plotly/orca) [OPTIONAL]:__ 162 | The regular training script (`script_configs/train.yml`) doesn't plot the graphs 163 | directly, but saves them as .png files. 
If you don't want to do all of this installing, just exclude `plot_params` in `scripts/train_yaml.py` 164 | ``` 165 | # on colab 166 | 167 | # installing anaconda and plotly with orca + dependencies 168 | !wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 169 | !chmod +x Miniconda3-latest-Linux-x86_64.sh 170 | !bash ./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local 171 | # !conda install -c plotly plotly-orca 172 | !conda install -c plotly plotly-orca psutil requests ipykernel 173 | !export PYTHONPATH="${PYTHONPATH}:/usr/local/lib/python3.7/site-packages/" 174 | !pip install nbformat 175 | 176 | # orca with xvfb support (so orca can save the graphs) 177 | # Plotly depedencies 178 | !apt-get install -y --no-install-recommends \ 179 | wget \ 180 | xvfb \ 181 | libgtk2.0-0 \ 182 | libxtst6 \ 183 | libxss1 \ 184 | libgconf-2-4 \ 185 | libnss3 \ 186 | libasound2 && \ 187 | mkdir -p /home/orca && \ 188 | cd /home/orca && \ 189 | wget https://github.com/plotly/orca/releases/download/v1.2.1/orca-1.2.1-x86_64.AppImage && \ 190 | chmod +x orca-1.2.1-x86_64.AppImage && \ 191 | ./orca-1.2.1-x86_64.AppImage --appimage-extract && \ 192 | printf '#!/bin/bash \nxvfb-run --auto-servernum --server-args "-screen 0 640x480x24" /home/orca/squashfs-root/app/orca "$@"' > /usr/bin/orca && \ 193 | chmod +x /usr/bin/orca 194 | 195 | # enabling xvfb 196 | import plotly.io as pio 197 | pio.orca.config.use_xvfb = True 198 | pio.orca.config.save() 199 | ``` 200 | 201 | ### Inference 202 | Please see the example yaml file at `script_configs/pred.yml`. There's a tumor-only 203 | example in `script_configs/infer_tu_only/pred.yml`. 204 | ``` 205 | # kidney-tumor 206 | python /content/kits19-cnn/scripts/predict.py --yml_path="/content/kits19-cnn/script_configs/pred.yml" 207 | # tumor only 208 | python /content/kits19-cnn/scripts/predict.py --yml_path="/content/kits19-cnn/script_configs/infer_tu_only/pred.yml" 209 | ``` 210 | 211 | ### Evaluation 212 | Please see the example yaml file at `script_configs/eval.yml`. There's a tumor-only 213 | example in `script_configs/infer_tu_only/eval.yml`. 214 | ``` 215 | # kidney-tumor 216 | python /content/kits19-cnn/scripts/evaluate.py --yml_path="/content/kits19-cnn/script_configs/eval.yml" 217 | # tumor only 218 | python /content/kits19-cnn/scripts/evaluate.py --yml_path="/content/kits19-cnn/script_configs/infer_tu_only/eval.yml" 219 | ``` 220 | 221 | ### Submission 222 | Currently, only on the `preprocess-test-set` branch. 223 | -------------------------------------------------------------------------------- /kits19cnn/experiments/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import numpy as np 5 | import batchgenerators.transforms as bg 6 | import torch 7 | from copy import deepcopy 8 | 9 | from kits19cnn.io import ROICropTransform, RepeatChannelsTransform, \ 10 | MultiClassToBinaryTransform, RandomResizedCropTransform 11 | 12 | bgut = bg.utility_transforms 13 | bgct = bg.color_transforms 14 | bgsnt = bg.sample_normalization_transforms 15 | 16 | def get_training_augmentation(augmentation_key="aug1"): 17 | default_angle = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. 
* np.pi) 18 | aug1_spatial_kwargs = {"patch_size": (80, 160, 160), 19 | "patch_center_dist_from_border": (30, 30, 30), 20 | "do_elastic_deform": True, 21 | "alpha": (0., 900.), 22 | "sigma": (9., 13.), 23 | "do_rotation": True, 24 | "angle_x": default_angle, 25 | "angle_y": default_angle, 26 | "angle_z": default_angle, 27 | "do_scale": True, 28 | "scale": (0.85, 1.25), 29 | "border_mode_data": "constant", 30 | "order_data": 3, 31 | "random_crop": True, 32 | "p_el_per_sample": 0.2, 33 | "p_scale_per_sample": 0.2, 34 | "p_rot_per_sample": 0.2 35 | } 36 | transform_dict = { 37 | "aug1": [ 38 | bg.SpatialTransform(**aug1_spatial_kwargs), 39 | bg.MirrorTransform(axes=(0, 1, 2)), 40 | bg.GammaTransform(gamma_range=(0.7, 1.5), 41 | invert_image=False, 42 | per_channel=True, 43 | retain_stats=True, 44 | p_per_sample=0.3), 45 | ], 46 | } 47 | # aug2 48 | aug2_spatial_kwargs = deepcopy(aug1_spatial_kwargs) 49 | aug2_spatial_kwargs["patch_size"] = (96, 160, 160) 50 | transform_dict["aug2"] = [bg.SpatialTransform(**aug2_spatial_kwargs),] \ 51 | + transform_dict["aug1"][1:] 52 | 53 | # aug3 54 | aug3_spatial_kwargs = deepcopy(aug2_spatial_kwargs) 55 | aug3_spatial_kwargs["random_crop"] = False 56 | new_transforms = [bg.SpatialTransform(**aug3_spatial_kwargs), 57 | bg.BrightnessTransform(mu=101, sigma=76.9, 58 | p_per_sample=0.3),] 59 | # spatial, mirror, gamma, brightness 60 | transform_dict["aug3"] = [new_transforms[0]] + transform_dict["aug1"][1:] \ 61 | + [new_transforms[1]] 62 | 63 | # aug4 64 | # roicrop, spatial, mirror, gamma, brightness 65 | transform_dict["aug4"] = [ROICropTransform(crop_size=(96, 160, 160)),] + \ 66 | transform_dict["aug3"] 67 | 68 | # aug5 69 | # roicrop, spatial, mirror, gamma, brightness 70 | aug5_spatial_kwargs = deepcopy(aug3_spatial_kwargs) 71 | aug5_spatial_kwargs["patch_center_dist_from_border"] = None 72 | aug5_spatial_kwargs["border_cval_data"] = 0 73 | new_transforms = [ROICropTransform(crop_size=(96, 160, 160)), 74 | bg.SpatialTransform(**aug5_spatial_kwargs),] 75 | # RemoveLabelTransform added to preprocessing 76 | transform_dict["aug5"] = new_transforms + transform_dict["aug4"][2:] 77 | # 2D Transforms 78 | # spatial, mirror, gamma, brightness 79 | aug6_spatial_kwargs = deepcopy(aug3_spatial_kwargs) 80 | aug6_spatial_kwargs["patch_size"] = (192, 192) 81 | transforms_2d = [bg.SpatialTransform(**aug6_spatial_kwargs), 82 | bg.MirrorTransform(axes=(0, 1)),] 83 | transform_dict["aug6"] = transforms_2d + transform_dict["aug3"][2:] 84 | 85 | # spatial, mirror, gamma, brightness 86 | aug7_spatial_kwargs = deepcopy(aug3_spatial_kwargs) 87 | aug7_spatial_kwargs["patch_size"] = (256, 256) 88 | transforms_2d = [bg.SpatialTransform(**aug7_spatial_kwargs), 89 | bg.MirrorTransform(axes=(0, 1)),] 90 | transform_dict["aug7"] = transforms_2d + transform_dict["aug3"][2:] 91 | 92 | # spatial, mirror, gamma, brightness, removelabel 93 | tu_only = [MultiClassToBinaryTransform(roi_label="2", remove_label="1")] 94 | transform_dict["tu_only2d"] = transform_dict["aug7"] + tu_only 95 | 96 | tu_only2d2_spatial_kwargs = deepcopy(aug7_spatial_kwargs) 97 | tu_only2d2_spatial_kwargs["p_scale_per_sample"] = 0.5 98 | tu_only2d2_spatial_kwargs["p_rot_per_sample"] = 0.5 99 | 100 | new_t = [bg.SpatialTransform(**tu_only2d2_spatial_kwargs), 101 | bg.BrightnessTransform(mu=101, sigma=76.9, p_per_sample=0.5)] 102 | transform_dict["tu_only2d2"] = [new_t[0]] + transform_dict["tu_only2d"][1:-2] \ 103 | + [new_t[1]] + [transform_dict["tu_only2d"][-1]] 104 | # adding RandomResizedCropTransform to the end 
105 | # spatial, mirror, gamma, brightness, random_resized_crop, removelabel 106 | rand_resized_t = RandomResizedCropTransform(target_size=(256, 256), 107 | p_per_sample=0.33) 108 | transform_dict["tu_only2d3"] = deepcopy(transform_dict["tu_only2d2"]) 109 | transform_dict["tu_only2d3"].insert(-2, rand_resized_t) 110 | 111 | train_transform = transform_dict[augmentation_key] 112 | print(f"Train Transforms: {train_transform}") 113 | return bg.Compose(train_transform) 114 | 115 | def get_validation_augmentation(augmentation_key): 116 | """ 117 | Validation data augmentations. Usually, just cropping. 118 | """ 119 | transform_dict = { 120 | "aug1": [ 121 | bg.RandomCropTransform(crop_size=(80, 160, 160)) 122 | ], 123 | "aug2": [ 124 | bg.RandomCropTransform(crop_size=(96, 160, 160)) 125 | ], 126 | "aug3": [ 127 | bg.RandomCropTransform(crop_size=(96, 160, 160)) 128 | ], 129 | "aug4": [ 130 | bg.RandomCropTransform(crop_size=(96, 160, 160)) 131 | ], 132 | "aug5": [ 133 | ROICropTransform(crop_size=(96, 160, 160)) 134 | ], 135 | "aug6": [ 136 | bg.RandomCropTransform(crop_size=(192, 192)) 137 | ], 138 | "aug7": [ 139 | bg.CenterCropTransform(crop_size=(256, 256)) 140 | ], 141 | "tu_only2d": [ 142 | bg.CenterCropTransform(crop_size=(256, 256)), 143 | bgut.RemoveLabelTransform(1, 0) 144 | ], 145 | "tu_only2d2": [ 146 | bg.CenterCropTransform(crop_size=(256, 256)), 147 | MultiClassToBinaryTransform(roi_label="2", remove_label="1"), 148 | ], 149 | "tu_only2d3": [ 150 | bg.CenterCropTransform(crop_size=(256, 256)), 151 | MultiClassToBinaryTransform(roi_label="2", remove_label="1"), 152 | ], 153 | } 154 | test_transform = transform_dict[augmentation_key] 155 | print(f"\nTest/Validation Transforms: {test_transform}") 156 | return bg.Compose(test_transform) 157 | 158 | def get_preprocessing(rgb: bool = False): 159 | """ 160 | Construct preprocessing transform 161 | 162 | Args: 163 | rgb (bool): Whether or not to return the input with three channels 164 | or just single (grayscale) 165 | Return: 166 | transform: albumentations.Compose 167 | """ 168 | _transform = [ 169 | bgct.ClipValueRange(min=-79, max=304), 170 | bgsnt.MeanStdNormalizationTransform(mean=101, std=76.9, 171 | per_channel=False), 172 | bgut.RemoveLabelTransform(-1, 0), 173 | bg.NumpyToTensor(), 174 | ] 175 | if rgb: 176 | # insert right before converting to a torch tensor 177 | _transform.insert(-2, RepeatChannelsTransform(num_repeats=3)) 178 | print(f"\nPreprocessing Transforms: {_transform}") 179 | return bg.Compose(_transform) 180 | 181 | def parse_fg_slice_dict_single_class(json_path, out_path, removed_fg_idx="1"): 182 | """ 183 | Reads the foreground (positive) class slice dictionary and creates a new 184 | dictionary with `removed_fg_idx` removed. 
185 | Args: 186 | json_path (str): json should be 'slice_indices.json' generated by 187 | `io.Preprocessor` (sub dicts contain keys for list of slice indices 188 | for each foreground class) 189 | out_path (str): path to the json to save the new dictionary with 190 | the `removed_fg_idx` key removed 191 | removed_fg_idx (str): string key to remove ('1' or '2' in this case) 192 | Returns: 193 | the changed slice_dict 194 | """ 195 | # reading json 196 | with open(json_path, "r") as fp: 197 | slice_dict = json.load(fp) 198 | cases = list(slice_dict.keys()) 199 | # vv assumes same for all cases (which is true) 200 | sub_dict_keys = list(slice_dict[cases[0]]) 201 | print(f"{len(cases)} Cases; Case sub-dict keys: {sub_dict_keys}") 202 | # removing all sub dicts with the `removed_fg_idx` key 203 | print(f"Removing idx: {removed_fg_idx}") 204 | for case in cases: 205 | case_dict = slice_dict[case] 206 | case_dict.pop(removed_fg_idx) 207 | sub_dict_keys = list(slice_dict[cases[0]]) 208 | print(f"New case sub-dict keys: {sub_dict_keys}") 209 | # saving dict 210 | with open(out_path, "w") as fp: 211 | json.dump(slice_dict, fp) 212 | print(f"Saved at {out_path}.") 213 | return slice_dict 214 | 215 | def seed_everything(seed=42): 216 | random.seed(seed) 217 | os.environ["PYTHONHASHSEED"] = str(seed) 218 | np.random.seed(seed) 219 | torch.manual_seed(seed) 220 | torch.cuda.manual_seed_all(seed) 221 | torch.backends.cudnn.benchmark = False ##uses the inbuilt cudnn auto-tuner to find the fastest convolution algorithms. - 222 | torch.backends.cudnn.enabled = True 223 | torch.backends.cudnn.deterministic = True 224 | -------------------------------------------------------------------------------- /kits19cnn/io/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, isdir 3 | from pathlib import Path 4 | from collections import defaultdict 5 | from tqdm import tqdm 6 | import nibabel as nib 7 | import numpy as np 8 | import json 9 | 10 | from kits19cnn.io.resample import resample_patient 11 | 12 | class Preprocessor(object): 13 | """ 14 | Preprocesses the original dataset (interpolated). 15 | Procedures: 16 | * clipping (ROI) 17 | * save as .npy array 18 | * imaging.npy 19 | * segmentation.npy (if with_masks) 20 | * resampling from `orig_spacing` to `target_spacing` 21 | currently uses spacing reported in the #1 solution 22 | """ 23 | def __init__(self, in_dir, out_dir, cases=None, kits_json_path=None, 24 | target_spacing=(3.22, 1.62, 1.62), 25 | clip_values=None, with_mask=False, fg_classes=[1, 2]): 26 | """ 27 | Attributes: 28 | in_dir (str): directory with the input data. Should be the 29 | kits19/data directory. 30 | out_dir (str): output directory where you want to save each case 31 | cases: list of case folders to preprocess 32 | kits_json_path (str): path to the kits.json file in the kits19/data 33 | directory. This only should be specfied if you're resampling. 34 | Defaults to None. 35 | target_spacing (list/tuple): spacing to resample to 36 | clip_values (list, tuple): values you want to clip CT scans to. 37 | Defaults to None for no clipping. 38 | with_mask (bool): whether or not to preprocess with masks or no 39 | masks. Applicable to preprocessing test set (no labels 40 | available). 
41 | fg_classes (list): of foreground class indices 42 | """ 43 | self.in_dir = in_dir 44 | self.out_dir = out_dir 45 | 46 | self._load_kits_json(kits_json_path) 47 | self.clip_values = clip_values 48 | self.target_spacing = np.array(target_spacing) 49 | self.with_mask = with_mask 50 | self.fg_classes = fg_classes 51 | self.cases = cases 52 | # automatically collecting all of the case folder names 53 | if self.cases is None: 54 | self.cases = [os.path.join(self.in_dir, case) \ 55 | for case in os.listdir(self.in_dir) \ 56 | if case.startswith("case")] 57 | self.cases = sorted(self.cases) 58 | assert len(self.cases) > 0, \ 59 | "Please make sure that in_dir refers to the proper directory." 60 | # making directory if out_dir doesn't exist 61 | if not isdir(out_dir): 62 | os.mkdir(out_dir) 63 | print("Created directory: {0}".format(out_dir)) 64 | 65 | def gen_data(self): 66 | """ 67 | Generates and saves preprocessed data 68 | Args: 69 | task_path: file path to the task directory (must have the corresponding "dataset.json" in it) 70 | Returns: 71 | preprocessed input image and mask 72 | """ 73 | # Generating data and saving them recursively 74 | for case in tqdm(self.cases): 75 | x_path, y_path = join(case, "imaging.nii.gz"), join(case, "segmentation.nii.gz") 76 | image = nib.load(x_path).get_fdata()[None] 77 | label = nib.load(y_path).get_fdata()[None] if self.with_mask \ 78 | else None 79 | preprocessed_img, preprocessed_label = self.preprocess(image, 80 | label, 81 | case) 82 | 83 | self.save_imgs(preprocessed_img, preprocessed_label, case) 84 | 85 | def preprocess(self, image, mask, case=None): 86 | """ 87 | Clipping, cropping, and resampling. 88 | Args: 89 | image: numpy array 90 | mask: numpy array or None 91 | case (str): path to a case folder 92 | Returns: 93 | tuple of: 94 | - preprocessed image 95 | - preprocessed mask or None 96 | """ 97 | raw_case = Path(case).name # raw case name, i.e. 
case_00000 98 | if self.target_spacing is not None: 99 | for info_dict in self.kits_json: 100 | # guaranteeing that the info is corresponding to the right 101 | # case 102 | if info_dict["case_id"] == raw_case: 103 | case_info_dict = info_dict 104 | break 105 | orig_spacing = (case_info_dict["captured_slice_thickness"], 106 | case_info_dict["captured_pixel_width"], 107 | case_info_dict["captured_pixel_width"]) 108 | image, mask = resample_patient(image, mask, np.array(orig_spacing), 109 | target_spacing=self.target_spacing) 110 | if self.clip_values is not None: 111 | image = np.clip(image, self.clip_values[0], self.clip_values[1]) 112 | 113 | mask = mask[None] if mask is not None else mask 114 | return (image[None], mask) 115 | 116 | def save_imgs(self, image, mask, case): 117 | """ 118 | Saves an image and mask pair as .npy arrays in the KiTS19 file structure 119 | Args: 120 | image: numpy array 121 | mask: numpy array 122 | case: path to a case folder (each element of self.cases) 123 | """ 124 | # saving the generated dataset 125 | # output dir in KiTS19 format 126 | # extracting the raw case folder name 127 | case = Path(case).name 128 | out_case_dir = join(self.out_dir, case) 129 | # checking to make sure that the output directories exist 130 | if not isdir(out_case_dir): 131 | os.mkdir(out_case_dir) 132 | 133 | np.save(os.path.join(out_case_dir, "imaging.npy"), image) 134 | if mask is not None: 135 | np.save(os.path.join(out_case_dir, "segmentation.npy"), mask) 136 | 137 | def save_dir_as_2d(self): 138 | """ 139 | Takes preprocessed 3D numpy arrays and saves them as slices 140 | in the same directory. 141 | """ 142 | self.pos_slice_dict = {} 143 | # Generating data and saving them recursively 144 | for case in tqdm(self.cases): 145 | # assumes the .npy files have shape: (n_channels, d, h, w) 146 | image = np.load(join(case, "imaging.npy")) 147 | label = np.load(join(case, "segmentation.npy")) 148 | image = image.squeeze(axis=0) if len(image.shape)==5 else image 149 | label = label.squeeze(axis=0) if len(label.shape)==5 else label 150 | 151 | self.save_3d_as_2d(image, label, case) 152 | self._save_pos_slice_dict() 153 | 154 | def save_3d_as_2d(self, image, mask, case): 155 | """ 156 | Saves an image and mask pair as .npy arrays in the 157 | KiTS19 file structure 158 | Args: 159 | image: numpy array 160 | mask: numpy array 161 | case: path to a case folder (each element of self.cases) 162 | """ 163 | # saving the generated dataset 164 | # output dir in KiTS19 format 165 | # extracting the raw case folder name 166 | case = Path(case).name 167 | out_case_dir = join(self.out_dir, case) 168 | # checking to make sure that the output directories exist 169 | if not isdir(out_case_dir): 170 | os.mkdir(out_case_dir) 171 | 172 | # iterates through all slices and saves them individually as 2D arrays 173 | fg_indices = defaultdict(list) 174 | if mask.shape[1] <= 1: 175 | print("WARNING: Please double check your mask shape;", 176 | f"Masks have shape {mask.shape} when it should be", 177 | "shape (n_channels, d, h, w)") 178 | raise Exception("Please fix shapes.") 179 | for slice_idx in range(mask.shape[1]): 180 | label_slice = mask[:, slice_idx] 181 | # appending fg slice indices 182 | for idx in self.fg_classes: 183 | if (label_slice == idx).any(): 184 | fg_indices[idx].append(slice_idx) 185 | # naming convention: {type of slice}_{case}_{slice_idx} 186 | slice_idx_str = str(slice_idx) 187 | # adding 0s to slice_idx until it reaches 3 digits, 188 | # so sorting files is easier when stacking 189 | 
while len(slice_idx_str) < 3:
190 |                 slice_idx_str = "0"+slice_idx_str
191 |             np.save(join(out_case_dir, f"imaging_{slice_idx_str}.npy"),
192 |                     image[:, slice_idx])
193 |             np.save(join(out_case_dir, f"segmentation_{slice_idx_str}.npy"),
194 |                     label_slice)
195 |         # {case1: [idx1, idx2,...], case2: ...}
196 |         self.pos_slice_dict[case] = fg_indices
197 | 
198 |     def _save_pos_slice_dict(self):
199 |         """
200 |         Saves the foreground (positive) class dictionaries:
201 |             - slice_indices.json
202 |                 saves the slice indices per class
203 |                 {
204 |                  case: {fg_class1: [slice indices...],
205 |                         fg_class2: [slice indices...],
206 |                         ...}
207 |                 }
208 |             - slice_indices_general.json
209 |                 saves the slice indices for all foreground classes into a
210 |                 single list
211 |                 {case: [slice indices...],}
212 |         """
213 |         # converting pos_slice_dict to general_slice_dict
214 |         general_slice_dict = defaultdict(list)
215 |         for case, slice_idx_dict in self.pos_slice_dict.items():
216 |             for slice_idx_list in list(slice_idx_dict.values()):
217 |                 for slice_idx in slice_idx_list:
218 |                     general_slice_dict[case].append(slice_idx)
219 | 
220 |         save_path = join(self.out_dir, "slice_indices.json")
221 |         save_path_general = join(self.out_dir, "slice_indices_general.json")
222 |         # saving the dictionaries
223 |         print(f"Logged the slice indices for each class in {self.fg_classes} at "
224 |               f"{save_path}.")
225 |         with open(save_path, "w") as fp:
226 |             json.dump(self.pos_slice_dict, fp)
227 |         print("Logged slice indices for all fg classes instead of for each",
228 |               f"class separately at {save_path_general}.")
229 |         with open(save_path_general, "w") as fp:
230 |             json.dump(general_slice_dict, fp)
231 | 
232 |     def _load_kits_json(self, json_path):
233 |         """
234 |         Loads the kits.json file into `self.kits_json`
235 |         """
236 |         if json_path is None:
237 |             print("`kits_json_path` is empty, so not resampling.")
238 |         else:
239 |             with open(json_path, "r") as fp:
240 |                 self.kits_json = json.load(fp)
241 | 
--------------------------------------------------------------------------------
/kits19cnn/io/custom_augmentations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from batchgenerators.augmentations.utils import resize_segmentation, \
3 |     resize_multichannel_image
4 | from batchgenerators.augmentations.crop_and_pad_augmentations import get_lbs_for_center_crop, \
5 |     get_lbs_for_random_crop
6 | 
7 | def get_bbox_coords_fg(mask, fg_classes=[1, 2]):
8 |     """
9 |     Creates bounding box coordinates for foreground
10 |     Args:
11 |         mask (np.ndarray): shape (x, y, z)
12 |         fg_classes (list-like/arr-like): foreground classes to sample from
13 |             (sampling is done uniformly across all classes)
14 |     Returns:
15 |         coords (list): [[x_min, x_max], [y_min, y_max], [z_min, z_max]]
16 |     """
17 |     # squeeze to remove the channels dim if necessary
18 |     if len(mask.shape) > 3:
19 |         mask = mask.squeeze(axis=0)
20 |     if fg_classes is None:
21 |         classes = np.unique(mask)
22 |         sampled_fg_class = np.random.choice(classes[np.where(classes > 0)])
23 |     else:
24 |         sampled_fg_class = np.random.choice(fg_classes)
25 |     all_coords = np.where(mask == sampled_fg_class)
26 |     min_, max_ = np.min(all_coords, axis=1), np.max(all_coords, axis=1)+1
27 |     coords = list(zip(min_, max_))
28 |     return coords
29 | 
30 | def get_lbs_from_bbox(coords):
31 |     """
32 |     Args:
33 |         coords (list/tuple of lists): bbox coords
34 |         i.e. 
2D: [[10, 100], [50, 76]]
35 |              3D: [[10, 100], [50, 76], [50, 76]]
36 |     Returns:
37 |         lb: coordinates for cropping
38 |     """
39 |     lb = []
40 |     for dim_range in coords:
41 |         lb.append(np.random.randint(dim_range[0], dim_range[1]))
42 |     return lb
43 | 
44 | def crop(data, seg=None, crop_size=128, margins=(0, 0, 0), crop_type="center",
45 |          pad_mode="constant", pad_kwargs={"constant_values": 0},
46 |          pad_mode_seg="constant", pad_kwargs_seg={"constant_values": 0},
47 |          bbox_coords=None):
48 |     """
49 |     crops data and seg (seg may be None) to crop_size. Whether this will be
50 |     achieved via center or random crop is determined by crop_type.
51 |     Margin will be respected only for random_crop and will prevent the crops
52 |     from being closer than margin to the respective image border. crop_size
53 |     can be larger than data_shape - margin -> data/seg will be padded with
54 |     zeros in that case. margins can be negative -> results in padding of
55 |     data/seg followed by cropping with margin=0 for the appropriate axes
56 |     :param data: b, c, x, y(, z)
57 |     :param seg:
58 |     :param crop_size:
59 |     :param margins: distance from each border, can be int or list/tuple of ints (one element for each dimension).
60 |                     Can be negative (data/seg will be padded if needed)
61 |     :param crop_type: center, random, or roi
62 |     :param bbox_coords: from get_bbox_coords_fg. Defaults to None.
63 |                         (Gets the bounding box coordinates on-the-fly if None)
64 |     :return:
65 |     """
66 |     if not isinstance(data, (list, tuple, np.ndarray)):
67 |         raise TypeError("data has to be either a numpy array or a list")
68 | 
69 |     data_shape = tuple([len(data)] + list(data[0].shape))
70 |     data_dtype = data[0].dtype
71 |     dim = len(data_shape) - 2
72 | 
73 |     if seg is not None:
74 |         seg_shape = tuple([len(seg)] + list(seg[0].shape))
75 |         seg_dtype = seg[0].dtype
76 | 
77 |         if not isinstance(seg, (list, tuple, np.ndarray)):
78 |             raise TypeError("seg has to be either a numpy array or a list")
79 | 
80 |         assert all([i == j for i, j in zip(seg_shape[2:], data_shape[2:])]), "data and seg must have the same spatial " \
81 |                                                                              "dimensions. 
Data: %s, seg: %s" % \ 82 | (str(data_shape), str(seg_shape)) 83 | 84 | if type(crop_size) not in (tuple, list, np.ndarray): 85 | crop_size = [crop_size] * dim 86 | else: 87 | assert len(crop_size) == len( 88 | data_shape) - 2, "If you provide a list/tuple as center crop make sure it has the same dimension as your " \ 89 | "data (2d/3d)" 90 | 91 | if not isinstance(margins, (np.ndarray, tuple, list)): 92 | margins = [margins] * dim 93 | 94 | data_return = np.zeros([data_shape[0], data_shape[1]] + list(crop_size), dtype=data_dtype) 95 | if seg is not None: 96 | seg_return = np.zeros([seg_shape[0], seg_shape[1]] + list(crop_size), dtype=seg_dtype) 97 | else: 98 | seg_return = None 99 | 100 | for b in range(data_shape[0]): 101 | data_shape_here = [data_shape[0]] + list(data[b].shape) 102 | if seg is not None: 103 | seg_shape_here = [seg_shape[0]] + list(seg[b].shape) 104 | 105 | if crop_type == "center": 106 | lbs = get_lbs_for_center_crop(crop_size, data_shape_here) 107 | elif crop_type == "random": 108 | lbs = get_lbs_for_random_crop(crop_size, data_shape_here, margins) 109 | elif crop_type == "roi": 110 | if bbox_coords is None: 111 | bbox_coords = get_bbox_coords_fg(seg[b]) 112 | lbs = get_lbs_from_bbox(bbox_coords) 113 | else: 114 | raise NotImplementedError("crop_type must be either center, roi, or random") 115 | 116 | need_to_pad = [[0, 0]] + [[abs(min(0, lbs[d])), 117 | abs(min(0, data_shape_here[d + 2] - (lbs[d] + crop_size[d])))] 118 | for d in range(dim)] 119 | 120 | # we should crop first, then pad -> reduces i/o for memmaps, reduces RAM usage and improves speed 121 | ubs = [min(lbs[d] + crop_size[d], data_shape_here[d+2]) for d in range(dim)] 122 | lbs = [max(0, lbs[d]) for d in range(dim)] 123 | 124 | slicer_data = [slice(0, data_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)] 125 | data_cropped = data[b][tuple(slicer_data)] 126 | 127 | if seg_return is not None: 128 | slicer_seg = [slice(0, seg_shape_here[1])] + [slice(lbs[d], ubs[d]) for d in range(dim)] 129 | seg_cropped = seg[b][tuple(slicer_seg)] 130 | 131 | if any([i > 0 for j in need_to_pad for i in j]): 132 | data_return[b] = np.pad(data_cropped, need_to_pad, pad_mode, **pad_kwargs) 133 | if seg_return is not None: 134 | seg_return[b] = np.pad(seg_cropped, need_to_pad, pad_mode_seg, **pad_kwargs_seg) 135 | else: 136 | data_return[b] = data_cropped 137 | if seg_return is not None: 138 | seg_return[b] = seg_cropped 139 | 140 | return data_return, seg_return 141 | 142 | def foreground_crop(data, seg=None, patch_size=128, margins=0, 143 | bbox_coords=None, crop_kwargs={}): 144 | """ 145 | Crops a region around the foreground 146 | Args: 147 | data (np.ndarray): shape (b, c, x, y(, z)) 148 | seg (np.ndarray or None): same shape as data 149 | patch_size (int or list-like): (x, y(,z)) 150 | margins (int or list-like): distance from border for each dimension 151 | i.e. =0 -> (0, 0, 0) 152 | bbox_coords (list-like): min/max for each dimension in (x, y, z) 153 | """ 154 | data_shape = tuple([len(data)] + list(data[0].shape)) 155 | dim = len(data_shape) - 2 156 | assert dim == 3, \ 157 | "Currently, only works for 3D training." 
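    # Overview of the steps below: patch_size and margins are first expanded
    # to one value per spatial dimension, margins are offset by
    # -patch_size // 2 ("centering the crop"), and crop(..., crop_type="roi")
    # draws its lower bound from a random voxel inside the bounding box of a
    # randomly sampled foreground class. The crop is re-drawn until the
    # cropped seg contains at least one foreground voxel (rejection sampling).
    # Illustrative call (shapes are only an assumed example):
    #   patch, patch_seg = foreground_crop(data, seg, patch_size=64)
    #   # data/seg: np.ndarray of shape (b, c, x, y, z), e.g. (1, 1, 80, 160, 160)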
158 | if isinstance(patch_size, int): 159 | patch_size = dim * [patch_size] 160 | if isinstance(margins, int): 161 | margins = dim * [margins] 162 | # centering the crop 163 | margins = [margins[d] - patch_size[d] // 2 for d in range(dim)] 164 | reject = True 165 | while reject: 166 | cropped = crop(data, seg, patch_size, margins=margins, 167 | crop_type="roi", bbox_coords=bbox_coords, 168 | **crop_kwargs) 169 | if np.sum(cropped[1]) > 0: 170 | reject = False 171 | return cropped 172 | 173 | def center_crop(data, crop_size, seg=None, crop_kwargs={}): 174 | """ 175 | same as: 176 | batchgenerators.augmentations.crop_and_pad_augmentations.center_crop, but 177 | now can specify crop kwargs bc I need the constant value to be (-1). 178 | """ 179 | return crop(data, seg, crop_size, margins=0, crop_type="center", 180 | **crop_kwargs) 181 | 182 | def resize_data_and_seg(data, size, seg=None, order_data=3, 183 | order_seg=1, cval_seg=0): 184 | """ 185 | Args: 186 | data (np.ndarray): shape (b, c, h, w (, d)) 187 | seg (np.ndarray): shape (b, c, h, w (, d)). Defaults to None. 188 | size (list/tuple of int): size to resize to 189 | does not include the batch size or number of channels 190 | order_data (int): interpolation order for data 191 | (see skimage.transform.resize) 192 | order_seg (int): interpolation order for seg 193 | (see skimage.transform.resize) 194 | """ 195 | target_data = np.ones(list(data.shape[:2]) + size) 196 | if seg is not None: 197 | target_seg = np.ones(list(seg.shape[:2]) + size) 198 | else: 199 | target_seg = None 200 | 201 | for b in range(len(data)): 202 | target_data[b] = resize_multichannel_image(data[b], size, order_data) 203 | if seg is not None: 204 | for c in range(seg.shape[1]): 205 | target_seg[b, c] = resize_segmentation(seg[b, c], size, 206 | order_seg, cval_seg) 207 | return target_data, target_seg 208 | 209 | def random_resized_crop(data, seg=None, target_size=128, crop_size=64, 210 | crop_kwargs={}, resize_kwargs={}): 211 | """ 212 | Crops to `crop_size` and then resizes the result to `target_size` 213 | Args: 214 | data (np.ndarray): shape (b, c, h, w (, d)) 215 | seg (np.ndarray): shape (b, c, h, w (, d)) 216 | target_size (int/(list/tuple of int)): size to resize to 217 | does not include the batch size or number of channels 218 | crop_size (int/(list/tuple of int)): initial crop size 219 | does not include the batch size or number of channels 220 | crop_kwargs (dict): for custom_augmentations.crop 221 | resize_kwargs (dict): for custom_augmentations.resize_data_and_seg 222 | order_data (int): interpolation order for data 223 | (see skimage.transform.resize) 224 | Defaults to 3. 225 | order_seg (int): interpolation order for seg 226 | (see skimage.transform.resize) 227 | Defaults to 1. 228 | cval_seg (int): fill value for segmentation 229 | Defaults to 0. 230 | """ 231 | dimensionality = len(data.shape) - 2 232 | if not isinstance(target_size, (list, tuple)): 233 | target_size_here = [target_size] * dimensionality 234 | else: 235 | assert len(target_size) == dimensionality, \ 236 | "If you give a tuple/list as target size, make sure it has " \ 237 | "the same dimensionality as data!" 
238 |         target_size_here = list(target_size)
239 |     data, seg = crop(data, seg, crop_size=crop_size, crop_type="random",
240 |                      **crop_kwargs)
241 |     data, seg = resize_data_and_seg(data, size=target_size_here, seg=seg,
242 |                                     **resize_kwargs)
243 |     return data, seg
244 | 
--------------------------------------------------------------------------------
/kits19cnn/io/dataset_2d.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | from pathlib import Path
4 | from glob import glob
5 | import numpy as np
6 | 
7 | import torch
8 | from torch.utils.data import Dataset
9 | 
10 | class SliceDataset(Dataset):
11 |     def __init__(self, im_ids: np.array, pos_slice_dict: dict, transforms=None,
12 |                  preprocessing=None, p_pos_per_sample: float = 0.33,
13 |                  mode: str = "segmentation", num_classes: int = 3):
14 |         """
15 |         Reads from a directory of 2D slice numpy arrays and samples positive
16 |         slices. Assumes the data directory contains 2D slices processed by
17 |         `io.Preprocessor.save_dir_as_2d()`.
18 |         Attributes
19 |             im_ids (np.ndarray): paths to the case folders to sample from.
20 |             pos_slice_dict (dict): dictionary generated by
21 |                 `io.Preprocessor.save_dir_as_2d()`
22 |             transforms: batchgenerators-style transforms (called with `data`
23 |                 and `seg` keyword arrays) to apply before preprocessing. Defaults to None.
24 |             preprocessing: ops to perform after transforms, such as
25 |                 z-score standardization. Defaults to None.
26 |             p_pos_per_sample (float): probability at which to sample slices
27 |                 that contain foreground classes.
28 |             mode (str): one of ["segmentation", "classification", "both"]
29 |             num_classes (int): number of classes. 3 for this challenge.
30 |         """
31 |         self.im_ids = im_ids
32 |         self.pos_slice_dict = pos_slice_dict
33 |         self.transforms = transforms
34 |         self.preprocessing = preprocessing
35 |         self.p_pos_per_sample = p_pos_per_sample
36 |         print("Assuming inputs are .npy files...")
37 |         self.check_fg_idx_per_class()
38 |         assert mode in ["segmentation", "classification", "both"], \
39 |             "mode must be one of ['segmentation', 'classification', 'both']"
40 |         self.mode = mode
41 |         self.num_classes = num_classes
42 | 
43 |     def __getitem__(self, idx):
44 |         # loads data as a numpy arr and then adds the channel + batch size dimensions
45 |         case_id = self.im_ids[idx]
46 |         x, y = self.load_slices(case_id)
47 | 
48 |         if self.transforms:
49 |             x = x[None] if len(x.shape) == 3 else x
50 |             y = y[None] if len(y.shape) == 3 else y
51 |             # batchgenerators requires shape: (b, c, ...)
52 |             data_dict = self.transforms(**{"data": x, "seg": y})
53 |             x, y = data_dict["data"], data_dict["seg"]
54 | 
55 |         if self.mode in ["both", "classification"]:
56 |             y_clf = torch.from_numpy(self.get_clf_label_from_mask(y))
57 | 
58 |         if self.preprocessing:
59 |             preprocessed = self.preprocessing(**{"data": x, "seg": y})
60 |             x, y = preprocessed["data"], preprocessed["seg"]
61 |         # squeeze to remove batch size dim
62 |         x = torch.squeeze(x, dim=0).float()
63 |         y = torch.squeeze(y, dim=0)
64 | 
65 |         if self.mode == "both":
66 |             return {"features": x, "seg_targets": y, "clf_targets": y_clf}
67 |         elif self.mode == "classification":
68 |             return (x, y_clf)
69 |         elif self.mode == "segmentation":
70 |             return (x, y)
71 | 
72 |     def __len__(self):
73 |         return len(self.im_ids)
74 | 
75 |     def load_slices(self, case_fpath):
76 |         """
77 |         Gets the slice idx using self.get_slice_idx_str() and actually loads
78 |         the appropriate slice array.
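        Returns a tuple of (image, mask) slice arrays, each with an extra
        leading dim prepended (i.e. (1, c, h, w)) to match the
        batchgenerators (b, c, ...) convention.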
79 | """ 80 | slice_idx_str = self.get_slice_idx_str(case_fpath) 81 | x_path = join(case_fpath, f"imaging_{slice_idx_str}.npy") 82 | y_path = join(case_fpath, f"segmentation_{slice_idx_str}.npy") 83 | return (np.load(x_path)[None], np.load(y_path)[None]) 84 | 85 | def get_slice_idx_str(self, case_fpath): 86 | """ 87 | Gets the slice idx and processes it so that it fits how the arrays 88 | were saved by `io.Preprocessor.save_dir_as_2d`. 89 | """ 90 | # extracting slice: 91 | temp_p = np.random.uniform(0, 1) 92 | if temp_p < self.p_pos_per_sample: 93 | slice_idx = self.get_rand_pos_slice_idx(case_fpath) 94 | else: 95 | slice_idx = self.get_rand_slice_idx(case_fpath) 96 | slice_idx_str = self._parse_slice_idx_to_str(slice_idx) 97 | return slice_idx_str 98 | 99 | def get_rand_pos_slice_idx(self, case_fpath): 100 | """ 101 | Gets a random positive slice index from self.pos_slice_dict (that was 102 | generated by io.preprocess.Preprocessor when save_as_slices=True). 103 | Args: 104 | case_fpath: each element of self.im_ids (path to a case folder) 105 | Returns: 106 | an integer representing a random non-background class slice index 107 | """ 108 | case_raw = Path(case_fpath).name 109 | # finding random positive class index 110 | if self.fg_idx_per_class: 111 | sampled_class = np.random.choice(self.fg_classes) 112 | slice_indices = self.pos_slice_dict[case_raw][sampled_class] 113 | random_pos_coord = np.random.choice(slice_indices) 114 | else: 115 | random_pos_coord = np.random.choice(self.pos_slice_dict[case_raw]) 116 | return random_pos_coord 117 | 118 | def get_rand_slice_idx(self, case_fpath): 119 | """ 120 | Args: 121 | case_fpath: each element of self.im_ids (path to a case folder) 122 | Returns: 123 | A randomly selected slice index 124 | """ 125 | # assumes that there are no other files in said directory with "imaging_" 126 | _slice_files = [file for file in os.listdir(case_fpath) 127 | if file.startswith("imaging_")] 128 | return np.random.randint(0, len(_slice_files)) 129 | 130 | def check_fg_idx_per_class(self): 131 | """ 132 | checks the first key: value pair of self.pos_slice_dict 133 | If dict -> fg_idx_per_class, if list: not fg_idx_per_class 134 | fg_idx_per_class -> uniformly sample per class v. sample all fg idx 135 | """ 136 | dummy_key = list(self.pos_slice_dict.keys())[0] 137 | dummy_value = self.pos_slice_dict[dummy_key] 138 | self.fg_idx_per_class = True if isinstance(dummy_value, dict) else False 139 | if self.fg_idx_per_class: 140 | self.fg_classes = list(dummy_value.keys()) 141 | 142 | def get_clf_label_from_mask(self, mask: np.array): 143 | """ 144 | Multi-label one-hot encoding of mask to get the classification 145 | label. 146 | Args: 147 | mask (np.ndarray): contains int in [0, num_classes-1] 148 | Returns: 149 | one_hot (np.ndarray): multi-label one hot encoded array 150 | i.e. [0, 1, 0] or [1, 0, 1], etc. 151 | """ 152 | unique = np.unique(mask).astype(np.int32) 153 | one_hot = np.zeros(self.num_classes) 154 | one_hot[unique] = 1 155 | return one_hot 156 | 157 | def _parse_slice_idx_to_str(self, slice_idx): 158 | """ 159 | Parse the slice index to a three digit string for reading the 2D .npy 160 | files generated by io.preprocess.Preprocessor. 
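        e.g. 7 -> "007", 42 -> "042"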
161 | """ 162 | slice_idx_str = str(slice_idx) 163 | while len(slice_idx_str) < 3: 164 | slice_idx_str = "0"+slice_idx_str 165 | return slice_idx_str 166 | 167 | class PseudoSliceDataset(SliceDataset): 168 | def __init__(self, im_ids: np.array, pos_slice_dict: dict, transforms=None, 169 | preprocessing=None, p_pos_per_sample: float = 0.33, 170 | mode: str = "segmentation", num_classes: int = 3, 171 | num_pseudo_slices=1): 172 | """ 173 | Reads from a directory of 2D slice numpy arrays and samples positive 174 | slices. Assumes the data directory contains 2D slices processed by 175 | `io.Preprocessor.save_dir_as_2d()`. 176 | Attributes 177 | im_ids (np.ndarray): of image names. 178 | pos_slice_dict (dict): dictionary generated by 179 | `io.Preprocessor.save_dir_as_2d()` 180 | transforms (albumentations.augmentation): transforms to apply 181 | before preprocessing. Defaults to HFlip and ToTensor 182 | preprocessing: ops to perform after transforms, such as 183 | z-score standardization. Defaults to None. 184 | p_pos_per_sample (float): probability at which to sample slices 185 | that contain foreground classes. 186 | mode (str): one of ["seg", "clf", "both"] 187 | num_classes (int): number of classes. 3 for this challenge. 188 | num_pseudo_slices (int): number of pseudo 3D slices. Defaults to 1. 189 | 1 meaning no pseudo slices. If it's greater than 1, it must 190 | be odd (even numbers above and below) 191 | """ 192 | super().__init__(im_ids=im_ids, pos_slice_dict=pos_slice_dict, 193 | transforms=transforms, preprocessing=preprocessing, 194 | p_pos_per_sample=p_pos_per_sample, mode=mode, 195 | num_classes=num_classes) 196 | self.num_pseudo_slices = num_pseudo_slices 197 | assert num_pseudo_slices % 2 == 1, \ 198 | "`num_pseudo_slices` must be odd. i.e. 7 -> 3 above and 3 below" 199 | 200 | def load_slices(self, case_fpath): 201 | """ 202 | Gets the slice idx using self.get_slice_idx_str() and actually loads 203 | the appropriate slice array. Returned arrays have shape: 204 | (batch_size, n_channels, h, w) 205 | for batchgenerators transforms. 206 | """ 207 | center_slice_idx_str, center_slice_idx = self.get_slice_idx_str(case_fpath) 208 | total_num_slices = len(glob(join(case_fpath, "imaging_*.npy"))) 209 | min = center_slice_idx - (self.num_pseudo_slices - 1) // 2 210 | max = center_slice_idx + (self.num_pseudo_slices - 1) // 2 + 1 211 | 212 | x_path = join(case_fpath, f"imaging_{center_slice_idx_str}.npy") 213 | y_path = join(case_fpath, f"segmentation_{center_slice_idx_str}.npy") 214 | center_x, center_y = np.load(x_path)[None], np.load(y_path)[None] 215 | 216 | if self.num_pseudo_slices == 1: 217 | return (center_x, center_y) 218 | elif self.num_pseudo_slices > 1: 219 | # total shape: (1, num_pseudo_slices, h, w) 220 | x_arr = np.zeros((1, self.num_pseudo_slices) + center_x.shape[2:]) 221 | for idx, slice_idx in enumerate(range(min, max)): 222 | slice_idx_str = self._parse_slice_idx_to_str(slice_idx) 223 | x_path = join(case_fpath, f"imaging_{slice_idx_str}.npy") 224 | # loading slices if they exist 225 | if os.path.isfile(x_path): 226 | x_arr[:, idx] = np.load(x_path) 227 | return (x_arr, center_y) 228 | 229 | def get_slice_idx_str(self, case_fpath): 230 | """ 231 | Gets the slice idx and processes it so that it fits how the arrays 232 | were saved by `io.Preprocessor.save_dir_as_2d`. 
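        Returns both the zero-padded string and the raw integer index; the
        integer is used by `load_slices` to locate the neighboring
        pseudo-3D slices.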
233 | """ 234 | # extracting slice: 235 | temp_p = np.random.uniform(0, 1) 236 | if temp_p < self.p_pos_per_sample: 237 | slice_idx = self.get_rand_pos_slice_idx(case_fpath) 238 | else: 239 | slice_idx = self.get_rand_slice_idx(case_fpath) 240 | slice_idx_str = self._parse_slice_idx_to_str(slice_idx) 241 | return (slice_idx_str, slice_idx) 242 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /kits19cnn/experiments/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from abc import abstractmethod 4 | from pathlib import Path 5 | import catalyst.dl.callbacks as callbacks 6 | from sklearn.model_selection import train_test_split 7 | from torch.utils.data import DataLoader 8 | import torch 9 | 10 | from kits19cnn.loss_functions import DC_and_CE_loss, BCEDiceLoss 11 | from .utils import get_preprocessing, get_training_augmentation, \ 12 | get_validation_augmentation, seed_everything 13 | 14 | class TrainExperiment(object): 15 | def __init__(self, config: dict): 16 | """ 17 | Args: 18 | config (dict): from `train_classification_yaml.py` 19 | 20 | Attributes: 21 | config-related: 22 | config (dict): from `train_classification_yaml.py` 23 | io_params (dict): contains io-related parameters 24 | image_folder (key: str): path to the image folder 25 | df_setup_type (key: str): regular or pos_only 26 | test_size (key: float): split size for test 27 | split_seed (key: int): seed 28 | batch_size (key: int): <- 29 | num_workers (key: int): # of workers for data loaders 30 | aug_key (key: str): One of the augmentation keys for 31 | `get_training_augmentation` and `get_validation_augmentation` 32 | in `scripts/utils.py` 33 | opt_params (dict): optimizer related parameters 34 | lr (key: str): learning rate 35 | opt (key: str): optimizer name 36 | Currently, only supports sgd and adam. 37 | scheduler_params (key: str): dict of: 38 | scheduler (key: str): scheduler name 39 | {scheduler} (key: dict): args for the above scheduler 40 | cb_params (dict): 41 | earlystop (key: str): 42 | dict -> kwargs for EarlyStoppingCallback 43 | accuracy (key: str): 44 | dict -> kwargs for AccuracyCallback 45 | checkpoint_params (key: dict): 46 | checkpoint_path (key: str): path to the checkpoint 47 | checkpoint_mode (key: str): model_only or 48 | full (for stateful loading) 49 | split_dict (dict): train_ids and valid_ids 50 | train_dset, val_dset: <- 51 | loaders (dict): train/validation loaders 52 | model (torch.nn.Module): <- 53 | opt (torch.optim.Optimizer): <- 54 | lr_scheduler (torch.optim.lr_scheduler): <- 55 | criterion (torch.nn.Module): <- 56 | cb_list (list): list of catalyst callbacks 57 | """ 58 | # for reuse 59 | self.config = config 60 | self.io_params = config["io_params"] 61 | self.opt_params = config["opt_params"] 62 | self.cb_params = config["callback_params"] 63 | self.criterion_params = config["criterion_params"] 64 | # initializing the experiment components 65 | self.case_list = self.setup_im_ids() 66 | train_ids, val_ids, _ = self.get_split() 67 | self.train_dset, self.val_dset = self.get_datasets(train_ids, val_ids) 68 | self.loaders = self.get_loaders() 69 | self.model = self.get_model() 70 | self.opt = self.get_opt() 71 | self.lr_scheduler = self.get_lr_scheduler() 72 | self.criterion = self.get_criterion() 73 | self.cb_list = self.get_callbacks() 74 | 75 | @abstractmethod 76 | def get_datasets(self, train_ids, valid_ids): 77 | """ 78 | Initializes the data augmentation and preprocessing transforms. Creates 79 | and returns the train and validation datasets. 80 | """ 81 | return 82 | 83 | @abstractmethod 84 | def get_model(self): 85 | """ 86 | Creates and returns the model. 
87 | """ 88 | return 89 | 90 | def setup_im_ids(self): 91 | """ 92 | Creates a list of all paths to case folders for the dataset split 93 | """ 94 | search_path = os.path.join(self.config["data_folder"], "*/") 95 | case_list = sorted(glob(search_path)) 96 | case_list = case_list[:210] if len(case_list) >= 210 else case_list 97 | return case_list 98 | 99 | def get_split(self): 100 | """ 101 | Creates train/valid filename splits 102 | """ 103 | # setting up the train/val split with filenames 104 | split_seed: int = self.io_params["split_seed"] 105 | test_size: float = self.io_params["test_size"] 106 | # doing the splits: 1-test_size, test_size//2, test_size//2 107 | print("Splitting the dataset normally...") 108 | train_ids, total_test = train_test_split(self.case_list, 109 | random_state=split_seed, 110 | test_size=test_size) 111 | val_ids, test_ids = train_test_split(sorted(total_test), 112 | random_state=split_seed, 113 | test_size=0.5) 114 | return (train_ids, val_ids, test_ids) 115 | 116 | def get_loaders(self): 117 | """ 118 | Creates train/val loaders from datasets created in self.get_datasets. 119 | Returns the loaders. 120 | """ 121 | # setting up the loaders 122 | b_size, num_workers = self.io_params["batch_size"], self.io_params["num_workers"] 123 | train_loader = DataLoader(self.train_dset, batch_size=b_size, 124 | shuffle=True, num_workers=num_workers) 125 | valid_loader = DataLoader(self.val_dset, batch_size=b_size, 126 | shuffle=False, num_workers=num_workers) 127 | 128 | self.train_steps = len(self.train_dset) # for schedulers 129 | return {"train": train_loader, "valid": valid_loader} 130 | 131 | def get_opt(self): 132 | """ 133 | Creates the optimizer 134 | """ 135 | assert isinstance(self.model, torch.nn.Module), \ 136 | "`model` must be an instance of torch.nn.Module`" 137 | # fetching optimizers 138 | opt_name = self.opt_params["opt"] 139 | opt_kwargs = self.opt_params[opt_name] 140 | opt_cls = torch.optim.__dict__[opt_name] 141 | opt = opt_cls(filter(lambda p: p.requires_grad, 142 | self.model.parameters()), 143 | **opt_kwargs) 144 | print(f"Optimizer: {opt}") 145 | return opt 146 | 147 | def get_lr_scheduler(self): 148 | """ 149 | Creates the LR scheduler from the optimizer created in `self.get_opt` 150 | """ 151 | assert isinstance(self.opt, torch.optim.Optimizer), \ 152 | "`optimizer` must be an instance of torch.optim.Optimizer" 153 | sched_params = self.opt_params["scheduler_params"] 154 | scheduler_name = sched_params["scheduler"] 155 | scheduler_args = sched_params[scheduler_name] 156 | scheduler_cls = torch.optim.lr_scheduler.__dict__[scheduler_name] 157 | scheduler = scheduler_cls(optimizer=self.opt, **scheduler_args) 158 | print(f"LR Scheduler: {scheduler.__class__.__name__}") 159 | return scheduler 160 | 161 | def get_criterion(self): 162 | """ 163 | Fetches the criterion. (Only one loss.) 164 | """ 165 | loss_name = self.criterion_params["loss"].lower() 166 | loss_dict = { 167 | "bce_dice_loss": BCEDiceLoss(eps=1.), 168 | "bce": torch.nn.BCEWithLogitsLoss(), 169 | "ce_dice_loss": DC_and_CE_loss(soft_dice_kwargs={}, ce_kwargs={}), 170 | } 171 | # re-initializing criterion with kwargs 172 | loss_kwargs = self.criterion_params.get(loss_name) 173 | loss_kwargs = {} if loss_kwargs is None else loss_kwargs 174 | 175 | loss = loss_dict[loss_name] 176 | loss.__init__(**loss_kwargs) 177 | print(f"Criterion: {loss}") 178 | return loss 179 | 180 | def get_callbacks(self): 181 | """ 182 | Creates a list of callbacks. 
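        Every key in `callback_params` except `checkpoint_params` is treated
        as a catalyst callback class name and instantiated with its kwargs.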
183 | """ 184 | cb_name_list = list(self.cb_params.keys()) 185 | cb_name_list.remove("checkpoint_params") 186 | callbacks_list = [callbacks.__dict__[cb_name](**self.cb_params[cb_name]) 187 | for cb_name in cb_name_list] 188 | callbacks_list = self.load_weights(callbacks_list) 189 | print(f"Callbacks: {[cb.__class__.__name__ for cb in callbacks_list]}") 190 | return callbacks_list 191 | 192 | def load_weights(self, callbacks_list): 193 | """ 194 | Loads model weights and appends the CheckpointCallback if doing 195 | stateful model loading. This doesn't add the CheckpointCallback if 196 | it's 'model_only' loading bc SupervisedRunner adds it by default. 197 | """ 198 | ckpoint_params = self.cb_params["checkpoint_params"] 199 | # Having checkpoint_params=None is a hacky way to say no checkpoint 200 | # callback but eh what the heck 201 | if ckpoint_params["checkpoint_path"] != None: 202 | mode = ckpoint_params["mode"].lower() 203 | if mode == "full": 204 | print("Stateful loading...") 205 | ckpoint_p = Path(ckpoint_params["checkpoint_path"]) 206 | fname = ckpoint_p.name 207 | # everything in the path besides the base file name 208 | resume_dir = str(ckpoint_p.parents[0]) 209 | print(f"Loading {fname} from {resume_dir}. \ 210 | \nCheckpoints will also be saved in {resume_dir}.") 211 | # adding the checkpoint callback 212 | ckpoint = [callbacks.CheckpointCallback(resume=fname, 213 | resume_dir=resume_dir)] 214 | callbacks_list = callbacks_list + ckpoint 215 | elif mode == "model_only": 216 | print("Loading weights into model...") 217 | self.model = load_weights_train(ckpoint_params["checkpoint_path"], 218 | self.model) 219 | return callbacks_list 220 | 221 | class TrainClfSegExperiment(TrainExperiment): 222 | """ 223 | Stores the main parts of a classification+segmentation experiment: 224 | - df split 225 | - datasets 226 | - loaders 227 | - model 228 | - optimizer 229 | - lr_scheduler 230 | - criterion 231 | - callbacks 232 | """ 233 | def __init__(self, config: dict): 234 | """ 235 | Args: 236 | config (dict): from a .yml file. 237 | """ 238 | self.model_params = config["model_params"] 239 | super().__init__(config=config) 240 | 241 | @abstractmethod 242 | def get_datasets(self, train_ids, valid_ids): 243 | """ 244 | Initializes the data augmentation and preprocessing transforms. Creates 245 | and returns the train and validation datasets. 246 | """ 247 | return 248 | 249 | @abstractmethod 250 | def get_model(self): 251 | """ 252 | Creates and returns the model. 
253 | """ 254 | return 255 | 256 | def get_criterion(self): 257 | """ 258 | Returns a dictionary of the desired criterion (for seg and clf) 259 | """ 260 | loss_dict = { 261 | "bce_dice_loss": BCEDiceLoss(eps=1.), 262 | "bce": torch.nn.BCEWithLogitsLoss(), 263 | "ce_dice_loss": DC_and_CE_loss(soft_dice_kwargs={}, ce_kwargs={}), 264 | } 265 | 266 | seg_loss_name = self.criterion_params["seg_loss"].lower() 267 | clf_loss_name = self.criterion_params["clf_loss"].lower() 268 | 269 | # re-initializing criterion with kwargs 270 | seg_kwargs = self.criterion_params.get(seg_loss_name) 271 | clf_kwargs = self.criterion_params.get(clf_loss_name) 272 | seg_kwargs = {} if seg_kwargs is None else seg_kwargs 273 | clf_kwargs = {} if clf_kwargs is None else clf_kwargs 274 | 275 | seg_loss = loss_dict[seg_loss_name] 276 | clf_loss = loss_dict[clf_loss_name] 277 | seg_loss.__init__(**seg_kwargs), clf_loss.__init__(**clf_kwargs) 278 | criterion_dict = {seg_loss_name: seg_loss, 279 | clf_loss_name: clf_loss} 280 | print(f"Criterion: {criterion_dict}") 281 | return criterion_dict 282 | 283 | def get_callbacks(self): 284 | """ 285 | Gets the callbacks list; since this is multi-task, we need multiple 286 | metrics! Therefore, callbacks_list will now contain the 287 | CriterionAggregatorCallback and CriterionCallback. They calculate and 288 | record the `seg_loss` and `clf_loss`. 289 | """ 290 | from catalyst.dl.callbacks import CriterionAggregatorCallback, \ 291 | CriterionCallback 292 | seg_loss_name = self.criterion_params["seg_loss"].lower() 293 | clf_loss_name = self.criterion_params["clf_loss"].lower() 294 | criterion_cb_list = [ 295 | CriterionCallback(prefix="seg_loss", 296 | input_key="seg_targets", 297 | output_key="seg_logits", 298 | criterion_key=seg_loss_name), 299 | CriterionCallback(prefix="clf_loss", 300 | input_key="clf_targets", 301 | output_key="clf_logits", 302 | criterion_key=clf_loss_name), 303 | CriterionAggregatorCallback(prefix="loss", 304 | loss_keys=\ 305 | ["seg_loss", "clf_loss"]), 306 | ] 307 | # regular callbacks 308 | cb_name_list = list(self.cb_params.keys()) 309 | cb_name_list.remove("checkpoint_params") 310 | callbacks_list = [callbacks.__dict__[cb_name](**self.cb_params[cb_name]) 311 | for cb_name in cb_name_list] 312 | callbacks_list = self.load_weights(callbacks_list) + criterion_cb_list 313 | print(f"Callbacks: {[cb.__class__.__name__ for cb in callbacks_list]}") 314 | return callbacks_list 315 | 316 | def load_weights_train(checkpoint_path, model): 317 | """ 318 | Loads weights from a checkpoint and into training. 
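    Tries the catalyst checkpoint format first (the "model_state_dict" key)
    and falls back to loading the file as a plain state dict.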
319 | 320 | Args: 321 | checkpoint_path (str): path to a .pt or .pth checkpoint 322 | model (torch.nn.Module): <- 323 | Returns: 324 | Model with loaded weights and in train() mode 325 | """ 326 | try: 327 | # catalyst weights 328 | state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"] 329 | except: 330 | # anything else 331 | state_dict = torch.load(checkpoint_path, map_location="cpu") 332 | model.load_state_dict(state_dict, strict=True) 333 | model.train() 334 | return model 335 | -------------------------------------------------------------------------------- /kits19cnn/models/nnunet/generic_UNet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from copy import deepcopy 16 | from torch import nn 17 | import torch 18 | import numpy as np 19 | import torch.nn.functional 20 | 21 | from kits19cnn.utils import softmax_helper 22 | from .initialization import InitWeights_He 23 | from .neural_network import SegmentationNetwork 24 | 25 | class ConvDropoutNormNonlin(nn.Module): 26 | def __init__(self, input_channels, output_channels, 27 | conv_op=nn.Conv2d, conv_kwargs=None, 28 | norm_op=nn.BatchNorm2d, norm_op_kwargs=None, 29 | dropout_op=nn.Dropout2d, dropout_op_kwargs=None, 30 | nonlin=nn.LeakyReLU, nonlin_kwargs=None): 31 | super(ConvDropoutNormNonlin, self).__init__() 32 | if nonlin_kwargs is None: 33 | nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} 34 | if dropout_op_kwargs is None: 35 | dropout_op_kwargs = {'p': 0.5, 'inplace': True} 36 | if norm_op_kwargs is None: 37 | norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} 38 | if conv_kwargs is None: 39 | conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True} 40 | 41 | self.nonlin_kwargs = nonlin_kwargs 42 | self.nonlin = nonlin 43 | self.dropout_op = dropout_op 44 | self.dropout_op_kwargs = dropout_op_kwargs 45 | self.norm_op_kwargs = norm_op_kwargs 46 | self.conv_kwargs = conv_kwargs 47 | self.conv_op = conv_op 48 | self.norm_op = norm_op 49 | 50 | self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs) 51 | if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[ 52 | 'p'] > 0: 53 | self.dropout = self.dropout_op(**self.dropout_op_kwargs) 54 | else: 55 | self.dropout = None 56 | self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs) 57 | self.lrelu = self.nonlin(**self.nonlin_kwargs) 58 | 59 | def forward(self, x): 60 | x = self.conv(x) 61 | if self.dropout is not None: 62 | x = self.dropout(x) 63 | return self.lrelu(self.instnorm(x)) 64 | 65 | 66 | class StackedConvLayers(nn.Module): 67 | def __init__(self, input_feature_channels, output_feature_channels, num_convs, 68 | conv_op=nn.Conv2d, conv_kwargs=None, 69 | 
norm_op=nn.BatchNorm2d, norm_op_kwargs=None, 70 | dropout_op=nn.Dropout2d, dropout_op_kwargs=None, 71 | nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None): 72 | ''' 73 | stacks ConvDropoutNormLReLU layers. initial_stride will only be applied to first layer in the stack. The other parameters affect all layers 74 | :param input_feature_channels: 75 | :param output_feature_channels: 76 | :param num_convs: 77 | :param dilation: 78 | :param kernel_size: 79 | :param padding: 80 | :param dropout: 81 | :param initial_stride: 82 | :param conv_op: 83 | :param norm_op: 84 | :param dropout_op: 85 | :param inplace: 86 | :param neg_slope: 87 | :param norm_affine: 88 | :param conv_bias: 89 | ''' 90 | self.input_channels = input_feature_channels 91 | self.output_channels = output_feature_channels 92 | 93 | if nonlin_kwargs is None: 94 | nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True} 95 | if dropout_op_kwargs is None: 96 | dropout_op_kwargs = {'p': 0.5, 'inplace': True} 97 | if norm_op_kwargs is None: 98 | norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1} 99 | if conv_kwargs is None: 100 | conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True} 101 | 102 | self.nonlin_kwargs = nonlin_kwargs 103 | self.nonlin = nonlin 104 | self.dropout_op = dropout_op 105 | self.dropout_op_kwargs = dropout_op_kwargs 106 | self.norm_op_kwargs = norm_op_kwargs 107 | self.conv_kwargs = conv_kwargs 108 | self.conv_op = conv_op 109 | self.norm_op = norm_op 110 | 111 | if first_stride is not None: 112 | self.conv_kwargs_first_conv = deepcopy(conv_kwargs) 113 | self.conv_kwargs_first_conv['stride'] = first_stride 114 | else: 115 | self.conv_kwargs_first_conv = conv_kwargs 116 | 117 | super(StackedConvLayers, self).__init__() 118 | self.blocks = nn.Sequential( 119 | *([ConvDropoutNormNonlin(input_feature_channels, output_feature_channels, self.conv_op, 120 | self.conv_kwargs_first_conv, 121 | self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, 122 | self.nonlin, self.nonlin_kwargs)] + 123 | [ConvDropoutNormNonlin(output_feature_channels, output_feature_channels, self.conv_op, 124 | self.conv_kwargs, 125 | self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, 126 | self.nonlin, self.nonlin_kwargs) for _ in range(num_convs - 1)])) 127 | 128 | def forward(self, x): 129 | return self.blocks(x) 130 | 131 | 132 | def print_module_training_status(module): 133 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d) or isinstance(module, nn.Dropout3d) or \ 134 | isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout) or isinstance(module, nn.InstanceNorm3d) \ 135 | or isinstance(module, nn.InstanceNorm2d) or isinstance(module, nn.InstanceNorm1d) \ 136 | or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or isinstance(module, nn.BatchNorm1d): 137 | print(str(module), module.training) 138 | 139 | 140 | class Upsample(nn.Module): 141 | def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False): 142 | super(Upsample, self).__init__() 143 | self.align_corners = align_corners 144 | self.mode = mode 145 | self.scale_factor = scale_factor 146 | self.size = size 147 | 148 | def forward(self, x): 149 | return nn.functional.interpolate(x, size=self.size, 150 | scale_factor=self.scale_factor, 151 | mode=self.mode, 152 | align_corners=self.align_corners) 153 | 154 | class ClassificationHead(nn.Module): 155 | def __init__(self, num_classes=3, input_features=320, 156 | 
final_nonlin=lambda x: x, conv_op=nn.Conv3d): 157 | super().__init__() 158 | if conv_op == nn.Conv2d: 159 | self.final_pool = nn.AdaptiveAvgPool2d((1, 1)) 160 | elif conv_op == nn.Conv3d: 161 | self.final_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) 162 | self.out_dense = nn.Linear(input_features, num_classes) 163 | self.final_nonlin = final_nonlin 164 | 165 | def forward(self, x): 166 | pooled = self.final_pool(x) 167 | # (b, max_num_features, 1, 1 (,1)) 168 | logits = self.out_dense(torch.flatten(pooled, start_dim=1)) 169 | return self.final_nonlin(logits) 170 | 171 | class Generic_UNet(SegmentationNetwork): 172 | DEFAULT_BATCH_SIZE_3D = 2 173 | DEFAULT_PATCH_SIZE_3D = (64, 192, 160) 174 | SPACING_FACTOR_BETWEEN_STAGES = 2 175 | BASE_NUM_FEATURES_3D = 30 176 | MAX_NUMPOOL_3D = 999 177 | MAX_NUM_FILTERS_3D = 320 178 | 179 | DEFAULT_PATCH_SIZE_2D = (256, 256) 180 | BASE_NUM_FEATURES_2D = 30 181 | DEFAULT_BATCH_SIZE_2D = 50 182 | MAX_NUMPOOL_2D = 999 183 | MAX_FILTERS_2D = 480 184 | 185 | use_this_for_batch_size_computation_2D = 19739648 186 | use_this_for_batch_size_computation_3D = 520000000 # 505789440 187 | 188 | def __init__(self, input_channels, base_num_features, num_classes, num_pool, 189 | num_conv_per_stage=2, feat_map_mul_on_downscale=2, 190 | conv_op=nn.Conv2d, norm_op=nn.BatchNorm2d, norm_op_kwargs=None, 191 | dropout_op=nn.Dropout2d, dropout_op_kwargs=None, 192 | nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, 193 | dropout_in_localization=False, final_nonlin=softmax_helper, 194 | weightInitializer=InitWeights_He(1e-2), 195 | pool_op_kernel_sizes=None, 196 | conv_kernel_sizes=None, 197 | upscale_logits=False, convolutional_pooling=False, 198 | convolutional_upsampling=False, 199 | max_num_features=None, 200 | classification=False): 201 | """ 202 | basically more flexible than v1, architecture is the same 203 | Does this look complicated? Nah bro. Functionality > usability 204 | This does everything you need, including world peace. 205 | Questions? 
-> f.isensee@dkfz.de 206 | """ 207 | super(Generic_UNet, self).__init__() 208 | self.convolutional_upsampling = convolutional_upsampling 209 | self.convolutional_pooling = convolutional_pooling 210 | self.upscale_logits = upscale_logits 211 | if nonlin_kwargs is None: 212 | nonlin_kwargs = {'negative_slope':1e-2, 'inplace':True} 213 | if dropout_op_kwargs is None: 214 | dropout_op_kwargs = {'p':0.5, 'inplace':True} 215 | if norm_op_kwargs is None: 216 | norm_op_kwargs = {'eps':1e-5, 'affine':True, 'momentum':0.1} 217 | 218 | self.conv_kwargs = {'stride':1, 'dilation':1, 'bias':True} 219 | 220 | self.nonlin = nonlin 221 | self.nonlin_kwargs = nonlin_kwargs 222 | self.dropout_op_kwargs = dropout_op_kwargs 223 | self.norm_op_kwargs = norm_op_kwargs 224 | self.weightInitializer = weightInitializer 225 | self.conv_op = conv_op 226 | self.norm_op = norm_op 227 | self.dropout_op = dropout_op 228 | self.num_classes = num_classes 229 | self.final_nonlin = final_nonlin 230 | self.do_ds = deep_supervision 231 | 232 | if conv_op == nn.Conv2d: 233 | upsample_mode = 'bilinear' 234 | pool_op = nn.MaxPool2d 235 | transpconv = nn.ConvTranspose2d 236 | if pool_op_kernel_sizes is None: 237 | pool_op_kernel_sizes = [(2, 2)] * num_pool 238 | if conv_kernel_sizes is None: 239 | conv_kernel_sizes = [(3, 3)] * (num_pool + 1) 240 | elif conv_op == nn.Conv3d: 241 | upsample_mode = 'trilinear' 242 | pool_op = nn.MaxPool3d 243 | transpconv = nn.ConvTranspose3d 244 | if pool_op_kernel_sizes is None: 245 | pool_op_kernel_sizes = [(2, 2, 2)] * num_pool 246 | if conv_kernel_sizes is None: 247 | conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1) 248 | else: 249 | raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op)) 250 | 251 | self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64) 252 | self.pool_op_kernel_sizes = pool_op_kernel_sizes 253 | self.conv_kernel_sizes = conv_kernel_sizes 254 | 255 | self.conv_pad_sizes = [] 256 | for krnl in self.conv_kernel_sizes: 257 | self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl]) 258 | 259 | if max_num_features is None: 260 | if self.conv_op == nn.Conv3d: 261 | self.max_num_features = self.MAX_NUM_FILTERS_3D 262 | else: 263 | self.max_num_features = self.MAX_FILTERS_2D 264 | else: 265 | self.max_num_features = max_num_features 266 | 267 | self.conv_blocks_context = [] 268 | self.conv_blocks_localization = [] 269 | self.td = [] 270 | self.tu = [] 271 | self.seg_outputs = [] 272 | 273 | output_features = base_num_features 274 | input_features = input_channels 275 | 276 | for d in range(num_pool): 277 | # determine the first stride 278 | if d != 0 and self.convolutional_pooling: 279 | first_stride = pool_op_kernel_sizes[d-1] 280 | else: 281 | first_stride = None 282 | 283 | self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d] 284 | self.conv_kwargs['padding'] = self.conv_pad_sizes[d] 285 | # add convolutions 286 | self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage, 287 | self.conv_op, self.conv_kwargs, self.norm_op, 288 | self.norm_op_kwargs, self.dropout_op, 289 | self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, 290 | first_stride)) 291 | if not self.convolutional_pooling: 292 | self.td.append(pool_op(pool_op_kernel_sizes[d])) 293 | input_features = output_features 294 | output_features = int(np.round(output_features * feat_map_mul_on_downscale)) 295 | 296 | output_features = min(output_features, self.max_num_features) 297 | 298 | 299 | # now the 
bottleneck. 300 | # determine the first stride 301 | if self.convolutional_pooling: 302 | first_stride = pool_op_kernel_sizes[-1] 303 | else: 304 | first_stride = None 305 | 306 | # the output of the last conv must match the number of features from the skip connection if we are not using 307 | # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be 308 | # done by the transposed conv 309 | if self.convolutional_upsampling: 310 | final_num_features = output_features 311 | else: 312 | final_num_features = self.conv_blocks_context[-1].output_channels 313 | 314 | self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool] 315 | self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool] 316 | self.conv_blocks_context.append(nn.Sequential( 317 | StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs, 318 | self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, 319 | self.nonlin_kwargs, first_stride), 320 | StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs, 321 | self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin, 322 | self.nonlin_kwargs))) 323 | 324 | # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here 325 | if not dropout_in_localization: 326 | old_dropout_p = self.dropout_op_kwargs['p'] 327 | self.dropout_op_kwargs['p'] = 0.0 328 | 329 | # now lets build the localization pathway 330 | for u in range(num_pool): 331 | nfeatures_from_down = final_num_features 332 | nfeatures_from_skip = self.conv_blocks_context[-(2 + u)].output_channels # self.conv_blocks_context[-1] is bottleneck, so start with -2 333 | n_features_after_tu_and_concat = nfeatures_from_skip * 2 334 | 335 | # the first conv reduces the number of features to match those of skip 336 | # the following convs work on that number of features 337 | # if not convolutional upsampling then the final conv reduces the num of features again 338 | if u != num_pool - 1 and not self.convolutional_upsampling: 339 | final_num_features = self.conv_blocks_context[-(3 + u)].output_channels 340 | else: 341 | final_num_features = nfeatures_from_skip 342 | 343 | if not self.convolutional_upsampling: 344 | self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u+1)], mode=upsample_mode)) 345 | else: 346 | self.tu.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u+1)], 347 | pool_op_kernel_sizes[-(u+1)], bias=False)) 348 | 349 | self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u+1)] 350 | self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u+1)] 351 | self.conv_blocks_localization.append(nn.Sequential( 352 | StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1, 353 | self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op, 354 | self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs), 355 | StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs, 356 | self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, 357 | self.nonlin, self.nonlin_kwargs) 358 | )) 359 | 360 | for ds in range(len(self.conv_blocks_localization)): 361 | self.seg_outputs.append(conv_op(self.conv_blocks_localization[ds][-1].output_channels, num_classes, 362 | 1, 1, 0, 1, 1, False)) 363 | 364 | self.upscale_logits_ops = [] 365 | cum_upsample = 
np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1] 366 | for usl in range(num_pool - 1): 367 | if self.upscale_logits: 368 | self.upscale_logits_ops.append(Upsample(scale_factor=tuple([int(i) for i in cum_upsample[usl+1]]), 369 | mode=upsample_mode)) 370 | else: 371 | self.upscale_logits_ops.append(lambda x: x) 372 | 373 | if not dropout_in_localization: 374 | self.dropout_op_kwargs['p'] = old_dropout_p 375 | 376 | # register all modules properly 377 | self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization) 378 | self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context) 379 | self.td = nn.ModuleList(self.td) 380 | self.tu = nn.ModuleList(self.tu) 381 | self.seg_outputs = nn.ModuleList(self.seg_outputs) 382 | if self.upscale_logits: 383 | self.upscale_logits_ops = nn.ModuleList(self.upscale_logits_ops) # lambda x:x is not a Module so we need to distinguish here 384 | 385 | if self.weightInitializer is not None: 386 | self.apply(self.weightInitializer) 387 | #self.apply(print_module_training_status) 388 | self.classification = classification 389 | if self.classification: 390 | self.clf_head = ClassificationHead(num_classes=num_classes, 391 | input_features=\ 392 | self.max_num_features, 393 | final_nonlin=final_nonlin, 394 | conv_op=conv_op) 395 | print("The final classification layer assumes that the bottom", 396 | "context layer will have the max number features", 397 | f"({self.max_num_features}) as the number of input features.", 398 | "\nIf this is not true, please adjust `max_num_features`.") 399 | 400 | def forward(self, x): 401 | skips = [] 402 | seg_outputs = [] 403 | for d in range(len(self.conv_blocks_context) - 1): 404 | x = self.conv_blocks_context[d](x) 405 | skips.append(x) 406 | if not self.convolutional_pooling: 407 | x = self.td[d](x) 408 | 409 | x = self.conv_blocks_context[-1](x) 410 | if self.classification: 411 | clf_out = self.clf_head(x) 412 | 413 | for u in range(len(self.tu)): 414 | x = self.tu[u](x) 415 | x = torch.cat((x, skips[-(u + 1)]), dim=1) 416 | x = self.conv_blocks_localization[u](x) 417 | seg_outputs.append(self.final_nonlin(self.seg_outputs[u](x))) 418 | 419 | if self.classification: 420 | # returns classification pred last 421 | if self.do_ds: 422 | ds_out = [seg_outputs[-1]] + [i(j) for i, j in 423 | zip(list(self.upscale_logits_ops)[::-1], 424 | seg_outputs[:-1][::-1])] 425 | return tuple(ds_out + [clf_out,]) 426 | else: 427 | return tuple([seg_outputs[-1], clf_out]) 428 | else: 429 | if self.do_ds: 430 | return tuple([seg_outputs[-1],] + [i(j) for i, j in 431 | zip(list(self.upscale_logits_ops)[::-1], seg_outputs[:-1][::-1])]) 432 | else: 433 | return seg_outputs[-1] 434 | 435 | @staticmethod 436 | def compute_approx_vram_consumption(patch_size, num_pool_per_axis, base_num_features, max_num_features, 437 | num_modalities, num_classes, pool_op_kernel_sizes): 438 | """ 439 | This only applies for num_conv_per_stage and convolutional_upsampling=True 440 | not real vram consumption. 
just a constant term to which the vram consumption will be approx proportional 441 | (+ offset for parameter storage) 442 | :param patch_size: spatial size of the input patch, one entry per spatial axis 443 | :param num_pool_per_axis: only its length is used here (one entry per spatial axis) 444 | :param base_num_features: number of feature maps at the highest resolution 445 | :param max_num_features: cap on the number of feature maps per stage 446 | :return: a unitless score proportional to the activation memory of one patch 447 | """ 448 | 449 | if not isinstance(num_pool_per_axis, np.ndarray): 450 | num_pool_per_axis = np.array(num_pool_per_axis) 451 | 452 | npool = len(pool_op_kernel_sizes) 453 | 454 | map_size = np.array(patch_size) 455 | tmp = np.int64(5 * np.prod(map_size, dtype=np.int64) * base_num_features + num_modalities * np.prod(map_size, dtype=np.int64) + \ 456 | num_classes * np.prod(map_size, dtype=np.int64)) 457 | 458 | num_feat = base_num_features 459 | 460 | for p in range(npool): 461 | for pi in range(len(num_pool_per_axis)): 462 | map_size[pi] /= pool_op_kernel_sizes[p][pi] 463 | num_feat = min(num_feat * 2, max_num_features) 464 | num_blocks = 5 if p < (npool -1) else 2 # 2 + 2 for the convs of encode/decode and 1 for transposed conv 465 | tmp += num_blocks * np.prod(map_size, dtype=np.int64) * num_feat 466 | # print(p, map_size, num_feat, tmp) 467 | return tmp 468 | --------------------------------------------------------------------------------
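To illustrate the first_stride behaviour of StackedConvLayers, here is a minimal sketch (not part of the repository) that builds one encoder stage the way Generic_UNet does when convolutional_pooling=True: the first conv downsamples with stride 2, the remaining convs keep stride 1. The import path follows the package layout above; channel counts and patch size are arbitrary.

import torch
import torch.nn as nn
from kits19cnn.models.nnunet.generic_UNet import StackedConvLayers

# Two ConvDropoutNormNonlin blocks: 32 -> 64 channels, first conv with stride (2, 2, 2).
stage = StackedConvLayers(32, 64, num_convs=2,
                          conv_op=nn.Conv3d, norm_op=nn.InstanceNorm3d,
                          dropout_op=nn.Dropout3d, first_stride=(2, 2, 2))

x = torch.randn(1, 32, 16, 32, 32)   # (batch, channels, depth, height, width)
print(stage(x).shape)                # torch.Size([1, 64, 8, 16, 16])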
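A minimal sketch of instantiating Generic_UNet for a 3D, three-class (background / kidney / tumor) setup and running a forward pass. The patch size, pooling depth, and normalization choice below are illustrative assumptions, not values taken from this repository's training configs.

import torch
import torch.nn as nn
from kits19cnn.models import Generic_UNet

net = Generic_UNet(input_channels=1,          # single CT modality
                   base_num_features=30,      # matches BASE_NUM_FEATURES_3D
                   num_classes=3,
                   num_pool=4,
                   conv_op=nn.Conv3d,
                   norm_op=nn.InstanceNorm3d,
                   dropout_op=nn.Dropout3d,
                   deep_supervision=False,    # return only the full-resolution map
                   convolutional_pooling=True,
                   convolutional_upsampling=True)

# Each spatial dim must be divisible by net.input_shape_must_be_divisible_by
# (2**num_pool = 16 per axis with the default (2, 2, 2) pool kernels).
x = torch.randn(1, 1, 32, 64, 64)
with torch.no_grad():
    probs = net(x)   # softmax over classes via the default final_nonlin; shape (1, 3, 32, 64, 64)
print(probs.shape)

With deep_supervision=True the forward pass instead returns a tuple of segmentation maps with the full-resolution output first; with classification=True a classification prediction from ClassificationHead is appended as the last element.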
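compute_approx_vram_consumption returns a unitless score proportional to activation memory, intended for comparing candidate patch/pool configurations against the class constants (use_this_for_batch_size_computation_2D/3D) rather than for measuring bytes. A sketch with illustrative numbers:

from kits19cnn.models import Generic_UNet

score = Generic_UNet.compute_approx_vram_consumption(
    patch_size=(64, 160, 160),
    num_pool_per_axis=(4, 4, 4),           # only its length is used to walk the axes
    base_num_features=30,
    max_num_features=320,
    num_modalities=1,
    num_classes=3,
    pool_op_kernel_sizes=[(2, 2, 2)] * 4)

# Not bytes: only proportional to activation memory. nnU-Net-style planning would
# shrink the patch or batch size until this score fits under
# Generic_UNet.use_this_for_batch_size_computation_3D.
print(score)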