├── __init__.py
├── unetr_pp
│   ├── network_architecture
│   │   ├── acdc
│   │   │   ├── __init__.py
│   │   │   ├── transformerblock.py
│   │   │   └── unetr_pp_acdc.py
│   │   ├── lung
│   │   │   ├── __init__.py
│   │   │   ├── unetr_pp_lung.py
│   │   │   └── transformerblock.py
│   │   ├── tumor
│   │   │   ├── __init__.py
│   │   │   ├── transformerblock.py
│   │   │   └── unetr_pp_tumor.py
│   │   ├── synapse
│   │   │   ├── __init__.py
│   │   │   ├── transformerblock.py
│   │   │   └── unetr_pp_synapse.py
│   │   ├── __init__.py
│   │   ├── README.md
│   │   ├── initialization.py
│   │   └── layers.py
│   ├── __init__.py
│   ├── run
│   │   ├── __init__.py
│   │   └── default_configuration.py
│   ├── inference
│   │   ├── __init__.py
│   │   └── inferTs
│   │       └── swin_nomask_2
│   │           └── plans.pkl
│   ├── utilities
│   │   ├── __init__.py
│   │   ├── random_stuff.py
│   │   ├── nd_softmax.py
│   │   ├── sitk_stuff.py
│   │   ├── one_hot_encoding.py
│   │   ├── file_endings.py
│   │   ├── to_torch.py
│   │   ├── recursive_delete_npz.py
│   │   ├── tensor_utilities.py
│   │   ├── recursive_rename_taskXX_to_taskXXX.py
│   │   ├── folder_names.py
│   │   ├── distributed.py
│   │   ├── task_name_id_conversion.py
│   │   └── file_conversions.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── model_selection
│   │   │   ├── __init__.py
│   │   │   ├── collect_all_fold0_results_and_summarize_in_one_csv.py
│   │   │   └── summarize_results_with_plans.py
│   │   ├── unetr_pp_acdc_checkpoint
│   │   │   └── unetr_pp
│   │   │       └── 3d_fullres
│   │   │           └── Task001_ACDC
│   │   │               └── unetr_pp_trainer_acdc__unetr_pp_Plansv2.1
│   │   │                   └── fold_0
│   │   │                       └── .gitignore
│   │   ├── unetr_pp_lung_checkpoint
│   │   │   └── unetr_pp
│   │   │       └── 3d_fullres
│   │   │           └── Task006_Lung
│   │   │               └── unetr_pp_trainer_lung__unetr_pp_Plansv2.1
│   │   │                   └── fold_0
│   │   │                       └── .gitignore
│   │   ├── unetr_pp_tumor_checkpoint
│   │   │   └── unetr_pp
│   │   │       └── 3d_fullres
│   │   │           └── Task003_tumor
│   │   │               └── unetr_pp_trainer_tumor__unetr_pp_Plansv2.1
│   │   │                   └── fold_0
│   │   │                       └── .gitignore
│   │   ├── unetr_pp_synapse_checkpoint
│   │   │   └── unetr_pp
│   │   │       └── 3d_fullres
│   │   │           └── Task002_Synapse
│   │   │               └── unetr_pp_trainer_synapse__unetr_pp_Plansv2.1
│   │   │                   └── fold_0
│   │   │                       └── .gitignore
│   │   ├── collect_results_files.py
│   │   ├── add_mean_dice_to_json.py
│   │   ├── surface_dice.py
│   │   ├── add_dummy_task_with_mean_over_all_tasks.py
│   │   └── region_based_evaluation.py
│   ├── experiment_planning
│   │   ├── __init__.py
│   │   ├── change_batch_size.py
│   │   ├── alternative_experiment_planning
│   │   │   ├── experiment_planner_baseline_3DUNet_v23.py
│   │   │   ├── normalization
│   │   │   │   ├── experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py
│   │   │   │   ├── experiment_planner_3DUNet_nonCT.py
│   │   │   │   └── experiment_planner_3DUNet_CT2.py
│   │   │   ├── target_spacing
│   │   │   │   ├── experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py
│   │   │   │   └── experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py
│   │   │   ├── experiment_planner_baseline_3DUNet_v21_3convperstage.py
│   │   │   └── experiment_planner_baseline_3DUNet_v22.py
│   │   ├── nnFormer_convert_decathlon_task.py
│   │   ├── summarize_plans.py
│   │   └── experiment_planner_baseline_2DUNet_v21.py
│   ├── configuration.py
│   ├── postprocessing
│   │   ├── consolidate_postprocessing_simple.py
│   │   ├── consolidate_all_for_paper.py
│   │   └── consolidate_postprocessing.py
│   ├── paths.py
│   ├── preprocessing
│   │   └── custom_preprocessors
│   │       └── preprocessor_scale_RGB_to_0_1.py
│   ├── inference_acdc.py
│   └── inference_tumor.py
├── metadata.json
├── patch.png
├── location.png
├── labels
│   ├── 20230702185753_inklabels.png
│   ├── 20230929220926_inklabels.png
│   ├── 20231005123336_inklabels.png
│   ├── 20231007101619_inklabels.png
│   ├── 20231012184423_inklabels.png
│   ├── 20231016151002_inklabels.png
│   ├── 20231022170901_inklabels.png
│   ├── 20231031143852_inklabels.png
│   ├── 20231106155351_inklabels.png
│   ├── 20231210121321_inklabels.png
│   └── 20231221180251_inklabels.png
├── sample_num_layers.py
├── segment_utils.py
├── LICENSE
├── distribute_labels.py
├── simplify_tiles.py
├── padding_fixer.py
├── data_setup.py
├── custom_augmentations.py
├── coord_plotter.py
├── rotate_from_hough_lines.py
├── size_sorter.py
├── rotator_viewer.py
├── .gitignore
├── README.md
├── window_visualizer.py
└── data_downloader.py

/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/acdc/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/lung/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/tumor/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/synapse/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/unetr_pp/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
3 | 
--------------------------------------------------------------------------------
/unetr_pp/run/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
--------------------------------------------------------------------------------
/metadata.json:
--------------------------------------------------------------------------------
1 | {
2 |     "version": "13",
3 |     "type": "greybox_unetr_pp_new_aug"
4 | }
--------------------------------------------------------------------------------
/unetr_pp/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
--------------------------------------------------------------------------------
/unetr_pp/utilities/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
--------------------------------------------------------------------------------
/patch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/patch.png
--------------------------------------------------------------------------------
/unetr_pp/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
--------------------------------------------------------------------------------
/location.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/location.png
--------------------------------------------------------------------------------
/unetr_pp/experiment_planning/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
--------------------------------------------------------------------------------
/unetr_pp/evaluation/model_selection/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from . import *
--------------------------------------------------------------------------------
/labels/20230702185753_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20230702185753_inklabels.png
--------------------------------------------------------------------------------
/labels/20230929220926_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20230929220926_inklabels.png
--------------------------------------------------------------------------------
/labels/20231005123336_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231005123336_inklabels.png
--------------------------------------------------------------------------------
/labels/20231007101619_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231007101619_inklabels.png
--------------------------------------------------------------------------------
/labels/20231012184423_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231012184423_inklabels.png
--------------------------------------------------------------------------------
/labels/20231016151002_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231016151002_inklabels.png
--------------------------------------------------------------------------------
/labels/20231022170901_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231022170901_inklabels.png
--------------------------------------------------------------------------------
/labels/20231031143852_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231031143852_inklabels.png
--------------------------------------------------------------------------------
/labels/20231106155351_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231106155351_inklabels.png
--------------------------------------------------------------------------------
/labels/20231210121321_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231210121321_inklabels.png
--------------------------------------------------------------------------------
/labels/20231221180251_inklabels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/labels/20231221180251_inklabels.png
--------------------------------------------------------------------------------
/unetr_pp/inference/inferTs/swin_nomask_2/plans.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SQMah/Vesuvius-Grand-Prize-Submission/HEAD/unetr_pp/inference/inferTs/swin_nomask_2/plans.pkl
--------------------------------------------------------------------------------
/unetr_pp/evaluation/unetr_pp_acdc_checkpoint/unetr_pp/3d_fullres/Task001_ACDC/unetr_pp_trainer_acdc__unetr_pp_Plansv2.1/fold_0/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 | 
--------------------------------------------------------------------------------
/unetr_pp/evaluation/unetr_pp_lung_checkpoint/unetr_pp/3d_fullres/Task006_Lung/unetr_pp_trainer_lung__unetr_pp_Plansv2.1/fold_0/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 | 
--------------------------------------------------------------------------------
/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/unetr_pp/3d_fullres/Task003_tumor/unetr_pp_trainer_tumor__unetr_pp_Plansv2.1/fold_0/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 | 
--------------------------------------------------------------------------------
/unetr_pp/evaluation/unetr_pp_synapse_checkpoint/unetr_pp/3d_fullres/Task002_Synapse/unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/fold_0/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 | 
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/README.md:
--------------------------------------------------------------------------------
1 | You can change the batch size and the input data size here:
2 | ```
3 | https://github.com/282857341/nnFormer/blob/6e36d76f9b7d0bea522e1cd05adf502ba85480e6/nnformer/run/default_configuration.py#L49-L68
4 | ```
5 | 
--------------------------------------------------------------------------------
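The linked `default_configuration.py` reads these settings out of the task's plans pickle, so they can also be changed by editing that file directly. A minimal sketch (the path below is hypothetical; `change_batch_size.py` further down does the same thing for the batch size):

```python
from batchgenerators.utilities.file_and_folder_operations import load_pickle, save_pickle

# Hypothetical plans file path -- substitute your own preprocessed task.
plans_file = "DATASET/unetr_pp_preprocessed/Task002_Synapse/unetr_pp_Plansv2.1_plans_3D.pkl"
plans = load_pickle(plans_file)
stage = plans['plans_per_stage'][0]
print(stage['batch_size'], stage['patch_size'])  # inspect current values
stage['batch_size'] = 2                          # e.g. shrink to fit GPU memory
save_pickle(plans, plans_file)
```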
/unetr_pp/configuration.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | default_num_threads = 8 if 'nnFormer_def_n_proc' not in os.environ else int(os.environ['nnFormer_def_n_proc'])
4 | RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD = 3  # determines what threshold to use for resampling the low resolution axis
5 | # separately (with NN)
--------------------------------------------------------------------------------
/unetr_pp/experiment_planning/change_batch_size.py:
--------------------------------------------------------------------------------
1 | from batchgenerators.utilities.file_and_folder_operations import *
2 | import numpy as np
3 | 
4 | if __name__ == '__main__':
5 |     input_file = '/home/xychen/new_transformer/nnFormerFrame/DATASET/nnFormer_preprocessed/Task008_Verse1/nnFormerPlansv2.1_plans_3D.pkl'
6 |     output_file = '/home/xychen/new_transformer/nnFormerFrame/DATASET/nnFormer_preprocessed/Task008_Verse1/nnFormerPlansv2.1_plans_3D.pkl'
7 |     a = load_pickle(input_file)
8 |     # a['plans_per_stage'][0]['batch_size'] = int(np.floor(6 / 9 * a['plans_per_stage'][0]['batch_size']))
9 |     a['plans_per_stage'][0]['batch_size'] = 4
10 |     save_pickle(a, output_file)
--------------------------------------------------------------------------------
/unetr_pp/utilities/random_stuff.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | class no_op(object):
17 |     def __enter__(self):
18 |         pass
19 | 
20 |     def __exit__(self, *args):
21 |         pass
22 | 
--------------------------------------------------------------------------------
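A brief usage note (my sketch, not code from this repo): `no_op` is a do-nothing context manager, and nnU-Net-style trainers typically drop it in where `torch.cuda.amp.autocast` would otherwise go, so the training loop keeps a single `with` statement whether or not mixed precision is enabled:

```python
import torch
from torch import nn
from unetr_pp.utilities.random_stuff import no_op

net = nn.Linear(4, 2)
x = torch.randn(1, 4)
use_amp = torch.cuda.is_available()  # assumption for illustration
amp_context = torch.cuda.amp.autocast if use_amp else no_op
with amp_context():
    y = net(x)  # same call site with or without AMP
```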
/unetr_pp/utilities/nd_softmax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | from torch import nn
17 | import torch.nn.functional as F
18 | 
19 | 
20 | softmax_helper = lambda x: F.softmax(x, 1)
21 | 
22 | 
--------------------------------------------------------------------------------
/unetr_pp/utilities/sitk_stuff.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | import SimpleITK as sitk
17 | 
18 | 
19 | def copy_geometry(image: sitk.Image, ref: sitk.Image):
20 |     image.SetOrigin(ref.GetOrigin())
21 |     image.SetDirection(ref.GetDirection())
22 |     image.SetSpacing(ref.GetSpacing())
23 |     return image
24 | 
--------------------------------------------------------------------------------
/sample_num_layers.py:
--------------------------------------------------------------------------------
1 | from utils import get_all_segment_paths
2 | import os
3 | from collections import defaultdict
4 | import matplotlib.pyplot as plt
5 | 
6 | 
7 | def get_num_layers_across_all_data(data_dir):
8 |     layers = defaultdict(int)
9 |     segments_paths = get_all_segment_paths(data_dir)
10 |     for segment_path in segments_paths:
11 |         segment_layer_path = os.path.join(segment_path, "layers")
12 |         if os.path.isdir(segment_layer_path):
13 |             print(f"Processing segment layer path {segment_layer_path}")
14 |             for filename in os.listdir(segment_layer_path):
15 |                 if filename.endswith(".tif"):
16 |                     filename_without_ext = os.path.splitext(filename)[0]
17 |                     layers[int(filename_without_ext)] += 1
18 | 
19 |     # Plot layers as a bar chart
20 |     plt.bar(layers.keys(), layers.values())
21 | 
22 |     # Write the plot to disk *before* showing it: plt.show() hands the figure to the GUI
23 |     # backend and clears it on close, so calling savefig afterwards can write a blank image.
24 |     plt.savefig(os.path.join("./", "num_layers.png"))
25 |     plt.show()
26 | 
27 | 
28 | if __name__ == "__main__":
29 |     get_num_layers_across_all_data("./data")
30 | 
--------------------------------------------------------------------------------
/segment_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from typing import Dict
4 | 
5 | 
6 | @lru_cache(maxsize=10)
7 | def get_segment_id_paths_dict(data_dir) -> Dict[str, str]:
8 |     segment_id_path_keyed_by_segment_id = dict()
9 |     for data_source in os.listdir(data_dir):
10 |         data_source_path = os.path.join(data_dir, data_source)
11 |         if os.path.isdir(data_source_path):
12 |             for segment_id in os.listdir(data_source_path):
13 |                 segment_id_path = os.path.join(data_source_path, segment_id)
14 |                 if os.path.isdir(segment_id_path):
15 |                     if segment_id in segment_id_path_keyed_by_segment_id:
16 |                         raise Exception(f"{segment_id} was already previously found with path "
17 |                                         f"{segment_id_path_keyed_by_segment_id[segment_id]}, "
18 |                                         f"now trying to insert new path {segment_id_path}")
19 |                     segment_id_path_keyed_by_segment_id[segment_id] = segment_id_path
20 |     return segment_id_path_keyed_by_segment_id
21 | 
--------------------------------------------------------------------------------
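For orientation (my sketch; the `data/<data_source>/<segment_id>/` layout is inferred from the two nested loops above), the function flattens that layout into a single id-to-directory map:

```python
from segment_utils import get_segment_id_paths_dict

# Expects e.g. data/scroll1_hari/20230827161847/ on disk.
paths = get_segment_id_paths_dict("./data")
print(paths.get("20230827161847"))  # -> ./data/scroll1_hari/20230827161847
```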
/unetr_pp/utilities/one_hot_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import numpy as np
16 | 
17 | 
18 | def to_one_hot(seg, all_seg_labels=None):
19 |     if all_seg_labels is None:
20 |         all_seg_labels = np.unique(seg)
21 |     result = np.zeros((len(all_seg_labels), *seg.shape), dtype=seg.dtype)
22 |     for i, l in enumerate(all_seg_labels):
23 |         result[i][seg == l] = 1
24 |     return result
25 | 
--------------------------------------------------------------------------------
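A quick illustration of the output layout (my example, not from the repo): each label gets its own leading channel, so a segmentation with L distinct labels becomes an array of shape `(L, *seg.shape)`:

```python
import numpy as np
from unetr_pp.utilities.one_hot_encoding import to_one_hot

seg = np.array([[0, 1], [2, 1]])
onehot = to_one_hot(seg)  # labels inferred as [0, 1, 2]
print(onehot.shape)       # (3, 2, 2)
print(onehot[1])          # [[0 1] [0 1]] -- binary mask for label 1
```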
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Shao-Qian (SQ) Mah
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/unetr_pp/utilities/file_endings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from batchgenerators.utilities.file_and_folder_operations import *
17 | 
18 | 
19 | def remove_trailing_slash(filename: str):
20 |     while filename.endswith('/'):
21 |         filename = filename[:-1]
22 |     return filename
23 | 
24 | 
25 | def maybe_add_0000_to_all_niigz(folder):
26 |     nii_gz = subfiles(folder, suffix='.nii.gz')
27 |     for n in nii_gz:
28 |         n = remove_trailing_slash(n)
29 |         if not n.endswith('_0000.nii.gz'):
30 |             os.rename(n, n[:-7] + '_0000.nii.gz')
31 | 
--------------------------------------------------------------------------------
/unetr_pp/utilities/to_torch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import torch
16 | 
17 | 
18 | def maybe_to_torch(d):
19 |     if isinstance(d, list):
20 |         d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
21 |     elif not isinstance(d, torch.Tensor):
22 |         d = torch.from_numpy(d).float()
23 |     return d
24 | 
25 | 
26 | def to_cuda(data, non_blocking=True, gpu_id=0):
27 |     if isinstance(data, list):
28 |         data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
29 |     else:
30 |         data = data.cuda(gpu_id, non_blocking=non_blocking)
31 |     return data
32 | 
--------------------------------------------------------------------------------
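These two helpers are typically chained at the start of a training or inference step; a minimal sketch (my example):

```python
import numpy as np
import torch
from unetr_pp.utilities.to_torch import maybe_to_torch, to_cuda

batch = [np.zeros((1, 64, 64), dtype=np.float32),
         np.ones((1, 64, 64), dtype=np.float32)]
tensors = maybe_to_torch(batch)  # numpy arrays -> float tensors; tensors pass through
if torch.cuda.is_available():
    tensors = to_cuda(tensors)   # moves each tensor to GPU 0 by default
```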
/distribute_labels.py:
--------------------------------------------------------------------------------
1 | import os
2 | from cfg import CFG
3 | from utils import make_symlink
4 | from segment_utils import get_segment_id_paths_dict
5 | 
6 | label_path = CFG.processed_labels_dir
7 | train_val_segment_ids_set = set(CFG.train_val_segment_ids)
8 | data_dir = CFG.base_data_dir
9 | 
10 | 
11 | def distribute_labels(label_dir, train_val_segments_ids_set, segment_id_paths_dict):
12 |     labels_segment_ids_added = set()
13 |     for label in os.listdir(label_dir):
14 |         if label.endswith(".png"):
15 |             segment_label_path = os.path.join(label_dir, label)
16 |             # Assume the filename scheme is {segment_id}_inklabels.png.
17 |             label_segment_id = label.split("_")[0]
18 |             labels_segment_ids_added.add(label_segment_id)
19 |             new_segment_label_path = os.path.join(segment_id_paths_dict[label_segment_id], label)
20 |             make_symlink(segment_label_path, new_segment_label_path)
21 |     diff = train_val_segments_ids_set - labels_segment_ids_added
22 |     if diff:
23 |         raise Exception(f"Expected all train/val segment ids to have corresponding ink labels, but segments "
24 |                         f"{diff} do not.")
25 | 
26 | 
27 | if __name__ == "__main__":
28 |     distribute_labels(label_path, train_val_segment_ids_set, get_segment_id_paths_dict(data_dir))
29 | 
--------------------------------------------------------------------------------
/unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v23.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
16 |     ExperimentPlanner3D_v21
17 | from unetr_pp.paths import *
18 | 
19 | 
20 | class ExperimentPlanner3D_v23(ExperimentPlanner3D_v21):
21 |     """
22 |     Identical to ExperimentPlanner3D_v21, but writes v2.3 data/plans identifiers and uses the
23 |     Preprocessor3DDifferentResampling preprocessor.
24 |     """
25 |     def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
26 |         super(ExperimentPlanner3D_v23, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
27 |         self.data_identifier = "nnFormerData_plans_v2.3"
28 |         self.plans_fname = join(self.preprocessed_output_folder,
29 |                                 "nnFormerPlansv2.3_plans_3D.pkl")
30 |         self.preprocessor_name = "Preprocessor3DDifferentResampling"
31 | 
--------------------------------------------------------------------------------
/simplify_tiles.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | 
5 | 
6 | def process_image(image_path, threshold_percent):
7 |     # Load the image unchanged; grayscale content is assumed (0 = black, 255 = white)
8 |     img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
9 |     print(img.dtype)
10 |     out = np.zeros(img.shape, dtype=np.uint8)
11 | 
12 |     # Window size and stride
13 |     window_size = 64
14 |     stride = 64
15 | 
16 |     # Iterate over the image
17 |     for y in range(0, img.shape[0], stride):
18 |         for x in range(0, img.shape[1], stride):
19 |             window = img[y:y + window_size, x:x + window_size]
20 | 
21 |             # Check if the window is complete
22 |             if window.shape[0] == window_size and window.shape[1] == window_size:
23 |                 # Calculate the percentage of white pixels
24 |                 white_pixels = np.sum(window >= 100)
25 |                 total_pixels = window_size * window_size
26 |                 white_percent = (white_pixels / total_pixels) * 100
27 | 
28 |                 # If the percentage of white exceeds the threshold, set the window to white
29 |                 if white_percent >= threshold_percent:
30 |                     out[y:y + window_size, x:x + window_size] = 255
31 | 
32 |     return out
33 | 
34 | 
35 | if __name__ == '__main__':
36 |     # Replace 'path_to_your_image.jpg' with your image path
37 |     processed_img = process_image('data/scroll1_hari/20230827161847/layers/30.tif',
38 |                                   threshold_percent=50)  # 50% threshold
39 | 
40 |     cv2.imshow("processed image", processed_img)
41 |     cv2.waitKey(0)
42 |     cv2.destroyAllWindows()
43 | 
--------------------------------------------------------------------------------
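The nested loop above is easy to follow but slow on full-resolution tiles. A vectorized equivalent is possible (my sketch; same `>= 100` "white" test, same all-or-nothing behavior per complete 64x64 window, and edge remainders stay black):

```python
import numpy as np

def process_image_vectorized(img: np.ndarray, threshold_percent: float,
                             window: int = 64) -> np.ndarray:
    # Crop to a multiple of the window size, mirroring the "complete window" check.
    h = (img.shape[0] // window) * window
    w = (img.shape[1] // window) * window
    blocks = img[:h, :w].reshape(h // window, window, w // window, window)
    # Percentage of "white" pixels (>= 100) per block.
    white_percent = (blocks >= 100).mean(axis=(1, 3)) * 100
    out = np.zeros(img.shape, dtype=np.uint8)
    # Expand the per-block decision back to pixel resolution.
    mask = np.kron(white_percent >= threshold_percent,
                   np.ones((window, window), dtype=np.uint8)).astype(bool)
    out[:h, :w][mask] = 255
    return out
```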
Skipping") 24 | continue 25 | data_segment_tif = os.path.join(segment_id_paths_dict[segment_id], f"{segment_id}.tif") 26 | h, w = cv2.imread(data_segment_tif).shape[:2] 27 | label_data = cv2.imread(label_segment_path, 0) 28 | orig_h, orig_w = label_data.shape[:2] 29 | resized_label_data = label_data[0:h, 0:w] 30 | cv2.imwrite(os.path.join(resized_label_path, filename), resized_label_data) 31 | print(f"[Success] Resizing filename {filename}. Original h {orig_h}, w {orig_w}. Final h {h}, w {w}.") 32 | 33 | 34 | if __name__ == "__main__": 35 | add_or_remove_padding_ink_label_padding(segment_id_paths_dict=get_segment_id_paths_dict(CFG.base_data_dir), 36 | label_path=label_path, resized_label_path=resized_label_path) 37 | -------------------------------------------------------------------------------- /data_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from cfg import CFG 3 | from utils import make_symlink 4 | from segment_utils import get_segment_id_paths_dict 5 | import shutil 6 | 7 | data_path = CFG.base_data_dir 8 | train_val_data_path = CFG.train_val_dir 9 | test_data_path = CFG.test_data_dir 10 | label_path = CFG.processed_labels_dir 11 | 12 | 13 | def create_links_from_list(base_dir, segment_ids, segment_id_paths): 14 | for segment_id in segment_ids: 15 | if segment_id not in segment_id_paths: 16 | print(f"{segment_id} does not exist in segment id paths in data path.") 17 | continue 18 | new_path = os.path.join(base_dir, segment_id) 19 | prev_path = os.path.join(os.getcwd(), segment_id_paths[segment_id]) 20 | ink_label_name = f"{segment_id}_inklabels.png" 21 | new_ink_path = os.path.join(prev_path, ink_label_name) 22 | if os.path.exists(new_path): 23 | os.remove(new_path) 24 | # if os.path.exists(new_ink_path): 25 | # os.remove(new_ink_path) 26 | make_symlink(prev_path, os.path.join(os.getcwd(), new_path)) 27 | try: 28 | make_symlink(os.path.join(os.getcwd(), os.path.join(label_path, ink_label_name)), new_ink_path) 29 | except FileExistsError: 30 | print(f"Skipping {segment_id} ink label because already created.") 31 | 32 | 33 | def setup_data(segment_id_paths_dict): 34 | create_links_from_list(train_val_data_path, CFG.train_segment_ids + CFG.val_segment_ids, segment_id_paths_dict) 35 | create_links_from_list(test_data_path, CFG.test_segment_ids, segment_id_paths_dict) 36 | 37 | 38 | if __name__ == "__main__": 39 | # shutil.rmtree(os.path.join(data_path, "scroll1", "20230702185753")) 40 | segment_id_paths_dict = get_segment_id_paths_dict(data_path) 41 | setup_data(segment_id_paths_dict) 42 | -------------------------------------------------------------------------------- /unetr_pp/network_architecture/initialization.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
/unetr_pp/network_architecture/initialization.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from torch import nn
17 | 
18 | 
19 | class InitWeights_He(object):
20 |     def __init__(self, neg_slope=1e-2):
21 |         self.neg_slope = neg_slope
22 | 
23 |     def __call__(self, module):
24 |         if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):
25 |             module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
26 |             if module.bias is not None:
27 |                 module.bias = nn.init.constant_(module.bias, 0)
28 | 
29 | 
30 | class InitWeights_XavierUniform(object):
31 |     def __init__(self, gain=1):
32 |         self.gain = gain
33 | 
34 |     def __call__(self, module):
35 |         if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):
36 |             module.weight = nn.init.xavier_uniform_(module.weight, self.gain)
37 |             if module.bias is not None:
38 |                 module.bias = nn.init.constant_(module.bias, 0)
39 | 
--------------------------------------------------------------------------------
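As a usage note (my sketch): these initializer objects are designed to be passed to `nn.Module.apply`, which calls them on every submodule of a network:

```python
from torch import nn
from unetr_pp.network_architecture.initialization import InitWeights_He

net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.LeakyReLU(1e-2), nn.Conv2d(8, 1, 3))
net.apply(InitWeights_He(neg_slope=1e-2))  # Kaiming-init every conv weight, zero every bias
```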
/unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from unetr_pp.experiment_planning.experiment_planner_baseline_2DUNet_v21 import ExperimentPlanner2D_v21
17 | from unetr_pp.paths import *
18 | 
19 | 
20 | class ExperimentPlanner2D_v21_RGB_scaleTo_0_1(ExperimentPlanner2D_v21):
21 |     """
22 |     used by tutorial unetr_pp.tutorials.custom_preprocessing
23 |     """
24 |     def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
25 |         super().__init__(folder_with_cropped_data, preprocessed_output_folder)
26 |         self.data_identifier = "nnFormer_RGB_scaleTo_0_1"
27 |         self.plans_fname = join(self.preprocessed_output_folder, "nnFormer_RGB_scaleTo_0_1" + "_plans_2D.pkl")
28 | 
29 |         # The custom preprocessor class we intend to use is GenericPreprocessor_scale_uint8_to_0_1. It must be located
30 |         # in unetr_pp.preprocessing (any file and submodule) and will be found by its name. Make sure to always define
31 |         # unique names!
32 |         self.preprocessor_name = 'GenericPreprocessor_scale_uint8_to_0_1'
33 | 
--------------------------------------------------------------------------------
/unetr_pp/utilities/tensor_utilities.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import numpy as np
16 | import torch
17 | from torch import nn
18 | 
19 | 
20 | def sum_tensor(inp, axes, keepdim=False):
21 |     axes = np.unique(axes).astype(int)
22 |     if keepdim:
23 |         for ax in axes:
24 |             inp = inp.sum(int(ax), keepdim=True)
25 |     else:
26 |         for ax in sorted(axes, reverse=True):
27 |             inp = inp.sum(int(ax))
28 |     return inp
29 | 
30 | 
31 | def mean_tensor(inp, axes, keepdim=False):
32 |     axes = np.unique(axes).astype(int)
33 |     if keepdim:
34 |         for ax in axes:
35 |             inp = inp.mean(int(ax), keepdim=True)
36 |     else:
37 |         for ax in sorted(axes, reverse=True):
38 |             inp = inp.mean(int(ax))
39 |     return inp
40 | 
41 | 
42 | def flip(x, dim):
43 |     """
44 |     flips the tensor at dimension dim (mirroring!)
45 |     :param x:
46 |     :param dim:
47 |     :return:
48 |     """
49 |     indices = [slice(None)] * x.dim()
50 |     indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
51 |                                 dtype=torch.long, device=x.device)
52 |     return x[tuple(indices)]
53 | 
54 | 
55 | 
--------------------------------------------------------------------------------
/unetr_pp/utilities/recursive_rename_taskXX_to_taskXXX.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from batchgenerators.utilities.file_and_folder_operations import *
17 | import os
18 | 
19 | 
20 | def recursive_rename(folder):
21 |     s = subdirs(folder, join=False)
22 |     for ss in s:
23 |         if ss.startswith("Task") and ss.find("_") == 6:
24 |             task_id = int(ss[4:6])
25 |             name = ss[7:]
26 |             os.rename(join(folder, ss), join(folder, "Task%03.0d_" % task_id + name))
27 |     s = subdirs(folder, join=True)
28 |     for ss in s:
29 |         recursive_rename(ss)
30 | 
31 | if __name__ == "__main__":
32 |     recursive_rename("/media/fabian/Results/nnFormer")
33 |     recursive_rename("/media/fabian/unetr_pp")
34 |     recursive_rename("/media/fabian/My Book/MedicalDecathlon")
35 |     recursive_rename("/home/fabian/drives/datasets/nnFormer_raw")
36 |     recursive_rename("/home/fabian/drives/datasets/nnFormer_preprocessed")
37 |     recursive_rename("/home/fabian/drives/datasets/nnFormer_testSets")
38 |     recursive_rename("/home/fabian/drives/datasets/results/nnFormer")
39 |     recursive_rename("/home/fabian/drives/e230-dgx2-1-data_fabian/Decathlon_raw")
40 |     recursive_rename("/home/fabian/drives/e230-dgx2-1-data_fabian/nnFormer_preprocessed")
41 | 
42 | 
--------------------------------------------------------------------------------
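A concrete trace of the rename logic (my example): a two-digit task directory has its underscore at index 6, which is exactly what the guard checks for.

```python
ss = "Task08_HepaticVessel"                # hypothetical directory name
assert ss.startswith("Task") and ss.find("_") == 6
task_id = int(ss[4:6])                     # -> 8
name = ss[7:]                              # -> "HepaticVessel"
print("Task%03.0d_" % task_id + name)      # -> "Task008_HepaticVessel"
```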
/unetr_pp/utilities/folder_names.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from batchgenerators.utilities.file_and_folder_operations import *
17 | from unetr_pp.paths import network_training_output_dir
18 | 
19 | 
20 | def get_output_folder_name(model: str, task: str = None, trainer: str = None, plans: str = None, fold: int = None,
21 |                            overwrite_training_output_dir: str = None):
22 |     """
23 |     Retrieves the correct output directory for the nnU-Net model described by the input parameters
24 | 
25 |     :param model:
26 |     :param task:
27 |     :param trainer:
28 |     :param plans:
29 |     :param fold:
30 |     :param overwrite_training_output_dir:
31 |     :return:
32 |     """
33 |     assert model in ["2d", "3d_cascade_fullres", '3d_fullres', '3d_lowres']
34 | 
35 |     if overwrite_training_output_dir is not None:
36 |         tr_dir = overwrite_training_output_dir
37 |     else:
38 |         tr_dir = network_training_output_dir
39 | 
40 |     current = join(tr_dir, model)
41 |     if task is not None:
42 |         current = join(current, task)
43 |     if trainer is not None and plans is not None:
44 |         current = join(current, trainer + "__" + plans)
45 |     if fold is not None:
46 |         current = join(current, "fold_%d" % fold)
47 |     return current
48 | 
--------------------------------------------------------------------------------
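For example (my sketch; the argument values mirror the checkpoint layout under `unetr_pp/evaluation/` in the tree above):

```python
from unetr_pp.utilities.folder_names import get_output_folder_name

folder = get_output_folder_name(
    model="3d_fullres",
    task="Task002_Synapse",
    trainer="unetr_pp_trainer_synapse",
    plans="unetr_pp_Plansv2.1",
    fold=0,
)
# -> <network_training_output_dir>/3d_fullres/Task002_Synapse/
#        unetr_pp_trainer_synapse__unetr_pp_Plansv2.1/fold_0
```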
/unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_nonCT.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | 
16 | from collections import OrderedDict
17 | 
18 | from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
19 | from unetr_pp.paths import *
20 | 
21 | 
22 | class ExperimentPlannernonCT(ExperimentPlanner):
23 |     """
24 |     Preprocesses all data in nonCT mode (this is what we use for MRI by default, but here it is applied to CT images
25 |     as well)
26 |     """
27 |     def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
28 |         super(ExperimentPlannernonCT, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
29 |         self.data_identifier = "nnFormer_nonCT"
30 |         self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "nonCT_plans_3D.pkl")
31 | 
32 |     def determine_normalization_scheme(self):
33 |         schemes = OrderedDict()
34 |         modalities = self.dataset_properties['modalities']
35 |         num_modalities = len(list(modalities.keys()))
36 | 
37 |         for i in range(num_modalities):
38 |             # Both branches deliberately assign "nonCT": CT modalities are normalized like non-CT here.
39 |             if modalities[i] == "CT":
40 |                 schemes[i] = "nonCT"
41 |             else:
42 |                 schemes[i] = "nonCT"
43 |         return schemes
44 | 
45 | 
--------------------------------------------------------------------------------
/unetr_pp/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import numpy as np
16 | from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21
17 | from unetr_pp.paths import *
18 | 
19 | 
20 | class ExperimentPlanner3D_v21_customTargetSpacing_2x2x2(ExperimentPlanner3D_v21):
21 |     def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
22 |         super(ExperimentPlanner3D_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
23 |         # we change the data identifier and plans_fname. This will make this experiment planner save the preprocessed
24 |         # data in a different folder so that they can co-exist with the default (ExperimentPlanner3D_v21). We also
25 |         # create a custom plans file that will be linked to this data
26 |         self.data_identifier = "nnFormerData_plans_v2.1_trgSp_2x2x2"
27 |         self.plans_fname = join(self.preprocessed_output_folder,
28 |                                 "nnFormerPlansv2.1_trgSp_2x2x2_plans_3D.pkl")
29 | 
30 |     def get_target_spacing(self):
31 |         # simply return the desired spacing as np.array
32 |         return np.array([2., 2., 2.])  # make sure this is float!!!! Not int!
33 | 
34 | 
--------------------------------------------------------------------------------
/custom_augmentations.py:
--------------------------------------------------------------------------------
1 | from albumentations.core.transforms_interface import ImageOnlyTransform
2 | import numpy as np
3 | import random
4 | 
5 | 
6 | class ChannelInvert(ImageOnlyTransform):
7 |     """Reverse channels of an input HWC image.
8 | 
9 |     Args:
10 |         p (float): probability of applying the transform. Default: 0.5.
11 | 
12 |     Targets:
13 |         image
14 | 
15 |     Image types:
16 |         uint8, float32
17 |     """
18 | 
19 |     def apply(self, img, **params):
20 |         # Assuming img is an HWC image.
21 |         img = img[..., ::-1]
22 |         return img
23 | 
24 |     def get_transform_init_args_names(self):
25 |         # This function returns a tuple of parameter names that will be
26 |         # used to initialize the transform object, if necessary.
27 |         return ()
28 | 
29 | 
30 | class FourthAugment(ImageOnlyTransform):
31 |     """Custom transformation that shuffles channels in the input image."""
32 | 
33 |     def __init__(self, always_apply=False, p=0.5):
34 |         super(FourthAugment, self).__init__(always_apply, p)
35 | 
36 |     def apply(self, img, **params):
37 |         # Assuming img is an HWC image.
38 |         in_chans = img.shape[-1]
39 |         image_tmp = np.zeros_like(img)
40 |         cropping_num = random.randint(12, 16)
41 | 
42 |         start_idx = random.randint(0, in_chans - cropping_num)
43 |         crop_indices = np.arange(start_idx, start_idx + cropping_num)
44 | 
45 |         start_paste_idx = random.randint(0, in_chans - cropping_num)
46 | 
47 |         tmp = np.arange(start_paste_idx, start_paste_idx + cropping_num)
48 |         np.random.shuffle(tmp)
49 | 
50 |         cutout_idx = random.randint(0, 2)
51 |         temporal_random_cutout_idx = tmp[:cutout_idx]
52 | 
53 |         image_tmp[..., start_paste_idx:start_paste_idx + cropping_num] = img[..., crop_indices]
54 | 
55 |         if random.random() > 0.4:
56 |             image_tmp[..., temporal_random_cutout_idx] = 0
57 |         return image_tmp
58 | 
59 |     def get_transform_init_args_names(self):
60 |         return ()
61 | 
--------------------------------------------------------------------------------
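A minimal wiring example (my sketch; the `A.Compose` usage follows the standard albumentations API, and the input is assumed to be HWC with scroll depth slices as channels):

```python
import numpy as np
import albumentations as A
from custom_augmentations import ChannelInvert, FourthAugment

transform = A.Compose([
    ChannelInvert(p=0.5),   # randomly reverse the depth/channel order
    FourthAugment(p=0.5),   # randomly crop, shuffle, and cut out channel runs
])

volume = np.random.rand(64, 64, 26).astype(np.float32)  # H x W x depth-channels
augmented = transform(image=volume)["image"]
```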
/unetr_pp/evaluation/collect_results_files.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import os
16 | import shutil
17 | from batchgenerators.utilities.file_and_folder_operations import subdirs, subfiles
18 | 
19 | 
20 | def crawl_and_copy(current_folder, out_folder, prefix="fabian_", suffix="ummary.json"):
21 |     """
22 |     This script will run recursively through all subfolders of current_folder and copy all files that end with
23 |     suffix, with an automatically generated prefix, into out_folder
24 |     :param current_folder:
25 |     :param out_folder:
26 |     :param prefix:
27 |     :param suffix:
28 |     :return:
29 |     """
30 |     s = subdirs(current_folder, join=False)
31 |     f = subfiles(current_folder, join=False)
32 |     f = [i for i in f if i.endswith(suffix)]
33 |     if current_folder.find("fold0") != -1:
34 |         for fl in f:
35 |             shutil.copy(os.path.join(current_folder, fl), os.path.join(out_folder, prefix + fl))
36 |     for su in s:
37 |         if prefix == "":
38 |             add = su
39 |         else:
40 |             add = "__" + su
41 |         crawl_and_copy(os.path.join(current_folder, su), out_folder, prefix=prefix + add)
42 | 
43 | 
44 | if __name__ == "__main__":
45 |     from unetr_pp.paths import network_training_output_dir
46 |     output_folder = "/home/fabian/PhD/results/nnFormerV2/leaderboard"
47 |     crawl_and_copy(network_training_output_dir, output_folder)
48 |     from unetr_pp.evaluation.add_mean_dice_to_json import run_in_folder
49 |     run_in_folder(output_folder)
50 | 
--------------------------------------------------------------------------------
/unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_3convperstage.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | from copy import deepcopy
16 | 
17 | import numpy as np
18 | from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
19 | from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
20 | from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21
21 | from unetr_pp.network_architecture.generic_UNet import Generic_UNet
22 | from unetr_pp.paths import *
23 | 
24 | 
25 | class ExperimentPlanner3D_v21_3cps(ExperimentPlanner3D_v21):
26 |     """
27 |     have 3x conv-in-lrelu per resolution instead of 2 while remaining in the same memory budget
28 | 
29 |     This only works with 3d fullres because we use the same data as ExperimentPlanner3D_v21. Lowres would require
30 |     rerunning preprocessing (different patch size = different 3d lowres target spacing)
31 |     """
32 |     def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
33 |         super(ExperimentPlanner3D_v21_3cps, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
34 |         self.plans_fname = join(self.preprocessed_output_folder,
35 |                                 "nnFormerPlansv2.1_3cps_plans_3D.pkl")
36 |         self.unet_base_num_features = 32
37 |         self.conv_per_stage = 3
38 | 
39 |     def run_preprocessing(self, num_threads):
40 |         pass
41 | 
--------------------------------------------------------------------------------
add more annotations as needed 44 | } 45 | 46 | # For each annotation, draw the bounding boxes and the letter name 47 | for letter, data in annotations.items(): 48 | for box in data['boxes']: 49 | cv2.rectangle(img, box[0], box[1], (0, 255, 0), 2) # Draw rectangle 50 | # Display the letter name slightly above the bounding box 51 | cv2.putText(img, letter, (box[0][0], box[0][1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2, 52 | cv2.LINE_AA) 53 | 54 | # Display the annotated image 55 | cv2.imshow('Annotated Image', img) 56 | cv2.waitKey(0) 57 | cv2.destroyAllWindows() 58 | -------------------------------------------------------------------------------- /rotate_from_hough_lines.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def find_rotation_theta(img_data, show=False): 8 | image_lines = np.copy(img_data) 9 | 10 | h, w = img_data.shape[:2] 11 | 12 | # Preprocess the image 13 | gray = cv2.cvtColor(img_data, cv2.COLOR_BGR2GRAY) 14 | blurred = cv2.GaussianBlur(gray, (5, 5), 0) 15 | 16 | # Edge detection using Canny 17 | outside_edge = cv2.Canny(blurred, 320, 350) 18 | internal_edges = cv2.subtract(cv2.Canny(blurred, 150, 250), outside_edge) # saturating subtract avoids uint8 wrap-around creating spurious edge pixels 19 | 20 | # Apply Hough Line Transformation 21 | lines = cv2.HoughLines(internal_edges, 1, np.pi / 180, int(0.0163 * max(h, w))) 22 | rhos = [] 23 | thetas = [] 24 | 25 | # Draw the detected lines on the original image 26 | if lines is not None: 27 | for rho, theta in lines[:, 0]: 28 | rhos.append(rho) 29 | thetas.append(theta) 30 | a = np.cos(theta) 31 | b = np.sin(theta) 32 | x0 = a * rho 33 | y0 = b * rho 34 | x1 = int(x0 + 20000 * (-b)) 35 | y1 = int(y0 + 20000 * a) 36 | x2 = int(x0 - 20000 * (-b)) 37 | y2 = int(y0 - 20000 * a) 38 | cv2.line(image_lines, (x1, y1), (x2, y2), (0, 0, 255), 4) 39 | 40 | # Sort lines into two groups based on theta 41 | # Group 1 is seeded with a random theta; its median is refined as members join 42 | median_theta_group_1 = random.choice(thetas) 43 | group1, group2 = [], [] 44 | for theta in thetas: 45 | if np.abs(theta - median_theta_group_1) < np.pi / 2: 46 | group1.append(theta) 47 | median_theta_group_1 = np.median(group1) 48 | else: 49 | group2.append(theta) 50 | 51 | # Calculate the median of each group 52 | median_group1_theta = np.median(group1) if len(group1) > 0 else None 53 | median_group2_theta = np.median(group2) if len(group2) > 0 else None 54 | 55 | if show: 56 | cv2.imshow('Edges', internal_edges) 57 | cv2.imshow('Hough Line Transform', image_lines) 58 | cv2.waitKey(0) 59 | cv2.destroyAllWindows() 60 | 61 | return (median_group1_theta if len(group1) > len(group2) else median_group2_theta) - np.pi / 2 62 | -------------------------------------------------------------------------------- /unetr_pp/evaluation/add_mean_dice_to_json.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | import numpy as np 17 | from batchgenerators.utilities.file_and_folder_operations import subfiles 18 | from collections import OrderedDict 19 | 20 | 21 | def foreground_mean(filename): 22 | with open(filename, 'r') as f: 23 | res = json.load(f) 24 | class_ids = np.array([int(i) for i in res['results']['mean'].keys() if (i != 'mean')]) 25 | class_ids = class_ids[class_ids != 0] 26 | class_ids = class_ids[class_ids != -1] 27 | class_ids = class_ids[class_ids != 99] 28 | 29 | tmp = res['results']['mean'].get('99') 30 | if tmp is not None: 31 | _ = res['results']['mean'].pop('99') 32 | 33 | metrics = res['results']['mean']['1'].keys() 34 | res['results']['mean']["mean"] = OrderedDict() 35 | for m in metrics: 36 | foreground_values = [res['results']['mean'][str(i)][m] for i in class_ids] 37 | res['results']['mean']["mean"][m] = np.nanmean(foreground_values) 38 | with open(filename, 'w') as f: 39 | json.dump(res, f, indent=4, sort_keys=True) 40 | 41 | 42 | def run_in_folder(folder): 43 | json_files = subfiles(folder, True, None, ".json", True) 44 | json_files = [i for i in json_files if not i.split("/")[-1].startswith(".") and not i.endswith("_globalMean.json")] # stupid mac 45 | for j in json_files: 46 | foreground_mean(j) 47 | 48 | 49 | if __name__ == "__main__": 50 | folder = "/media/fabian/Results/nnFormerOutput_final/summary_jsons" 51 | run_in_folder(folder) 52 | -------------------------------------------------------------------------------- /unetr_pp/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_CT2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from collections import OrderedDict 17 | 18 | from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner 19 | from unetr_pp.paths import * 20 | 21 | 22 | class ExperimentPlannerCT2(ExperimentPlanner): 23 | """ 24 | preprocesses CT data with the "CT2" normalization. 
25 | 26 | (clip range comes from training set and is the 0.5 and 99.5 percentile of intensities in foreground) 27 | CT = clip to range, then normalize with global mn and sd (computed on foreground in training set) 28 | CT2 = clip to range, normalize each case separately with its own mn and std (computed within the area that was in clip_range) 29 | """ 30 | def __init__(self, folder_with_cropped_data, preprocessed_output_folder): 31 | super(ExperimentPlannerCT2, self).__init__(folder_with_cropped_data, preprocessed_output_folder) 32 | self.data_identifier = "nnFormer_CT2" 33 | self.plans_fname = join(self.preprocessed_output_folder, "nnFormerPlans" + "CT2_plans_3D.pkl") 34 | 35 | def determine_normalization_scheme(self): 36 | schemes = OrderedDict() 37 | modalities = self.dataset_properties['modalities'] 38 | num_modalities = len(list(modalities.keys())) 39 | 40 | for i in range(num_modalities): 41 | if modalities[i] == "CT": 42 | schemes[i] = "CT2" 43 | else: 44 | schemes[i] = "nonCT" 45 | return schemes 46 | -------------------------------------------------------------------------------- /size_sorter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from segment_utils import get_segment_id_paths_dict 4 | from concurrent.futures import ThreadPoolExecutor 5 | import concurrent 6 | 7 | num_workers = 64 8 | 9 | 10 | def get_segment_size(segment_dir): 11 | print(f"Processing {segment_dir}") 12 | mask_name_list = list(filter(lambda x: "_mask" in x, os.listdir(segment_dir))) 13 | if not mask_name_list: 14 | print(f"Skipping {segment_dir} because no mask was found.") 15 | return -1 16 | mask_name = mask_name_list[0] 17 | mask_path = os.path.join(segment_dir, mask_name) 18 | img = cv2.imread(mask_path) 19 | size = 0 20 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 21 | 22 | # Threshold the image to keep only white regions 23 | _, thresh = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY) # 250 is an arbitrary threshold, tweak it if necessary 24 | 25 | # Find contours 26 | contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 27 | 28 | # Find the size of the white contour(s) 29 | for contour in contours: 30 | size += cv2.contourArea(contour) 31 | 32 | return size 33 | 34 | 35 | def find_all_contour_sizes(data_dir): 36 | sizes = {} 37 | segment_paths_dict = get_segment_id_paths_dict(data_dir) 38 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 39 | future_size_dict = {executor.submit(get_segment_size, segment_paths_dict[segment_id]): segment_id for segment_id 40 | in segment_paths_dict} 41 | for future in concurrent.futures.as_completed(future_size_dict): 42 | segment_id = future_size_dict[future] 43 | size = future.result() 44 | sizes[segment_id] = size 45 | sorted_sizes = dict(sorted(sizes.items(), key=lambda item: item[1], reverse=True)) 46 | return sorted_sizes 47 | 48 | 49 | def keep_last_in_sequence(d): 50 | sorted_keys = sorted(k for k in d.keys() if isinstance(k, int)) 51 | keys_to_remove = set() 52 | 53 | prev_key = None 54 | for key in sorted_keys: 55 | if prev_key is not None and key == prev_key + 1: 56 | keys_to_remove.add(prev_key) 57 | prev_key = key 58 | 59 | print(f"Keys to remove: {keys_to_remove}") 60 | for key in keys_to_remove: 61 | del d[key] 62 | 63 | return d 64 | 65 | 66 | if __name__ == "__main__": 67 | val = keep_last_in_sequence(find_all_contour_sizes('./data')) 68 | print(val) 69 | print(list(val.keys())) 70 | 
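# --- Editor's note: a small worked example of keep_last_in_sequence() above,
# added for illustration (not part of the original file). The function keeps
# only the last key of every run of consecutive integer keys:
#
#     >>> keep_last_in_sequence({1: 'a', 2: 'b', 3: 'c', 7: 'd'})
#     Keys to remove: {1, 2}
#     {3: 'c', 7: 'd'}
#
# Keys 1 and 2 are dropped because each is immediately followed by key + 1;
# key 3 (the end of the run) and the isolated key 7 survive.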
-------------------------------------------------------------------------------- /unetr_pp/network_architecture/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import math 5 | 6 | 7 | class LayerNorm(nn.Module): 8 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 9 | super().__init__() 10 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 11 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 12 | self.eps = eps 13 | self.data_format = data_format 14 | if self.data_format not in ["channels_last", "channels_first"]: 15 | raise NotImplementedError 16 | self.normalized_shape = (normalized_shape,) 17 | 18 | def forward(self, x): 19 | if self.data_format == "channels_last": 20 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 21 | elif self.data_format == "channels_first": 22 | u = x.mean(1, keepdim=True) 23 | s = (x - u).pow(2).mean(1, keepdim=True) 24 | x = (x - u) / torch.sqrt(s + self.eps) 25 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 26 | return x 27 | 28 | 29 | class PositionalEncodingFourier(nn.Module): 30 | def __init__(self, hidden_dim=32, dim=768, temperature=10000): 31 | super().__init__() 32 | self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) 33 | self.scale = 2 * math.pi 34 | self.temperature = temperature 35 | self.hidden_dim = hidden_dim 36 | self.dim = dim 37 | 38 | def forward(self, B, H, W): 39 | mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device) 40 | not_mask = ~mask 41 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 42 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 43 | eps = 1e-6 44 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 45 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 46 | 47 | dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device) 48 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim) 49 | 50 | pos_x = x_embed[:, :, :, None] / dim_t 51 | pos_y = y_embed[:, :, :, None] / dim_t 52 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), 53 | pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 54 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), 55 | pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 56 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 57 | pos = self.token_projection(pos) 58 | 59 | return pos 60 | -------------------------------------------------------------------------------- /unetr_pp/evaluation/surface_dice.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
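# --- Editor's note: a hypothetical usage sketch of normalized_surface_dice()
# (defined below), added for illustration and not part of the original file.
# Both inputs are binary masks of identical shape; `threshold` is in mm, so
# `spacing` must describe the real voxel size:
#
#     import numpy as np
#     a = np.zeros((16, 64, 64), dtype=bool)
#     b = np.zeros((16, 64, 64), dtype=bool)
#     a[5:10, 20:40, 20:40] = True
#     b[6:11, 20:40, 20:40] = True
#     nsd = normalized_surface_dice(a, b, threshold=1.0, spacing=(3.0, 1.0, 1.0))
#
# The metric is symmetric in a and b, as the docstring below notes.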
14 | 15 | 16 | import numpy as np 17 | from medpy.metric.binary import __surface_distances 18 | 19 | 20 | def normalized_surface_dice(a: np.ndarray, b: np.ndarray, threshold: float, spacing: tuple = None, connectivity=1): 21 | """ 22 | This implementation differs from the official surface dice implementation! These two are not comparable!!!!! 23 | 24 | The normalized surface dice is symmetric, so it should not matter whether a or b is the reference image 25 | 26 | This implementation natively supports 2D and 3D images. Whether other dimensions are supported depends on the 27 | __surface_distances implementation in medpy 28 | 29 | :param a: image 1, must have the same shape as b 30 | :param b: image 2, must have the same shape as a 31 | :param threshold: distances below this threshold will be counted as true positives. Threshold is in mm, not voxels! 32 | (if spacing = (1, 1(, 1)) then one voxel=1mm so the threshold is effectively in voxels) 33 | must be a tuple of len dimension(a) 34 | :param spacing: how many mm is one voxel in reality? Can be left at None, we then assume an isotropic spacing of 1mm 35 | :param connectivity: see scipy.ndimage.generate_binary_structure for more information. I suggest you leave that 36 | one alone 37 | :return: 38 | """ 39 | assert all([i == j for i, j in zip(a.shape, b.shape)]), "a and b must have the same shape. a.shape= %s, " \ 40 | "b.shape= %s" % (str(a.shape), str(b.shape)) 41 | if spacing is None: 42 | spacing = tuple([1 for _ in range(len(a.shape))]) 43 | a_to_b = __surface_distances(a, b, spacing, connectivity) 44 | b_to_a = __surface_distances(b, a, spacing, connectivity) 45 | 46 | numel_a = len(a_to_b) 47 | numel_b = len(b_to_a) 48 | 49 | tp_a = np.sum(a_to_b <= threshold) / numel_a 50 | tp_b = np.sum(b_to_a <= threshold) / numel_b 51 | 52 | fp = np.sum(a_to_b > threshold) / numel_a 53 | fn = np.sum(b_to_a > threshold) / numel_b 54 | 55 | dc = (tp_a + tp_b) / (tp_a + tp_b + fp + fn + 1e-8) # 1e-8 just so that we don't get div by 0 56 | return dc 57 | 58 | -------------------------------------------------------------------------------- /unetr_pp/postprocessing/consolidate_postprocessing_simple.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import argparse 17 | from unetr_pp.postprocessing.consolidate_postprocessing import consolidate_folds 18 | from unetr_pp.utilities.folder_names import get_output_folder_name 19 | from unetr_pp.utilities.task_name_id_conversion import convert_id_to_task_name 20 | from unetr_pp.paths import default_cascade_trainer, default_trainer, default_plans_identifier 21 | 22 | 23 | def main(): 24 | argparser = argparse.ArgumentParser(usage="Used to determine the postprocessing for a trained model. 
Useful for " 25 | "when the best configuration (2d, 3d_fullres etc) was selected manually.") 26 | argparser.add_argument("-m", type=str, required=True, help="U-Net model (2d, 3d_lowres, 3d_fullres or " 27 | "3d_cascade_fullres)") 28 | argparser.add_argument("-t", type=str, required=True, help="Task name or id") 29 | argparser.add_argument("-tr", type=str, required=False, default=None, 30 | help="nnFormerTrainer class. Default: %s, unless 3d_cascade_fullres " 31 | "(then it's %s)" % (default_trainer, default_cascade_trainer)) 32 | argparser.add_argument("-pl", type=str, required=False, default=default_plans_identifier, 33 | help="Plans name, Default=%s" % default_plans_identifier) 34 | argparser.add_argument("-val", type=str, required=False, default="validation_raw", 35 | help="Validation folder name. Default: validation_raw") 36 | 37 | args = argparser.parse_args() 38 | model = args.m 39 | task = args.t 40 | trainer = args.tr 41 | plans = args.pl 42 | val = args.val 43 | 44 | if not task.startswith("Task"): 45 | task_id = int(task) 46 | task = convert_id_to_task_name(task_id) 47 | 48 | if trainer is None: 49 | if model == "3d_cascade_fullres": 50 | trainer = "nnFormerTrainerV2CascadeFullRes" 51 | else: 52 | trainer = "nnFormerTrainerV2" 53 | 54 | folder = get_output_folder_name(model, task, trainer, plans, None) 55 | 56 | consolidate_folds(folder, val) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /unetr_pp/paths.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
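# --- Editor's note: an illustrative shell setup for the three environment
# variables this module reads (paths are placeholders, not part of the
# original file):
#
#     export unetr_pp_raw_data_base=/data/unetr_pp_raw
#     export unetr_pp_preprocessed=/data/unetr_pp_preprocessed
#     export RESULTS_FOLDER=/data/unetr_pp_results
#
# Each variable is optional, but as the warnings below explain, the matching
# functionality (experiment planning/preprocessing, training, inference) is
# unavailable when it is unset.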
14 | 15 | import os 16 | from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p, join 17 | 18 | # do not modify these unless you know what you are doing 19 | my_output_identifier = "unetr_pp" 20 | default_plans_identifier = "unetr_pp_Plansv2.1" 21 | default_data_identifier = 'unetr_pp_Data_plans_v2.1' 22 | default_trainer = "unetr_pp_trainer_synapse" 23 | 24 | """ 25 | PLEASE READ paths.md FOR INFORMATION TO HOW TO SET THIS UP 26 | """ 27 | 28 | base = os.environ['unetr_pp_raw_data_base'] if "unetr_pp_raw_data_base" in os.environ.keys() else None 29 | preprocessing_output_dir = os.environ['unetr_pp_preprocessed'] if "unetr_pp_preprocessed" in os.environ.keys() else None 30 | network_training_output_dir_base = os.path.join(os.environ['RESULTS_FOLDER']) if "RESULTS_FOLDER" in os.environ.keys() else None 31 | 32 | if base is not None: 33 | nnFormer_raw_data = join(base, "unetr_pp_raw_data") 34 | nnFormer_cropped_data = join(base, "unetr_pp_cropped_data") 35 | maybe_mkdir_p(nnFormer_raw_data) 36 | maybe_mkdir_p(nnFormer_cropped_data) 37 | else: 38 | print("unetr_pp_raw_data_base is not defined and model can only be used on data for which preprocessed files " 39 | "are already present on your system. model cannot be used for experiment planning and preprocessing like " 40 | "this. If this is not intended, please read run_training_synapse.sh/run_training_acdc.sh " 41 | "for information on how to set this up properly.") 42 | nnFormer_cropped_data = nnFormer_raw_data = None 43 | 44 | if preprocessing_output_dir is not None: 45 | maybe_mkdir_p(preprocessing_output_dir) 46 | else: 47 | print("unetr_pp_preprocessed is not defined and model can not be used for preprocessing " 48 | "or training. If this is not intended, please read documentation/setting_up_paths.md for " 49 | "information on how to set this up.") 50 | preprocessing_output_dir = None 51 | 52 | if network_training_output_dir_base is not None: 53 | network_training_output_dir = join(network_training_output_dir_base, my_output_identifier) 54 | maybe_mkdir_p(network_training_output_dir) 55 | else: 56 | print("RESULTS_FOLDER is not defined and nnU-Net cannot be used for training or " 57 | "inference. If this is not intended behavior, please read run_training_synapse.sh/run_training_acdc.sh " 58 | "for information on how to set this up.") 59 | network_training_output_dir = None 60 | -------------------------------------------------------------------------------- /unetr_pp/postprocessing/consolidate_all_for_paper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
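# --- Editor's note: get_commands() below prints one LSF submission line per
# (task, model) pair. A rendered example of its output, assuming the first
# node in the pool (formatting approximate, for illustration only):
#
#     bsub -m hdf18-gpu01 -q gputest -L /bin/bash "source ~/.bashrc && python
#     postprocessing/consolidate_postprocessing.py -f <output_folder> "
#
# where <output_folder> is whatever get_output_folder_name() resolves for
# that task/model/trainer combination.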
14 | 15 | 16 | from unetr_pp.utilities.folder_names import get_output_folder_name 17 | 18 | 19 | def get_datasets(): 20 | configurations_all = { 21 | "Task01_BrainTumour": ("3d_fullres", "2d"), 22 | "Task02_Heart": ("3d_fullres", "2d",), 23 | "Task03_Liver": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), 24 | "Task04_Hippocampus": ("3d_fullres", "2d",), 25 | "Task05_Prostate": ("3d_fullres", "2d",), 26 | "Task06_Lung": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), 27 | "Task07_Pancreas": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), 28 | "Task08_HepaticVessel": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), 29 | "Task09_Spleen": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), 30 | "Task10_Colon": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"), 31 | "Task48_KiTS_clean": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d"), 32 | "Task27_ACDC": ("3d_fullres", "2d",), 33 | "Task24_Promise": ("3d_fullres", "2d",), 34 | "Task35_ISBILesionSegmentation": ("3d_fullres", "2d",), 35 | "Task38_CHAOS_Task_3_5_Variant2": ("3d_fullres", "2d",), 36 | "Task29_LITS": ("3d_cascade_fullres", "3d_lowres", "2d", "3d_fullres",), 37 | "Task17_AbdominalOrganSegmentation": ("3d_cascade_fullres", "3d_lowres", "2d", "3d_fullres",), 38 | "Task55_SegTHOR": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d",), 39 | "Task56_VerSe": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d",), 40 | } 41 | return configurations_all 42 | 43 | 44 | def get_commands(configurations, regular_trainer="nnFormerTrainerV2", cascade_trainer="nnFormerTrainerV2CascadeFullRes", 45 | plans="nnFormerPlansv2.1"): 46 | 47 | node_pool = ["hdf18-gpu%02.0d" % i for i in range(1, 21)] + ["hdf19-gpu%02.0d" % i for i in range(1, 8)] + ["hdf19-gpu%02.0d" % i for i in range(11, 16)] 48 | ctr = 0 49 | for task in configurations: 50 | models = configurations[task] 51 | for m in models: 52 | if m == "3d_cascade_fullres": 53 | trainer = cascade_trainer 54 | else: 55 | trainer = regular_trainer 56 | 57 | folder = get_output_folder_name(m, task, trainer, plans, overwrite_training_output_dir="/datasets/datasets_fabian/results/nnFormer") 58 | node = node_pool[ctr % len(node_pool)] 59 | print("bsub -m %s -q gputest -L /bin/bash \"source ~/.bashrc && python postprocessing/" 60 | "consolidate_postprocessing.py -f" % node, folder, "\"") 61 | ctr += 1 62 | -------------------------------------------------------------------------------- /unetr_pp/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v22.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
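# --- Editor's note: a worked example of the anisotropy handling in
# get_target_spacing() below, added for illustration (the numbers are
# invented). Suppose the median spacing is (5.0, 0.8, 0.8) and
# anisotropy_threshold is 3:
#
#     worst_spacing_axis = 0                # 5.0 is the largest spacing
#     has_aniso_spacing  = 5.0 > 3 * 0.8    # True
#
# If the voxel-count test also passes, axis 0's target spacing is replaced by
# the 10th percentile of that axis's spacings across the dataset, but never
# pushed below 3 * 0.8 = 2.4 mm thanks to the max(...) clamp below.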
14 | 15 | import numpy as np 16 | from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \ 17 | ExperimentPlanner3D_v21 18 | from unetr_pp.paths import * 19 | 20 | 21 | class ExperimentPlanner3D_v22(ExperimentPlanner3D_v21): 22 | """ 23 | """ 24 | def __init__(self, folder_with_cropped_data, preprocessed_output_folder): 25 | super().__init__(folder_with_cropped_data, preprocessed_output_folder) 26 | self.data_identifier = "nnFormerData_plans_v2.2" 27 | self.plans_fname = join(self.preprocessed_output_folder, 28 | "nnFormerPlansv2.2_plans_3D.pkl") 29 | 30 | def get_target_spacing(self): 31 | spacings = self.dataset_properties['all_spacings'] 32 | sizes = self.dataset_properties['all_sizes'] 33 | 34 | target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0) 35 | target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0) 36 | target_size_mm = np.array(target) * np.array(target_size) 37 | # we need to identify datasets for which a different target spacing could be beneficial. These datasets have 38 | # the following properties: 39 | # - one axis which much lower resolution than the others 40 | # - the lowres axis has much less voxels than the others 41 | # - (the size in mm of the lowres axis is also reduced) 42 | worst_spacing_axis = np.argmax(target) 43 | other_axes = [i for i in range(len(target)) if i != worst_spacing_axis] 44 | other_spacings = [target[i] for i in other_axes] 45 | other_sizes = [target_size[i] for i in other_axes] 46 | 47 | has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings)) 48 | has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes) 49 | # we don't use the last one for now 50 | #median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm) 51 | 52 | if has_aniso_spacing and has_aniso_voxels: 53 | spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis] 54 | target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10) 55 | # don't let the spacing of that axis get higher than self.anisotropy_thresholdxthe_other_axes 56 | target_spacing_of_that_axis = max(max(other_spacings) * self.anisotropy_threshold, target_spacing_of_that_axis) 57 | target[worst_spacing_axis] = target_spacing_of_that_axis 58 | return target 59 | 60 | -------------------------------------------------------------------------------- /unetr_pp/utilities/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
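# --- Editor's note: a sketch of the gather/scatter symmetry implemented by
# awesome_allgather_function below, added for illustration. With world_size
# ranks, each holding a tensor of shape (B, ...):
#
#     forward:  all_gather -> list of world_size tensors -> cat along dim 0,
#               so every rank sees a (world_size * B, ...) tensor
#     backward: the incoming gradient has shape (world_size * B, ...); each
#               rank keeps only the slice [rank * B : (rank + 1) * B] that
#               corresponds to its own forward input
#
# Note that the backward slicing assumes all ranks contributed equally sized
# inputs along dim 0.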
14 | 15 | 16 | import torch 17 | from torch import distributed 18 | from torch import autograd 19 | from torch.nn.parallel import DistributedDataParallel as DDP 20 | 21 | 22 | def print_if_rank0(*args): 23 | if distributed.get_rank() == 0: 24 | print(*args) 25 | 26 | 27 | class awesome_allgather_function(autograd.Function): 28 | @staticmethod 29 | def forward(ctx, input): 30 | world_size = distributed.get_world_size() 31 | # create a destination list for the allgather. I'm assuming you're gathering from 3 workers. 32 | allgather_list = [torch.empty_like(input) for _ in range(world_size)] 33 | #if distributed.get_rank() == 0: 34 | # import IPython;IPython.embed() 35 | distributed.all_gather(allgather_list, input) 36 | return torch.cat(allgather_list, dim=0) 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | #print_if_rank0("backward grad_output len", len(grad_output)) 41 | #print_if_rank0("backward grad_output shape", grad_output.shape) 42 | grads_per_rank = grad_output.shape[0] // distributed.get_world_size() 43 | rank = distributed.get_rank() 44 | # We'll receive gradients for the entire catted forward output, so to mimic DataParallel, 45 | # return only the slice that corresponds to this process's input: 46 | sl = slice(rank * grads_per_rank, (rank + 1) * grads_per_rank) 47 | #print("worker", rank, "backward slice", sl) 48 | return grad_output[sl] 49 | 50 | 51 | if __name__ == "__main__": 52 | import torch.distributed as dist 53 | import argparse 54 | from torch import nn 55 | from torch.optim import Adam 56 | 57 | argumentparser = argparse.ArgumentParser() 58 | argumentparser.add_argument("--local_rank", type=int) 59 | args = argumentparser.parse_args() 60 | 61 | torch.cuda.set_device(args.local_rank) 62 | dist.init_process_group(backend='nccl', init_method='env://') 63 | 64 | rnd = torch.rand((5, 2)).cuda() 65 | 66 | rnd_gathered = awesome_allgather_function.apply(rnd) 67 | print("gathering random tensors\nbefore\b", rnd, "\nafter\n", rnd_gathered) 68 | 69 | # so far this works as expected 70 | print("now running a DDP model") 71 | c = nn.Conv2d(2, 3, 3, 1, 1, 1, 1, True).cuda() 72 | c = DDP(c) 73 | opt = Adam(c.parameters()) 74 | 75 | bs = 5 76 | if dist.get_rank() == 0: 77 | bs = 4 78 | inp = torch.rand((bs, 2, 5, 5)).cuda() 79 | 80 | out = c(inp) 81 | print("output_shape", out.shape) 82 | 83 | out_gathered = awesome_allgather_function.apply(out) 84 | print("output_shape_after_gather", out_gathered.shape) 85 | # this also works 86 | 87 | loss = out_gathered.sum() 88 | loss.backward() 89 | opt.step() 90 | -------------------------------------------------------------------------------- /unetr_pp/evaluation/add_dummy_task_with_mean_over_all_tasks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import json 16 | import numpy as np 17 | from batchgenerators.utilities.file_and_folder_operations import subfiles 18 | import os 19 | from collections import OrderedDict 20 | 21 | folder = "/home/fabian/drives/E132-Projekte/Projects/2018_MedicalDecathlon/Leaderboard" 22 | task_descriptors = ['2D final 2', 23 | '2D final, less pool, dc and topK, fold0', 24 | '2D final pseudo3d 7, fold0', 25 | '2D final, less pool, dc and ce, fold0', 26 | '3D stage0 final 2, fold0', 27 | '3D fullres final 2, fold0'] 28 | task_ids_with_no_stage0 = ["Task001_BrainTumour", "Task004_Hippocampus", "Task005_Prostate"] 29 | 30 | mean_scores = OrderedDict() 31 | for t in task_descriptors: 32 | mean_scores[t] = OrderedDict() 33 | 34 | json_files = subfiles(folder, True, None, ".json", True) 35 | json_files = [i for i in json_files if not i.split("/")[-1].startswith(".")] # stupid mac 36 | for j in json_files: 37 | with open(j, 'r') as f: 38 | res = json.load(f) 39 | task = res['task'] 40 | if task != "Task999_ALL": 41 | name = res['name'] 42 | if name in task_descriptors: 43 | if task not in list(mean_scores[name].keys()): 44 | mean_scores[name][task] = res['results']['mean']['mean'] 45 | else: 46 | raise RuntimeError("duplicate task %s for description %s" % (task, name)) 47 | 48 | for t in task_ids_with_no_stage0: 49 | mean_scores["3D stage0 final 2, fold0"][t] = mean_scores["3D fullres final 2, fold0"][t] 50 | 51 | a = set() 52 | for i in mean_scores.keys(): 53 | a = a.union(list(mean_scores[i].keys())) 54 | 55 | for i in mean_scores.keys(): 56 | try: 57 | for t in list(a): 58 | assert t in mean_scores[i].keys(), "did not find task %s for experiment %s" % (t, i) 59 | new_res = OrderedDict() 60 | new_res['name'] = i 61 | new_res['author'] = "Fabian" 62 | new_res['task'] = "Task999_ALL" 63 | new_res['results'] = OrderedDict() 64 | new_res['results']['mean'] = OrderedDict() 65 | new_res['results']['mean']['mean'] = OrderedDict() 66 | tasks = list(mean_scores[i].keys()) 67 | metrics = mean_scores[i][tasks[0]].keys() 68 | for m in metrics: 69 | foreground_values = [mean_scores[i][n][m] for n in tasks] 70 | new_res['results']['mean']["mean"][m] = np.nanmean(foreground_values) 71 | output_fname = i.replace(" ", "_") + "_globalMean.json" 72 | with open(os.path.join(folder, output_fname), 'w') as f: 73 | json.dump(new_res, f) 74 | except AssertionError: 75 | print("could not process experiment %s" % i) 76 | print("did not find task %s for experiment %s" % (t, i)) 77 | 78 | -------------------------------------------------------------------------------- /unetr_pp/preprocessing/custom_preprocessors/preprocessor_scale_RGB_to_0_1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
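# --- Editor's note: the normalization override below boils down to a single
# division; an illustrative equivalent (not part of the original file):
#
#     import numpy as np
#     rgb_uint8 = np.array([[0, 127, 255]], dtype=np.uint8)
#     scaled = rgb_uint8.astype(np.float32) / 255.
#     # -> [[0.0, 0.498..., 1.0]], i.e. intensities rescaled to [0, 1]
#
# The resampling half of resample_and_normalize() is copied verbatim from the
# parent class, as its comments note; only the final normalization loop changes.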
14 | 15 | import numpy as np 16 | from unetr_pp.preprocessing.preprocessing import PreprocessorFor2D, resample_patient 17 | 18 | 19 | class GenericPreprocessor_scale_uint8_to_0_1(PreprocessorFor2D): 20 | """ 21 | For RGB images with a value range of [0, 255]. This preprocessor overwrites the default normalization scheme by 22 | normalizing intensity values through a simple division by 255 which rescales them to [0, 1] 23 | 24 | NOTE THAT THIS INHERITS FROM PreprocessorFor2D, SO ITS WRITTEN FOR 2D ONLY! WHEN CREATING A PREPROCESSOR FOR 3D 25 | DATA, USE GenericPreprocessor AS PARENT! 26 | """ 27 | def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None): 28 | ############ THIS PART IS IDENTICAL TO PARENT CLASS ################ 29 | 30 | original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward] 31 | before = { 32 | 'spacing': properties["original_spacing"], 33 | 'spacing_transposed': original_spacing_transposed, 34 | 'data.shape (data is transposed)': data.shape 35 | } 36 | target_spacing[0] = original_spacing_transposed[0] 37 | data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1, 38 | force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0, 39 | separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold) 40 | after = { 41 | 'spacing': target_spacing, 42 | 'data.shape (data is resampled)': data.shape 43 | } 44 | print("before:", before, "\nafter: ", after, "\n") 45 | 46 | if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf? 47 | seg[seg < -1] = 0 48 | 49 | properties["size_after_resampling"] = data[0].shape 50 | properties["spacing_after_resampling"] = target_spacing 51 | use_nonzero_mask = self.use_nonzero_mask 52 | 53 | assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \ 54 | "must have as many entries as data has " \ 55 | "modalities" 56 | assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \ 57 | " has modalities" 58 | 59 | print("normalization...") 60 | 61 | ############ HERE IS WHERE WE START CHANGING THINGS!!!!!!!################ 62 | 63 | # this is where the normalization takes place. We ignore use_nonzero_mask and normalization_scheme_per_modality 64 | for c in range(len(data)): 65 | data[c] = data[c].astype(np.float32) / 255. 66 | return data, seg, properties -------------------------------------------------------------------------------- /unetr_pp/utilities/task_name_id_conversion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
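# --- Editor's note: an illustrative round trip for the two converters below
# (assumes a folder such as Task002_Synapse exists in one of the searched
# locations; not part of the original file):
#
#     convert_id_to_task_name(2)                  # -> "Task002_Synapse"
#     convert_task_name_to_id("Task002_Synapse")  # -> 2
#
# convert_id_to_task_name() searches the raw, cropped, preprocessed and
# trained-model folders and raises if the id resolves to zero or to multiple
# distinct task names.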
14 | 15 | 16 | from unetr_pp.paths import nnFormer_raw_data, preprocessing_output_dir, nnFormer_cropped_data, network_training_output_dir 17 | from batchgenerators.utilities.file_and_folder_operations import * 18 | import numpy as np 19 | 20 | 21 | def convert_id_to_task_name(task_id: int): 22 | startswith = "Task%03.0d" % task_id 23 | if preprocessing_output_dir is not None: 24 | candidates_preprocessed = subdirs(preprocessing_output_dir, prefix=startswith, join=False) 25 | else: 26 | candidates_preprocessed = [] 27 | 28 | if nnFormer_raw_data is not None: 29 | candidates_raw = subdirs(nnFormer_raw_data, prefix=startswith, join=False) 30 | else: 31 | candidates_raw = [] 32 | 33 | if nnFormer_cropped_data is not None: 34 | candidates_cropped = subdirs(nnFormer_cropped_data, prefix=startswith, join=False) 35 | else: 36 | candidates_cropped = [] 37 | 38 | candidates_trained_models = [] 39 | if network_training_output_dir is not None: 40 | for m in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres']: 41 | if isdir(join(network_training_output_dir, m)): 42 | candidates_trained_models += subdirs(join(network_training_output_dir, m), prefix=startswith, join=False) 43 | 44 | all_candidates = candidates_cropped + candidates_preprocessed + candidates_raw + candidates_trained_models 45 | unique_candidates = np.unique(all_candidates) 46 | if len(unique_candidates) > 1: 47 | raise RuntimeError("More than one task name found for task id %d. Please correct that. (I looked in the " 48 | "following folders:\n%s\n%s\n%s)" % (task_id, nnFormer_raw_data, preprocessing_output_dir, 49 | nnFormer_cropped_data)) 50 | if len(unique_candidates) == 0: 51 | raise RuntimeError("Could not find a task with the ID %d. Make sure the requested task ID exists and that " 52 | "nnU-Net knows where raw and preprocessed data are located (see Documentation - " 53 | "Installation). Here are your currently defined folders:\nunetr_pp_preprocessed=%s\nRESULTS_" 54 | "FOLDER=%s\nunetr_pp_raw_data_base=%s\nIf something is not right, adapt your environment " 55 | "variables." % 56 | (task_id, 57 | os.environ.get('unetr_pp_preprocessed') if os.environ.get('unetr_pp_preprocessed') is not None else 'None', 58 | os.environ.get('RESULTS_FOLDER') if os.environ.get('RESULTS_FOLDER') is not None else 'None', 59 | os.environ.get('unetr_pp_raw_data_base') if os.environ.get('unetr_pp_raw_data_base') is not None else 'None', 60 | )) 61 | return unique_candidates[0] 62 | 63 | 64 | def convert_task_name_to_id(task_name: str): 65 | assert task_name.startswith("Task") 66 | task_id = int(task_name[4:7]) 67 | return task_id 68 | -------------------------------------------------------------------------------- /unetr_pp/evaluation/model_selection/collect_all_fold0_results_and_summarize_in_one_csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from unetr_pp.evaluation.model_selection.summarize_results_in_one_json import summarize2 16 | from unetr_pp.paths import network_training_output_dir 17 | from batchgenerators.utilities.file_and_folder_operations import * 18 | 19 | if __name__ == "__main__": 20 | summary_output_folder = join(network_training_output_dir, "summary_jsons_fold0_new") 21 | maybe_mkdir_p(summary_output_folder) 22 | summarize2(['all'], output_dir=summary_output_folder, folds=(0,)) 23 | 24 | results_csv = join(network_training_output_dir, "summary_fold0.csv") 25 | 26 | summary_files = subfiles(summary_output_folder, suffix='.json', join=False) 27 | 28 | with open(results_csv, 'w') as f: 29 | for s in summary_files: 30 | if s.find("ensemble") == -1: 31 | task, network, trainer, plans, validation_folder, folds = s.split("__") 32 | else: 33 | n1, n2 = s.split("--") 34 | n1 = n1[n1.find("ensemble_") + len("ensemble_") :] 35 | task = s.split("__")[0] 36 | network = "ensemble" 37 | trainer = n1 38 | plans = n2 39 | validation_folder = "none" 40 | folds = folds[:-len('.json')] 41 | results = load_json(join(summary_output_folder, s)) 42 | results_mean = results['results']['mean']['mean']['Dice'] 43 | results_median = results['results']['median']['mean']['Dice'] 44 | f.write("%s,%s,%s,%s,%s,%02.4f,%02.4f\n" % (task, 45 | network, trainer, validation_folder, plans, results_mean, results_median)) 46 | 47 | summary_output_folder = join(network_training_output_dir, "summary_jsons_new") 48 | maybe_mkdir_p(summary_output_folder) 49 | summarize2(['all'], output_dir=summary_output_folder) 50 | 51 | results_csv = join(network_training_output_dir, "summary_allFolds.csv") 52 | 53 | summary_files = subfiles(summary_output_folder, suffix='.json', join=False) 54 | 55 | with open(results_csv, 'w') as f: 56 | for s in summary_files: 57 | if s.find("ensemble") == -1: 58 | task, network, trainer, plans, validation_folder, folds = s.split("__") 59 | else: 60 | n1, n2 = s.split("--") 61 | n1 = n1[n1.find("ensemble_") + len("ensemble_") :] 62 | task = s.split("__")[0] 63 | network = "ensemble" 64 | trainer = n1 65 | plans = n2 66 | validation_folder = "none" 67 | folds = folds[:-len('.json')] 68 | results = load_json(join(summary_output_folder, s)) 69 | results_mean = results['results']['mean']['mean']['Dice'] 70 | results_median = results['results']['median']['mean']['Dice'] 71 | f.write("%s,%s,%s,%s,%s,%02.4f,%02.4f\n" % (task, 72 | network, trainer, validation_folder, plans, results_mean, results_median)) 73 | 74 | -------------------------------------------------------------------------------- /rotator_viewer.py: -------------------------------------------------------------------------------- 1 | import tkinter as tk 2 | from tkinter import filedialog 3 | from PIL import Image, ImageTk 4 | import os 5 | import PIL 6 | 7 | PIL.Image.MAX_IMAGE_PIXELS = 11881676800 8 | 9 | 10 | class ImageRotatorApp: 11 | def __init__(self, root): 12 | self.root = root 13 | root.title("Image Rotator") 14 | 15 | self.angle = 0 16 | self.image = None 17 | self.photo_image = None 18 | self.image_path = None 19 | self.max_width, self.max_height = 750, 750 # Max dimensions 20 | 21 | self.canvas = tk.Canvas(root, width=self.max_width, height=self.max_height) 22 | self.canvas.pack() 23 | 24 | btn_load = tk.Button(root, text="Load Image", command=self.load_image) 25 | btn_load.pack() 26 | 27 | btn_save = tk.Button(root, text="Save Rotation Angle", 
command=self.save_rotation_angle) 28 | btn_save.pack() 29 | 30 | self.label_angle = tk.Label(root, text=f"Rotation: {self.angle}°") 31 | self.label_angle.pack() 32 | 33 | # Bind arrow keys to rotation functions 34 | root.bind('<Left>', lambda event: self.rotate_image(-1)) 35 | root.bind('<Right>', lambda event: self.rotate_image(1)) 36 | 37 | def load_image(self): 38 | file_path = filedialog.askopenfilename() 39 | if file_path: 40 | self.image_path = file_path 41 | image = Image.open(file_path) 42 | self.image = self.resize_image(image) 43 | if os.path.exists(os.path.join('rotations/', 44 | (os.path.basename(self.image_path).split('.')[0]).split("_")[0] + '.txt')): 45 | with open(os.path.join('rotations/', 46 | (os.path.basename(self.image_path).split('.')[0]).split("_")[0] + '.txt'), 47 | 'r') as file: 48 | self.angle = int(file.read()) 49 | else: 50 | self.angle = 0 # Reset the angle when a new image is loaded 51 | self.label_angle.config(text=f"Rotation: {self.angle}°") 52 | self.update_canvas() 53 | 54 | def save_rotation_angle(self): 55 | if self.image_path and self.image: 56 | rotation_directory = './rotations/' 57 | os.makedirs(rotation_directory, exist_ok=True) 58 | filename = (os.path.basename(self.image_path).split('.')[0]).split("_")[0] + '.txt' 59 | with open(os.path.join(rotation_directory, filename), 'w') as file: 60 | file.write(f'{self.angle}') 61 | print(f"Rotation angle saved to {os.path.join(rotation_directory, filename)}") 62 | 63 | def resize_image(self, image): 64 | # Resize the image to fit within the max dimensions while maintaining aspect ratio 65 | ratio = min(self.max_width / image.width, self.max_height / image.height) 66 | new_size = (int(image.width * ratio), int(image.height * ratio)) 67 | return image.resize(new_size, Image.BILINEAR) 68 | 69 | def rotate_image(self, angle): 70 | if self.image: 71 | self.angle = (self.angle + angle) % 360 72 | rotated_image = self.image.rotate(-self.angle, expand=True) 73 | self.update_canvas(image=rotated_image) 74 | self.label_angle.config(text=f"Rotation: {self.angle}°") 75 | 76 | def update_canvas(self, image=None): 77 | if image is None: 78 | image = self.image 79 | 80 | self.photo_image = ImageTk.PhotoImage(image) 81 | self.canvas.config(width=image.width, height=image.height) 82 | self.canvas.create_image(image.width // 2, image.height // 2, image=self.photo_image, anchor=tk.CENTER) 83 | 84 | 85 | def main(): 86 | root = tk.Tk() 87 | app = ImageRotatorApp(root) 88 | root.mainloop() 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | -------------------------------------------------------------------------------- /unetr_pp/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | from unetr_pp.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner 17 | from unetr_pp.paths import * 18 | 19 | 20 | class ExperimentPlannerTargetSpacingForAnisoAxis(ExperimentPlanner): 21 | def __init__(self, folder_with_cropped_data, preprocessed_output_folder): 22 | super().__init__(folder_with_cropped_data, preprocessed_output_folder) 23 | self.data_identifier = "nnFormerData_targetSpacingForAnisoAxis" 24 | self.plans_fname = join(self.preprocessed_output_folder, 25 | "nnFormerPlans" + "targetSpacingForAnisoAxis_plans_3D.pkl") 26 | 27 | def get_target_spacing(self): 28 | """ 29 | By default we use the 50th percentile (median) for the target spacing. Higher spacing results in smaller data 30 | and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training 31 | 32 | For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic 33 | (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low 34 | resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially 35 | impact performance (due to the low number of slices). 36 | """ 37 | spacings = self.dataset_properties['all_spacings'] 38 | sizes = self.dataset_properties['all_sizes'] 39 | 40 | target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0) 41 | target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0) 42 | target_size_mm = np.array(target) * np.array(target_size) 43 | # we need to identify datasets for which a different target spacing could be beneficial. 
These datasets have 44 | # the following properties: 45 | # - one axis which much lower resolution than the others 46 | # - the lowres axis has much less voxels than the others 47 | # - (the size in mm of the lowres axis is also reduced) 48 | worst_spacing_axis = np.argmax(target) 49 | other_axes = [i for i in range(len(target)) if i != worst_spacing_axis] 50 | other_spacings = [target[i] for i in other_axes] 51 | other_sizes = [target_size[i] for i in other_axes] 52 | 53 | has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings)) 54 | has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < max(other_sizes) 55 | # we don't use the last one for now 56 | #median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm) 57 | 58 | if has_aniso_spacing and has_aniso_voxels: 59 | spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis] 60 | target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10) 61 | target[worst_spacing_axis] = target_spacing_of_that_axis 62 | return target 63 | 64 | -------------------------------------------------------------------------------- /unetr_pp/experiment_planning/nnFormer_convert_decathlon_task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from batchgenerators.utilities.file_and_folder_operations import * 15 | from unetr_pp.configuration import default_num_threads 16 | from unetr_pp.experiment_planning.utils import split_4d 17 | from unetr_pp.utilities.file_endings import remove_trailing_slash 18 | 19 | 20 | def crawl_and_remove_hidden_from_decathlon(folder): 21 | folder = remove_trailing_slash(folder) 22 | assert folder.split('/')[-1].startswith("Task"), "This does not seem to be a decathlon folder. Please give me a " \ 23 | "folder that starts with TaskXX and has the subfolders imagesTr, " \ 24 | "labelsTr and imagesTs" 25 | subf = subfolders(folder, join=False) 26 | assert 'imagesTr' in subf, "This does not seem to be a decathlon folder. Please give me a " \ 27 | "folder that starts with TaskXX and has the subfolders imagesTr, " \ 28 | "labelsTr and imagesTs" 29 | assert 'imagesTs' in subf, "This does not seem to be a decathlon folder. Please give me a " \ 30 | "folder that starts with TaskXX and has the subfolders imagesTr, " \ 31 | "labelsTr and imagesTs" 32 | assert 'labelsTr' in subf, "This does not seem to be a decathlon folder. 
Please give me a " \ 33 | "folder that starts with TaskXX and has the subfolders imagesTr, " \ 34 | "labelsTr and imagesTs" 35 | _ = [os.remove(i) for i in subfiles(folder, prefix=".")] 36 | _ = [os.remove(i) for i in subfiles(join(folder, 'imagesTr'), prefix=".")] 37 | _ = [os.remove(i) for i in subfiles(join(folder, 'labelsTr'), prefix=".")] 38 | _ = [os.remove(i) for i in subfiles(join(folder, 'imagesTs'), prefix=".")] 39 | 40 | 41 | def main(): 42 | import argparse 43 | parser = argparse.ArgumentParser(description="The MSD provides data as 4D Niftis with the modality being the first" 44 | " dimension. We think this may be cumbersome for some users and " 45 | "therefore expect 3D niftis instead, with one file per modality. " 46 | "This utility will convert 4D MSD data into the format nnU-Net " 47 | "expects") 48 | parser.add_argument("-i", help="Input folder. Must point to a TaskXX_TASKNAME folder as downloaded from the MSD " 49 | "website", required=False, 50 | default="/home/maaz/PycharmProjects/nnFormer/DATASET/nnFormer_raw/nnFormer_raw_data" 51 | "/Task02_Synapse") 52 | parser.add_argument("-p", required=False, default=default_num_threads, type=int, 53 | help="Use this to specify how many processes are used to run the script. " 54 | "Default is %d" % default_num_threads) 55 | parser.add_argument("-output_task_id", required=False, default=None, type=int, 56 | help="If specified, this will overwrite the task id in the output folder. If unspecified, the " 57 | "task id of the input folder will be used.") 58 | args = parser.parse_args() 59 | 60 | crawl_and_remove_hidden_from_decathlon(args.i) 61 | 62 | split_4d(args.i, args.p, args.output_task_id) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /unetr_pp/experiment_planning/summarize_plans.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from batchgenerators.utilities.file_and_folder_operations import * 16 | from unetr_pp.paths import preprocessing_output_dir 17 | 18 | 19 | # This file is intended to double check nnFormer's design choices.
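(Hypothetical usage, added for illustration -- the task name and plans identifier below are assumptions: summarize_plans(join(preprocessing_output_dir, 'Task002_Synapse', 'unetr_pp_Plansv2.1_plans_3D.pkl')) prints the classes, modalities, normalization schemes and per-stage plan of a preprocessed task.)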
It is intended to be used for development purposes only 20 | def summarize_plans(file): 21 | plans = load_pickle(file) 22 | print("num_classes: ", plans['num_classes']) 23 | print("modalities: ", plans['modalities']) 24 | print("use_mask_for_norm", plans['use_mask_for_norm']) 25 | print("keep_only_largest_region", plans['keep_only_largest_region']) 26 | print("min_region_size_per_class", plans['min_region_size_per_class']) 27 | print("min_size_per_class", plans['min_size_per_class']) 28 | print("normalization_schemes", plans['normalization_schemes']) 29 | print("stages...\n") 30 | 31 | for i in range(len(plans['plans_per_stage'])): 32 | print("stage: ", i) 33 | print(plans['plans_per_stage'][i]) 34 | print("") 35 | 36 | 37 | def write_plans_to_file(f, plans_file): 38 | print(plans_file) 39 | a = load_pickle(plans_file) 40 | stages = list(a['plans_per_stage'].keys()) 41 | stages.sort() 42 | for stage in stages: 43 | patch_size_in_mm = [i * j for i, j in zip(a['plans_per_stage'][stages[stage]]['patch_size'], 44 | a['plans_per_stage'][stages[stage]]['current_spacing'])] 45 | median_patient_size_in_mm = [i * j for i, j in zip(a['plans_per_stage'][stages[stage]]['median_patient_size_in_voxels'], 46 | a['plans_per_stage'][stages[stage]]['current_spacing'])] 47 | f.write(plans_file.split("/")[-2]) 48 | f.write(";%s" % plans_file.split("/")[-1]) 49 | f.write(";%d" % stage) 50 | f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['batch_size'])) 51 | f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['num_pool_per_axis'])) 52 | f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['patch_size'])) 53 | f.write(";%s" % str([str("%03.2f" % i) for i in patch_size_in_mm])) 54 | f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['median_patient_size_in_voxels'])) 55 | f.write(";%s" % str([str("%03.2f" % i) for i in median_patient_size_in_mm])) 56 | f.write(";%s" % str([str("%03.2f" % i) for i in a['plans_per_stage'][stages[stage]]['current_spacing']])) 57 | f.write(";%s" % str([str("%03.2f" % i) for i in a['plans_per_stage'][stages[stage]]['original_spacing']])) 58 | f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['pool_op_kernel_sizes'])) 59 | f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['conv_kernel_sizes'])) 60 | f.write(";%s" % str(a['data_identifier'])) 61 | f.write("\n") 62 | 63 | 64 | if __name__ == "__main__": 65 | base_dir = './'  # preprocessing_output_dir 66 | task_dirs = [i for i in subdirs(base_dir, join=False, prefix="Task") if i.find("BrainTumor") == -1 and i.find("MSSeg") == -1] 67 | print("found %d tasks" % len(task_dirs)) 68 | 69 | with open("2019_02_06_plans_summary.csv", 'w') as f: 70 | f.write("task;plans_file;stage;batch_size;num_pool_per_axis;patch_size;patch_size(mm);median_patient_size_in_voxels;median_patient_size_in_mm;current_spacing;original_spacing;pool_op_kernel_sizes;conv_kernel_sizes\n") 71 | for t in task_dirs: 72 | print(t) 73 | tmp = join(base_dir, t) 74 | plans_files = [i for i in subfiles(tmp, suffix=".pkl", join=False) if i.find("_plans_") != -1 and i.find("Dgx2") == -1] 75 | for p in plans_files: 76 | write_plans_to_file(f, join(tmp, p)) 77 | f.write("\n") 78 | 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vesuvius Grand Prize Submission 2 | We just want to read the scrolls!!
3 | 4 | ![location](https://github.com/SQMah/Vesuvius-Grand-Prize-Submission/blob/main/location.png?raw=true) 5 | 6 | **Location of segments within the greater scroll.** The ids correspond to scroll segments with their locations here: [http://dl.ash2txt.org/full-scrolls/Scroll1.volpkg/paths/](http://dl.ash2txt.org/full-scrolls/Scroll1.volpkg/paths/) 7 | 8 | ## Methodology 9 | ### Minimum System Specs 10 | - 1 TB drive 11 | - 1 Nvidia GPU with high VRAM (I personally tried with 40GB) 12 | 13 | ### Producing Results 14 | Set up a Linux-based system with CUDA 12.1. 15 | 16 | 1. Change directory into the folder. 17 | 2. Run `$ conda env create -f environment.yml`. 18 | 3. Run `$ conda activate Vesuvius-Challenge`. 19 | 4. Run `$ python data_downloader.py`. 20 | 5. Run `$ python data_setup.py`. 21 | 6. Place the downloaded [model.ckpt](https://drive.google.com/file/d/1rh0xGOPhznqPT6QqcK6tbnq86eAM9XiI/view?usp=drive_link) into `./models`. 22 | 7. Run `$ accelerate launch inference_unetr_pp.py`. 23 | 24 | Results will be saved in the `results/` folder. 25 | 26 | ### Training 27 | After following the above steps to set up and activate the Conda environment: 28 | 29 | 1. Run `$ python data_downloader.py`. 30 | 2. Run `$ python data_setup.py`. 31 | 3. Run `$ python training_unetr_pp.py`. 32 | 33 | Trained models will be saved in the `training/` folder. 34 | 35 | ## Hallucination Mitigation 36 | Hallucinations were mitigated in four ways: 37 | 1. Labeled data was created using only 64x64 pixel windowed models. The 256x256 pixel windowed model was used to generate cleaner/more legible results only in the last two iterations. 38 | 2. Including more negative ink labels. The patch extraction technique described below under Technical Details extracts more negative labels than positive ones, reducing bias towards positive labels. Hence, the model is less likely to hallucinate positive ink labels, and the risk of misreading text through hallucinated negative ink labels is far lower. 39 | 3. Strong data augmentation. Distorting augmentations such as optical, grid, and elastic deformations were used during training, greatly reducing the possibility that the models memorize the shapes of Greek letters instead of learning true ink signals. 40 | 4. Results were generated over a 32 pixel stride instead of the 64 pixel stride used during training. Therefore, the results will be generated on unseen parts of characters even if that part of the scroll was used during training. 41 | 42 | ## Technical Details 43 | I use a custom adaptation of the state-of-the-art [UNETR++ model](https://arxiv.org/abs/2212.04497), a transformer-based U-Net derivative used in medical imaging, as a 3D feature extractor, max pooling over the depth layers, then a final feature extractor based on [Segformer B-5](https://arxiv.org/abs/2105.15203). 44 | 45 | We exclusively ran detections on PHerc Paris 3 (scroll 1), with an ink detection window of **256x256** pixels, which corresponds to a ~2.02496mm ink detection window, with a stride of 32 pixels to ensure sufficient training data. Since this is larger than the recommended 64x64 pixel detection window, the ways in which I mitigated hallucinations are discussed above. 46 | 47 | ## Patch Extraction Technique 48 | I propose a patch extraction technique that works well for larger window sizes (up to 512 pixels) while still giving the model sufficient examples of both positive and negative ink labels; a sketch of the validity check follows below.
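A minimal sketch of the window-validity idea (an illustration under assumptions, not the actual implementation -- the real check is `window_is_valid` in `ink_label_processor.py`, exercised by `window_visualizer.py` below; the `Box` stand-in, the area arithmetic, and the `min_ink_fraction` name are mine):

```python
from collections import namedtuple

# Hypothetical stand-in for ink_label_processor.BoundingBox:
# (x1, y1) top-left and (x2, y2) bottom-right corners, in pixels.
Box = namedtuple("Box", ["x1", "y1", "x2", "y2"])


def window_is_valid_sketch(window, ink_boxes, min_ink_fraction=0.5):
    """Keep a candidate window when the ink boxes it intersects cover
    at least min_ink_fraction of its area.

    Overlapping ink boxes double-count here; the real check may differ.
    """
    window_area = (window.x2 - window.x1) * (window.y2 - window.y1)
    covered = 0
    for box in ink_boxes:
        dx = min(window.x2, box.x2) - max(window.x1, box.x1)
        dy = min(window.y2, box.y2) - max(window.y1, box.y1)
        if dx > 0 and dy > 0:
            covered += dx * dy
    return covered / window_area >= min_ink_fraction


# e.g. a 256x256 window at the origin against a single ink box:
# window_is_valid_sketch(Box(0, 0, 256, 256), [Box(64, 64, 512, 512)])
```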
This technique is especially important for learning characters where negative ink labels (negative space) are crucial, for example in distinguishing the characters ο, ϲ, and θ, which have very similar ink structures, especially when the data is noisy. 49 | 50 | ![patch example](https://github.com/SQMah/Vesuvius-Grand-Prize-Submission/blob/main/patch.png?raw=true) 51 | 52 | The patch extractor works by first identifying all the areas in the manually annotated ink label ground truth data that contain ink, and then only passing on the ink area and the surrounding non-ink area that is critical to understanding what character it is. This includes the non-ink labels inside the character itself, which is crucial in distinguishing the aforementioned ο, ϲ, and θ. 53 | 54 | 55 | Example patch extraction from manually annotated ink labels on PHerc Paris 3 segment 20231012184423. The green boxes denote the areas of the scroll that the model will be trained on. Note that the stride is 64, hence there are many overlapping boxes. 56 | 57 | You can run and visualize the patch extraction algorithm yourself using `window_visualizer.py`. 58 | 59 | 60 | -------------------------------------------------------------------------------- /unetr_pp/evaluation/region_based_evaluation.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from multiprocessing.pool import Pool 3 | 4 | from batchgenerators.utilities.file_and_folder_operations import * 5 | from medpy import metric 6 | import SimpleITK as sitk 7 | import numpy as np 8 | from unetr_pp.configuration import default_num_threads 9 | from unetr_pp.postprocessing.consolidate_postprocessing import collect_cv_niftis 10 | 11 | 12 | def get_brats_regions(): 13 | """ 14 | this is only valid for the brats data in here where the labels are 1, 2, and 3. The original brats data have a 15 | different labeling convention! 16 | :return: 17 | """ 18 | regions = { 19 | "whole tumor": (1, 2, 3), 20 | "tumor core": (2, 3), 21 | "enhancing tumor": (3,) 22 | } 23 | return regions 24 | 25 | 26 | def get_KiTS_regions(): 27 | regions = { 28 | "kidney incl tumor": (1, 2), 29 | "tumor": (2,) 30 | } 31 | return regions 32 | 33 | 34 | def create_region_from_mask(mask, join_labels: tuple): 35 | mask_new = np.zeros_like(mask, dtype=np.uint8) 36 | for l in join_labels: 37 | mask_new[mask == l] = 1 38 | return mask_new 39 | 40 | 41 | def evaluate_case(file_pred: str, file_gt: str, regions): 42 | image_gt = sitk.GetArrayFromImage(sitk.ReadImage(file_gt)) 43 | image_pred = sitk.GetArrayFromImage(sitk.ReadImage(file_pred)) 44 | results = [] 45 | for r in regions: 46 | mask_pred = create_region_from_mask(image_pred, r) 47 | mask_gt = create_region_from_mask(image_gt, r) 48 | dc = np.nan if np.sum(mask_gt) == 0 and np.sum(mask_pred) == 0 else metric.dc(mask_pred, mask_gt) 49 | results.append(dc) 50 | return results 51 | 52 | 53 | def evaluate_regions(folder_predicted: str, folder_gt: str, regions: dict, processes=default_num_threads): 54 | region_names = list(regions.keys()) 55 | files_in_pred = subfiles(folder_predicted, suffix='.nii.gz', join=False) 56 | files_in_gt = subfiles(folder_gt, suffix='.nii.gz', join=False) 57 | have_no_gt = [i for i in files_in_pred if i not in files_in_gt] 58 | assert len(have_no_gt) == 0, "Some files in folder_predicted have no ground truth in folder_gt" 59 | have_no_pred = [i for i in files_in_gt if i not in files_in_pred] 60 | if len(have_no_pred) > 0: 61 | print("WARNING!
Some files in folder_gt were not predicted (not present in folder_predicted)!") 62 | 63 | files_in_gt.sort() 64 | files_in_pred.sort() 65 | 66 | # run for all cases 67 | full_filenames_gt = [join(folder_gt, i) for i in files_in_pred] 68 | full_filenames_pred = [join(folder_predicted, i) for i in files_in_pred] 69 | 70 | p = Pool(processes) 71 | res = p.starmap(evaluate_case, zip(full_filenames_pred, full_filenames_gt, [list(regions.values())] * len(files_in_gt))) 72 | p.close() 73 | p.join() 74 | 75 | all_results = {r: [] for r in region_names} 76 | with open(join(folder_predicted, 'summary.csv'), 'w') as f: 77 | f.write("casename") 78 | for r in region_names: 79 | f.write(",%s" % r) 80 | f.write("\n") 81 | for i in range(len(files_in_pred)): 82 | f.write(files_in_pred[i][:-7]) 83 | result_here = res[i] 84 | for k, r in enumerate(region_names): 85 | dc = result_here[k] 86 | f.write(",%02.4f" % dc) 87 | all_results[r].append(dc) 88 | f.write("\n") 89 | 90 | f.write('mean') 91 | for r in region_names: 92 | f.write(",%02.4f" % np.nanmean(all_results[r])) 93 | f.write("\n") 94 | f.write('median') 95 | for r in region_names: 96 | f.write(",%02.4f" % np.nanmedian(all_results[r])) 97 | f.write("\n") 98 | 99 | f.write('mean (nan is 1)') 100 | for r in region_names: 101 | tmp = np.array(all_results[r]) 102 | tmp[np.isnan(tmp)] = 1 103 | f.write(",%02.4f" % np.mean(tmp)) 104 | f.write("\n") 105 | f.write('median (nan is 1)') 106 | for r in region_names: 107 | tmp = np.array(all_results[r]) 108 | tmp[np.isnan(tmp)] = 1 109 | f.write(",%02.4f" % np.median(tmp)) 110 | f.write("\n") 111 | 112 | 113 | if __name__ == '__main__': 114 | collect_cv_niftis('./', './cv_niftis') 115 | evaluate_regions('./cv_niftis/', './gt_niftis/', get_brats_regions()) 116 | -------------------------------------------------------------------------------- /unetr_pp/utilities/file_conversions.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Union 2 | from skimage import io 3 | import SimpleITK as sitk 4 | import numpy as np 5 | import tifffile 6 | 7 | 8 | def convert_2d_image_to_nifti(input_filename: str, output_filename_truncated: str, spacing=(999, 1, 1), 9 | transform=None, is_seg: bool = False) -> None: 10 | """ 11 | Reads an image (must be a format that is recognized by skimage.io.imread) and converts it into a series of niftis. 12 | The image can have an arbitrary number of input channels which will be exported separately (_0000.nii.gz, 13 | _0001.nii.gz, etc for images and only .nii.gz for seg). 14 | Spacing can be ignored most of the time. 15 | !!!2D images are often natural images which do not have a voxel spacing that could be used for resampling. These images 16 | must be resampled by you prior to converting them to nifti!!! 17 | 18 | Datasets converted with this utility can only be used with the 2d U-Net configuration of nnU-Net 19 | 20 | If transform is not None it will be applied to the image after loading. 21 | 22 | Segmentations will be converted to np.uint32! 23 | 24 | :param is_seg: 25 | :param transform: 26 | :param input_filename: 27 | :param output_filename_truncated: do not use a file ending for this one! Example: output_name='./converted/image1'. This 28 | function will add the suffix (_0000) and file ending (.nii.gz) for you. 29 | :param spacing: 30 | :return: 31 | """ 32 | img = io.imread(input_filename) 33 | 34 | if transform is not None: 35 | img = transform(img) 36 | 37 | if len(img.shape) == 2: # 2d image with no color channels 38 | img = img[None, None] # add dimensions 39 | else: 40 | assert len(img.shape) == 3, "image should be 3d with color channel last but has shape %s" % str(img.shape) 41 | # we assume that the color channel is the last dimension. Transpose it to be in first 42 | img = img.transpose((2, 0, 1)) 43 | # add third dimension 44 | img = img[:, None] 45 | 46 | # image is now (c, x, y, z) where x=1 since it's 2d 47 | if is_seg: 48 | assert img.shape[0] == 1, 'segmentations can only have one color channel, not sure what happened here' 49 | 50 | for j, i in enumerate(img): 51 | 52 | if is_seg: 53 | i = i.astype(np.uint32) 54 | 55 | itk_img = sitk.GetImageFromArray(i) 56 | itk_img.SetSpacing(list(spacing)[::-1]) 57 | if not is_seg: 58 | sitk.WriteImage(itk_img, output_filename_truncated + "_%04.0d.nii.gz" % j) 59 | else: 60 | sitk.WriteImage(itk_img, output_filename_truncated + ".nii.gz") 61 | 62 | 63 | def convert_3d_tiff_to_nifti(filenames: List[str], output_name: str, spacing: Union[tuple, list], transform=None, is_seg=False) -> None: 64 | """ 65 | filenames must be a list of strings, each pointing to a separate 3d tiff file. One file per modality. If your data 66 | only has one imaging modality, simply pass a list with only a single entry 67 | 68 | Files in filenames must be readable with tifffile. 69 | 70 | Note: we always only pass one file into tifffile.imread, not multiple (even though it supports it). This is because 71 | I am not familiar enough with this functionality and would like to have control over what happens. 72 | 73 | If transform is not None it will be applied to the image after loading. 74 | 75 | :param transform: 76 | :param filenames: 77 | :param output_name: 78 | :param spacing: 79 | :return: 80 | """ 81 | if is_seg: 82 | assert len(filenames) == 1 83 | 84 | for j, i in enumerate(filenames): 85 | img = tifffile.imread(i) 86 | 87 | if transform is not None: 88 | img = transform(img) 89 | 90 | itk_img = sitk.GetImageFromArray(img) 91 | itk_img.SetSpacing(list(spacing)[::-1]) 92 | 93 | if not is_seg: 94 | sitk.WriteImage(itk_img, output_name + "_%04.0d.nii.gz" % j) 95 | else: 96 | sitk.WriteImage(itk_img, output_name + ".nii.gz") 97 | 98 | 99 | def convert_2d_segmentation_nifti_to_img(nifti_file: str, output_filename: str, transform=None, export_dtype=np.uint8): 100 | img = sitk.GetArrayFromImage(sitk.ReadImage(nifti_file)) 101 | assert img.shape[0] == 1, "This function can only export 2D segmentations!" 102 | img = img[0] 103 | if transform is not None: 104 | img = transform(img) 105 | 106 | io.imsave(output_filename, img.astype(export_dtype), check_contrast=False) 107 | 108 | 109 | def convert_3d_segmentation_nifti_to_tiff(nifti_file: str, output_filename: str, transform=None, export_dtype=np.uint8): 110 | img = sitk.GetArrayFromImage(sitk.ReadImage(nifti_file)) 111 | assert len(img.shape) == 3, "This function can only export 3D segmentations!
112 | if transform is not None: 113 | img = transform(img) 114 | 115 | tifffile.imsave(output_filename, img.astype(export_dtype)) 116 | -------------------------------------------------------------------------------- /window_visualizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import cv2 4 | import numpy as np 5 | from ink_label_processor import window_is_valid, get_ink_text_bounding_boxes, BoundingBox 6 | import albumentations as A 7 | import matplotlib.pyplot as plt 8 | from custom_augmentations import FourthAugment 9 | 10 | 11 | def process_image(image_data, window_size, stride, white_threshold, visualize=False): 12 | # Load the grayscale image 13 | image = image_data 14 | ink_bounding_boxes = get_ink_text_bounding_boxes(image) 15 | 16 | # Dimensions of the image 17 | height, width = image.shape 18 | 19 | # Count of boxes 20 | box_count = 0 21 | 22 | # Create a copy for drawing boxes 23 | output_image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 24 | 25 | # Convert to binary (white and black) 26 | _, binary_image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY) 27 | 28 | valid_boxes = [] 29 | 30 | # Iterate over the image 31 | for y in range(0, height - window_size + 1, stride): 32 | for x in range(0, width - window_size + 1, stride): 33 | # Extract the window 34 | window_bounding_box = BoundingBox(x, y, x + window_size, y + window_size) 35 | 36 | # Check if white percentage is above the threshold 37 | if window_is_valid(window_bounding_box, ink_bounding_boxes, white_threshold): 38 | # Draw a box 39 | cv2.rectangle(output_image, (x, y), (x + window_size, y + window_size), (0, 255, 0), 20) 40 | box_count += 1 41 | valid_boxes.append(window_bounding_box) 42 | 43 | print(f"Number of boxes drawn: {box_count}") 44 | # Display the result 45 | if visualize: 46 | cv2.imshow('Boxes on Image', output_image) 47 | cv2.waitKey(0) 48 | cv2.destroyAllWindows() 49 | 50 | return valid_boxes 51 | 52 | 53 | def visualize_augmentations(img_data, size: int, boxes: List[BoundingBox]): 54 | """Use matplotlib to visualize the augmentations in a grid, based on the boxes returned from process_image""" 55 | augmentations_list = [ 56 | # A.ToFloat(max_value=65535.0), 57 | # A.RandomResizedCrop( 58 | # size, size, scale=(0.85, 1.0)), 59 | FourthAugment(p=1.0), 60 | A.Resize(size, size), 61 | A.HorizontalFlip(p=0.5), 62 | A.VerticalFlip(p=0.5), 63 | # A.RandomRotate90(p=0.6), 64 | # A.ChannelShuffle(p=0.5), 65 | A.GridDistortion(p=0.5, num_steps=5, distort_limit=0.3, border_mode=cv2.BORDER_CONSTANT, value=0.0, 66 | normalized=True), 67 | A.ElasticTransform(p=0.5, alpha=1, sigma=50, alpha_affine=50, border_mode=cv2.BORDER_CONSTANT, value=0.0), 68 | A.RandomBrightnessContrast( 69 | brightness_limit=0.3, contrast_limit=0.3, p=0.75 70 | ), 71 | A.ShiftScaleRotate(rotate_limit=360, shift_limit=0.15, scale_limit=0.15, p=0.9, 72 | border_mode=cv2.BORDER_CONSTANT, value=0.0), 73 | A.OneOf([ 74 | A.GaussianBlur(p=0.3), 75 | A.GaussNoise(p=0.3), 76 | # A.OpticalDistortion(p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0.0), 77 | A.PiecewiseAffine(p=0.5), # IAAPiecewiseAffine 78 | A.MotionBlur(), 79 | ], p=0.9), 80 | A.CoarseDropout(max_holes=2, max_width=int(size * 0.2), max_height=int(size * 0.2), 81 | mask_fill_value=0, p=0.5), 82 | ] 83 | augmentations = A.Compose(augmentations_list) 84 | box_data = [box.get_img_from_box(img_data) for box in boxes] 85 | empty_image = np.zeros((size, size), dtype=np.uint8) 86 | augmented_images = 
[augmentations(image=empty_image, mask=box)[ 87 | 'mask'] for box in box_data] # Apply the augmentations to the images in the boxes 88 | # Display original images using matplotlib 89 | fig = plt.figure(figsize=(8, 8)) 90 | columns = 4 91 | rows = 4 92 | for i in range(1, columns * rows + 1): 93 | fig.add_subplot(rows, columns, i) 94 | plt.imshow(box_data[i - 1]) 95 | plt.show() 96 | # Display grid of augmented images using matplotlib 97 | fig = plt.figure(figsize=(8, 8)) 98 | columns = 4 99 | rows = 4 100 | for i in range(1, columns * rows + 1): 101 | fig.add_subplot(rows, columns, i) 102 | plt.imshow(augmented_images[i - 1]) 103 | plt.show() 104 | 105 | 106 | if __name__ == '__main__': 107 | # Example usage 108 | image_path = './labels/20231012184423_inklabels.png' 109 | window_size = 256 # Example window size 110 | stride = 256 # Example stride length 111 | white_threshold = 0.5 # Fraction threshold of white 112 | visualize_bboxes = True 113 | 114 | # Process the image 115 | img_data = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE) 116 | boxes = process_image(img_data, window_size, stride, white_threshold, visualize=visualize_bboxes) 117 | if not visualize_bboxes: 118 | visualize_augmentations(img_data, window_size, boxes) 119 | -------------------------------------------------------------------------------- /unetr_pp/postprocessing/consolidate_postprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import shutil 16 | from typing import Tuple 17 | 18 | from batchgenerators.utilities.file_and_folder_operations import * 19 | from unetr_pp.configuration import default_num_threads 20 | from unetr_pp.evaluation.evaluator import aggregate_scores 21 | from unetr_pp.postprocessing.connected_components import determine_postprocessing 22 | import argparse 23 | 24 | 25 | def collect_cv_niftis(cv_folder: str, output_folder: str, validation_folder_name: str = 'validation_raw', 26 | folds: tuple = (0, 1, 2, 3, 4)): 27 | validation_raw_folders = [join(cv_folder, "fold_%d" % i, validation_folder_name) for i in folds] 28 | exist = [isdir(i) for i in validation_raw_folders] 29 | 30 | if not all(exist): 31 | raise RuntimeError("some folds are missing. Please run the full 5-fold cross-validation. 
" 32 | "The following folds seem to be missing: %s" % 33 | [i for j, i in enumerate(folds) if not exist[j]]) 34 | 35 | # now copy all raw niftis into cv_niftis_raw 36 | maybe_mkdir_p(output_folder) 37 | for f in folds: 38 | niftis = subfiles(validation_raw_folders[f], suffix=".nii.gz") 39 | for n in niftis: 40 | shutil.copy(n, join(output_folder)) 41 | 42 | 43 | def consolidate_folds(output_folder_base, validation_folder_name: str = 'validation_raw', 44 | advanced_postprocessing: bool = False, folds: Tuple[int] = (0, 1, 2, 3, 4)): 45 | """ 46 | Used to determine the postprocessing for an experiment after all five folds have been completed. In the validation of 47 | each fold, the postprocessing can only be determined on the cases within that fold. This can result in different 48 | postprocessing decisions for different folds. In the end, we can only decide for one postprocessing per experiment, 49 | so we have to rerun it 50 | :param folds: 51 | :param advanced_postprocessing: 52 | :param output_folder_base:experiment output folder (fold_0, fold_1, etc must be subfolders of the given folder) 53 | :param validation_folder_name: dont use this 54 | :return: 55 | """ 56 | output_folder_raw = join(output_folder_base, "cv_niftis_raw") 57 | if isdir(output_folder_raw): 58 | shutil.rmtree(output_folder_raw) 59 | 60 | output_folder_gt = join(output_folder_base, "gt_niftis") 61 | collect_cv_niftis(output_folder_base, output_folder_raw, validation_folder_name, 62 | folds) 63 | 64 | num_niftis_gt = len(subfiles(join(output_folder_base, "gt_niftis"), suffix='.nii.gz')) 65 | # count niftis in there 66 | num_niftis = len(subfiles(output_folder_raw, suffix='.nii.gz')) 67 | if num_niftis != num_niftis_gt: 68 | raise AssertionError("If does not seem like you trained all the folds! Train all folds first!") 69 | 70 | # load a summary file so that we can know what class labels to expect 71 | summary_fold0 = load_json(join(output_folder_base, "fold_0", validation_folder_name, "summary.json"))['results'][ 72 | 'mean'] 73 | classes = [int(i) for i in summary_fold0.keys()] 74 | niftis = subfiles(output_folder_raw, join=False, suffix=".nii.gz") 75 | test_pred_pairs = [(join(output_folder_gt, i), join(output_folder_raw, i)) for i in niftis] 76 | 77 | # determine_postprocessing needs a summary.json file in the folder where the raw predictions are. 
We could compute 78 | # that from the summary files of the five folds but I am feeling lazy today 79 | aggregate_scores(test_pred_pairs, labels=classes, json_output_file=join(output_folder_raw, "summary.json"), 80 | num_threads=default_num_threads) 81 | 82 | determine_postprocessing(output_folder_base, output_folder_gt, 'cv_niftis_raw', 83 | final_subf_name="cv_niftis_postprocessed", processes=default_num_threads, 84 | advanced_postprocessing=advanced_postprocessing) 85 | # determine_postprocessing will create a postprocessing.json file that can be used for inference 86 | 87 | 88 | if __name__ == "__main__": 89 | argparser = argparse.ArgumentParser() 90 | argparser.add_argument("-f", type=str, required=True, help="experiment output folder (fold_0, fold_1, " 91 | "etc must be subfolders of the given folder)") 92 | 93 | args = argparser.parse_args() 94 | 95 | folder = args.f 96 | 97 | consolidate_folds(folder) 98 | -------------------------------------------------------------------------------- /unetr_pp/run/default_configuration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import unetr_pp 17 | from unetr_pp.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier 18 | from batchgenerators.utilities.file_and_folder_operations import * 19 | from unetr_pp.experiment_planning.summarize_plans import summarize_plans 20 | from unetr_pp.training.model_restore import recursive_find_python_class 21 | import numpy as np 22 | import pickle 23 | 24 | 25 | def get_configuration_from_output_folder(folder): 26 | # split off network_training_output_dir 27 | folder = folder[len(network_training_output_dir):] 28 | if folder.startswith("/"): 29 | folder = folder[1:] 30 | 31 | configuration, task, trainer_and_plans_identifier = folder.split("/") 32 | trainer, plans_identifier = trainer_and_plans_identifier.split("__") 33 | return configuration, task, trainer, plans_identifier 34 | 35 | 36 | def get_default_configuration(network, task, network_trainer, plans_identifier=default_plans_identifier, 37 | search_in=(unetr_pp.__path__[0], "training", "network_training"), 38 | base_module='unetr_pp.training.network_training'): 39 | assert network in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'], \ 40 | "network can only be one of the following: \'2d\', \'3d_lowres\', \'3d_fullres\', \'3d_cascade_fullres\'" 41 | 42 | dataset_directory = join(preprocessing_output_dir, task) 43 | 44 | if network == '2d': 45 | plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_2D.pkl") 46 | else: 47 | plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_3D.pkl") 48 | 49 | plans = load_pickle(plans_file) 50 | # There may be two kinds of plans; choose the later one 51 | if len(plans['plans_per_stage']) == 2: 52 | Stage = 1 53 | else: 54 | Stage = 0 55 | if task == 'Task001_ACDC': 56 | plans['plans_per_stage'][Stage]['batch_size'] = 4 57 | plans['plans_per_stage'][Stage]['patch_size'] = np.array([16, 160, 160]) 58 | pickle_file = open(plans_file, 'wb') 59 | pickle.dump(plans, pickle_file) 60 | pickle_file.close() 61 | 62 | elif task == 'Task002_Synapse': 63 | plans['plans_per_stage'][Stage]['batch_size'] = 2 64 | plans['plans_per_stage'][Stage]['patch_size'] = np.array([64, 128, 128]) 65 | plans['plans_per_stage'][Stage]['pool_op_kernel_sizes'] = [[2, 2, 2], [2, 2, 2], 66 | [2, 2, 2]] # for deep supervision 67 | pickle_file = open(plans_file, 'wb') 68 | pickle.dump(plans, pickle_file) 69 | pickle_file.close() 70 | possible_stages = list(plans['plans_per_stage'].keys()) 71 | 72 | if (network == '3d_cascade_fullres' or network == "3d_lowres") and len(possible_stages) == 1: 73 | raise RuntimeError("3d_lowres/3d_cascade_fullres only applies if there is more than one stage. This task does " 74 | "not require the cascade.
Run 3d_fullres instead") 75 | 76 | if network == '2d' or network == "3d_lowres": 77 | stage = 0 78 | else: 79 | stage = possible_stages[-1] 80 | 81 | trainer_class = recursive_find_python_class([join(*search_in)], network_trainer, 82 | current_module=base_module) 83 | 84 | output_folder_name = join(network_training_output_dir, network, task, network_trainer + "__" + plans_identifier) 85 | 86 | print("###############################################") 87 | print("I am running the following nnFormer: %s" % network) 88 | print("My trainer class is: ", trainer_class) 89 | print("For that I will be using the following configuration:") 90 | summarize_plans(plans_file) 91 | print("I am using stage %d from these plans" % stage) 92 | 93 | if (network == '2d' or len(possible_stages) > 1) and not network == '3d_lowres': 94 | batch_dice = True 95 | print("I am using batch dice + CE loss") 96 | else: 97 | batch_dice = False 98 | print("I am using sample dice + CE loss") 99 | 100 | print("\nI am using data from this folder: ", join(dataset_directory, plans['data_identifier'])) 101 | print("###############################################") 102 | return plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class 103 | -------------------------------------------------------------------------------- /unetr_pp/inference_acdc.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import SimpleITK as sitk 4 | import numpy as np 5 | from medpy.metric import binary 6 | from sklearn.neighbors import KDTree 7 | from scipy import ndimage 8 | import argparse 9 | 10 | 11 | def read_nii(path): 12 | itk_img=sitk.ReadImage(path) 13 | spacing=np.array(itk_img.GetSpacing()) 14 | return sitk.GetArrayFromImage(itk_img),spacing 15 | 16 | def dice(pred, label): 17 | if (pred.sum() + label.sum()) == 0: 18 | return 1 19 | else: 20 | return 2. 
* np.logical_and(pred, label).sum() / (pred.sum() + label.sum()) 21 | 22 | def process_label(label): 23 | rv = label == 1 24 | myo = label == 2 25 | lv = label == 3 26 | 27 | return rv,myo,lv 28 | ''' 29 | def hd(pred,gt): 30 | pred[pred > 0] = 1 31 | gt[gt > 0] = 1 32 | if pred.sum() > 0 and gt.sum()>0: 33 | dice = binary.dc(pred, gt) 34 | hd95 = binary.hd95(pred, gt) 35 | return dice, hd95 36 | elif pred.sum() > 0 and gt.sum()==0: 37 | return 1, 0 38 | else: 39 | return 0, 0 40 | ''' 41 | 42 | def hd(pred,gt): 43 | #labelPred=sitk.GetImageFromArray(lP.astype(np.float32), isVector=False) 44 | #labelTrue=sitk.GetImageFromArray(lT.astype(np.float32), isVector=False) 45 | #hausdorffcomputer=sitk.HausdorffDistanceImageFilter() 46 | #hausdorffcomputer.Execute(labelTrue>0.5,labelPred>0.5) 47 | #return hausdorffcomputer.GetAverageHausdorffDistance() 48 | if pred.sum() > 0 and gt.sum()>0: 49 | hd95 = binary.hd95(pred, gt) 50 | print(hd95) 51 | return hd95 52 | else: 53 | return 0 54 | 55 | def test(fold): 56 | 57 | label_path = './'  # placeholder: point this at the folder containing the ground-truth *.nii.gz files 58 | pred_path = '/'  # placeholder: point this at the folder containing the predicted *.nii.gz files 59 | 60 | label_list = sorted(glob.glob(os.path.join(label_path, '*nii.gz'))) 61 | infer_list = sorted(glob.glob(os.path.join(pred_path, '*nii.gz'))) 62 | 63 | print("loading success...") 64 | print(label_list) 65 | print(infer_list) 66 | Dice_rv=[] 67 | Dice_myo=[] 68 | Dice_lv=[] 69 | 70 | hd_rv=[] 71 | hd_myo=[] 72 | hd_lv=[] 73 | 74 | file = pred_path + 'inferTs/' + fold  # write the summary next to the predictions 75 | if not os.path.exists(file): 76 | os.makedirs(file) 77 | fw = open(file+'/dice_pre.txt', 'w') 78 | 79 | for label_path,infer_path in zip(label_list,infer_list): 80 | print(label_path.split('/')[-1]) 81 | print(infer_path.split('/')[-1]) 82 | label,spacing= read_nii(label_path) 83 | infer,spacing= read_nii(infer_path) 84 | label_rv,label_myo,label_lv=process_label(label) 85 | infer_rv,infer_myo,infer_lv=process_label(infer) 86 | 87 | Dice_rv.append(dice(infer_rv,label_rv)) 88 | Dice_myo.append(dice(infer_myo,label_myo)) 89 | Dice_lv.append(dice(infer_lv,label_lv)) 90 | 91 | hd_rv.append(hd(infer_rv,label_rv)) 92 | hd_myo.append(hd(infer_myo,label_myo)) 93 | hd_lv.append(hd(infer_lv,label_lv)) 94 | 95 | fw.write('*'*20+'\n',) 96 | fw.write(infer_path.split('/')[-1]+'\n') 97 | fw.write('hd_rv: {:.4f}\n'.format(hd_rv[-1])) 98 | fw.write('hd_myo: {:.4f}\n'.format(hd_myo[-1])) 99 | fw.write('hd_lv: {:.4f}\n'.format(hd_lv[-1])) 100 | #fw.write('*'*20+'\n') 101 | fw.write('*'*20+'\n',) 102 | fw.write(infer_path.split('/')[-1]+'\n') 103 | fw.write('Dice_rv: {:.4f}\n'.format(Dice_rv[-1])) 104 | fw.write('Dice_myo: {:.4f}\n'.format(Dice_myo[-1])) 105 | fw.write('Dice_lv: {:.4f}\n'.format(Dice_lv[-1])) 106 | fw.write('hd_rv: {:.4f}\n'.format(hd_rv[-1])) 107 | fw.write('hd_myo: {:.4f}\n'.format(hd_myo[-1])) 108 | fw.write('hd_lv: {:.4f}\n'.format(hd_lv[-1])) 109 | fw.write('*'*20+'\n') 110 | 111 | #fw.write('*'*20+'\n') 112 | #fw.write('Mean_hd\n') 113 | #fw.write('hd_rv'+str(np.mean(hd_rv))+'\n') 114 | #fw.write('hd_myo'+str(np.mean(hd_myo))+'\n') 115 | #fw.write('hd_lv'+str(np.mean(hd_lv))+'\n') 116 | #fw.write('*'*20+'\n') 117 | 118 | fw.write('*'*20+'\n') 119 | fw.write('Mean_Dice\n') 120 | fw.write('Dice_rv'+str(np.mean(Dice_rv))+'\n') 121 | fw.write('Dice_myo'+str(np.mean(Dice_myo))+'\n') 122 | fw.write('Dice_lv'+str(np.mean(Dice_lv))+'\n') 123 | fw.write('Mean_HD\n') 124 | fw.write('HD_rv'+str(np.mean(hd_rv))+'\n') 125 | fw.write('HD_myo'+str(np.mean(hd_myo))+'\n') 126 | fw.write('HD_lv'+str(np.mean(hd_lv))+'\n') 127 | fw.write('*'*20+'\n') 128 | 129 | dsc=[] 130 | dsc.append(np.mean(Dice_rv)) 131 | dsc.append(np.mean(Dice_myo)) 132 | dsc.append(np.mean(Dice_lv)) 133 | avg_hd=[] 134 | avg_hd.append(np.mean(hd_rv)) 135 | avg_hd.append(np.mean(hd_myo)) 136 | avg_hd.append(np.mean(hd_lv)) 137 | fw.write('avg_hd:'+str(np.mean(avg_hd))+'\n') 138 | 139 | fw.write('DSC:'+str(np.mean(dsc))+'\n') 140 | fw.write('HD:'+str(np.mean(avg_hd))+'\n') 141 | 142 | print('done') 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument("fold", help="fold name") 147 | args = parser.parse_args() 148 | fold = args.fold 149 | test(fold) 150 | -------------------------------------------------------------------------------- /unetr_pp/inference_tumor.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import SimpleITK as sitk 4 | import numpy as np 5 | import argparse 6 | from medpy.metric import binary 7 | 8 | def read_nii(path): 9 | return sitk.GetArrayFromImage(sitk.ReadImage(path)) 10 | 11 | def new_dice(pred,label): 12 | tp_hard = np.sum((pred == 1).astype(float) * (label == 1).astype(float)) 13 | fp_hard = np.sum((pred == 1).astype(float) * (label != 1).astype(float)) 14 | fn_hard = np.sum((pred != 1).astype(float) * (label == 1).astype(float)) 15 | return 2*tp_hard/(2*tp_hard+fp_hard+fn_hard) 16 | 17 | def dice(pred, label): 18 | if (pred.sum() + label.sum()) == 0: 19 | return 1 20 | else: 21 | return 2. * np.logical_and(pred, label).sum() / (pred.sum() + label.sum()) 22 | 23 | def hd(pred,gt): 24 | if pred.sum() > 0 and gt.sum()>0: 25 | hd95 = binary.hd95(pred, gt) 26 | return hd95 27 | else: 28 | return 0 29 | 30 | def process_label(label): 31 | net = label == 2 32 | ed = label == 1 33 | et = label == 3 34 | ET=et 35 | TC=net+et 36 | WT=net+et+ed 37 | ED= ed 38 | NET=net 39 | return ET,TC,WT,ED,NET 40 | 41 | def test(fold): 42 | #path='./' 43 | 44 | path = None # Replace None by the full path of : unetr_plus_plus/DATASET_Tumor/unetr_pp_raw/unetr_pp_raw_data/Task03_tumor/" 45 | 46 | label_list=sorted(glob.glob(os.path.join(path,'labelsTs','*nii.gz'))) 47 | 48 | infer_path = None # Replace None by the full path of : unetr_plus_plus/unetr_pp/evaluation/unetr_pp_tumor_checkpoint/" 49 | 50 | infer_list=sorted(glob.glob(os.path.join(infer_path,'inferTs','*nii.gz'))) 51 | print("loading success...") 52 | Dice_et=[] 53 | Dice_tc=[] 54 | Dice_wt=[] 55 | Dice_ed=[] 56 | Dice_net=[] 57 | 58 | HD_et=[] 59 | HD_tc=[] 60 | HD_wt=[] 61 | HD_ed=[] 62 | HD_net=[] 63 | file=infer_path + 'inferTs/'+fold 64 | if not os.path.exists(file): 65 | os.makedirs(file) 66 | fw = open(file+'/dice_five.txt', 'w') 67 | 68 | for label_path,infer_path in zip(label_list,infer_list): 69 | print(label_path.split('/')[-1]) 70 | print(infer_path.split('/')[-1]) 71 | label,infer = read_nii(label_path),read_nii(infer_path) 72 | label_et,label_tc,label_wt,label_ed,label_net=process_label(label) 73 | infer_et,infer_tc,infer_wt,infer_ed,infer_net=process_label(infer) 74 | Dice_et.append(dice(infer_et,label_et)) 75 | Dice_tc.append(dice(infer_tc,label_tc)) 76 | Dice_wt.append(dice(infer_wt,label_wt)) 77 | Dice_ed.append(dice(infer_ed,label_ed)) 78 | Dice_net.append(dice(infer_net,label_net)) 79 | 80 | HD_et.append(hd(infer_et,label_et)) 81 | HD_tc.append(hd(infer_tc,label_tc)) 82 | HD_wt.append(hd(infer_wt,label_wt)) 83 | HD_ed.append(hd(infer_ed,label_ed)) 84 | HD_net.append(hd(infer_net,label_net)) 85 | 86 | 87 | fw.write('*'*20+'\n',) 88 |
fw.write(infer_path.split('/')[-1]+'\n') 89 | fw.write('hd_et: {:.4f}\n'.format(HD_et[-1])) 90 | fw.write('hd_tc: {:.4f}\n'.format(HD_tc[-1])) 91 | fw.write('hd_wt: {:.4f}\n'.format(HD_wt[-1])) 92 | fw.write('hd_ed: {:.4f}\n'.format(HD_ed[-1])) 93 | fw.write('hd_net: {:.4f}\n'.format(HD_net[-1])) 94 | fw.write('*'*20+'\n',) 95 | fw.write('Dice_et: {:.4f}\n'.format(Dice_et[-1])) 96 | fw.write('Dice_tc: {:.4f}\n'.format(Dice_tc[-1])) 97 | fw.write('Dice_wt: {:.4f}\n'.format(Dice_wt[-1])) 98 | fw.write('Dice_ed: {:.4f}\n'.format(Dice_ed[-1])) 99 | fw.write('Dice_net: {:.4f}\n'.format(Dice_net[-1])) 100 | 101 | #print('dice_et: {:.4f}'.format(np.mean(Dice_et))) 102 | #print('dice_tc: {:.4f}'.format(np.mean(Dice_tc))) 103 | #print('dice_wt: {:.4f}'.format(np.mean(Dice_wt))) 104 | dsc=[] 105 | avg_hd=[] 106 | dsc.append(np.mean(Dice_et)) 107 | dsc.append(np.mean(Dice_tc)) 108 | dsc.append(np.mean(Dice_wt)) 109 | dsc.append(np.mean(Dice_ed)) 110 | dsc.append(np.mean(Dice_net)) 111 | 112 | 113 | avg_hd.append(np.mean(HD_et)) 114 | avg_hd.append(np.mean(HD_tc)) 115 | avg_hd.append(np.mean(HD_wt)) 116 | avg_hd.append(np.mean(HD_ed)) 117 | avg_hd.append(np.mean(HD_net)) 118 | 119 | fw.write('Dice_et'+str(np.mean(Dice_et))+' '+'\n') 120 | fw.write('Dice_tc'+str(np.mean(Dice_tc))+' '+'\n') 121 | fw.write('Dice_wt'+str(np.mean(Dice_wt))+' '+'\n') 122 | fw.write('Dice_ed'+str(np.mean(Dice_ed))+' '+'\n') 123 | fw.write('Dice_net'+str(np.mean(Dice_net))+' '+'\n') 124 | 125 | fw.write('HD_et'+str(np.mean(HD_et))+' '+'\n') 126 | fw.write('HD_tc'+str(np.mean(HD_tc))+' '+'\n') 127 | fw.write('HD_wt'+str(np.mean(HD_wt))+' '+'\n') 128 | fw.write('HD_ed'+str(np.mean(HD_ed))+' '+'\n') 129 | fw.write('HD_net'+str(np.mean(HD_net))+' '+'\n') 130 | 131 | fw.write('Dice'+str(np.mean(dsc))+' '+'\n') 132 | fw.write('HD'+str(np.mean(avg_hd))+' '+'\n') 133 | #print('Dice'+str(np.mean(dsc))+' '+'\n') 134 | #print('HD'+str(np.mean(avg_hd))+' '+'\n') 135 | 136 | 137 | 138 | if __name__ == '__main__': 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument("fold", help="fold name") 141 | args = parser.parse_args() 142 | fold=args.fold 143 | test(fold) 144 | -------------------------------------------------------------------------------- /data_downloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from concurrent.futures import ThreadPoolExecutor 4 | from urllib.parse import urlparse, urljoin 5 | 6 | import requests 7 | from bs4 import BeautifulSoup 8 | 9 | data_dir = "./data" 10 | max_concurrent_downloads = 48 11 | 12 | # URLs 13 | scroll_one_base_url = ("scroll1", "http://dl.ash2txt.org/full-scrolls/Scroll1.volpkg/paths/") 14 | scroll_two_base_url = ("scroll2", "http://dl.ash2txt.org/full-scrolls/Scroll2.volpkg/paths/") 15 | 16 | urls = [scroll_one_base_url, scroll_two_base_url] 17 | 18 | username = "registeredusers" 19 | password = "only" 20 | 21 | 22 | def get_url_data_path(data_url): 23 | return os.path.join(data_dir, data_url[0]) 24 | 25 | 26 | def saved_segments(base_data_dir, data_urls): 27 | data_url_results = {} 28 | for data_url in data_urls: 29 | saved_dir = os.path.join(base_data_dir, data_url[0]) 30 | if os.path.exists(saved_dir) and os.path.isdir(saved_dir): 31 | exclude_list = ['.', '..', '.DS_Store', 'Thumbs.db', 'desktop.ini', '.Trashes', 'lost+found'] 32 | data_url_results[data_url[0]] = {item for item in os.listdir(saved_dir) if item not in exclude_list} 33 | else: 34 | raise FileNotFoundError(f"The directory 
'{saved_dir}' does not exist.") 35 | return data_url_results 36 | 37 | 38 | def relevant_segments_from_url_directories(url_directories): 39 | segments = {url.split('/')[-2] for url in url_directories} 40 | # Ignore if superseded 41 | return {url_segment for url_segment in segments if not is_superseded(url_segment, segments)} 42 | 43 | 44 | def is_superseded(segment_id, all_segment_ids_set): 45 | try: 46 | return "_superseded" in segment_id or (str(int(segment_id) + 1) in all_segment_ids_set) 47 | except ValueError: 48 | return False 49 | 50 | 51 | def get_all_scroll_segments(scroll_url): 52 | response = requests.get(scroll_url, auth=(username, password)) 53 | 54 | if response.status_code == 200: 55 | soup = BeautifulSoup(response.text, "html.parser") 56 | 57 | directories = [] 58 | 59 | for link in soup.find_all("a"): 60 | href = link.get("href") 61 | 62 | if href: 63 | full_url = urljoin(scroll_url, href) 64 | parsed_url = urlparse(full_url) 65 | 66 | # Check if it's a directory (based on the path component of the URL) 67 | if parsed_url.path.endswith('/') and full_url.startswith(scroll_url): 68 | directories.append(full_url) 69 | return directories 70 | else: 71 | print(f"Failed to fetch content from {scroll_url}. Status code: {response.status_code}") 72 | 73 | 74 | def download_segment(base_dir, segment_to_get): 75 | def download_segment_helper(url_to_download): 76 | print(f"Downloading {segment_to_get}, from url {url_to_download}, saving to {base_dir}") 77 | # if os.path.exists(f"{base_dir}/{segment_to_get}"): 78 | # print(f"Skipping {segment_to_get} because it already exists") 79 | # return 80 | url_parts = url_to_download.rstrip('/').split('/') 81 | cut_dirs = len(url_parts) - 3 # Adjust the number of directories to cut as needed 82 | url_to_download = f"{urljoin(url_to_download, segment_to_get)}/" 83 | 84 | wget_command = [ 85 | "wget", 86 | f"--user={username}", 87 | f"--password={password}", 88 | "-r", 89 | "-N", 90 | "--no-parent", 91 | "-A", "*.tif,mask.png", 92 | "-R", "*cellmap*, index.html*", 93 | "-nH", 94 | f"--cut-dirs={cut_dirs}", 95 | "-P", base_dir, 96 | f"{url_to_download}" 97 | ] 98 | result = subprocess.run(wget_command, check=True) 99 | if result.returncode == 0: 100 | directory_to_write = os.path.join(base_dir, segment_to_get) 101 | print(f"Successfully downloaded {segment_to_get} from {url_to_download}") 102 | with open(f"{directory_to_write}/base_url.txt", "w") as f: 103 | f.write(url_to_download) 104 | else: 105 | print(f"Failed to download {segment_to_get} from {url_to_download}") 106 | 107 | return download_segment_helper 108 | 109 | 110 | if __name__ == "__main__": 111 | if not os.path.exists(data_dir): 112 | os.mkdir(data_dir) 113 | for url_tuple in urls: 114 | print(f"Checking if directory {get_url_data_path(url_tuple)} exists") 115 | if not os.path.exists(get_url_data_path(url_tuple)): 116 | print(f"Creating directory {get_url_data_path(url_tuple)}") 117 | os.mkdir(get_url_data_path(url_tuple)) 118 | # downloaded_segments = saved_segments(data_dir, urls) 119 | for url_tuple in urls: 120 | key, url = url_tuple 121 | scroll_segments = get_all_scroll_segments(url) 122 | segments_to_download = relevant_segments_from_url_directories(scroll_segments) 123 | print(f"For key {key}, segments to download: {segments_to_download}") 124 | with ThreadPoolExecutor(max_workers=max_concurrent_downloads) as executor: 125 | for segment in segments_to_download: 126 | executor.submit(download_segment(get_url_data_path(url_tuple), segment), url) 127 | # for segment in 
["20230522181603", "20230702185752"]: 128 | # download_segment("./data/scroll1_hari", segment)("http://dl.ash2txt.org/hari-seldon-uploads/team-finished-paths/scroll1/") 129 | -------------------------------------------------------------------------------- /unetr_pp/network_architecture/tumor/transformerblock.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from unetr_pp.network_architecture.dynunet_block import UnetResBlock 4 | import math 5 | 6 | 7 | class TransformerBlock(nn.Module): 8 | """ 9 | A transformer block, based on: "Shaker et al., 10 | UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" 11 | """ 12 | 13 | def __init__( 14 | self, 15 | input_size: int, 16 | hidden_size: int, 17 | proj_size: int, 18 | num_heads: int, 19 | dropout_rate: float = 0.0, 20 | pos_embed=False, 21 | ) -> None: 22 | """ 23 | Args: 24 | input_size: the size of the input for each stage. 25 | hidden_size: dimension of hidden layer. 26 | proj_size: projection size for keys and values in the spatial attention module. 27 | num_heads: number of attention heads. 28 | dropout_rate: faction of the input units to drop. 29 | pos_embed: bool argument to determine if positional embedding is used. 30 | 31 | """ 32 | 33 | super().__init__() 34 | 35 | if not (0 <= dropout_rate <= 1): 36 | raise ValueError("dropout_rate should be between 0 and 1.") 37 | 38 | if hidden_size % num_heads != 0: 39 | print("Hidden size is ", hidden_size) 40 | print("Num heads is ", num_heads) 41 | raise ValueError("hidden_size should be divisible by num_heads.") 42 | 43 | self.norm = nn.LayerNorm(hidden_size) 44 | self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True) 45 | self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads, channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate) 46 | self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") 47 | self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)) 48 | 49 | self.pos_embed = None 50 | if pos_embed: 51 | self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size)) 52 | 53 | def forward(self, x): 54 | B, C, H, W, D = x.shape 55 | 56 | x = x.reshape(B, C, H * W * D).permute(0, 2, 1) 57 | 58 | if self.pos_embed is not None: 59 | x = x + self.pos_embed 60 | attn = x + self.gamma * self.epa_block(self.norm(x)) 61 | 62 | attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3) # (B, C, H, W, D) 63 | attn = self.conv51(attn_skip) 64 | x = attn_skip + self.conv8(attn) 65 | 66 | return x 67 | 68 | 69 | def init_(tensor): 70 | dim = tensor.shape[-1] 71 | std = 1 / math.sqrt(dim) 72 | tensor.uniform_(-std, std) 73 | return tensor 74 | 75 | 76 | class EPA(nn.Module): 77 | """ 78 | Efficient Paired Attention Block, based on: "Shaker et al., 79 | UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" 80 | """ 81 | def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, 82 | channel_attn_drop=0.1, spatial_attn_drop=0.1): 83 | super().__init__() 84 | self.num_heads = num_heads 85 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 86 | self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) 87 | 88 | # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel) 89 | self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias) 90 | 
91 | # E and F are projection matrices with shared weights used in spatial attention module to project 92 | # keys and values from HWD-dimension to P-dimension 93 | self.EF = nn.Parameter(init_(torch.zeros(input_size, proj_size))) 94 | 95 | self.attn_drop = nn.Dropout(channel_attn_drop) 96 | self.attn_drop_2 = nn.Dropout(spatial_attn_drop) 97 | 98 | def forward(self, x): 99 | B, N, C = x.shape 100 | 101 | qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads) 102 | qkvv = qkvv.permute(2, 0, 3, 1, 4) 103 | q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3] 104 | 105 | q_shared = q_shared.transpose(-2, -1) 106 | k_shared = k_shared.transpose(-2, -1) 107 | v_CA = v_CA.transpose(-2, -1) 108 | v_SA = v_SA.transpose(-2, -1) 109 | 110 | proj_e_f = lambda args: torch.einsum('bhdn,nk->bhdk', *args) 111 | k_shared_projected, v_SA_projected = map(proj_e_f, zip((k_shared, v_SA), (self.EF, self.EF))) 112 | 113 | q_shared = torch.nn.functional.normalize(q_shared, dim=-1) 114 | k_shared = torch.nn.functional.normalize(k_shared, dim=-1) 115 | 116 | attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature 117 | attn_CA = attn_CA.softmax(dim=-1) 118 | attn_CA = self.attn_drop(attn_CA) 119 | x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C) 120 | 121 | attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2 122 | attn_SA = attn_SA.softmax(dim=-1) 123 | attn_SA = self.attn_drop_2(attn_SA) 124 | x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C) 125 | 126 | return x_CA + x_SA 127 | 128 | @torch.jit.ignore 129 | def no_weight_decay(self): 130 | return {'temperature', 'temperature2'} -------------------------------------------------------------------------------- /unetr_pp/network_architecture/synapse/transformerblock.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from unetr_pp.network_architecture.dynunet_block import UnetResBlock 4 | 5 | 6 | class TransformerBlock(nn.Module): 7 | """ 8 | A transformer block, based on: "Shaker et al., 9 | UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" 10 | """ 11 | 12 | def __init__( 13 | self, 14 | input_size: int, 15 | hidden_size: int, 16 | proj_size: int, 17 | num_heads: int, 18 | dropout_rate: float = 0.0, 19 | pos_embed=False, 20 | ) -> None: 21 | """ 22 | Args: 23 | input_size: the size of the input for each stage. 24 | hidden_size: dimension of hidden layer. 25 | proj_size: projection size for keys and values in the spatial attention module. 26 | num_heads: number of attention heads. 27 | dropout_rate: fraction of the input units to drop. 28 | pos_embed: bool argument to determine if positional embedding is used.
29 | 30 | """ 31 | 32 | super().__init__() 33 | 34 | if not (0 <= dropout_rate <= 1): 35 | raise ValueError("dropout_rate should be between 0 and 1.") 36 | 37 | if hidden_size % num_heads != 0: 38 | print("Hidden size is ", hidden_size) 39 | print("Num heads is ", num_heads) 40 | raise ValueError("hidden_size should be divisible by num_heads.") 41 | 42 | self.norm = nn.LayerNorm(hidden_size) 43 | self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True) 44 | self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads, channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate) 45 | self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") 46 | self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)) 47 | 48 | self.pos_embed = None 49 | if pos_embed: 50 | self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size)) 51 | 52 | def forward(self, x): 53 | B, C, H, W, D = x.shape 54 | 55 | x = x.reshape(B, C, H * W * D).permute(0, 2, 1) 56 | 57 | if self.pos_embed is not None: 58 | x = x + self.pos_embed 59 | attn = x + self.gamma * self.epa_block(self.norm(x)) 60 | 61 | attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3) # (B, C, H, W, D) 62 | attn = self.conv51(attn_skip) 63 | x = attn_skip + self.conv8(attn) 64 | 65 | return x 66 | 67 | 68 | class EPA(nn.Module): 69 | """ 70 | Efficient Paired Attention Block, based on: "Shaker et al., 71 | UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" 72 | """ 73 | def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, 74 | channel_attn_drop=0.1, spatial_attn_drop=0.1): 75 | super().__init__() 76 | self.num_heads = num_heads 77 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 78 | self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) 79 | 80 | # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel) 81 | self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias) 82 | 83 | # E and F are projection matrices with shared weights used in spatial attention module to project 84 | # keys and values from HWD-dimension to P-dimension 85 | self.E = self.F = nn.Linear(input_size, proj_size) 86 | 87 | self.attn_drop = nn.Dropout(channel_attn_drop) 88 | self.attn_drop_2 = nn.Dropout(spatial_attn_drop) 89 | 90 | self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2)) 91 | self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2)) 92 | 93 | def forward(self, x): 94 | B, N, C = x.shape 95 | 96 | qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads) 97 | 98 | qkvv = qkvv.permute(2, 0, 3, 1, 4) 99 | 100 | q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3] 101 | 102 | q_shared = q_shared.transpose(-2, -1) 103 | k_shared = k_shared.transpose(-2, -1) 104 | v_CA = v_CA.transpose(-2, -1) 105 | v_SA = v_SA.transpose(-2, -1) 106 | 107 | k_shared_projected = self.E(k_shared) 108 | 109 | v_SA_projected = self.F(v_SA) 110 | 111 | q_shared = torch.nn.functional.normalize(q_shared, dim=-1) 112 | k_shared = torch.nn.functional.normalize(k_shared, dim=-1) 113 | 114 | attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature 115 | 116 | attn_CA = attn_CA.softmax(dim=-1) 117 | attn_CA = self.attn_drop(attn_CA) 118 | 119 | x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C) 120 | 121 | attn_SA = (q_shared.permute(0, 1, 3, 2) @ 
122 | 
123 |         attn_SA = attn_SA.softmax(dim=-1)
124 |         attn_SA = self.attn_drop_2(attn_SA)
125 | 
126 |         x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)
127 | 
128 |         # Concat fusion
129 |         x_SA = self.out_proj(x_SA)
130 |         x_CA = self.out_proj2(x_CA)
131 |         x = torch.cat((x_SA, x_CA), dim=-1)
132 |         return x
133 | 
134 |     @torch.jit.ignore
135 |     def no_weight_decay(self):
136 |         return {'temperature', 'temperature2'}
137 | 
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/tumor/unetr_pp_tumor.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from typing import Tuple, Union
3 | from unetr_pp.network_architecture.neural_network import SegmentationNetwork
4 | from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
5 | from unetr_pp.network_architecture.tumor.model_components import UnetrPPEncoder, UnetrUpBlock
6 | 
7 | 
8 | class UNETR_PP(SegmentationNetwork):
9 |     """
10 |     UNETR++ based on: "Shaker et al.,
11 |     UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
12 |     """
13 |     def __init__(
14 |             self,
15 |             in_channels: int,
16 |             out_channels: int,
17 |             feature_size: int = 16,
18 |             hidden_size: int = 256,
19 |             num_heads: int = 4,
20 |             pos_embed: str = "perceptron",
21 |             norm_name: Union[Tuple, str] = "instance",
22 |             dropout_rate: float = 0.0,
23 |             depths=None,
24 |             dims=None,
25 |             conv_op=nn.Conv3d,
26 |             do_ds=True,
27 | 
28 |     ) -> None:
29 |         """
30 |         Args:
31 |             in_channels: dimension of input channels.
32 |             out_channels: dimension of output channels.
33 |             img_size: dimension of input image.
34 |             feature_size: dimension of network feature size.
35 |             hidden_size: dimension of the last encoder.
36 |             num_heads: number of attention heads.
37 |             pos_embed: position embedding layer type.
38 |             norm_name: feature normalization type and arguments.
39 |             dropout_rate: fraction of the input units to drop.
40 |             depths: number of blocks for each stage.
41 |             dims: number of channel maps for the stages.
42 |             conv_op: type of convolution operation.
43 |             do_ds: use deep supervision to compute the loss.
44 | """ 45 | 46 | super().__init__() 47 | if depths is None: 48 | depths = [3, 3, 3, 3] 49 | self.do_ds = do_ds 50 | self.conv_op = conv_op 51 | self.num_classes = out_channels 52 | if not (0 <= dropout_rate <= 1): 53 | raise AssertionError("dropout_rate should be between 0 and 1.") 54 | 55 | if pos_embed not in ["conv", "perceptron"]: 56 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 57 | 58 | self.feat_size = (4, 4, 4,) 59 | self.hidden_size = hidden_size 60 | 61 | self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads) 62 | 63 | self.encoder1 = UnetResBlock( 64 | spatial_dims=3, 65 | in_channels=in_channels, 66 | out_channels=feature_size, 67 | kernel_size=3, 68 | stride=1, 69 | norm_name=norm_name, 70 | ) 71 | self.decoder5 = UnetrUpBlock( 72 | spatial_dims=3, 73 | in_channels=feature_size * 16, 74 | out_channels=feature_size * 8, 75 | kernel_size=3, 76 | upsample_kernel_size=2, 77 | norm_name=norm_name, 78 | out_size=8*8*8, 79 | ) 80 | self.decoder4 = UnetrUpBlock( 81 | spatial_dims=3, 82 | in_channels=feature_size * 8, 83 | out_channels=feature_size * 4, 84 | kernel_size=3, 85 | upsample_kernel_size=2, 86 | norm_name=norm_name, 87 | out_size=16*16*16, 88 | ) 89 | self.decoder3 = UnetrUpBlock( 90 | spatial_dims=3, 91 | in_channels=feature_size * 4, 92 | out_channels=feature_size * 2, 93 | kernel_size=3, 94 | upsample_kernel_size=2, 95 | norm_name=norm_name, 96 | out_size=32*32*32, 97 | ) 98 | self.decoder2 = UnetrUpBlock( 99 | spatial_dims=3, 100 | in_channels=feature_size * 2, 101 | out_channels=feature_size, 102 | kernel_size=3, 103 | upsample_kernel_size=(4, 4, 4), 104 | norm_name=norm_name, 105 | out_size=128*128*128, 106 | conv_decoder=True, 107 | ) 108 | self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) 109 | if self.do_ds: 110 | self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) 111 | self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) 112 | 113 | def proj_feat(self, x, hidden_size, feat_size): 114 | x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) 115 | x = x.permute(0, 4, 1, 2, 3).contiguous() 116 | return x 117 | 118 | def forward(self, x_in): 119 | #print("###########reached forward network") 120 | #print("XIN",x_in.shape) 121 | x_output, hidden_states = self.unetr_pp_encoder(x_in) 122 | convBlock = self.encoder1(x_in) 123 | 124 | # Four encoders 125 | enc1 = hidden_states[0] 126 | enc2 = hidden_states[1] 127 | enc3 = hidden_states[2] 128 | enc4 = hidden_states[3] 129 | 130 | # Four decoders 131 | dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size) 132 | dec3 = self.decoder5(dec4, enc3) 133 | dec2 = self.decoder4(dec3, enc2) 134 | dec1 = self.decoder3(dec2, enc1) 135 | 136 | out = self.decoder2(dec1, convBlock) 137 | if self.do_ds: 138 | logits = [self.out1(out), self.out2(dec1), self.out3(dec2)] 139 | else: 140 | logits = self.out1(out) 141 | 142 | return logits 143 | -------------------------------------------------------------------------------- /unetr_pp/network_architecture/acdc/transformerblock.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from unetr_pp.network_architecture.dynunet_block import UnetResBlock 4 | 5 | 6 | class TransformerBlock(nn.Module): 7 | """ 8 | A transformer block, based on: "Shaker et al., 9 | UNETR++: Delving into 
Efficient and Accurate 3D Medical Image Segmentation"
10 |     """
11 | 
12 |     def __init__(
13 |             self,
14 |             input_size: int,
15 |             hidden_size: int,
16 |             proj_size: int,
17 |             num_heads: int,
18 |             dropout_rate: float = 0.0,
19 |             pos_embed=False,
20 |     ) -> None:
21 |         """
22 |         Args:
23 |             input_size: the size of the input for each stage.
24 |             hidden_size: dimension of hidden layer.
25 |             proj_size: projection size for keys and values in the spatial attention module.
26 |             num_heads: number of attention heads.
27 |             dropout_rate: fraction of the input units to drop.
28 |             pos_embed: bool argument to determine if positional embedding is used.
29 | 
30 |         """
31 | 
32 |         super().__init__()
33 | 
34 |         if not (0 <= dropout_rate <= 1):
35 |             raise ValueError("dropout_rate should be between 0 and 1.")
36 | 
37 |         if hidden_size % num_heads != 0:
38 |             print("Hidden size is ", hidden_size)
39 |             print("Num heads is ", num_heads)
40 |             raise ValueError("hidden_size should be divisible by num_heads.")
41 | 
42 |         self.norm = nn.LayerNorm(hidden_size)
43 |         self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
44 |         self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads,
45 |                              channel_attn_drop=dropout_rate, spatial_attn_drop=dropout_rate)
46 |         self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
47 |         self.conv52 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch")
48 |         self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1))
49 | 
50 |         self.pos_embed = None
51 |         if pos_embed:
52 |             self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))
53 | 
54 |     def forward(self, x):
55 |         B, C, H, W, D = x.shape
56 | 
57 |         x = x.reshape(B, C, H * W * D).permute(0, 2, 1)
58 | 
59 |         if self.pos_embed is not None:
60 |             x = x + self.pos_embed
61 |         attn = x + self.gamma * self.epa_block(self.norm(x))
62 | 
63 |         attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)  # (B, C, H, W, D)
64 |         attn = self.conv51(attn_skip)
65 |         attn = self.conv52(attn)
66 |         x = attn_skip + self.conv8(attn)
67 |         return x
68 | 
69 | class EPA(nn.Module):
70 |     """
71 |     Efficient Paired Attention Block, based on: "Shaker et al.,
72 |     UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
73 |     """
74 |     def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, channel_attn_drop=0.1, spatial_attn_drop=0.1):
75 |         super().__init__()
76 |         self.num_heads = num_heads
77 |         self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
78 |         self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))
79 | 
80 |         # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
81 |         self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)
82 | 
83 |         # E and F are projection matrices used in the spatial attention module to project keys and values from HWD-dimension to P-dimension
84 |         self.E = nn.Linear(input_size, proj_size)
85 |         self.F = nn.Linear(input_size, proj_size)
86 | 
87 |         self.attn_drop = nn.Dropout(channel_attn_drop)
88 |         self.attn_drop_2 = nn.Dropout(spatial_attn_drop)
89 | 
90 |         self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
91 |         self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))
92 | 
93 |     def forward(self, x):
94 |         B, N, C = x.shape
95 | 
96 |         qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
97 | 
98 |         qkvv = qkvv.permute(2, 0, 3, 1, 4)
99 | 
100 |         q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]
101 | 
102 |         q_shared = q_shared.transpose(-2, -1)
103 |         k_shared = k_shared.transpose(-2, -1)
104 |         v_CA = v_CA.transpose(-2, -1)
105 |         v_SA = v_SA.transpose(-2, -1)
106 | 
107 |         k_shared_projected = self.E(k_shared)
108 | 
109 |         v_SA_projected = self.F(v_SA)
110 | 
111 |         q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
112 |         k_shared = torch.nn.functional.normalize(k_shared, dim=-1)
113 | 
114 |         attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
115 | 
116 |         attn_CA = attn_CA.softmax(dim=-1)
117 |         attn_CA = self.attn_drop(attn_CA)
118 | 
119 |         x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)
120 | 
121 |         attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2
122 | 
123 |         attn_SA = attn_SA.softmax(dim=-1)
124 |         attn_SA = self.attn_drop_2(attn_SA)
125 | 
126 |         x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)
127 | 
128 |         # Concat fusion
129 |         x_SA = self.out_proj(x_SA)
130 |         x_CA = self.out_proj2(x_CA)
131 |         x = torch.cat((x_SA, x_CA), dim=-1)
132 |         return x
133 | 
134 |     @torch.jit.ignore
135 |     def no_weight_decay(self):
136 |         return {'temperature', 'temperature2'}
137 | 
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/lung/unetr_pp_lung.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from typing import Tuple, Union
3 | from unetr_pp.network_architecture.neural_network import SegmentationNetwork
4 | from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock
5 | from unetr_pp.network_architecture.lung.model_components import UnetrPPEncoder, UnetrUpBlock
6 | 
7 | 
8 | class UNETR_PP(SegmentationNetwork):
9 |     """
10 |     UNETR++ based on: "Shaker et al.,
11 |     UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
12 |     """
13 |     def __init__(
14 |             self,
15 |             in_channels: int,
16 |             out_channels: int,
17 |             feature_size: int = 16,
18 |             hidden_size: int = 256,
19 |             num_heads: int = 4,
20 |             pos_embed: str = "perceptron",
21 |             norm_name: Union[Tuple, str] = "instance",
22 |             dropout_rate: float = 0.0,
23 |             depths=None,
24 |             dims=None,
25 |             conv_op=nn.Conv3d,
26 |             do_ds=True,
27 | 
28 |     ) -> None:
29 |         """
30 |         Args:
31 |             in_channels: dimension of input channels.
32 |             out_channels: dimension of output channels.
33 |             img_size: dimension of input image.
34 |             feature_size: dimension of network feature size.
35 |             hidden_size: dimension of the last encoder.
36 |             num_heads: number of attention heads.
37 |             pos_embed: position embedding layer type.
38 |             norm_name: feature normalization type and arguments.
39 |             dropout_rate: fraction of the input units to drop.
40 |             depths: number of blocks for each stage.
41 |             dims: number of channel maps for the stages.
42 |             conv_op: type of convolution operation.
43 |             do_ds: use deep supervision to compute the loss.
44 | """ 45 | 46 | super().__init__() 47 | if depths is None: 48 | depths = [3, 3, 3, 3] 49 | self.do_ds = do_ds 50 | self.conv_op = conv_op 51 | self.num_classes = out_channels 52 | if not (0 <= dropout_rate <= 1): 53 | raise AssertionError("dropout_rate should be between 0 and 1.") 54 | 55 | if pos_embed not in ["conv", "perceptron"]: 56 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 57 | 58 | self.feat_size = (4, 6, 6,) 59 | self.hidden_size = hidden_size 60 | 61 | self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads) 62 | 63 | self.encoder1 = UnetResBlock( 64 | spatial_dims=3, 65 | in_channels=in_channels, 66 | out_channels=feature_size, 67 | kernel_size=3, 68 | stride=1, 69 | norm_name=norm_name, 70 | ) 71 | self.decoder5 = UnetrUpBlock( 72 | spatial_dims=3, 73 | in_channels=feature_size * 16, 74 | out_channels=feature_size * 8, 75 | kernel_size=3, 76 | upsample_kernel_size=2, 77 | norm_name=norm_name, 78 | out_size=8*12*12, 79 | ) 80 | self.decoder4 = UnetrUpBlock( 81 | spatial_dims=3, 82 | in_channels=feature_size * 8, 83 | out_channels=feature_size * 4, 84 | kernel_size=3, 85 | upsample_kernel_size=2, 86 | norm_name=norm_name, 87 | out_size=16*24*24, 88 | ) 89 | self.decoder3 = UnetrUpBlock( 90 | spatial_dims=3, 91 | in_channels=feature_size * 4, 92 | out_channels=feature_size * 2, 93 | kernel_size=3, 94 | upsample_kernel_size=2, 95 | norm_name=norm_name, 96 | out_size=32*48*48, 97 | ) 98 | self.decoder2 = UnetrUpBlock( 99 | spatial_dims=3, 100 | in_channels=feature_size * 2, 101 | out_channels=feature_size, 102 | kernel_size=3, 103 | upsample_kernel_size=(1, 4, 4), 104 | norm_name=norm_name, 105 | out_size=32*192*192, 106 | conv_decoder=True, 107 | ) 108 | self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) 109 | if self.do_ds: 110 | self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels) 111 | self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels) 112 | 113 | def proj_feat(self, x, hidden_size, feat_size): 114 | x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) 115 | x = x.permute(0, 4, 1, 2, 3).contiguous() 116 | return x 117 | 118 | def forward(self, x_in): 119 | #print("#####input_shape:", x_in.shape) 120 | x_output, hidden_states = self.unetr_pp_encoder(x_in) 121 | 122 | convBlock = self.encoder1(x_in) 123 | 124 | # Four encoders 125 | enc1 = hidden_states[0] 126 | #print("ENC1:",enc1.shape) 127 | enc2 = hidden_states[1] 128 | #print("ENC2:",enc2.shape) 129 | enc3 = hidden_states[2] 130 | #print("ENC3:",enc3.shape) 131 | enc4 = hidden_states[3] 132 | #print("ENC4:",enc4.shape) 133 | 134 | # Four decoders 135 | dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size) 136 | dec3 = self.decoder5(dec4, enc3) 137 | dec2 = self.decoder4(dec3, enc2) 138 | dec1 = self.decoder3(dec2, enc1) 139 | 140 | out = self.decoder2(dec1, convBlock) 141 | if self.do_ds: 142 | logits = [self.out1(out), self.out2(dec1), self.out3(dec2)] 143 | else: 144 | logits = self.out1(out) 145 | 146 | return logits 147 | -------------------------------------------------------------------------------- /unetr_pp/network_architecture/acdc/unetr_pp_acdc.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from typing import Tuple, Union 3 | from unetr_pp.network_architecture.neural_network import SegmentationNetwork 4 | from 
5 | from unetr_pp.network_architecture.acdc.model_components import UnetrPPEncoder, UnetrUpBlock
6 | 
7 | class UNETR_PP(SegmentationNetwork):
8 |     """
9 |     UNETR++ based on: "Shaker et al.,
10 |     UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
11 |     """
12 |     def __init__(
13 |             self,
14 |             in_channels: int,
15 |             out_channels: int,
16 |             feature_size: int = 16,
17 |             hidden_size: int = 256,
18 |             num_heads: int = 4,
19 |             pos_embed: str = "perceptron",
20 |             norm_name: Union[Tuple, str] = "instance",
21 |             dropout_rate: float = 0.0,
22 |             depths=None,
23 |             dims=None,
24 |             conv_op=nn.Conv3d,
25 |             do_ds=True,
26 |             window_size: int = 256,
27 |             depth: int = 16,
28 |             img_size=()
29 | 
30 |     ) -> None:
31 |         """
32 |         Args:
33 |             in_channels: dimension of input channels.
34 |             out_channels: dimension of output channels.
35 |             feature_size: dimension of network feature size.
36 |             hidden_size: dimension of the last encoder.
37 |             num_heads: number of attention heads.
38 |             pos_embed: position embedding layer type.
39 |             norm_name: feature normalization type and arguments.
40 |             dropout_rate: fraction of the input units to drop.
41 |             depths: number of blocks for each stage.
42 |             dims: number of channel maps for the stages.
43 |             conv_op: type of convolution operation.
44 |             do_ds: use deep supervision to compute the loss.
45 |         """
46 | 
47 |         super().__init__()
48 |         if depths is None:
49 |             depths = [3, 3, 3, 3]
50 |         self.do_ds = do_ds
51 |         self.conv_op = conv_op
52 |         self.num_classes = out_channels
53 |         if not (0 <= dropout_rate <= 1):
54 |             raise AssertionError("dropout_rate should be between 0 and 1.")
55 | 
56 |         if pos_embed not in ["conv", "perceptron"]:
57 |             raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
58 | 
59 |         self.feat_size = (depth // 8, window_size // 32, window_size // 32)
60 |         self.hidden_size = hidden_size
61 | 
62 |         self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads, window_size=window_size, depth=depth)
63 | 
64 |         self.encoder1 = UnetResBlock(
65 |             spatial_dims=3,
66 |             in_channels=in_channels,
67 |             out_channels=feature_size,
68 |             kernel_size=3,
69 |             stride=1,
70 |             norm_name=norm_name,
71 |         )
72 |         self.decoder5 = UnetrUpBlock(
73 |             spatial_dims=3,
74 |             in_channels=feature_size * 16,
75 |             out_channels=feature_size * 8,
76 |             kernel_size=3,
77 |             upsample_kernel_size=2,
78 |             norm_name=norm_name,
79 |             out_size=(depth // 4) * (window_size // 16) ** 2,
80 |         )
81 |         self.decoder4 = UnetrUpBlock(
82 |             spatial_dims=3,
83 |             in_channels=feature_size * 8,
84 |             out_channels=feature_size * 4,
85 |             kernel_size=3,
86 |             upsample_kernel_size=2,
87 |             norm_name=norm_name,
88 |             out_size=(depth // 2) * (window_size // 8) ** 2,
89 |         )
90 |         self.decoder3 = UnetrUpBlock(
91 |             spatial_dims=3,
92 |             in_channels=feature_size * 4,
93 |             out_channels=feature_size * 2,
94 |             kernel_size=3,
95 |             upsample_kernel_size=2,
96 |             norm_name=norm_name,
97 |             out_size=depth * (window_size // 4) ** 2,
98 |         )
99 |         self.decoder2 = UnetrUpBlock(
100 |             spatial_dims=3,
101 |             in_channels=feature_size * 2,
102 |             out_channels=feature_size,
103 |             kernel_size=3,
104 |             upsample_kernel_size=(1, 4, 4),
105 |             norm_name=norm_name,
106 |             out_size=depth * (window_size // 2) ** 2,
107 |             conv_decoder=True,
108 |         )
109 |         self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
110 |         if self.do_ds:
111 |             self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
112 |             self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)
113 | 
114 |     def proj_feat(self, x, hidden_size, feat_size):
115 |         x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
116 |         x = x.permute(0, 4, 1, 2, 3).contiguous()
117 |         return x
118 | 
119 |     def forward(self, x_in):
120 |         x_output, hidden_states = self.unetr_pp_encoder(x_in)
121 | 
122 |         convBlock = self.encoder1(x_in)
123 | 
124 |         # Four encoders
125 |         enc1 = hidden_states[0]
126 |         enc2 = hidden_states[1]
127 |         enc3 = hidden_states[2]
128 |         enc4 = hidden_states[3]
129 | 
130 |         # Four decoders
131 |         dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
132 |         dec3 = self.decoder5(dec4, enc3)
133 |         dec2 = self.decoder4(dec3, enc2)
134 |         dec1 = self.decoder3(dec2, enc1)
135 | 
136 |         out = self.decoder2(dec1, convBlock)
137 |         if self.do_ds:
138 |             logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
139 |         else:
140 |             logits = self.out1(out)
141 | 
142 |         return logits
143 | 
--------------------------------------------------------------------------------
/unetr_pp/network_architecture/lung/transformerblock.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from unetr_pp.network_architecture.dynunet_block import UnetResBlock
4 | 
5 | 
6 | class TransformerBlock(nn.Module):
7 |     """
8 |     A transformer block, based on: "Shaker et al.,
9 |     UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
10 |     """
11 | 
12 |     def __init__(
13 |             self,
14 |             input_size: int,
15 |             hidden_size: int,
16 |             proj_size: int,
17 |             num_heads: int,
18 |             dropout_rate: float = 0.0,
19 |             pos_embed=False,
20 |     ) -> None:
21 |         """
22 |         Args:
23 |             input_size: the size of the input for each stage.
24 |             hidden_size: dimension of hidden layer.
25 |             proj_size: projection size for keys and values in the spatial attention module.
26 |             num_heads: number of attention heads.
27 |             dropout_rate: fraction of the input units to drop.
28 |             pos_embed: bool argument to determine if positional embedding is used.
29 | 30 | """ 31 | 32 | super().__init__() 33 | 34 | if not (0 <= dropout_rate <= 1): 35 | raise ValueError("dropout_rate should be between 0 and 1.") 36 | 37 | if hidden_size % num_heads != 0: 38 | print("Hidden size is ", hidden_size) 39 | print("Num heads is ", num_heads) 40 | raise ValueError("hidden_size should be divisible by num_heads.") 41 | 42 | self.norm = nn.LayerNorm(hidden_size) 43 | self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True) 44 | self.epa_block = EPA(input_size=input_size, hidden_size=hidden_size, proj_size=proj_size, num_heads=num_heads, 45 | channel_attn_drop=dropout_rate,spatial_attn_drop=dropout_rate) 46 | self.conv51 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") 47 | self.conv52 = UnetResBlock(3, hidden_size, hidden_size, kernel_size=3, stride=1, norm_name="batch") 48 | self.conv8 = nn.Sequential(nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)) 49 | 50 | self.pos_embed = None 51 | if pos_embed: 52 | #print("input size", input_size) 53 | self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size)) 54 | 55 | def forward(self, x): 56 | B, C, H, W, D = x.shape 57 | #print("XSHAPE ",x.shape) 58 | x = x.reshape(B, C, H * W * D).permute(0, 2, 1) 59 | 60 | if self.pos_embed is not None: 61 | #print ("x",x.shape) 62 | #print("pos_embed",self.pos_embed.shape) 63 | x = x + self.pos_embed 64 | attn = x + self.gamma * self.epa_block(self.norm(x)) 65 | 66 | attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3) # (B, C, H, W, D) 67 | attn = self.conv51(attn_skip) 68 | attn = self.conv52(attn) 69 | x = attn_skip + self.conv8(attn) 70 | return x 71 | 72 | class EPA(nn.Module): 73 | """ 74 | Efficient Paired Attention Block, based on: "Shaker et al., 75 | UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" 76 | """ 77 | def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False, channel_attn_drop=0.1, spatial_attn_drop=0.1): 78 | super().__init__() 79 | self.num_heads = num_heads 80 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 81 | self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) 82 | 83 | # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel) 84 | self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias) 85 | 86 | # E and F are projection matrices used in spatial attention module to project keys and values from HWD-dimension to P-dimension 87 | self.E = nn.Linear(input_size, proj_size) 88 | self.F = nn.Linear(input_size, proj_size) 89 | 90 | self.attn_drop = nn.Dropout(channel_attn_drop) 91 | self.attn_drop_2 = nn.Dropout(spatial_attn_drop) 92 | 93 | self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2)) 94 | self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2)) 95 | 96 | def forward(self, x): 97 | B, N, C = x.shape 98 | #print("The shape in EPA ", self.E.shape) 99 | 100 | qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads) 101 | 102 | qkvv = qkvv.permute(2, 0, 3, 1, 4) 103 | 104 | q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3] 105 | 106 | q_shared = q_shared.transpose(-2, -1) 107 | k_shared = k_shared.transpose(-2, -1) 108 | v_CA = v_CA.transpose(-2, -1) 109 | v_SA = v_SA.transpose(-2, -1) 110 | 111 | k_shared_projected = self.E(k_shared) 112 | 113 | v_SA_projected = self.F(v_SA) 114 | 115 | q_shared = torch.nn.functional.normalize(q_shared, dim=-1) 116 | k_shared = 
117 | 
118 |         attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
119 | 
120 |         attn_CA = attn_CA.softmax(dim=-1)
121 |         attn_CA = self.attn_drop(attn_CA)
122 | 
123 |         x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)
124 | 
125 |         attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2
126 | 
127 |         attn_SA = attn_SA.softmax(dim=-1)
128 |         attn_SA = self.attn_drop_2(attn_SA)
129 | 
130 |         x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)
131 | 
132 |         # Concat fusion
133 |         x_SA = self.out_proj(x_SA)
134 |         x_CA = self.out_proj2(x_CA)
135 |         x = torch.cat((x_SA, x_CA), dim=-1)
136 |         return x
137 | 
138 |     @torch.jit.ignore
139 |     def no_weight_decay(self):
140 |         return {'temperature', 'temperature2'}
141 | 
--------------------------------------------------------------------------------
/unetr_pp/experiment_planning/experiment_planner_baseline_2DUNet_v21.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from copy import deepcopy
15 | 
16 | from unetr_pp.experiment_planning.common_utils import get_pool_and_conv_props
17 | from unetr_pp.experiment_planning.experiment_planner_baseline_2DUNet import ExperimentPlanner2D
18 | from unetr_pp.network_architecture.generic_UNet import Generic_UNet
19 | from unetr_pp.paths import *
20 | import numpy as np
21 | 
22 | 
23 | class ExperimentPlanner2D_v21(ExperimentPlanner2D):
24 |     def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
25 |         super(ExperimentPlanner2D_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
26 |         self.data_identifier = "nnFormerData_plans_v2.1_2D"
27 |         self.plans_fname = join(self.preprocessed_output_folder,
28 |                                 "nnFormerPlansv2.1_plans_2D.pkl")
29 |         self.unet_base_num_features = 32
30 | 
31 |     def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
32 |                                  num_modalities, num_classes):
33 | 
34 |         new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
35 | 
36 |         dataset_num_voxels = np.prod(new_median_shape, dtype=np.int64) * num_cases
37 |         input_patch_size = new_median_shape[1:]
38 | 
39 |         network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
40 |             shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], input_patch_size,
41 |                                                                  self.unet_featuremap_min_edge_length,
42 |                                                                  self.unet_max_numpool)
43 | 
44 |         # we pretend to use 30 feature maps. This will yield the same configuration as in V1. The larger memory
45 |         # footprint of 32 vs 30 is more than offset by the fp16 training. We make fp16 training the default.
46 |         # Reason for 32 vs 30 feature maps is that 32 is faster in fp16 training (because multiple of 8)
47 |         ref = Generic_UNet.use_this_for_batch_size_computation_2D * Generic_UNet.DEFAULT_BATCH_SIZE_2D / 2  # for batch size 2
48 |         here = Generic_UNet.compute_approx_vram_consumption(new_shp,
49 |                                                             network_num_pool_per_axis,
50 |                                                             30,
51 |                                                             self.unet_max_num_filters,
52 |                                                             num_modalities, num_classes,
53 |                                                             pool_op_kernel_sizes,
54 |                                                             conv_per_stage=self.conv_per_stage)
55 |         while here > ref:
56 |             axis_to_be_reduced = np.argsort(new_shp / new_median_shape[1:])[-1]
57 | 
58 |             tmp = deepcopy(new_shp)
59 |             tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
60 |             _, _, _, _, shape_must_be_divisible_by_new = \
61 |                 get_pool_and_conv_props(current_spacing[1:], tmp, self.unet_featuremap_min_edge_length,
62 |                                         self.unet_max_numpool)
63 |             new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
64 | 
65 |             # we have to recompute numpool now:
66 |             network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
67 |                 shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], new_shp,
68 |                                                                      self.unet_featuremap_min_edge_length,
69 |                                                                      self.unet_max_numpool)
70 | 
71 |             here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
72 |                                                                 self.unet_base_num_features,
73 |                                                                 self.unet_max_num_filters, num_modalities,
74 |                                                                 num_classes, pool_op_kernel_sizes,
75 |                                                                 conv_per_stage=self.conv_per_stage)
76 |             # print(new_shp)
77 | 
78 |         batch_size = int(np.floor(ref / here) * 2)
79 |         input_patch_size = new_shp
80 | 
81 |         if batch_size < self.unet_min_batch_size:
82 |             raise RuntimeError("This should not happen")
83 | 
84 |         # check if batch size is too large (more than 5 % of dataset)
85 |         max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
86 |                                   np.prod(input_patch_size, dtype=np.int64)).astype(int)
87 |         batch_size = max(1, min(batch_size, max_batch_size))
88 | 
89 |         plan = {
90 |             'batch_size': batch_size,
91 |             'num_pool_per_axis': network_num_pool_per_axis,
92 |             'patch_size': input_patch_size,
93 |             'median_patient_size_in_voxels': new_median_shape,
94 |             'current_spacing': current_spacing,
95 |             'original_spacing': original_spacing,
96 |             'pool_op_kernel_sizes': pool_op_kernel_sizes,
97 |             'conv_kernel_sizes': conv_kernel_sizes,
98 |             'do_dummy_2D_data_aug': False
99 |         }
100 |         return plan
101 | 
--------------------------------------------------------------------------------
/unetr_pp/evaluation/model_selection/summarize_results_with_plans.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
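# The helpers below build one semicolon-separated row per trained configuration.
# As a rough illustration (these example calls are not part of the original file):
# list_to_string joins a sequence of floats with three-decimal formatting, e.g.
#     list_to_string([1.0, 2.5])         ->  "1.000,2.500"
#     list_to_string([1.0, 2.5], "; ")   ->  "1.000; 2.500"
# write_plans_to_file then appends the plan hyperparameters for one stage of a
# plans.pkl file (batch size, patch size in voxels and mm, spacings, kernel sizes).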
14 | 
15 | 
16 | from batchgenerators.utilities.file_and_folder_operations import *
17 | import os
18 | from unetr_pp.evaluation.model_selection.summarize_results_in_one_json import summarize
19 | from unetr_pp.paths import network_training_output_dir
20 | import numpy as np
21 | 
22 | 
23 | def list_to_string(l, delim=","):
24 |     st = "%03.3f" % l[0]
25 |     for i in l[1:]:
26 |         st += delim + "%03.3f" % i
27 |     return st
28 | 
29 | 
30 | def write_plans_to_file(f, plans_file, stage=0, do_linebreak_at_end=True, override_name=None):
31 |     a = load_pickle(plans_file)
32 |     stages = list(a['plans_per_stage'].keys())
33 |     stages.sort()
34 |     patch_size_in_mm = [i * j for i, j in zip(a['plans_per_stage'][stages[stage]]['patch_size'],
35 |                                               a['plans_per_stage'][stages[stage]]['current_spacing'])]
36 |     median_patient_size_in_mm = [i * j for i, j in zip(a['plans_per_stage'][stages[stage]]['median_patient_size_in_voxels'],
37 |                                                        a['plans_per_stage'][stages[stage]]['current_spacing'])]
38 |     if override_name is None:
39 |         f.write(plans_file.split("/")[-2] + "__" + plans_file.split("/")[-1])
40 |     else:
41 |         f.write(override_name)
42 |     f.write(";%d" % stage)
43 |     f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['batch_size']))
44 |     f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['num_pool_per_axis']))
45 |     f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['patch_size']))
46 |     f.write(";%s" % list_to_string(patch_size_in_mm))
47 |     f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['median_patient_size_in_voxels']))
48 |     f.write(";%s" % list_to_string(median_patient_size_in_mm))
49 |     f.write(";%s" % list_to_string(a['plans_per_stage'][stages[stage]]['current_spacing']))
50 |     f.write(";%s" % list_to_string(a['plans_per_stage'][stages[stage]]['original_spacing']))
51 |     f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['pool_op_kernel_sizes']))
52 |     f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['conv_kernel_sizes']))
53 |     if do_linebreak_at_end:
54 |         f.write("\n")
55 | 
56 | 
57 | if __name__ == "__main__":
58 |     summarize((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 24, 27), output_dir=join(network_training_output_dir, "summary_fold0"), folds=(0,))
59 |     base_dir = os.environ['RESULTS_FOLDER']
60 |     nnformers = ['nnFormerV2', 'nnFormerV2_zspacing']
61 |     task_ids = list(range(99))
62 |     with open("summary.csv", 'w') as f:
63 |         f.write("identifier;stage;batch_size;num_pool_per_axis;patch_size;patch_size(mm);median_patient_size_in_voxels;median_patient_size_in_mm;current_spacing;original_spacing;pool_op_kernel_sizes;conv_kernel_sizes;patient_dc;global_dc\n")
64 |         for i in task_ids:
65 |             for nnformer in nnformers:
66 |                 try:
67 |                     summary_folder = join(base_dir, nnformer, "summary_fold0")
68 |                     if isdir(summary_folder):
69 |                         summary_files = subfiles(summary_folder, join=False, prefix="Task%03.0d_" % i, suffix=".json", sort=True)
70 |                         for s in summary_files:
71 |                             tmp = s.split("__")
72 |                             trainer = tmp[2]
73 | 
74 |                             expected_output_folder = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2].split(".")[0])
75 |                             name = tmp[0] + "__" + nnformer + "__" + tmp[1] + "__" + tmp[2].split(".")[0]
76 |                             global_dice_json = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2].split(".")[0], "fold_0", "validation_tiledTrue_doMirror_True", "global_dice.json")
77 | 
78 |                             if not isdir(expected_output_folder) or len(tmp) > 3:
79 |                                 if len(tmp) == 2:
80 |                                     continue
81 |                                 expected_output_folder = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2] + "__" + tmp[3].split(".")[0])
82 |                                 name = tmp[0] + "__" + nnformer + "__" + tmp[1] + "__" + tmp[2] + "__" + tmp[3].split(".")[0]
+ "__" + tmp[3].split(".")[0] 83 | global_dice_json = join(base_dir, nnformer, tmp[1], tmp[0], tmp[2] + "__" + tmp[3].split(".")[0], "fold_0", "validation_tiledTrue_doMirror_True", "global_dice.json") 84 | 85 | assert isdir(expected_output_folder), "expected output dir not found" 86 | plans_file = join(expected_output_folder, "plans.pkl") 87 | assert isfile(plans_file) 88 | 89 | plans = load_pickle(plans_file) 90 | num_stages = len(plans['plans_per_stage']) 91 | if num_stages > 1 and tmp[1] == "3d_fullres": 92 | stage = 1 93 | elif (num_stages == 1 and tmp[1] == "3d_fullres") or tmp[1] == "3d_lowres": 94 | stage = 0 95 | else: 96 | print("skipping", s) 97 | continue 98 | 99 | g_dc = load_json(global_dice_json) 100 | mn_glob_dc = np.mean(list(g_dc.values())) 101 | 102 | write_plans_to_file(f, plans_file, stage, False, name) 103 | # now read and add result to end of line 104 | results = load_json(join(summary_folder, s)) 105 | mean_dc = results['results']['mean']['mean']['Dice'] 106 | f.write(";%03.3f" % mean_dc) 107 | f.write(";%03.3f\n" % mn_glob_dc) 108 | print(name, mean_dc) 109 | except Exception as e: 110 | print(e) 111 | -------------------------------------------------------------------------------- /unetr_pp/network_architecture/synapse/unetr_pp_synapse.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from typing import Tuple, Union 3 | from unetr_pp.network_architecture.neural_network import SegmentationNetwork 4 | from unetr_pp.network_architecture.dynunet_block import UnetOutBlock, UnetResBlock 5 | from unetr_pp.network_architecture.synapse.model_components import UnetrPPEncoder, UnetrUpBlock 6 | 7 | 8 | class UNETR_PP(SegmentationNetwork): 9 | """ 10 | UNETR++ based on: "Shaker et al., 11 | UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation" 12 | """ 13 | 14 | def __init__( 15 | self, 16 | in_channels: int, 17 | out_channels: int, 18 | img_size: [64, 128, 128], 19 | feature_size: int = 16, 20 | hidden_size: int = 256, 21 | num_heads: int = 4, 22 | pos_embed: str = "perceptron", # TODO: Remove the argument 23 | norm_name: Union[Tuple, str] = "instance", 24 | dropout_rate: float = 0.0, 25 | depths=None, 26 | dims=None, 27 | conv_op=nn.Conv3d, 28 | do_ds=True, 29 | 30 | ) -> None: 31 | """ 32 | Args: 33 | in_channels: dimension of input channels. 34 | out_channels: dimension of output channels. 35 | img_size: dimension of input image. 36 | feature_size: dimension of network feature size. 37 | hidden_size: dimension of the last encoder. 38 | num_heads: number of attention heads. 39 | pos_embed: position embedding layer type. 40 | norm_name: feature normalization type and arguments. 41 | dropout_rate: faction of the input units to drop. 42 | depths: number of blocks for each stage. 43 | dims: number of channel maps for the stages. 44 | conv_op: type of convolution operation. 45 | do_ds: use deep supervision to compute the loss. 
46 | 
47 |         Examples::
48 | 
49 |             # for single-channel input, 14-channel output with patch size of (64, 128, 128), feature size of 16, batch
50 |             norm, and depths of [3, 3, 3, 3] with output channels [32, 64, 128, 256], 4 heads, and 14 classes with
51 |             deep supervision:
52 |             >>> net = UNETR_PP(in_channels=1, out_channels=14, img_size=(64, 128, 128), feature_size=16, num_heads=4,
53 |             >>>                norm_name='batch', depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True)
54 |         """
55 | 
56 |         super().__init__()
57 |         if depths is None:
58 |             depths = [3, 3, 3, 3]
59 |         self.do_ds = do_ds
60 |         self.conv_op = conv_op
61 |         self.num_classes = out_channels
62 |         if not (0 <= dropout_rate <= 1):
63 |             raise AssertionError("dropout_rate should be between 0 and 1.")
64 | 
65 |         if pos_embed not in ["conv", "perceptron"]:
66 |             raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
67 | 
68 |         self.patch_size = (2, 4, 4)
69 |         self.feat_size = (
70 |             img_size[0] // self.patch_size[0] // 8,  # 8 is the downsampling performed across the four encoder stages
71 |             img_size[1] // self.patch_size[1] // 8,  # 8 is the downsampling performed across the four encoder stages
72 |             img_size[2] // self.patch_size[2] // 8,  # 8 is the downsampling performed across the four encoder stages
73 |         )
74 |         self.hidden_size = hidden_size
75 | 
76 |         self.unetr_pp_encoder = UnetrPPEncoder(dims=dims, depths=depths, num_heads=num_heads)
77 | 
78 |         self.encoder1 = UnetResBlock(
79 |             spatial_dims=3,
80 |             in_channels=in_channels,
81 |             out_channels=feature_size,
82 |             kernel_size=3,
83 |             stride=1,
84 |             norm_name=norm_name,
85 |         )
86 |         self.decoder5 = UnetrUpBlock(
87 |             spatial_dims=3,
88 |             in_channels=feature_size * 16,
89 |             out_channels=feature_size * 8,
90 |             kernel_size=3,
91 |             upsample_kernel_size=2,
92 |             norm_name=norm_name,
93 |             out_size=8 * 8 * 8,
94 |         )
95 |         self.decoder4 = UnetrUpBlock(
96 |             spatial_dims=3,
97 |             in_channels=feature_size * 8,
98 |             out_channels=feature_size * 4,
99 |             kernel_size=3,
100 |             upsample_kernel_size=2,
101 |             norm_name=norm_name,
102 |             out_size=16 * 16 * 16,
103 |         )
104 |         self.decoder3 = UnetrUpBlock(
105 |             spatial_dims=3,
106 |             in_channels=feature_size * 4,
107 |             out_channels=feature_size * 2,
108 |             kernel_size=3,
109 |             upsample_kernel_size=2,
110 |             norm_name=norm_name,
111 |             out_size=32 * 32 * 32,
112 |         )
113 |         self.decoder2 = UnetrUpBlock(
114 |             spatial_dims=3,
115 |             in_channels=feature_size * 2,
116 |             out_channels=feature_size,
117 |             kernel_size=3,
118 |             upsample_kernel_size=(2, 4, 4),
119 |             norm_name=norm_name,
120 |             out_size=64 * 128 * 128,
121 |             conv_decoder=True,
122 |         )
123 |         self.out1 = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels)
124 |         if self.do_ds:
125 |             self.out2 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 2, out_channels=out_channels)
126 |             self.out3 = UnetOutBlock(spatial_dims=3, in_channels=feature_size * 4, out_channels=out_channels)
127 | 
128 |     def proj_feat(self, x, hidden_size, feat_size):
129 |         x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size)
130 |         x = x.permute(0, 4, 1, 2, 3).contiguous()
131 |         return x
132 | 
133 |     def forward(self, x_in):
134 |         x_output, hidden_states = self.unetr_pp_encoder(x_in)
135 | 
136 |         convBlock = self.encoder1(x_in)
137 | 
138 |         # Four encoders
139 |         enc1 = hidden_states[0]
140 |         enc2 = hidden_states[1]
141 |         enc3 = hidden_states[2]
142 |         enc4 = hidden_states[3]
143 | 
144 |         # Four decoders
145 |         dec4 = self.proj_feat(enc4, self.hidden_size, self.feat_size)
146 |         dec3 = self.decoder5(dec4, enc3)
147 |         dec2 = self.decoder4(dec3, enc2)
148 |         dec1 = self.decoder3(dec2, enc1)
149 | 
150 |         out = self.decoder2(dec1, convBlock)
151 |         if self.do_ds:
152 |             logits = [self.out1(out), self.out2(dec1), self.out3(dec2)]
153 |         else:
154 |             logits = self.out1(out)
155 | 
156 |         return logits
157 | 
--------------------------------------------------------------------------------
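The Examples section in unetr_pp_synapse.py above can be exercised end-to-end. A minimal smoke-test sketch, assuming the unetr_pp package and its torch dependencies are importable; the printed output shapes follow from the decoder out_size/upsample settings above and are stated as expectations, not guarantees:

import torch
from unetr_pp.network_architecture.synapse.unetr_pp_synapse import UNETR_PP

# Mirrors the docstring example: one input channel, 14 Synapse classes,
# a (64, 128, 128) patch, and four encoder stages with dims [32, 64, 128, 256].
net = UNETR_PP(in_channels=1, out_channels=14, img_size=(64, 128, 128),
               feature_size=16, num_heads=4, norm_name='batch',
               depths=[3, 3, 3, 3], dims=[32, 64, 128, 256], do_ds=True)
net.eval()  # avoid batch-norm statistic updates on a single random sample

x = torch.randn(1, 1, 64, 128, 128)  # (B, C, D, H, W) input patch
with torch.no_grad():
    outputs = net(x)

# With do_ds=True, forward returns three deep-supervision heads: out1 at full
# patch resolution, plus coarser maps computed from dec1 and dec2.
for o in outputs:
    print(o.shape)  # expected: (1, 14, 64, 128, 128), (1, 14, 32, 32, 32), (1, 14, 16, 16, 16)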