├── src ├── dafne │ ├── bin │ │ ├── __init__.py │ │ ├── batch_validate_ui.py │ │ ├── calc_transforms.py │ │ ├── edit_config.py │ │ ├── batch_validate.py │ │ └── dafne.py │ ├── resources │ │ ├── __init__.py │ │ ├── circle.png │ │ ├── square.png │ │ ├── dafne_anim.gif │ │ ├── dafne_logo.png │ │ └── runtime_dependencies.cfg │ ├── config │ │ ├── version.py │ │ └── __init__.py │ ├── MedSAM │ │ ├── segment_anything │ │ │ ├── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── transforms.py │ │ │ │ └── onnx.py │ │ │ ├── modeling │ │ │ │ ├── __init__.py │ │ │ │ ├── common.py │ │ │ │ └── mask_decoder.py │ │ │ ├── __init__.py │ │ │ └── build_sam.py │ │ ├── utils │ │ │ ├── ckpt_convert.py │ │ │ ├── pre_grey_rgb.py │ │ │ ├── format_convert.py │ │ │ ├── split.py │ │ │ ├── README.md │ │ │ └── pre_CT_MR.py │ │ └── MedSAM_Inference.py │ ├── __init__.py │ ├── utils │ │ ├── open_folder.py │ │ ├── __init__.py │ │ ├── resource_utils.py │ │ ├── log.py │ │ ├── compressed_pickle.py │ │ ├── ThreadHelpers.py │ │ ├── mask_utils.py │ │ └── polyToMask.py │ └── ui │ │ ├── __init__.py │ │ ├── LogWindow.py │ │ ├── LogWindowUI.ui │ │ ├── ModelBrowser.ui │ │ ├── ModelBrowserUI.py │ │ ├── LogWindowUI.py │ │ ├── CalcTransformsUI.ui │ │ ├── ContourPainter.py │ │ ├── CalcTransformsUI.py │ │ ├── WhatsNew.py │ │ ├── BatchCalcTransforms.py │ │ ├── BrushPatches.py │ │ ├── ValidateUI.ui │ │ ├── ValidateUI.py │ │ └── Viewer3D.py └── __init__.py ├── icons ├── dafne_icon.icns ├── dafne_icon.ico ├── dafne_icon.png ├── dafne_icon1024.png ├── mac_installer_bg.png ├── calctransform_ico.ico ├── calctransform_ico.png ├── mac_installer_bg.svg └── dafne_icon.svg ├── pyproject.toml ├── use_local_directories_no ├── .gitmodules ├── install_scripts ├── create_windows_installer.bat ├── create_linux_installer.sh ├── entitlements.plist ├── make_mac_icons.sh ├── update_version.py ├── dafne_linux.spec ├── dafne_win.iss ├── dafne_win.spec ├── create_mac_installer.sh ├── dafne_mac.spec └── fix_app_bundle_for_mac.py ├── pyinstaller_hooks 
├── hook-pydicom.py ├── hook-dosma.py ├── hook-dafne_dl.py └── hook-dafne.py ├── requirements.txt ├── create_new_version.sh ├── batch_validate.py ├── calc_transforms.py ├── batch_validate_ui.py ├── edit_config.py ├── dafne ├── test ├── __init__.py ├── test_seg.py ├── plotSegmentations.py ├── testDL.py ├── testILearn.py ├── testILearn_AA.py ├── testILearn_split_AA.py ├── validation_ILearn.py └── validation_split_ILearn.py ├── setup.cfg ├── .gitignore └── README.md

/src/dafne/bin/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | 
--------------------------------------------------------------------------------
/src/dafne/resources/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | 
--------------------------------------------------------------------------------
/src/dafne/config/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | VERSION='1.9-alpha'
--------------------------------------------------------------------------------
/icons/dafne_icon.icns:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/dafne_icon.icns
--------------------------------------------------------------------------------
/icons/dafne_icon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/dafne_icon.ico
--------------------------------------------------------------------------------
/icons/dafne_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/dafne_icon.png
--------------------------------------------------------------------------------
/icons/dafne_icon1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/dafne_icon1024.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | 
3 | from .config.version import VERSION  # NOTE(review): the tree lists no src/config package (config lives in src/dafne/config), so this import would fail if src were imported as a package — confirm this file is still needed
--------------------------------------------------------------------------------
/icons/mac_installer_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/mac_installer_bg.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/icons/calctransform_ico.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/calctransform_ico.ico
--------------------------------------------------------------------------------
/icons/calctransform_ico.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/calctransform_ico.png
--------------------------------------------------------------------------------
/src/dafne/resources/circle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/src/dafne/resources/circle.png
--------------------------------------------------------------------------------
/src/dafne/resources/square.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/src/dafne/resources/square.png
--------------------------------------------------------------------------------
/src/dafne/resources/dafne_anim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/src/dafne/resources/dafne_anim.gif
--------------------------------------------------------------------------------
/src/dafne/resources/dafne_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/src/dafne/resources/dafne_logo.png
--------------------------------------------------------------------------------
/use_local_directories_no:
--------------------------------------------------------------------------------
1 | #This is a dummy file. If present, the configuration will use the local directories instead of the system-defined ones.
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "dl"]
2 | 	path = dl
3 | 	url = https://github.com/dafne-imaging/dafne-dl.git
4 | 	branch = master
5 | 	update = merge
6 | 
--------------------------------------------------------------------------------
/install_scripts/create_windows_installer.bat:
--------------------------------------------------------------------------------
1 | python -V
2 | python update_version.py
3 | pyinstaller dafne_win.spec --noconfirm
4 | "C:\Program Files (x86)\Inno Setup 6\Compil32.exe" /cc dafne_win.iss
--------------------------------------------------------------------------------
/pyinstaller_hooks/hook-pydicom.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # Hook file for pyinstaller
3 | 
4 | from PyInstaller.utils.hooks import collect_submodules
5 | 
6 | hiddenimports=collect_submodules('pydicom')
--------------------------------------------------------------------------------
/pyinstaller_hooks/hook-dosma.py:
--------------------------------------------------------------------------------
1 | # Hook file for pyinstaller
2 | 
3 | from PyInstaller.utils.hooks import collect_submodules, collect_data_files
4 | 
5 | hiddenimports=collect_submodules('dosma')
6 | datas = collect_data_files('dosma')
--------------------------------------------------------------------------------
/pyinstaller_hooks/hook-dafne_dl.py:
--------------------------------------------------------------------------------
1 | # Hook file for pyinstaller
2 | 
3 | from PyInstaller.utils.hooks import collect_submodules
4 | 
5 | hiddenimports=collect_submodules('dafne_dl') + \
6 |     collect_submodules('skimage.filters')
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | 
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
--------------------------------------------------------------------------------
/install_scripts/create_linux_installer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright (c) 2022 Dafne-Imaging Team
4 | #
5 | 
6 | VERSION=`python update_version.py | tail -n1`
7 | echo $VERSION
8 | ../venv_system/bin/pyinstaller dafne_linux.spec --noconfirm
9 | cd dist
10 | mv dafne "dafne_linux_$VERSION"
--------------------------------------------------------------------------------
/pyinstaller_hooks/hook-dafne.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # Hook file for pyinstaller
3 | 
4 | from PyInstaller.utils.hooks import collect_submodules
5 | 
6 | hiddenimports=collect_submodules('dafne') + \
7 |     collect_submodules('skimage.segmentation')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py
2 | numpy
3 | scipy
4 | matplotlib
5 | pyqt5
6 | nibabel
7 | pydicom
8 | dill
9 | progress
10 | appdirs
11 | requests
12 | scikit-image
13 | ormir-pyvoxel
14 | importlib_resources ; python_version < "3.10"
15 | dafne-dl >= 1.4a2
16 | flexidep>=0.0.6
17 | pyvistaqt
18 | torch
19 | torchvision
20 | dafne-dicomUtils
--------------------------------------------------------------------------------
/install_scripts/entitlements.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 |     <key>com.apple.security.cs.allow-jit</key>
6 |     <true/>
7 |     <key>com.apple.security.cs.allow-unsigned-executable-memory</key>
8 |     <true/>
9 | </dict>
10 | </plist>
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | 
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | from .sam import Sam
9 | from .image_encoder import ImageEncoderViT
10 | from .mask_decoder import MaskDecoder
11 | from .prompt_encoder import PromptEncoder
12 | from .transformer import TwoWayTransformer
13 | 
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | 
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | from .build_sam import (
9 |     build_sam,
10 |     build_sam_vit_h,
11 |     build_sam_vit_l,
12 |     build_sam_vit_b,
13 |     sam_model_registry,
14 | )
15 | from .predictor import SamPredictor
16 | from .automatic_mask_generator import SamAutomaticMaskGenerator
17 | 
--------------------------------------------------------------------------------
/src/dafne/MedSAM/utils/ckpt_convert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | 
4 | # %% convert medsam model checkpoint to sam checkpoint format for convenient inference
5 | # Fill in the three paths below before running this script.
6 | sam_ckpt_path = ""
7 | medsam_ckpt_path = ""
8 | save_path = ""
9 | multi_gpu_ckpt = True  # set as True if the model is trained with multi-gpu
10 | 
11 | if not (sam_ckpt_path and medsam_ckpt_path and save_path):
12 |     raise SystemExit("Please set sam_ckpt_path, medsam_ckpt_path and save_path in ckpt_convert.py before running it.")
13 | 
14 | # map_location="cpu": checkpoints saved from GPU tensors can then be converted on machines without CUDA.
15 | sam_ckpt = torch.load(sam_ckpt_path, map_location="cpu")
16 | medsam_ckpt = torch.load(medsam_ckpt_path, map_location="cpu")
17 | # Checkpoints of a model trained with multiple GPUs store every key with a "module." prefix.
18 | key_prefix = "module." if multi_gpu_ckpt else ""
19 | for key in sam_ckpt.keys():
20 |     sam_ckpt[key] = medsam_ckpt["model"][key_prefix + key]
21 | 
22 | torch.save(sam_ckpt, save_path)
23 | 
--------------------------------------------------------------------------------
/create_new_version.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright (c) 2022 Dafne-Imaging Team
4 | #
5 | 
6 | new_version=$1
7 | version_file=src/dafne/config/version.py
8 | current_version=`tail -n 1 $version_file | sed "s/VERSION='\(.*\)'/\1/"`
9 | 
10 | if [ "$new_version" == "" ]; then
11 |     echo "Usage: $0 <new_version>"
12 |     echo "Current version: $current_version"
13 |     exit 1
14 | fi
15 | 
16 | echo "# Copyright (c) 2022 Dafne-Imaging Team" > $version_file
17 | echo "# This file was auto-generated. Any changes might be overwritten" >> $version_file
18 | echo "VERSION='$new_version'" >> $version_file
19 | 
20 | # -f: do not report an error when dist/ is empty or does not exist yet
21 | rm -f dist/*
22 | python -m build --sdist --wheel
23 | python -m twine upload dist/*
24 | 
--------------------------------------------------------------------------------
/src/dafne/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | 
3 | from .config.version import VERSION
4 | __version__ = VERSION
5 | 
6 | from . import resources
7 | 
8 | import sys
9 | import flexidep
10 | 
11 | assert sys.version_info.major == 3, "This software is only compatible with Python 3.x"
12 | 
13 | if sys.version_info.minor < 10:
14 |     import importlib_resources as pkg_resources
15 | else:
16 |     import importlib.resources as pkg_resources
17 | 
18 | # install the required resources
19 | if not flexidep.is_frozen():
20 |     with pkg_resources.files(resources).joinpath('runtime_dependencies.cfg').open() as f:
21 |         dm = flexidep.DependencyManager(config_file=f)
22 |         dm.install_interactive()
23 | 
--------------------------------------------------------------------------------
/src/dafne/utils/open_folder.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 | import os
4 | 
5 | 
6 | def open_folder(folder_path):
7 |     """Open folder_path in the platform file browser (explorer / open / xdg-open).
8 | 
9 |     If folder_path is not an existing directory, its parent directory is opened instead.
10 |     Raises NotImplementedError on platforms other than win32, darwin and linux.
11 |     Exceptions raised while launching the browser are printed, not propagated, and the
12 |     browser's exit status is not checked.
13 |     """
14 |     if not os.path.isdir(folder_path):
15 |         # abspath() first: for a bare file name dirname() would return '', which is
16 |         # not a usable folder for the file browser.
17 |         folder_path = os.path.dirname(os.path.abspath(folder_path))
18 | 
19 |     print('Opening folder:', folder_path)
20 | 
21 |     if sys.platform == 'win32':
22 |         proc_name = 'explorer'
23 |     elif sys.platform == 'darwin':
24 |         proc_name = 'open'
25 |     elif sys.platform.startswith('linux'):
26 |         proc_name = 'xdg-open'
27 |     else:
28 |         raise NotImplementedError('Unsupported platform')
29 | 
30 |     try:
31 |         subprocess.run([proc_name, folder_path])
32 |     except Exception as e:
33 |         print(f'Error while opening folder: {e}')
34 | 
--------------------------------------------------------------------------------
/batch_validate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # generic stub for executable scripts residing in bin.
3 | # This code will execute the main function of a script residing in bin having the same name as the script.
4 | # The main function must be named "main" and must be in the global scope.
/batch_validate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Dafne-Imaging Team 2 | # generic stub for executable scripts residing in bin. 3 | # This code will execute the main function of a script residing in bin having the same name as the script. 4 | # The main function must be named "main" and must be in the global scope. 5 | 6 | import os 7 | import sys 8 | import importlib 9 | 10 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) 11 | if src_path not in sys.path: 12 | sys.path.append(src_path) 13 | 14 | this_script = os.path.basename(__file__) 15 | this_script_name = os.path.splitext(this_script)[0] 16 | 17 | import_module_name = f'dafne.bin.{this_script_name}' 18 | 19 | i = importlib.import_module(import_module_name) 20 | 21 | if __name__ == '__main__': 22 | i.main() -------------------------------------------------------------------------------- /calc_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Dafne-Imaging Team 2 | # generic stub for executable scripts residing in bin. 3 | # This code will execute the main function of a script residing in bin having the same name as the script. 4 | # The main function must be named "main" and must be in the global scope. 
5 | 6 | import os 7 | import sys 8 | import importlib 9 | 10 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) 11 | if src_path not in sys.path: 12 | sys.path.append(src_path) 13 | 14 | this_script = os.path.basename(__file__) 15 | this_script_name = os.path.splitext(this_script)[0] 16 | 17 | import_module_name = f'dafne.bin.{this_script_name}' 18 | 19 | i = importlib.import_module(import_module_name) 20 | 21 | if __name__ == '__main__': 22 | i.main() -------------------------------------------------------------------------------- /batch_validate_ui.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Dafne-Imaging Team 2 | # generic stub for executable scripts residing in bin. 3 | # This code will execute the main function of a script residing in bin having the same name as the script. 4 | # The main function must be named "main" and must be in the global scope. 5 | 6 | import os 7 | import sys 8 | import importlib 9 | 10 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) 11 | if src_path not in sys.path: 12 | sys.path.append(src_path) 13 | 14 | this_script = os.path.basename(__file__) 15 | this_script_name = os.path.splitext(this_script)[0] 16 | 17 | import_module_name = f'dafne.bin.{this_script_name}' 18 | 19 | i = importlib.import_module(import_module_name) 20 | 21 | if __name__ == '__main__': 22 | i.main() -------------------------------------------------------------------------------- /edit_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Dafne-Imaging Team 2 | # generic stub for executable scripts residing in bin. 3 | # This code will execute the main function of a script residing in bin having the same name as the script. 4 | # The main function must be named "main" and must be in the global scope. 
5 | 6 | import os 7 | import sys 8 | import importlib 9 | 10 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) 11 | if src_path not in sys.path: 12 | sys.path.append(src_path) 13 | 14 | 15 | this_script = os.path.basename(__file__) 16 | this_script_name = os.path.splitext(this_script)[0] 17 | 18 | import_module_name = f'dafne.bin.{this_script_name}' 19 | 20 | i = importlib.import_module(import_module_name) 21 | 22 | if __name__ == '__main__': 23 | i.main() -------------------------------------------------------------------------------- /dafne: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2022 Dafne-Imaging Team 3 | # generic stub for executable scripts residing in bin. 4 | # This code will execute the main function of a script residing in bin having the same name as the script. 5 | # The main function must be named "main" and must be in the global scope. 6 | 7 | import os 8 | import sys 9 | import importlib 10 | 11 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')) 12 | if src_path not in sys.path: 13 | sys.path.append(src_path) 14 | 15 | this_script = os.path.basename(__file__) 16 | this_script_name = os.path.splitext(this_script)[0] 17 | 18 | import_module_name = f'dafne.bin.{this_script_name}' 19 | 20 | i = importlib.import_module(import_module_name) 21 | 22 | if __name__ == '__main__': 23 | i.main() -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Dafne-Imaging Team 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | 
16 | #!/usr/bin/env python3
17 | # -*- coding: utf-8 -*-
18 | 
--------------------------------------------------------------------------------
/src/dafne/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | 
16 | from .config import *
17 | from .version import VERSION
--------------------------------------------------------------------------------
/src/dafne/ui/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | 
16 | #!/usr/bin/env python3
17 | # -*- coding: utf-8 -*-
18 | 
--------------------------------------------------------------------------------
/src/dafne/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 | 
16 | #!/usr/bin/env python3
17 | # -*- coding: utf-8 -*-
18 | 
--------------------------------------------------------------------------------
/src/dafne/utils/resource_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | import os
3 | import sys
4 | 
5 | assert sys.version_info.major == 3, "This software is only compatible with Python 3.x"
6 | 
7 | if sys.version_info.minor < 10:
8 |     import importlib_resources as pkg_resources
9 | else:
10 |     import importlib.resources as pkg_resources
11 | 
12 | from contextlib import contextmanager
13 | 
14 | from .. import resources
15 | 
16 | @contextmanager
17 | def get_resource_path(resource_name):
18 |     """Yield a filesystem path (as str) to the file resource_name in the dafne.resources package.
19 | 
20 |     Use as a context manager, e.g. `with get_resource_path('dafne_logo.png') as path: ...`.
21 |     The path is only guaranteed to be valid inside the with-block, because
22 |     pkg_resources.as_file() may provide a temporary copy that is removed on exit.
23 |     In a PyInstaller bundle (sys._MEIPASS set) the path is built under <_MEIPASS>/resources.
24 |     """
25 |     if getattr(sys, '_MEIPASS', None):
26 |         yield os.path.join(sys._MEIPASS, 'resources', resource_name) # PyInstaller support. If _MEIPASS is set, we are in a Pyinstaller environment
27 |     else:
28 |         with pkg_resources.as_file(pkg_resources.files(resources).joinpath(resource_name)) as resource:
29 |             yield str(resource)
30 | 
--------------------------------------------------------------------------------
/src/dafne/resources/runtime_dependencies.cfg:
--------------------------------------------------------------------------------
1 | [Global]
2 | interactive initialization = False
3 | use gui = Yes
4 | local install = False
5 | package manager = pip
6 | id = network.dafne.dafne
7 | optional packages =
8 |     radiomics
9 | priority = radiomics
10 | 
11 | [Packages]
12 | # in macos, install tensorflow-metal after tensorflow
13 | tensorflow =
14 |     tensorflow ; platform_machine == 'x86_64'
15 |     tensorflow_macos ++tensorflow-metal ; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version <= "3.11"
16 |     tensorflow_macos ; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version > "3.11"
17 | 
18 | SimpleITK =
19 |     SimpleITK-SimpleElastix
20 |     SimpleITK
21 | 
22 | # uninstall SimpleITK after radiomics, as the user should choose to install SimpleITK-SimpleElastix instead
23 | radiomics =
24 |     pyradiomics --SimpleITK
--------------------------------------------------------------------------------
/src/dafne/bin/batch_validate_ui.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Copyright (c) 2021 Dafne-Imaging Team
4 | #
5 | # This program is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
16 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
17 | 
18 | from ..ui.BatchValidateWindow import run
19 | 
20 | def main():
21 |     """Entry point: delegates to dafne.ui.BatchValidateWindow.run()."""
22 |     run()
--------------------------------------------------------------------------------
/src/dafne/bin/calc_transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | 
19 | from ..ui.BatchCalcTransforms import run
20 | 
21 | def main():
22 |     """Entry point (console script dafne_calc_transforms): delegates to dafne.ui.BatchCalcTransforms.run()."""
23 |     run()
--------------------------------------------------------------------------------
/install_scripts/make_mac_icons.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | mkdir dafne_icon.iconset
4 | sips -z 16 16 dafne_icon1024.png --out dafne_icon.iconset/icon_16x16.png
5 | sips -z 32 32 dafne_icon1024.png --out dafne_icon.iconset/icon_16x16@2x.png
6 | sips -z 32 32 dafne_icon1024.png --out dafne_icon.iconset/icon_32x32.png
7 | sips -z 64 64 dafne_icon1024.png --out dafne_icon.iconset/icon_32x32@2x.png
8 | sips -z 128 128 dafne_icon1024.png --out dafne_icon.iconset/icon_128x128.png
9 | sips -z 256 256 dafne_icon1024.png --out dafne_icon.iconset/icon_128x128@2x.png
10 | sips -z 256 256 dafne_icon1024.png --out dafne_icon.iconset/icon_256x256.png
11 | sips -z 512 512 dafne_icon1024.png --out dafne_icon.iconset/icon_256x256@2x.png
12 | sips -z 512 512 dafne_icon1024.png --out dafne_icon.iconset/icon_512x512.png
13 | cp dafne_icon1024.png dafne_icon.iconset/icon_512x512@2x.png
14 | iconutil -c icns dafne_icon.iconset
15 | rm -R dafne_icon.iconset
--------------------------------------------------------------------------------
/src/dafne/utils/log.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtCore import QObject, pyqtSignal
2 | 
3 | 
4 | class LogStream(QObject):
5 |     """Text stream that writes to a log file and notifies Qt listeners.
6 | 
7 |     Every chunk passed to write() is written and flushed to the log file, mirrored to
8 |     old_descriptor when one is given, appended to the in-memory copy returned by
9 |     get_data(), and emitted through the `updated` signal. Attributes not defined here
10 |     are delegated to the underlying file object.
11 |     """
12 | 
13 |     updated = pyqtSignal(str)
14 | 
15 |     def __init__(self, file, old_descriptor = None, parent=None):
16 |         super(LogStream, self).__init__(parent)
17 |         # errors='replace': characters the default encoding cannot represent are
18 |         # replaced instead of making write() raise UnicodeEncodeError.
19 |         self.fdesc = open(file, 'w', errors='replace')
20 |         self.old_descriptor = old_descriptor  # optional second stream that receives a copy of every write
21 |         self.data = ''  # everything written so far; grows without bound
22 | 
23 |     def write(self, data):
24 |         self.fdesc.write(data)
25 |         self.fdesc.flush()
26 |         if self.old_descriptor is not None:
27 |             self.old_descriptor.write(data)
28 |             self.old_descriptor.flush()
29 |         self.data += data
30 |         self.updated.emit(data)
31 | 
32 |     def writelines(self, lines):
33 |         for line in lines:
34 |             self.write(line)
35 | 
36 |     def __getattr__(self, item):
37 |         # Only reached for attributes not found normally. If 'fdesc' itself is missing
38 |         # (e.g. open() failed in __init__), delegating would recurse until RecursionError.
39 |         if item == 'fdesc':
40 |             raise AttributeError(item)
41 |         return getattr(self.fdesc, item)
42 | 
43 |     def get_data(self):
44 |         return self.data
45 | 
46 |     def close(self):
47 |         self.fdesc.close()
48 | 
49 | log_objects = {}
50 | 
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = dafne
3 | version = attr: dafne.config.version.VERSION
4 | author = Francesco Santini
5 | author_email = francesco.santini@unibas.ch
6 | description = Dafne - Deep Anatomical Federated Network
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | url = https://github.com/dafne-imaging/dafne
10 | project_urls =
11 |     Bug Tracker = https://github.com/dafne-imaging/dafne/issues
12 | classifiers =
13 |     Programming Language :: Python :: 3
14 |     License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)
15 |     Operating System :: OS Independent
16 | 
17 | [options]
18 | package_dir =
19 |     = src
20 | packages = find:
21 | include_package_data = True
22 | python_requires = >=3.6
23 | install_requires = file: requirements.txt
24 | 
25 | [options.packages.find]
26 | where = src
27 | 
28 | [options.package_data]
29 | dafne = resources/*
30 | 
31 | [options.entry_points]
32 | console_scripts =
33 |     dafne = dafne.bin.dafne:main
34 |     dafne_calc_transforms = dafne.bin.calc_transforms:main
35 |     dafne_edit_config = dafne.bin.edit_config:main
--------------------------------------------------------------------------------
/test/test_seg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | 
19 | from test.testDL import testSegmentation
20 | from test.testDL import testClassification
21 | 
22 | #testSegmentation('models/thigh.model', 'testImages/thigh_test.dcm')
23 | #testSegmentation('models/leg.model', 'testImages/leg_test.dcm')
24 | 
25 | testClassification('models/classifier.model', 'testImages/thigh_test.dcm')
26 | testClassification('models/classifier.model', 'testImages/leg_test.dcm')
--------------------------------------------------------------------------------
/install_scripts/update_version.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2022 Dafne-Imaging Team
3 | # Stamps the current dafne VERSION into dafne_win.iss and dafne_mac.spec.
4 | # All paths are relative to the working directory: run this script from install_scripts/.
5 | import shutil
6 | import sys
7 | import os
8 | 
9 | sys.path.append(os.path.abspath(os.path.join('..', 'src')))
10 | 
11 | from dafne.config.version import VERSION
12 | 
13 | shutil.move('dafne_win.iss', 'dafne_win.iss.old')
14 | 
15 | with open('dafne_win.iss.old', 'r') as orig_file:
16 |     with open('dafne_win.iss', 'w') as new_file:
17 |         for line in orig_file:
18 |             if line.startswith('#define MyAppVersion'):
19 |                 new_file.write(f'#define MyAppVersion "{VERSION}"\n')
20 |             elif line.startswith('OutputBaseFilename='):
21 |                 new_file.write(f'OutputBaseFilename=dafne_windows_setup_{VERSION}\n')
22 |             else:
23 |                 new_file.write(line)
24 | 
25 | shutil.move('dafne_mac.spec', 'dafne_mac.spec.old')
26 | with open('dafne_mac.spec.old', 'r') as orig_file:
27 |     with open('dafne_mac.spec', 'w') as new_file:
28 |         for line in orig_file:
29 |             if 'version=' in line:
30 |                 # Reuse the existing line's indentation instead of a hard-coded one. The
31 |                 # replacement closes the call with ')', so version= must be its last argument.
32 |                 indent = line[:len(line) - len(line.lstrip())]
33 |                 new_file.write(f"{indent}version='{VERSION}')\n")
34 |             else:
35 |                 new_file.write(line)
36 | 
37 | print(VERSION)  # must stay the last output line: create_linux_installer.sh reads it with tail -n1
--------------------------------------------------------------------------------
/test/plotSegmentations.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | 19 | import numpy as np 20 | import matplotlib.pyplot as plt 21 | 22 | def plotSegmentations(ima, segmentations): 23 | for label, mask in segmentations.items(): 24 | plt.figure() 25 | imaRGB = np.stack([ima, ima, ima], axis = -1) 26 | imaRGB = imaRGB / imaRGB.max() * 0.6 27 | imaRGB[:,:,0] = imaRGB[:,:,0] + 0.4 * mask 28 | plt.imshow(imaRGB) 29 | plt.title(label) -------------------------------------------------------------------------------- /src/dafne/bin/edit_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) 2022 Dafne-Imaging Team 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 
18 | 19 | from ..config.config import show_config_dialog, save_config, load_config 20 | 21 | import sys 22 | from PyQt5.QtWidgets import QApplication 23 | 24 | def main(): 25 | app = QApplication(sys.argv) 26 | app.setQuitOnLastWindowClosed(True) 27 | 28 | load_config() 29 | accepted = show_config_dialog(None, True) 30 | if accepted: 31 | save_config() 32 | print('Configuration saved') 33 | else: 34 | print('Aborted') -------------------------------------------------------------------------------- /src/dafne/utils/compressed_pickle.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Dafne-Imaging Team 2 | 3 | import pickle 4 | import bz2 5 | 6 | COMPRESS_LEVEL = 9 7 | 8 | 9 | def compressed_dump(obj, file, **kwargs): 10 | with bz2.BZ2File(file, 'wb', compresslevel=COMPRESS_LEVEL) as f: 11 | pickle.dump(obj, f, **kwargs) 12 | 13 | 14 | def compressed_dumps(obj, **kwargs): 15 | return bz2.compress(pickle.dumps(obj, **kwargs), compresslevel=COMPRESS_LEVEL) 16 | 17 | 18 | def compressed_load(file, **kwargs): 19 | with bz2.BZ2File(file, 'rb') as f: 20 | return pickle.load(f, **kwargs) 21 | 22 | 23 | def compressed_loads(compressed_bytes, **kwargs): 24 | return pickle.loads(bz2.decompress(compressed_bytes), **kwargs) 25 | 26 | 27 | def loads(byte_array, **kwargs): 28 | """ 29 | Generic loads to replace pickle load 30 | """ 31 | try: 32 | return compressed_loads(byte_array, **kwargs) 33 | except OSError: 34 | print("Loading uncompressed pickle") 35 | return pickle.loads(byte_array, **kwargs) 36 | 37 | 38 | def load(file, **kwargs): 39 | """ 40 | Generic load 41 | """ 42 | try: 43 | return compressed_load(file, **kwargs) 44 | except OSError: 45 | print("Loading uncompressed pickle") 46 | try: 47 | file.seek(0) 48 | except: 49 | pass 50 | return pickle.load(file, **kwargs) 51 | 52 | 53 | dump = compressed_dump 54 | dumps = compressed_dumps 55 | 
-------------------------------------------------------------------------------- /install_scripts/dafne_linux.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python ; coding: utf-8 -*- 2 | import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5) 3 | 4 | block_cipher = None 5 | 6 | 7 | a = Analysis( 8 | ['../dafne'], 9 | pathex=['../src'], 10 | binaries=[], 11 | datas=[('../LICENSE', '.'), ('../src/dafne/resources/*', 'resources/')], 12 | hiddenimports=[ 13 | 'dafne', 14 | 'pydicom', 15 | 'SimpleITK', 16 | 'tensorflow', 17 | 'skimage', 18 | 'nibabel', 19 | 'dafne_dl', 20 | 'cmath', 21 | 'ormir-pyvoxel', 22 | 'pyvistaqt', 23 | 'pyvista', 24 | 'vtk', 25 | 'torch', 26 | 'torchvision'], 27 | hookspath=['../pyinstaller_hooks'], 28 | hooksconfig={}, 29 | runtime_hooks=[], 30 | excludes=[], 31 | win_no_prefer_redirects=False, 32 | win_private_assemblies=False, 33 | cipher=block_cipher, 34 | noarchive=False, 35 | ) 36 | pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) 37 | 38 | exe = EXE( 39 | pyz, 40 | a.scripts, 41 | a.binaries, 42 | a.zipfiles, 43 | a.datas, 44 | [], 45 | name='dafne', 46 | debug=False, 47 | bootloader_ignore_signals=False, 48 | strip=False, 49 | upx=True, 50 | upx_exclude=[], 51 | runtime_tmpdir=None, 52 | console=False, 53 | disable_windowed_traceback=False, 54 | argv_emulation=False, 55 | target_arch=None, 56 | codesign_identity=None, 57 | entitlements_file=None, 58 | ) 59 | -------------------------------------------------------------------------------- /test/testDL.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) 2021 Dafne-Imaging Team 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at 
your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | 19 | from src.dafne_dl import DynamicDLModel 20 | from dicomUtils import loadDicomFile 21 | #import numpy as np 22 | from test.plotSegmentations import plotSegmentations 23 | import matplotlib.pyplot as plt 24 | 25 | def testSegmentation(modelPath, dicomPath): 26 | thighModel = DynamicDLModel.Load(open(modelPath, 'rb')) 27 | 28 | ima, info = loadDicomFile(dicomPath) 29 | 30 | resolution = info.PixelSpacing 31 | out = thighModel({'image': ima, 'resolution': resolution}) 32 | 33 | plotSegmentations(ima, out) 34 | 35 | plt.show() 36 | 37 | def testClassification(modelPath, dicomPath): 38 | classModel = DynamicDLModel.Load(open(modelPath, 'rb')) 39 | 40 | ima, info = loadDicomFile(dicomPath) 41 | 42 | resolution = info.PixelSpacing 43 | out = classModel({'image': ima, 'resolution': resolution}) 44 | print(out) -------------------------------------------------------------------------------- /src/dafne/utils/ThreadHelpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Dafne-Imaging Team 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | 16 | from PyQt5.QtCore import QRunnable, pyqtSlot, QThreadPool 17 | from functools import wraps 18 | import traceback 19 | 20 | threadpool = QThreadPool() 21 | 22 | class Runner(QRunnable): 23 | 24 | def __init__(self, func, *args, **kwargs): 25 | QRunnable.__init__(self) 26 | self.func = func 27 | self.args = args 28 | self.kwargs = kwargs 29 | 30 | @pyqtSlot() 31 | def run(self): 32 | try: 33 | setattr(self.args[0], 'separate_thread_running', True) 34 | except: 35 | pass 36 | self.func(*self.args, **self.kwargs) 37 | try: 38 | setattr(self.args[0], 'separate_thread_running', False) 39 | except: 40 | pass 41 | 42 | def separate_thread_decorator(func): 43 | @wraps(func) 44 | def run_wrapper(*args, **kwargs): 45 | runner = Runner(func, *args, **kwargs) 46 | threadpool.start(runner) 47 | return run_wrapper -------------------------------------------------------------------------------- /src/dafne/MedSAM/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from typing import Type 12 | 13 | 14 | class MLPBlock(nn.Module): 15 | def __init__( 16 | self, 17 | embedding_dim: int, 18 | mlp_dim: int, 19 | act: Type[nn.Module] = nn.GELU, 20 | ) -> None: 21 | super().__init__() 22 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 23 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 24 | self.act = act() 25 | 26 | def forward(self, x: torch.Tensor) -> torch.Tensor: 27 | return self.lin2(self.act(self.lin1(x))) 28 | 29 | 30 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 31 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 32 | class LayerNorm2d(nn.Module): 33 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 34 | super().__init__() 35 | self.weight = nn.Parameter(torch.ones(num_channels)) 36 | self.bias = nn.Parameter(torch.zeros(num_channels)) 37 | self.eps = eps 38 | 39 | def forward(self, x: torch.Tensor) -> torch.Tensor: 40 | u = x.mean(1, keepdim=True) 41 | s = (x - u).pow(2).mean(1, keepdim=True) 42 | x = (x - u) / torch.sqrt(s + self.eps) 43 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 44 | return x 45 | -------------------------------------------------------------------------------- /src/dafne/ui/LogWindow.py: -------------------------------------------------------------------------------- 1 | from PyQt5 import Qt 2 | from PyQt5.QtCore import pyqtSlot 3 | from PyQt5.QtWidgets import QDialog 4 | 5 | from .LogWindowUI import Ui_LogWindow 6 | 7 | from ..utils.log import log_objects 8 | 9 | 10 | class LogWindow(QDialog, Ui_LogWindow): 11 | def __init__(self, parent=None): 12 | super(LogWindow, self).__init__(parent) 13 | self.setupUi(self) 14 | self.setWindowTitle("Log") 15 | self.refresh_btn.clicked.connect(self.refresh) 16 | self.resize(1024, 768) 17 | self.refresh() 18 | try: 19 | 
log_objects['stdout'].updated.connect(self.append_output) 20 | except KeyError: 21 | pass 22 | 23 | try: 24 | log_objects['stderr'].updated.connect(self.append_error) 25 | except KeyError: 26 | pass 27 | 28 | @pyqtSlot(str) 29 | def append_output(self, data): 30 | self.output_text.moveCursor(Qt.QTextCursor.End) 31 | self.output_text.insertPlainText(data) 32 | self.output_text.moveCursor(Qt.QTextCursor.End) 33 | 34 | @pyqtSlot(str) 35 | def append_error(self, data): 36 | self.error_text.moveCursor(Qt.QTextCursor.End) 37 | self.error_text.insertPlainText(data) 38 | self.error_text.moveCursor(Qt.QTextCursor.End) 39 | 40 | def refresh(self): 41 | self.output_text.clear() 42 | self.error_text.clear() 43 | 44 | if 'stdout' not in log_objects: 45 | self.output_text.appendPlainText("Not available") 46 | else: 47 | self.output_text.appendPlainText(log_objects['stdout'].get_data()) 48 | 49 | if 'stderr' not in log_objects: 50 | self.error_text.appendPlainText("Not available") 51 | else: 52 | self.error_text.appendPlainText(log_objects['stderr'].get_data()) 53 | 54 | self.output_text.moveCursor(Qt.QTextCursor.End) 55 | self.error_text.moveCursor(Qt.QTextCursor.End) 56 | 57 | 58 | -------------------------------------------------------------------------------- /test/testILearn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2021 Dafne-Imaging Team 4 | # 5 | # This program is free software: you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation, either version 3 of the License, or 8 | # (at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 
14 | # 15 | # You should have received a copy of the GNU General Public License 16 | # along with this program. If not, see . 17 | 18 | from src.dafne_dl import LocalModelProvider 19 | import numpy as np 20 | import pickle 21 | 22 | GENERATE_PICKLE = False 23 | 24 | model_provider = LocalModelProvider('models') 25 | segmenter = model_provider.load_model('Thigh') 26 | 27 | data_in = np.load('testImages/test_data.npy') 28 | segment_in = np.load('testImages/test_segment.npz') 29 | 30 | n_slices = data_in.shape[2] 31 | resolution = np.array([1.037037, 1.037037]) 32 | 33 | if GENERATE_PICKLE: 34 | print("Generating data") 35 | slice_range = range(4,n_slices-5) 36 | image_list = [] 37 | seg_list = [] 38 | 39 | for slice in slice_range: 40 | image_list.append(data_in[:,:,slice].squeeze()) 41 | seg_dict = {} 42 | for roi_name in segment_in: 43 | seg_dict[roi_name] = segment_in[roi_name][:,:,slice] 44 | seg_list.append(seg_dict) 45 | 46 | pickle.dump(seg_list, open('testImages/test_segment.pickle', 'wb')) 47 | pickle.dump(image_list, open('testImages/test_data.pickle', 'wb')) 48 | else: 49 | seg_list = pickle.load(open('testImages/test_segment.pickle', 'rb')) 50 | image_list = pickle.load(open('testImages/test_data.pickle', 'rb')) 51 | 52 | print('Performing incremental learning') 53 | segmenter.incremental_learn({'image_list': image_list, 'resolution': resolution}, seg_list) 54 | -------------------------------------------------------------------------------- /install_scripts/dafne_win.iss: -------------------------------------------------------------------------------- 1 | ; Script generated by the Inno Setup Script Wizard. 2 | ; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES! 
3 | 4 | #define MyAppName "Dafne" 5 | #define MyAppVersion "1.5-alpha" 6 | #define MyAppPublisher "Dafne-imaging" 7 | #define MyAppURL "https://dafne.network/" 8 | #define MyAppExeName "dafne.exe" 9 | 10 | [Setup] 11 | ; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications. 12 | ; (To generate a new GUID, click Tools | Generate GUID inside the IDE.) 13 | AppId={{451322B2-10C5-4BA0-88DC-BB8933F78678} 14 | AppName={#MyAppName} 15 | AppVersion={#MyAppVersion} 16 | ;AppVerName={#MyAppName} {#MyAppVersion} 17 | AppPublisher={#MyAppPublisher} 18 | AppPublisherURL={#MyAppURL} 19 | AppSupportURL={#MyAppURL} 20 | AppUpdatesURL={#MyAppURL} 21 | ArchitecturesAllowed=x64 22 | ArchitecturesInstallIn64BitMode=x64 23 | DefaultDirName={autopf}\{#MyAppName} 24 | DisableProgramGroupPage=auto 25 | DefaultGroupName={#MyAppName} 26 | LicenseFile=dist\dafne\LICENSE 27 | ; Uncomment the following line to run in non administrative install mode (install for current user only.) 
28 | ;PrivilegesRequired=lowest 29 | PrivilegesRequiredOverridesAllowed=dialog 30 | OutputDir=C:\dafne 31 | OutputBaseFilename=dafne_windows_setup_1.5-alpha 32 | SetupIconFile=..\icons\dafne_icon.ico 33 | Compression=lzma 34 | SolidCompression=yes 35 | WizardStyle=modern 36 | 37 | [Languages] 38 | Name: "english"; MessagesFile: "compiler:Default.isl" 39 | 40 | [Tasks] 41 | Name: "desktopicon"; Description: "{cm:CreateDesktopIcon}"; GroupDescription: "{cm:AdditionalIcons}"; Flags: unchecked 42 | 43 | [Files] 44 | Source: "dist\dafne\{#MyAppExeName}"; DestDir: "{app}"; Flags: ignoreversion 45 | Source: "dist\dafne\*"; DestDir: "{app}"; Flags: ignoreversion recursesubdirs createallsubdirs 46 | ; NOTE: Don't use "Flags: ignoreversion" on any shared system files 47 | 48 | [Icons] 49 | Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}" 50 | Name: "{autodesktop}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; Tasks: desktopicon 51 | 52 | [Run] 53 | Filename: "{app}\{#MyAppExeName}"; Description: "{cm:LaunchProgram,{#StringChange(MyAppName, '&', '&&')}}"; Flags: nowait postinstall skipifsilent 54 | 55 | -------------------------------------------------------------------------------- /src/dafne/ui/LogWindowUI.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | LogWindow 4 | 5 | 6 | 7 | 0 8 | 0 9 | 885 10 | 519 11 | 12 | 13 | 14 | Form 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | Output: 25 | 26 | 27 | 28 | 29 | 30 | 31 | false 32 | 33 | 34 | QPlainTextEdit::NoWrap 35 | 36 | 37 | true 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | Error: 49 | 50 | 51 | 52 | 53 | 54 | 55 | false 56 | 57 | 58 | QPlainTextEdit::NoWrap 59 | 60 | 61 | true 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | Refresh 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /src/dafne/bin/batch_validate.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) 2021 Dafne-Imaging Team 4 | # 5 | # This program is free software: you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation, either version 3 of the License, or 8 | # (at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU General Public License 16 | # along with this program. If not, see . 17 | 18 | import argparse 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser(description='Batch model validation') 23 | parser.add_argument('--classification', type=str, help='Classification to use') 24 | parser.add_argument('--timestamp_start', type=int, help='Timestamp start') 25 | parser.add_argument('--timestamp_end', type=int, help='Timestamp end') 26 | parser.add_argument('--upload_stats', type=bool, help='Upload stats') 27 | parser.add_argument('--save_local', type=bool, help='Save local') 28 | parser.add_argument('--local_filename', type=str, help='Local filename') 29 | parser.add_argument('--roi', type=str, help='Reference ROI file') 30 | parser.add_argument('--masks', type=str, help='Reference Mask dataset') 31 | parser.add_argument('--comment', type=str, help='Comment for logging') 32 | parser.add_argument('dataset', type=str, help='Dataset for validation') 33 | 34 | args = parser.parse_args() 35 | args_to_pass = ['classification', 'timestamp_start', 'timestamp_end', 'upload_stats', 'save_local', 'local_filename'] 36 | args_dict = {k: v for k,v in vars(args).items() if v is not None and k in args_to_pass} 37 | 
dataset = args.dataset 38 | roi = args.roi 39 | masks = args.masks 40 | 41 | from ..utils.BatchValidator import BatchValidator 42 | validator = BatchValidator(**args_dict) 43 | validator.load_directory(dataset) 44 | if roi: 45 | validator.loadROIPickle(roi) 46 | elif masks: 47 | validator.mask_import(masks) 48 | 49 | assert validator.mask_list, 'No masks found' 50 | validator.calculate(args.comment) 51 | 52 | -------------------------------------------------------------------------------- /test/testILearn_AA.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2021 Dafne-Imaging Team 4 | # 5 | # This program is free software: you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation, either version 3 of the License, or 8 | # (at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU General Public License 16 | # along with this program. If not, see . 
17 | 18 | from src.dafne_dl import LocalModelProvider 19 | import numpy as np 20 | import pickle 21 | 22 | GENERATE_PICKLE = True 23 | 24 | model_provider = LocalModelProvider('models') 25 | segmenter = model_provider.load_model('Thigh') 26 | 27 | data_in = np.load('testImages/test_data.npy') 28 | segment_in = np.load('testImages/test_segment.npy',allow_pickle=True) 29 | 30 | n_slices = data_in.shape[2] 31 | resolution = np.array([1.037037, 1.037037]) 32 | 33 | LABELS_DICT = { 34 | 1: 'VL', 35 | 2: 'VM', 36 | 3: 'VI', 37 | 4: 'RF', 38 | 5: 'SAR', 39 | 6: 'GRA', 40 | 7: 'AM', 41 | 8: 'SM', 42 | 9: 'ST', 43 | 10: 'BFL', 44 | 11: 'BFS', 45 | 12: 'AL' 46 | } 47 | 48 | if GENERATE_PICKLE: 49 | print("Generating data") 50 | slice_range = range(4,n_slices-5) # range(15,25) 51 | image_list = [] 52 | seg_list = [] 53 | 54 | for slice in slice_range: 55 | image_list.append(data_in[:,:,slice].squeeze()) 56 | seg_dict = {} 57 | for k, v in LABELS_DICT.items(): 58 | seg_dict[v] = segment_in[slice][v][:,:] # segment_in[roi_name][:,:,slice] 59 | seg_list.append(seg_dict) 60 | 61 | pickle.dump(seg_list, open('testImages/test_segment.pickle', 'wb')) 62 | pickle.dump(image_list, open('testImages/test_data.pickle', 'wb')) 63 | else: 64 | seg_list = pickle.load(open('testImages/test_segment.pickle', 'rb')) 65 | image_list = pickle.load(open('testImages/test_data.pickle', 'rb')) 66 | 67 | print('Performing incremental learning') 68 | segmenter.incremental_learn({'image_list': image_list, 'resolution': resolution}, seg_list) 69 | -------------------------------------------------------------------------------- /test/testILearn_split_AA.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2021 Dafne-Imaging Team 4 | # 5 | # This program is free software: you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation, either 
version 3 of the License, or 8 | # (at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU General Public License 16 | # along with this program. If not, see . 17 | 18 | from src.dafne_dl import LocalModelProvider 19 | import numpy as np 20 | import pickle 21 | 22 | GENERATE_PICKLE = False 23 | 24 | model_provider = LocalModelProvider('models') 25 | segmenter = model_provider.load_model('Thigh-Split') 26 | 27 | data_in = np.load('testImages/test_data.npy') 28 | segment_in = np.load('testImages/test_segment.npy',allow_pickle=True) 29 | 30 | n_slices = data_in.shape[2] 31 | resolution = np.array([1.037037, 1.037037]) 32 | 33 | LABELS_DICT = { 34 | 1: 'VL', 35 | 2: 'VM', 36 | 3: 'VI', 37 | 4: 'RF', 38 | 5: 'SAR', 39 | 6: 'GRA', 40 | 7: 'AM', 41 | 8: 'SM', 42 | 9: 'ST', 43 | 10: 'BFL', 44 | 11: 'BFS', 45 | 12: 'AL' 46 | } 47 | 48 | if GENERATE_PICKLE: 49 | print("Generating data") 50 | slice_range = range(4,n_slices-5) # range(15,25) 51 | image_list = [] 52 | seg_list = [] 53 | 54 | for slice in slice_range: 55 | image_list.append(data_in[:,:,slice].squeeze()) 56 | seg_dict = {} 57 | for k, v in LABELS_DICT.items(): 58 | seg_dict[v] = segment_in[slice][v][:,:] # segment_in[roi_name][:,:,slice] 59 | seg_list.append(seg_dict) 60 | 61 | pickle.dump(seg_list, open('testImages/test_segment.pickle', 'wb')) 62 | pickle.dump(image_list, open('testImages/test_data.pickle', 'wb')) 63 | else: 64 | seg_list = pickle.load(open('testImages/test_segment.pickle', 'rb')) 65 | image_list = pickle.load(open('testImages/test_data.pickle', 'rb')) 66 | 67 | print('Performing incremental learning') 68 | segmenter.incremental_learn({'image_list': image_list, 'resolution': resolution}, seg_list) 
69 | -------------------------------------------------------------------------------- /src/dafne/ui/ModelBrowser.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | ModelBrowser 4 | 5 | 6 | 7 | 0 8 | 0 9 | 1024 10 | 768 11 | 12 | 13 | 14 | Dafne Model Browser 15 | 16 | 17 | 18 | 19 | 20 | 21 | Models 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | true 30 | 31 | 32 | QAbstractItemView::NoEditTriggers 33 | 34 | 35 | true 36 | 37 | 38 | false 39 | 40 | 41 | 42 | Name 43 | 44 | 45 | 46 | 47 | Info 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | Qt::Horizontal 56 | 57 | 58 | QDialogButtonBox::Cancel|QDialogButtonBox::Ok 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | buttonBox 68 | accepted() 69 | ModelBrowser 70 | accept() 71 | 72 | 73 | 248 74 | 254 75 | 76 | 77 | 157 78 | 274 79 | 80 | 81 | 82 | 83 | buttonBox 84 | rejected() 85 | ModelBrowser 86 | reject() 87 | 88 | 89 | 316 90 | 260 91 | 92 | 93 | 286 94 | 274 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /src/dafne/ui/ModelBrowserUI.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'ModelBrowser.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.7 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 
9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | 13 | 14 | class Ui_ModelBrowser(object): 15 | def setupUi(self, ModelBrowser): 16 | ModelBrowser.setObjectName("ModelBrowser") 17 | ModelBrowser.resize(1024, 768) 18 | self.verticalLayout = QtWidgets.QVBoxLayout(ModelBrowser) 19 | self.verticalLayout.setObjectName("verticalLayout") 20 | self.model_tree = QtWidgets.QTreeWidget(ModelBrowser) 21 | self.model_tree.setObjectName("model_tree") 22 | self.verticalLayout.addWidget(self.model_tree) 23 | self.details_table = QtWidgets.QTableWidget(ModelBrowser) 24 | self.details_table.setEnabled(True) 25 | self.details_table.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers) 26 | self.details_table.setObjectName("details_table") 27 | self.details_table.setColumnCount(2) 28 | self.details_table.setRowCount(0) 29 | item = QtWidgets.QTableWidgetItem() 30 | self.details_table.setHorizontalHeaderItem(0, item) 31 | item = QtWidgets.QTableWidgetItem() 32 | self.details_table.setHorizontalHeaderItem(1, item) 33 | self.details_table.horizontalHeader().setStretchLastSection(True) 34 | self.details_table.verticalHeader().setVisible(False) 35 | self.verticalLayout.addWidget(self.details_table) 36 | self.buttonBox = QtWidgets.QDialogButtonBox(ModelBrowser) 37 | self.buttonBox.setOrientation(QtCore.Qt.Horizontal) 38 | self.buttonBox.setStandardButtons(QtWidgets.QDialogButtonBox.Cancel|QtWidgets.QDialogButtonBox.Ok) 39 | self.buttonBox.setObjectName("buttonBox") 40 | self.verticalLayout.addWidget(self.buttonBox) 41 | 42 | self.retranslateUi(ModelBrowser) 43 | self.buttonBox.accepted.connect(ModelBrowser.accept) # type: ignore 44 | self.buttonBox.rejected.connect(ModelBrowser.reject) # type: ignore 45 | QtCore.QMetaObject.connectSlotsByName(ModelBrowser) 46 | 47 | def retranslateUi(self, ModelBrowser): 48 | _translate = QtCore.QCoreApplication.translate 49 | ModelBrowser.setWindowTitle(_translate("ModelBrowser", "Dafne Model Browser")) 50 | 
self.model_tree.headerItem().setText(0, _translate("ModelBrowser", "Models")) 51 | item = self.details_table.horizontalHeaderItem(0) 52 | item.setText(_translate("ModelBrowser", "Name")) 53 | item = self.details_table.horizontalHeaderItem(1) 54 | item.setText(_translate("ModelBrowser", "Info")) 55 |
--------------------------------------------------------------------------------
/src/dafne/ui/LogWindowUI.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'LogWindowUI.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.


from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_LogWindow(object):
    # Generated form class for the log viewer. A host widget calls setupUi()
    # on itself to receive the widgets and layouts defined below.
    # NOTE(review): comments in this generated file disappear on regeneration.

    def setupUi(self, LogWindow):
        """Create and lay out the log-window widgets on *LogWindow*.

        Layout, top to bottom: a row with two columns ("Output:" label above
        ``output_text``, "Error:" label above ``error_text``), then a
        full-width ``refresh_btn``. Both text panes are read-only, never wrap
        lines and have undo/redo disabled.

        Args:
            LogWindow: the host widget that receives the layout and children.
        """
        LogWindow.setObjectName("LogWindow")
        LogWindow.resize(885, 519)
        # Outer vertical layout: [row of two text panes] above [refresh button].
        self.verticalLayout_3 = QtWidgets.QVBoxLayout(LogWindow)
        self.verticalLayout_3.setObjectName("verticalLayout_3")
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setObjectName("horizontalLayout")
        # Left column: label (text "Output:", set in retranslateUi) over the output pane.
        self.verticalLayout = QtWidgets.QVBoxLayout()
        self.verticalLayout.setObjectName("verticalLayout")
        self.label = QtWidgets.QLabel(LogWindow)
        self.label.setObjectName("label")
        self.verticalLayout.addWidget(self.label)
        self.output_text = QtWidgets.QPlainTextEdit(LogWindow)
        self.output_text.setUndoRedoEnabled(False)
        self.output_text.setLineWrapMode(QtWidgets.QPlainTextEdit.NoWrap)
        self.output_text.setReadOnly(True)
        self.output_text.setObjectName("output_text")
        self.verticalLayout.addWidget(self.output_text)
        self.horizontalLayout.addLayout(self.verticalLayout)
        # Right column: label (text "Error:", set in retranslateUi) over the error pane.
        self.verticalLayout_2 = QtWidgets.QVBoxLayout()
        self.verticalLayout_2.setObjectName("verticalLayout_2")
        self.label_2 = QtWidgets.QLabel(LogWindow)
        self.label_2.setObjectName("label_2")
        self.verticalLayout_2.addWidget(self.label_2)
        self.error_text = QtWidgets.QPlainTextEdit(LogWindow)
        self.error_text.setUndoRedoEnabled(False)
        self.error_text.setLineWrapMode(QtWidgets.QPlainTextEdit.NoWrap)
        self.error_text.setReadOnly(True)
        self.error_text.setObjectName("error_text")
        self.verticalLayout_2.addWidget(self.error_text)
        self.horizontalLayout.addLayout(self.verticalLayout_2)
        self.verticalLayout_3.addLayout(self.horizontalLayout)
        # Full-width button below both panes.
        self.refresh_btn = QtWidgets.QPushButton(LogWindow)
        self.refresh_btn.setObjectName("refresh_btn")
        self.verticalLayout_3.addWidget(self.refresh_btn)

        self.retranslateUi(LogWindow)
        # Auto-connect slots on LogWindow named on_<objectName>_<signal>.
        QtCore.QMetaObject.connectSlotsByName(LogWindow)

    def retranslateUi(self, LogWindow):
        """Set the user-visible strings (window title, labels, button text).

        Every string goes through QCoreApplication.translate with the
        "LogWindow" context, so it can be localised.
        """
        _translate = QtCore.QCoreApplication.translate
        LogWindow.setWindowTitle(_translate("LogWindow", "Form"))
        self.label.setText(_translate("LogWindow", "Output:"))
        self.label_2.setText(_translate("LogWindow", "Error:"))
        self.refresh_btn.setText(_translate("LogWindow", "Refresh"))
--------------------------------------------------------------------------------
/install_scripts/dafne_win.spec:
--------------------------------------------------------------------------------
# -*- mode: python ; coding: utf-8 -*-
# PyInstaller spec for the Windows build. It produces two console programs,
# "dafne" and "calc_transforms", each collected into its own output folder.

block_cipher = None

# Analysis of the main Dafne application.
# NOTE(review): the macOS spec uses '../run_dafne.py' as its entry script and
# also lists 'cmath' and 'ormir-pyvoxel' in hiddenimports; confirm that
# '..\\dafne' and this shorter list are intentional for Windows.
a_dafne = Analysis(['..\\dafne'],
                   pathex=['..\\src'],
                   binaries=[],
                   # Ship the licence next to the executable and the package
                   # resources in a 'resources' sub-folder.
                   datas=[('..\\LICENSE', '.'), ('..\\src\\dafne\\resources\\*', 'resources\\')],
                   # Modules PyInstaller may not discover by static analysis.
                   hiddenimports = ['pydicom',
                                    'dafne',
                                    'SimpleITK',
                                    'tensorflow',
                                    'skimage',
                                    'nibabel',
                                    'dafne_dl',
                                    'pyvistaqt',
                                    'pyvista',
                                    'vtk',
                                    'torch',
                                    'torchvision'],
                   hookspath=['..\\pyinstaller_hooks'],
                   runtime_hooks=[],
                   excludes=[],
                   win_no_prefer_redirects=False,
                   win_private_assemblies=False,
                   cipher=block_cipher,
                   noarchive=False)

# Analysis of the stand-alone transform-calculation tool.
a_calc_tra = Analysis(['..\\calc_transforms.py'],
                      pathex=['..\\src'],
                      binaries=[],
                      datas=[],
                      hiddenimports=[
                          'dafne',
                          'pydicom',
                          'SimpleITK'],
                      hookspath=[],
                      runtime_hooks=[],
                      excludes=[],
                      win_no_prefer_redirects=False,
                      win_private_assemblies=False,
                      cipher=block_cipher,
                      noarchive=False)

# Multipackage bundle: lets the two programs share common dependencies instead
# of duplicating them. Must run after both Analysis calls and before PYZ/EXE.
MERGE( (a_dafne, 'dafne', 'dafne'), (a_calc_tra, 'calc_transforms', 'calc_transforms') )

# --- dafne executable (console window enabled) ---
pyz_dafne = PYZ(a_dafne.pure, a_dafne.zipped_data,
                cipher=block_cipher)
exe_dafne = EXE(pyz_dafne,
                a_dafne.scripts,
                [],
                exclude_binaries=True,
                name='dafne',
                debug=False,
                bootloader_ignore_signals=False,
                strip=False,
                upx=True,
                icon='..\\icons\\dafne_icon.ico',
                console=True)
coll_dafne = COLLECT(exe_dafne,
                     a_dafne.binaries,
                     a_dafne.zipfiles,
                     a_dafne.datas,
                     strip=False,
                     upx=True,
                     upx_exclude=[],
                     name='dafne')


# --- calc_transforms executable ---
pyz_calc_tra = PYZ(a_calc_tra.pure, a_calc_tra.zipped_data,
                   cipher=block_cipher)
exe_calc_tra = EXE(pyz_calc_tra,
                   a_calc_tra.scripts,
                   [],
                   exclude_binaries=True,
                   name='calc_transforms',
                   debug=False,
                   bootloader_ignore_signals=False,
                   strip=False,
                   upx=True,
                   icon='..\\icons\\calctransform_ico.ico',
                   console=True )
coll_calc_tra = COLLECT(exe_calc_tra,
                        a_calc_tra.binaries,
                        a_calc_tra.zipfiles,
                        a_calc_tra.datas,
                        strip=False,
                        upx=True,
                        upx_exclude=[],
                        name='calc_transforms')
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # avoid sending user configuration 2 | config.pickle 3 | 4 | # avoid sending data 5 | *.dcm 6 | *.model 7 | *.hdf5 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod]
12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | #*.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | cover/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | .pybuilder/ 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | # For a library or package, you might want to ignore these files since the code is 95 | # intended to run in multiple environments; otherwise, check them in: 96 | # .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | # pytype static type analyzer 143 | .pytype/ 144 | 145 | # Cython debug symbols 146 | cython_debug/ 147 | 148 | .vscode 149 | .DS_Store 150 | 151 | tmp 152 | config_localhost.pickle 153 | config_server.pickle 154 | -------------------------------------------------------------------------------- /test/validation_ILearn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) 2021 Dafne-Imaging Team 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 
18 | 19 | import numpy as np 20 | import pickle 21 | from generate_thigh_model import coscia_unet as unet 22 | import src.dafne_dl.common.preprocess_train as pretrain 23 | 24 | model=unet() 25 | #model.load_weights('weights/weights_coscia.hdf5') ## old 26 | model.load_weights('Weights_incremental/thigh/weights - 5 - 86965.35.hdf5') ## incremental 27 | 28 | seg_list = pickle.load(open('testImages/test_segment.pickle', 'rb')) 29 | image_list = pickle.load(open('testImages/test_data.pickle', 'rb')) 30 | 31 | LABELS_DICT = { 32 | 1: 'VL', 33 | 2: 'VM', 34 | 3: 'VI', 35 | 4: 'RF', 36 | 5: 'SAR', 37 | 6: 'GRA', 38 | 7: 'AM', 39 | 8: 'SM', 40 | 9: 'ST', 41 | 10: 'BFL', 42 | 11: 'BFS', 43 | 12: 'AL' 44 | } 45 | ''' 46 | LABELS_DICT = { 47 | 1: 'SOL', 48 | 2: 'GM', 49 | 3: 'GL', 50 | 4: 'TA', 51 | 5: 'ELD', 52 | 6: 'PE', 53 | } 54 | ''' 55 | MODEL_RESOLUTION = np.array([1.037037, 1.037037]) 56 | MODEL_SIZE = (432, 432) 57 | 58 | image_list, mask_list = pretrain.common_input_process(LABELS_DICT, MODEL_RESOLUTION, MODEL_SIZE, {'image_list': image_list, 'resolution': MODEL_RESOLUTION}, seg_list) 59 | 60 | ch = mask_list[0].shape[2] 61 | aggregated_masks = [] 62 | mask_list_no_overlap = [] 63 | for masks in mask_list: 64 | agg, new_masks = pretrain.calc_aggregated_masks_and_remove_overlap(masks) 65 | aggregated_masks.append(agg) 66 | mask_list_no_overlap.append(new_masks) 67 | 68 | for slice_number in range(len(image_list)): 69 | img = image_list[slice_number] 70 | segmentation = model.predict(np.expand_dims(np.stack([img,np.zeros(MODEL_SIZE)],axis=-1),axis=0)) 71 | segmentationnum = np.argmax(np.squeeze(segmentation[0,:,:,:ch]), axis=2) 72 | cateseg=np.zeros((432,432,ch),dtype='float32') 73 | for i in range(432): 74 | for j in range(432): 75 | cateseg[i,j,int(segmentationnum[i,j])]=1.0 76 | acc=0 77 | y_pred=cateseg 78 | y_true=mask_list_no_overlap[slice_number] 79 | for j in range(ch): ## Dice 80 | elements_per_class=y_true[:,:,j].sum() 81 | 
predicted_per_class=y_pred[:,:,j].sum() 82 | intersection=(np.multiply(y_pred[:,:,j],y_true[:,:,j])).sum() 83 | intersection=2.0*intersection 84 | union=elements_per_class+predicted_per_class 85 | acc+=intersection/(union+0.000001) 86 | acc=acc/ch 87 | print(str(slice_number)+'__'+str(acc)) 88 | -------------------------------------------------------------------------------- /test/validation_split_ILearn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # Copyright (c) 2021 Dafne-Imaging Team 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 
18 | 19 | import numpy as np 20 | import pickle 21 | from generate_thigh_split_model import coscia_unet as unet 22 | import src.dafne_dl.common.preprocess_train as pretrain 23 | 24 | model=unet() 25 | #model.load_weights('weights/weights_coscia_split.hdf5') ## old 26 | model.load_weights('Weights_incremental_split/thigh/weights - 5 - 47699.60.hdf5') ## incremental 27 | 28 | seg_list = pickle.load(open('testImages/test_segment.pickle', 'rb')) 29 | image_list = pickle.load(open('testImages/test_data.pickle', 'rb')) 30 | 31 | LABELS_DICT = { 32 | 1: 'VL', 33 | 2: 'VM', 34 | 3: 'VI', 35 | 4: 'RF', 36 | 5: 'SAR', 37 | 6: 'GRA', 38 | 7: 'AM', 39 | 8: 'SM', 40 | 9: 'ST', 41 | 10: 'BFL', 42 | 11: 'BFS', 43 | 12: 'AL' 44 | } 45 | ''' 46 | LABELS_DICT = { 47 | 1: 'SOL', 48 | 2: 'GM', 49 | 3: 'GL', 50 | 4: 'TA', 51 | 5: 'ELD', 52 | 6: 'PE', 53 | } 54 | ''' 55 | MODEL_RESOLUTION = np.array([1.037037, 1.037037]) 56 | MODEL_SIZE = (432, 432) 57 | MODEL_SIZE_SPLIT = (250, 250) 58 | 59 | image_list, mask_list = pretrain.common_input_process_split(LABELS_DICT, MODEL_RESOLUTION, MODEL_SIZE, MODEL_SIZE_SPLIT, {'image_list': image_list, 'resolution': MODEL_RESOLUTION}, seg_list) 60 | 61 | ch = mask_list[0].shape[2] 62 | aggregated_masks = [] 63 | mask_list_no_overlap = [] 64 | for masks in mask_list: 65 | agg, new_masks = pretrain.calc_aggregated_masks_and_remove_overlap(masks) 66 | aggregated_masks.append(agg) 67 | mask_list_no_overlap.append(new_masks) 68 | 69 | for slice_number in range(len(image_list)): 70 | img = image_list[slice_number] 71 | segmentation = model.predict(np.expand_dims(np.stack([img,np.zeros(MODEL_SIZE_SPLIT)],axis=-1),axis=0)) 72 | segmentationnum = np.argmax(np.squeeze(segmentation[0,:,:,:ch]), axis=2) 73 | cateseg=np.zeros((MODEL_SIZE_SPLIT[0],MODEL_SIZE_SPLIT[1],ch),dtype='float32') 74 | for i in range(MODEL_SIZE_SPLIT[0]): 75 | for j in range(MODEL_SIZE_SPLIT[1]): 76 | cateseg[i,j,int(segmentationnum[i,j])]=1.0 77 | acc=0 78 | y_pred=cateseg 79 | 
y_true=mask_list_no_overlap[slice_number] 80 | for j in range(ch): ## Dice 81 | elements_per_class=y_true[:,:,j].sum() 82 | predicted_per_class=y_pred[:,:,j].sum() 83 | intersection=(np.multiply(y_pred[:,:,j],y_true[:,:,j])).sum() 84 | intersection=2.0*intersection 85 | union=elements_per_class+predicted_per_class 86 | acc+=intersection/(union+0.000001) 87 | acc=acc/ch 88 | print(str(slice_number)+'__'+str(acc)) 89 | -------------------------------------------------------------------------------- /install_scripts/create_mac_installer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/zsh 2 | # 3 | # Copyright (c) 2022 Dafne-Imaging Team 4 | # 5 | 6 | # note: checking signature/notarization: spctl -a -vvv 7 | 8 | git pull 9 | git checkout master 10 | 11 | APPNAME=Dafne 12 | VERSION=$(python update_version.py | tail -n 1) 13 | ARCH=$(uname -a | sed -E -n 's/.*(arm64|x86_64)$/\1/p') 14 | DMG_NAME=dafne_mac_${VERSION}_$ARCH.dmg 15 | CODESIGN_IDENTITY="Francesco Santini" 16 | USE_ALTOOL=False 17 | 18 | 19 | NOTARYTOOL() { 20 | if [ -f /Library/Developer/CommandLineTools/usr/bin/notarytool ] 21 | then 22 | /Library/Developer/CommandLineTools/usr/bin/notarytool "$@" 23 | else 24 | xcrun notarytool "$@" 25 | fi 26 | } 27 | 28 | echo $VERSION 29 | pyinstaller dafne_mac.spec --noconfirm 30 | cd dist 31 | echo "Fixing app bundle" 32 | python ../fix_app_bundle_for_mac.py $APPNAME.app 33 | echo "Signing code" 34 | # Sign code outside MacOS 35 | find $APPNAME.app/Contents/Resources -name '*.dylib' | xargs codesign --force -v -s "$CODESIGN_IDENTITY" 36 | find $APPNAME.app/Contents/Resources -name '*.so' | xargs codesign --force -v -s "$CODESIGN_IDENTITY" 37 | # sign the app 38 | codesign --deep --force -v -s "$CODESIGN_IDENTITY" $APPNAME.app 39 | find $APPNAME.app/Contents -path '*bin/*' | xargs codesign --force -o runtime --timestamp --entitlements ../entitlements.plist -v -s "$CODESIGN_IDENTITY" 40 | 41 | # Resign the app with the 
correct entitlement 42 | codesign --force -o runtime --timestamp --entitlements ../entitlements.plist -v -s "$CODESIGN_IDENTITY" $APPNAME.app 43 | 44 | echo "Creating DMG" 45 | # Create-dmg at some point stopped working because of a newline prepended to the mount point. If this fails, check for this bug. 46 | create-dmg --volname "Dafne" --volicon $APPNAME.app/Contents/Resources/dafne_icon.icns \ 47 | --eula $APPNAME.app/Contents/Resources/LICENSE --background ../../icons/mac_installer_bg.png \ 48 | --window-size 420 220 --icon-size 64 --icon $APPNAME.app 46 31 \ 49 | --app-drop-link 236 90 "$DMG_NAME" $APPNAME.app 50 | codesign -s "$CODESIGN_IDENTITY" "$DMG_NAME" 51 | echo "Notarizing app" 52 | 53 | if [ "$USE_ALTOOL" = "True" ] 54 | then 55 | # make sure that the credentials are stored in keychain with 56 | # xcrun altool --store-password-in-keychain-item "AC_PASSWORD" -u "" -p "" 57 | # Note: password is a "app-specific password" created in the appleID site to bypass 2FA 58 | xcrun altool --notarize-app \ 59 | --primary-bundle-id "network.dafne.dafne" --password "@keychain:AC_PASSWORD" \ 60 | --file "$DMG_NAME" 61 | 62 | echo "Check the status in 1 hour or so with:" 63 | echo 'xcrun altool --notarization-history 0 -p "@keychain:AC_PASSWORD"' 64 | echo 'If there are any errors check' 65 | echo 'xcrun altool --notarization-info -p "@keychain:AC_PASSWORD"' 66 | echo 'if successful, staple the ticket running' 67 | echo "xcrun stapler staple $DMG_NAME" 68 | else 69 | # alternative with notarytool 70 | # store credentials: 71 | # xcrun notarytool store-credentials "AC_PASSWORD" --apple-id --password --team-id 72 | echo "This can take up to 1 hour" 73 | NOTARYTOOL submit "$DMG_NAME" --wait --keychain-profile "AC_PASSWORD" 74 | # This will wait for the notarization to complete 75 | echo 'If failed, save log file with:' 76 | echo 'xcrun notarytool log --keychain-profile "AC_PASSWORD" notarization.log' 77 | xcrun stapler staple "$DMG_NAME" 78 | fi 79 | 
-------------------------------------------------------------------------------- /src/dafne/bin/dafne.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) 2022 Dafne-Imaging Team 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | 19 | import os 20 | 21 | from ..ui.WhatsNew import show_news 22 | # Hide tensorflow warnings; set to 1 to see warnings 23 | from ..utils import log 24 | 25 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # or any {'0', '1', '2', '3'} 26 | 27 | from ..ui.MuscleSegmentation import MuscleSegmentation 28 | from ..config import GlobalConfig, load_config 29 | 30 | import matplotlib 31 | matplotlib.use("Qt5Agg") 32 | import argparse 33 | import matplotlib.pyplot as plt 34 | 35 | MODELS_DIR = 'models_old' 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser(description="Muscle segmentation tool.") 40 | parser.add_argument('path', nargs='?', type=str) 41 | parser.add_argument('-c', '--class', dest='dl_class', type=str, help='Specify the deep learning model to use for the dataset') 42 | parser.add_argument('-r', '--register', action='store_true', help='Perform the registration after loading.') 43 | parser.add_argument('-m', '--save-masks', action='store_true', help='Convert saved ROIs to masks.') 44 | parser.add_argument('-d', 
'--save-dicoms', action='store_true', help='Save ROIs as dicoms in addition to numpy') 45 | parser.add_argument('-q', '--quit', action='store_true', help='Quit after loading the dataset (useful with -r or -q options).') 46 | parser.add_argument('-rm', '--remote-model', action='store_true', help='Force remote model') 47 | parser.add_argument('-lm', '--local-model', action='store_true', help='Force local model') 48 | 49 | args = parser.parse_args() 50 | 51 | load_config() 52 | 53 | if args.remote_model: 54 | GlobalConfig['MODEL_PROVIDER'] = 'Remote' 55 | 56 | if args.local_model: 57 | GlobalConfig['MODEL_PROVIDER'] = 'Local' 58 | 59 | if GlobalConfig['REDIRECT_OUTPUT']: 60 | import sys 61 | 62 | log.log_objects['stdout'] = log.LogStream(GlobalConfig['OUTPUT_LOG_FILE'], sys.stdout if GlobalConfig['ECHO_OUTPUT'] else None) 63 | log.log_objects['stderr'] = log.LogStream(GlobalConfig['ERROR_LOG_FILE'], sys.stderr if GlobalConfig['ECHO_OUTPUT'] else None) 64 | 65 | sys.stdout = log.log_objects['stdout'] 66 | sys.stderr = log.log_objects['stderr'] 67 | 68 | imFig = MuscleSegmentation() 69 | 70 | dl_class = None 71 | 72 | if args.dl_class: 73 | dl_class = args.dl_class 74 | 75 | if args.path: 76 | imFig.loadDirectory(args.path, dl_class) 77 | 78 | if args.save_dicoms: 79 | imFig.saveDicom = True 80 | 81 | if args.register: 82 | imFig.calcTransforms() 83 | 84 | if args.save_masks: 85 | imFig.saveResults() 86 | 87 | if not args.quit: 88 | plt.show() 89 | 90 | 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI version](https://badge.fury.io/py/dafne.svg)](https://badge.fury.io/py/dafne) 2 | [![PDF Documentation](https://img.shields.io/badge/Docs-pdf-brightgreen)](https://www.dafne.network/files/documentation.pdf) 3 | [![HTML Documentation](https://img.shields.io/badge/Docs-html-brightgreen)](https://www.dafne.network/documentation/) 4 | 5 
| # Dafne 6 | Deep Anatomical Federated Network is a program for the segmentation of medical images. It relies on a server to provide deep learning models to aid the segmentation, and incremental learning is used to improve the performance. See https://www.dafne.network/ for documentation and user information. 7 | 8 | ## Windows binary installation 9 | Please install the Microsoft Visual C++ Redistributable package on Windows: https://aka.ms/vs/16/release/vc_redist.x64.exe 10 | Then run the provided installer. 11 | 12 | ## Mac binary installation 13 | Install the Dafne App from the downloaded .dmg file as usual. Make sure to download the archive appropriate for your architecture (x86 or arm). 14 | 15 | ## Linux binary installation 16 | The Linux distribution is a self-contained executable file. Simply download it, make it executable, and run it. 17 | 18 | ## pip installation 19 | Dafne can also be installed with pip: 20 | `pip install dafne` 21 | 22 | # Citing 23 | If you are writing a scientific paper and used Dafne for your data evaluation, please cite the following paper: 24 | 25 | > Santini F, Wasserthal J, Agosti A, et al. *Deep Anatomical Federated Network (Dafne): an open client/server framework for the continuous collaborative improvement of deep-learning-based medical image segmentation*. 2023. doi: [10.48550/arXiv.2302.06352](https://doi.org/10.48550/arXiv.2302.06352).
26 | 27 | 28 | # Notes for developers 29 | 30 | ## dafne 31 | 32 | Run: 33 | `python dafne.py [dataset_path]` 34 | 35 | 36 | ## Notes for the DL models 37 | 38 | ### Apply functions 39 | The input of the apply function is: 40 | ``` 41 | dict({ 42 | 'image': np.array (2D image) 43 | 'resolution': sequence with two elements (image resolution in mm) 44 | 'split_laterality': True/False (indicates whether the ROIs should be split in L/R if applicable) 45 | 'classification': str - The classification tag of the image (optional, to identify model variants) 46 | }) 47 | ``` 48 | 49 | The output of the classifier is a string. 50 | The output of the segmenters is: 51 | ``` 52 | dict({ 53 | roi_name_1: np.array (2D mask), 54 | roi_name_2: ... 55 | }) 56 | ``` 57 | 58 | ### Incremental learning functions 59 | The inputs of the incremental learning functions are: 60 | ``` 61 | training data: dict({ 62 | 'resolution': sequence (see above) 63 | 'classification': str (see above) 64 | 'image_list': list([ 65 | - np.array (2D image) 66 | - np.array (2D image) 67 | - ... 68 | ]) 69 | }) 70 | 71 | training outputs: list([ 72 | - dict({ 73 | roi_name_1: np.array (2D mask) 74 | roi_name_2: ... 75 | }) 76 | - dict... 77 | ``` 78 | 79 | Every entry in the training outputs list corresponds to an entry in the image_list inside the training data. 80 | So `len(training_data['image_list']) == len(training_outputs)`.
81 | 82 | # Acknowledgments 83 | Input/Output is based on [DOSMA](https://github.com/ad12/DOSMA) - GPLv3 license 84 | 85 | This software includes the [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything) - Apache 2.0 license 86 | 87 | Other packages required for this project are listed in requirements.txt 88 |
--------------------------------------------------------------------------------
/install_scripts/dafne_mac.spec:
--------------------------------------------------------------------------------
# -*- mode: python ; coding: utf-8 -*-
# PyInstaller spec for the macOS build. It produces the windowed "dafne"
# program, wrapped into Dafne.app by BUNDLE at the end, and the console tool
# "calc_transforms", which is collected separately and is not part of the .app.

block_cipher = None
# Raise the recursion limit; presumably to avoid RecursionError during
# PyInstaller's analysis of the large dependency graph.
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)

# Analysis of the main Dafne application.
a_dafne = Analysis(['../run_dafne.py'],
                   pathex=['../src'],
                   binaries=[],
                   # Licence at the bundle root, package resources under resources/.
                   datas=[('../LICENSE', '.'), ('../src/dafne/resources/*', 'resources/')],
                   # NOTE(review): hiddenimports expects importable module names;
                   # 'ormir-pyvoxel' contains a hyphen and looks like a
                   # distribution name - confirm the module actually resolves.
                   hiddenimports = ['dafne',
                                    'pydicom',
                                    'SimpleITK',
                                    'tensorflow',
                                    'skimage',
                                    'nibabel',
                                    'dafne_dl',
                                    'cmath',
                                    'ormir-pyvoxel',
                                    'pyvistaqt',
                                    'pyvista',
                                    'vtk',
                                    'torch',
                                    'torchvision'],
                   hookspath=['../pyinstaller_hooks'],
                   runtime_hooks=[],
                   excludes=[],
                   win_no_prefer_redirects=False,
                   win_private_assemblies=False,
                   cipher=block_cipher,
                   noarchive=False)

# Analysis of the stand-alone transform-calculation tool.
a_calc_tra = Analysis(['../calc_transforms.py'],
                      pathex=['../src'],
                      binaries=[],
                      datas=[],
                      hiddenimports=[
                          'dafne',
                          'pydicom',
                          'SimpleITK',
                          'ormir-pyvoxel'],
                      hookspath=[],
                      runtime_hooks=[],
                      excludes=[],
                      win_no_prefer_redirects=False,
                      win_private_assemblies=False,
                      cipher=block_cipher,
                      noarchive=False)

# Multipackage bundle: lets the two programs share common dependencies. Must
# run after both Analysis calls and before PYZ/EXE.
MERGE( (a_dafne, 'dafne', 'dafne'), (a_calc_tra, 'calc_transforms', 'calc_transforms') )

# --- dafne executable (windowed: console=False) ---
pyz_dafne = PYZ(a_dafne.pure, a_dafne.zipped_data,
                cipher=block_cipher)
exe_dafne = EXE(pyz_dafne,
                a_dafne.scripts,
                [],
                exclude_binaries=True,
                name='dafne',
                debug=False,
                icon='../icons/dafne_icon.icns',
                bootloader_ignore_signals=False,
                strip=False,
                upx=True,
                console=False )
coll_dafne = COLLECT(exe_dafne,
                     a_dafne.binaries,
                     a_dafne.zipfiles,
                     a_dafne.datas,
                     strip=False,
                     upx=True,
                     upx_exclude=[],
                     name='dafne')

# --- calc_transforms executable (console tool) ---
pyz_calc_tra = PYZ(a_calc_tra.pure, a_calc_tra.zipped_data,
                   cipher=block_cipher)
exe_calc_tra = EXE(pyz_calc_tra,
                   a_calc_tra.scripts,
                   [],
                   exclude_binaries=True,
                   name='calc_transforms',
                   debug=False,
                   bootloader_ignore_signals=False,
                   strip=False,
                   upx=True,
                   icon='../icons/calctransform_ico.ico',
                   console=True )
coll_calc_tra = COLLECT(exe_calc_tra,
                        a_calc_tra.binaries,
                        a_calc_tra.zipfiles,
                        a_calc_tra.datas,
                        strip=False,
                        upx=True,
                        upx_exclude=[],
                        name='calc_transforms')

# macOS application bundle containing only the dafne collection.
# NOTE(review): the version is hard-coded; presumably kept in sync by
# update_version.py - confirm.
app = BUNDLE(coll_dafne,
             name='Dafne.app',
             icon='../icons/dafne_icon.icns',
             bundle_identifier='network.dafne.dafne',
             version='1.8-alpha3')
-------------------------------------------------------------------------------- /src/dafne/ui/CalcTransformsUI.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | CalcTransformsUI 4 | 5 | 6 | 7 | 0 8 | 0 9 | 412 10 | 218 11 | 12 | 13 | 14 | Form 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 0 24 | 0 25 | 26 | 27 | 28 | Location: 29 | 30 | 31 | 32 | 33 | 34 | 35 | false 36 | 37 | 38 | 39 | 40 | 41 | 42 | Choose...
43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | false 52 | 53 | 54 | 0 55 | 56 | 57 | 58 | 59 | 60 | 61 | Orientation 62 | 63 | 64 | 65 | 66 | 67 | Original 68 | 69 | 70 | true 71 | 72 | 73 | 74 | 75 | 76 | 77 | Axial 78 | 79 | 80 | 81 | 82 | 83 | 84 | Sagittal 85 | 86 | 87 | 88 | 89 | 90 | 91 | Coronal 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | Qt::Vertical 102 | 103 | 104 | QSizePolicy::Expanding 105 | 106 | 107 | 108 | 20 109 | 45 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | false 118 | 119 | 120 | Calculate Transforms 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /src/dafne/MedSAM/utils/pre_grey_rgb.py: -------------------------------------------------------------------------------- 1 | #%% import packages 2 | import numpy as np 3 | import os 4 | join = os.path.join 5 | from skimage import io, transform 6 | from tqdm import tqdm 7 | 8 | # convert 2D data to npy files, including images and corresponding masks 9 | modality = 'dd' # e.g., 'Dermoscopy 10 | anatomy = 'dd' # e.g., 'SkinCancer' 11 | img_name_suffix = '.png' 12 | gt_name_suffix = '.png' 13 | prefix = modality + '_' + anatomy + '_' 14 | save_suffix = '.npy' 15 | image_size = 1024 16 | img_path = 'path to /images' # path to the images 17 | gt_path = 'path to/labels' # path to the corresponding annotations 18 | npy_path = 'path to/data/npy/' + prefix[:-1] # save npy path e.g., MedSAM/data/npy/; don't miss the `/` 19 | os.makedirs(join(npy_path, "gts"), exist_ok=True) 20 | os.makedirs(join(npy_path, "imgs"), exist_ok=True) 21 | names = sorted(os.listdir(gt_path)) 22 | print(f'ori \# files {len(names)=}') 23 | 24 | # set label ids that are excluded 25 | remove_label_ids = [] 26 | tumor_id = None # only set this when there are multiple tumors in one image; convert semantic masks to instance masks 27 | label_id_offset = 0 28 | do_intensity_cutoff = False # True for grey images 29 | #%% save preprocessed images 
and masks as npz files 30 | for name in tqdm(names): 31 | image_name = name.split(gt_name_suffix)[0] + img_name_suffix 32 | gt_name = name 33 | npy_save_name = prefix + gt_name.split(gt_name_suffix)[0]+save_suffix 34 | gt_data_ori = np.uint8(io.imread(join(gt_path, gt_name))) 35 | # remove label ids 36 | for remove_label_id in remove_label_ids: 37 | gt_data_ori[gt_data_ori==remove_label_id] = 0 38 | # label tumor masks as instances and remove from gt_data_ori 39 | if tumor_id is not None: 40 | tumor_bw = np.uint8(gt_data_ori==tumor_id) 41 | gt_data_ori[tumor_bw>0] = 0 42 | # label tumor masks as instances 43 | tumor_inst, tumor_n = cc3d.connected_components(tumor_bw, connectivity=26, return_N=True) 44 | # put the tumor instances back to gt_data_ori 45 | gt_data_ori[tumor_inst>0] = tumor_inst[tumor_inst>0] + label_id_offset + 1 46 | 47 | # crop the ground truth with non-zero slices 48 | image_data = io.imread(join(img_path, image_name)) 49 | if np.max(image_data) > 255.0: 50 | image_data = np.uint8((image_data-image_data.min()) / (np.max(image_data)-np.min(image_data))*255.0) 51 | if len(image_data.shape) == 2: 52 | image_data = np.repeat(np.expand_dims(image_data, -1), 3, -1) 53 | assert len(image_data.shape) == 3, 'image data is not three channels: img shape:' + str(image_data.shape) + image_name 54 | # convert three channel to one channel 55 | if image_data.shape[-1] > 3: 56 | image_data = image_data[:,:,:3] 57 | # image preprocess start 58 | if do_intensity_cutoff: 59 | lower_bound, upper_bound = np.percentile(image_data[image_data>0], 0.5), np.percentile(image_data[image_data>0], 99.5) 60 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 61 | image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0 62 | image_data_pre[image_data==0] = 0 63 | image_data_pre = np.uint8(image_data_pre) 64 | else: 65 | # print('no intensity cutoff') 66 | image_data_pre = image_data.copy() 67 | 
np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=image_data_pre, gts=gt_data_ori) 68 | resize_img = transform.resize(image_data_pre, (image_size, image_size), order=3, mode='constant', preserve_range=True, anti_aliasing=True) 69 | resize_img01 = resize_img/255.0 70 | resize_gt = transform.resize(gt_data_ori, (image_size, image_size), order=0, mode='constant', preserve_range=True, anti_aliasing=False) 71 | # save resize img and gt as npy 72 | np.save(join(npy_path, "imgs", npy_save_name), resize_img01) 73 | np.save(join(npy_path, "gts", npy_save_name), resize_gt.astype(np.uint8)) 74 | 75 | -------------------------------------------------------------------------------- /src/dafne/ui/ContourPainter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Dafne-Imaging Team 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
15 | 16 | from matplotlib.patches import Circle, Polygon 17 | from ..utils.pySplineInterp import SplineInterpROIClass 18 | 19 | MIN_KNOT_RADIUS = 0.5 # if radius is smaller than this, knots are not painted 20 | 21 | 22 | class ContourPainter: 23 | """ 24 | Class to paint a series of ROIs on a pyplot axes object 25 | """ 26 | def __init__(self, color, knot_radius): 27 | self._knots = [] 28 | self._curves = [] 29 | self.rois = [] 30 | self.color = color 31 | self.knot_radius = knot_radius 32 | self.painted = False 33 | 34 | def set_color(self, color): 35 | self.color = color 36 | self.recalculate_patches() 37 | 38 | def set_radius(self, radius): 39 | self.radius = radius 40 | self.recalculate_patches() 41 | 42 | def clear_patches(self, axes = None): 43 | if not self.painted: return 44 | self.painted = False 45 | if axes: 46 | # axes.patches = [] # Error with new matplotlib! 47 | try: 48 | for patch in axes.patches: 49 | patch.remove() 50 | except Exception as e: 51 | print("Error removing patches", e) 52 | return 53 | for knot in self._knots: 54 | try: 55 | knot.set_visible(False) 56 | except: 57 | pass 58 | try: 59 | knot.remove() 60 | except: 61 | pass 62 | for curve in self._curves: 63 | try: 64 | curve.set_visible(False) 65 | except: 66 | pass 67 | try: 68 | curve.remove() 69 | except: 70 | pass 71 | 72 | def recalculate_patches(self): 73 | self.clear_patches() 74 | self._knots = [] 75 | self._curves = [] 76 | for roi in self.rois: 77 | if self.knot_radius >= MIN_KNOT_RADIUS: 78 | for knot in roi.knots: 79 | self._knots.append(Circle(knot, 80 | self.knot_radius, 81 | facecolor='none', 82 | edgecolor=self.color, 83 | linewidth=1.0)) 84 | try: 85 | self._curves.append(Polygon(roi.getCurve(), 86 | facecolor = 'none', 87 | edgecolor = self.color, 88 | zorder=1)) 89 | except: 90 | pass 91 | 92 | def add_roi(self, roi: SplineInterpROIClass): 93 | self.rois.append(roi) 94 | self.recalculate_patches() 95 | 96 | def clear_rois(self, axes = None): 97 | 
self.clear_patches(axes) 98 | self._knots = [] 99 | self._curves = [] 100 | self.rois = [] 101 | 102 | def draw(self, axes, clear_first=False): 103 | # print("Calling Contourpainter draw") 104 | if clear_first: 105 | self.clear_patches() 106 | for knot in self._knots: 107 | #if not self.painted: 108 | axes.add_patch(knot) 109 | axes.draw_artist(knot) 110 | self.painted = True 111 | for curve in self._curves: 112 | #if not self.painted: 113 | axes.add_patch(curve) 114 | axes.draw_artist(curve) 115 | self.painted = True 116 | # print("Painted?", self.painted) 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /src/dafne/MedSAM/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import numpy as np 9 | import torch 10 | from torch.nn import functional as F 11 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 12 | 13 | from copy import deepcopy 14 | from typing import Tuple 15 | 16 | 17 | class ResizeLongestSide: 18 | """ 19 | Resizes images to the longest side 'target_length', as well as provides 20 | methods for resizing coordinates and boxes. Provides methods for 21 | transforming both numpy array and batched torch tensors. 22 | """ 23 | 24 | def __init__(self, target_length: int) -> None: 25 | self.target_length = target_length 26 | 27 | def apply_image(self, image: np.ndarray) -> np.ndarray: 28 | """ 29 | Expects a numpy array with shape HxWxC in uint8 format. 
30 | """ 31 | target_size = self.get_preprocess_shape( 32 | image.shape[0], image.shape[1], self.target_length 33 | ) 34 | return np.array(resize(to_pil_image(image), target_size)) 35 | 36 | def apply_coords( 37 | self, coords: np.ndarray, original_size: Tuple[int, ...] 38 | ) -> np.ndarray: 39 | """ 40 | Expects a numpy array of length 2 in the final dimension. Requires the 41 | original image size in (H, W) format. 42 | """ 43 | old_h, old_w = original_size 44 | new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length) 45 | new_coords = np.empty_like(coords) 46 | new_coords[..., 0] = coords[..., 0] * (new_w / old_w) 47 | new_coords[..., 1] = coords[..., 1] * (new_h / old_h) 48 | return new_coords 49 | 50 | def apply_boxes( 51 | self, boxes: np.ndarray, original_size: Tuple[int, ...] 52 | ) -> np.ndarray: 53 | """ 54 | Expects a numpy array shape Bx4. Requires the original image size 55 | in (H, W) format. 56 | """ 57 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 58 | return boxes.reshape(-1, 4) 59 | 60 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 61 | """ 62 | Expects batched images with shape BxCxHxW and float format. This 63 | transformation may not exactly match apply_image. apply_image is 64 | the transformation expected by the model. 65 | """ 66 | # Expects an image in BCHW format. May not exactly match apply_image. 67 | target_size = self.get_preprocess_shape( 68 | image.shape[2], image.shape[3], self.target_length 69 | ) 70 | return F.interpolate( 71 | image, target_size, mode="bilinear", align_corners=False, antialias=True 72 | ) 73 | 74 | def apply_coords_torch( 75 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 76 | ) -> torch.Tensor: 77 | """ 78 | Expects a torch tensor with length 2 in the last dimension. Requires the 79 | original image size in (H, W) format. 
80 | """ 81 | old_h, old_w = original_size 82 | new_h, new_w = self.get_preprocess_shape( 83 | original_size[0], original_size[1], self.target_length 84 | ) 85 | coords = deepcopy(coords).to(torch.float) 86 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 87 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 88 | return coords 89 | 90 | def apply_boxes_torch( 91 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 92 | ) -> torch.Tensor: 93 | """ 94 | Expects a torch tensor with shape Bx4. Requires the original image 95 | size in (H, W) format. 96 | """ 97 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 98 | return boxes.reshape(-1, 4) 99 | 100 | @staticmethod 101 | def get_preprocess_shape( 102 | oldh: int, oldw: int, long_side_length: int 103 | ) -> Tuple[int, int]: 104 | """ 105 | Compute the output size given input size and target long side length. 106 | """ 107 | scale = long_side_length * 1.0 / max(oldh, oldw) 108 | newh, neww = oldh * scale, oldw * scale 109 | neww = int(neww + 0.5) 110 | newh = int(newh + 0.5) 111 | return (newh, neww) 112 | -------------------------------------------------------------------------------- /src/dafne/MedSAM/utils/format_convert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | join = os.path.join 4 | import random 5 | import numpy as np 6 | from skimage import io 7 | import SimpleITK as sitk 8 | 9 | 10 | def dcm2nii(dcm_path, nii_path): 11 | """ 12 | Convert dicom files to nii files 13 | """ 14 | reader = sitk.ImageSeriesReader() 15 | dicom_names = reader.GetGDCMSeriesFileNames(dcm_path) 16 | reader.SetFileNames(dicom_names) 17 | image = reader.Execute() 18 | sitk.WriteImage(image, nii_path) 19 | 20 | def mhd2nii(mhd_path, nii_path): 21 | """ 22 | Convert mhd files to nii files 23 | """ 24 | image = sitk.ReadImage(mhd_path) 25 | sitk.WriteImage(image, nii_path) 26 | 27 | def nii2nii(nii_path, nii_gz_path): 28 
| """ 29 | Convert nii files to nii.gz files, which can reduce the file size 30 | """ 31 | image = sitk.ReadImage(nii_path) 32 | sitk.WriteImage(image, nii_gz_path) 33 | 34 | def nrrd2nii(nrrd_path, nii_path): 35 | """ 36 | Convert nrrd files to nii files 37 | """ 38 | image = sitk.ReadImage(nrrd_path) 39 | sitk.WriteImage(image, nii_path) 40 | 41 | def jpg2png(jpg_path, png_path): 42 | """ 43 | Convert jpg files to png files 44 | """ 45 | image = io.imread(jpg_path) 46 | io.imsave(png_path, image) 47 | 48 | def patchfy(img, mask, outpath, basename): 49 | """ 50 | Patchfy the image and mask into 1024x1024 patches 51 | """ 52 | image_patch_dir = join(outpath, "images") 53 | mask_patch_dir = join(outpath, "labels") 54 | os.makedirs(image_patch_dir, exist_ok=True) 55 | os.makedirs(mask_patch_dir, exist_ok=True) 56 | assert img.shape[:2] == mask.shape 57 | patch_height = 1024 58 | patch_width = 1024 59 | 60 | img_height, img_width = img.shape[:2] 61 | mask_height, mask_width = mask.shape 62 | 63 | if img_height % patch_height != 0: 64 | img = np.pad(img, ((0, patch_height - img_height % patch_height), (0, 0), (0, 0)), mode="constant") 65 | if img_width % patch_width != 0: 66 | img = np.pad(img, ((0, 0), (0, patch_width - img_width % patch_width), (0, 0)), mode="constant") 67 | if mask_height % patch_height != 0: 68 | mask = np.pad(mask, ((0, patch_height - mask_height % patch_height), (0, 0)), mode="constant") 69 | if mask_width % patch_width != 0: 70 | mask = np.pad(mask, ((0, 0), (0, patch_width - mask_width % patch_width)), mode="constant") 71 | 72 | assert img.shape[:2] == mask.shape 73 | assert img.shape[0] % patch_height == 0 74 | assert img.shape[1] % patch_width == 0 75 | assert mask.shape[0] % patch_height == 0 76 | assert mask.shape[1] % patch_width == 0 77 | 78 | height_steps = (img_height // patch_height) if img_height % patch_height == 0 else (img_height // patch_height + 1) 79 | width_steps = (img_width // patch_width) if img_width % patch_width == 0 else 
(img_width // patch_width + 1) 80 | 81 | for i in range(height_steps): 82 | for j in range(width_steps): 83 | img_patch = img[i * patch_height:(i + 1) * patch_height, j * patch_width:(j + 1) * patch_width, :] 84 | mask_patch = mask[i * patch_height:(i + 1) * patch_height, j * patch_width:(j + 1) * patch_width] 85 | assert img_patch.shape[:2] == mask_patch.shape 86 | assert img_patch.shape[0] == patch_height 87 | assert img_patch.shape[1] == patch_width 88 | print(f"img_patch.shape: {img_patch.shape}, mask_patch.shape: {mask_patch.shape}") 89 | img_patch_path = join(image_patch_dir, f"{basename}_{i}_{j}.png") 90 | mask_patch_path = join(mask_patch_dir, f"{basename}_{i}_{j}.png") 91 | io.imsave(img_patch_path, img_patch) 92 | io.imsave(mask_patch_path, mask_patch) 93 | 94 | 95 | def rle_decode(mask_rle, img_shape): 96 | """ 97 | #functions to convert encoding to mask and mask to encoding 98 | mask_rle: run-length as string formated (start length) 99 | shape: (height,width) of array to return 100 | Returns numpy array, 1 - mask, 0 - background 101 | """ 102 | seq = mask_rle.split() 103 | starts = np.array(list(map(int, seq[0::2]))) 104 | lengths = np.array(list(map(int, seq[1::2]))) 105 | assert len(starts) == len(lengths) 106 | ends = starts + lengths 107 | img = np.zeros((np.product(img_shape),), dtype=np.uint8) 108 | for begin, end in zip(starts, ends): 109 | img[begin:end] = 255 110 | # https://stackoverflow.com/a/46574906/4521646 111 | img.shape = img_shape 112 | return img.T -------------------------------------------------------------------------------- /src/dafne/MedSAM/utils/split.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | join = os.path.join 4 | import random 5 | 6 | path_nii = '' # please complete path; two subfolders: images and labels 7 | path_video = None # or specify the path 8 | path_2d = None # or specify the path 9 | 10 | #%% split 3D nii data 11 | if path_nii is 
not None: 12 | img_path = join(path_nii, 'images') 13 | gt_path = join(path_nii, 'labels') 14 | gt_names = sorted(os.listdir(gt_path)) 15 | img_suffix = '_0000.nii.gz' 16 | gt_suffix = '.nii.gz' 17 | # split 20% data for validation and testing 18 | validation_path = join(path_nii, 'validation') 19 | os.makedirs(join(validation_path, 'images'), exist_ok=True) 20 | os.makedirs(join(validation_path, 'labels'), exist_ok=True) 21 | testing_path = join(path_nii, 'testing') 22 | os.makedirs(join(testing_path, 'images'), exist_ok=True) 23 | os.makedirs(join(testing_path, 'labels'), exist_ok=True) 24 | candidates = random.sample(gt_names, int(len(gt_names)*0.2)) 25 | # split half of test names for validation 26 | validation_names = random.sample(candidates, int(len(candidates)*0.5)) 27 | test_names = [name for name in candidates if name not in validation_names] 28 | # move validation and testing data to corresponding folders 29 | for name in validation_names: 30 | img_name = name.split(gt_suffix)[0] + img_suffix 31 | os.rename(join(img_path, img_name), join(validation_path, 'images', img_name)) 32 | os.rename(join(gt_path, name), join(validation_path, 'labels', name)) 33 | for name in test_names: 34 | img_name = name.split(gt_suffix)[0] + img_suffix 35 | os.rename(join(img_path, img_name), join(testing_path, 'images', img_name)) 36 | os.rename(join(gt_path, name), join(testing_path, 'labels', name)) 37 | 38 | 39 | ##% split 2D images 40 | if path_2d is not None: 41 | img_path = join(path_2d, 'images') 42 | gt_path = join(path_2d, 'labels') 43 | gt_names = sorted(os.listdir(gt_path)) 44 | img_suffix = '.png' 45 | gt_suffix = '.png' 46 | # split 20% data for validation and testing 47 | validation_path = join(path_2d, 'validation') 48 | os.makedirs(join(validation_path, 'images'), exist_ok=True) 49 | os.makedirs(join(validation_path, 'labels'), exist_ok=True) 50 | testing_path = join(path_2d, 'testing') 51 | os.makedirs(join(testing_path, 'images'), exist_ok=True) 52 | 
os.makedirs(join(testing_path, 'labels'), exist_ok=True) 53 | candidates = random.sample(gt_names, int(len(gt_names)*0.2)) 54 | # split half of test names for validation 55 | validation_names = random.sample(candidates, int(len(candidates)*0.5)) 56 | test_names = [name for name in candidates if name not in validation_names] 57 | # move validation and testing data to corresponding folders 58 | for name in validation_names: 59 | img_name = name.split(gt_suffix)[0] + img_suffix 60 | os.rename(join(img_path, img_name), join(validation_path, 'images', img_name)) 61 | os.rename(join(gt_path, name), join(validation_path, 'labels', name)) 62 | 63 | for name in test_names: 64 | img_name = name.split(gt_suffix)[0] + img_suffix 65 | os.rename(join(img_path, img_name), join(testing_path, 'images', img_name)) 66 | os.rename(join(gt_path, name), join(testing_path, 'labels', name)) 67 | 68 | #%% split video data 69 | if path_video is not None: 70 | img_path = join(path_video, 'images') 71 | gt_path = join(path_video, 'labels') 72 | gt_folders = sorted(os.listdir(gt_path)) 73 | # split 20% videos for validation and testing 74 | validation_path = join(path_video, 'validation') 75 | os.makedirs(join(validation_path, 'images'), exist_ok=True) 76 | os.makedirs(join(validation_path, 'labels'), exist_ok=True) 77 | testing_path = join(path_video, 'testing') 78 | os.makedirs(join(testing_path, 'images'), exist_ok=True) 79 | os.makedirs(join(testing_path, 'labels'), exist_ok=True) 80 | candidates = random.sample(gt_folders, int(len(gt_folders)*0.2)) 81 | # split half of test names for validation 82 | validation_names = random.sample(candidates, int(len(candidates)*0.5)) 83 | test_names = [name for name in candidates if name not in validation_names] 84 | # move validation and testing data to corresponding folders 85 | for name in validation_names: 86 | os.rename(join(img_path, name), join(validation_path, 'images', name)) 87 | os.rename(join(gt_path, name), join(validation_path, 'labels', 
name)) 88 | for name in test_names: 89 | os.rename(join(img_path, name), join(testing_path, 'images', name)) 90 | os.rename(join(gt_path, name), join(testing_path, 'labels', name)) 91 | -------------------------------------------------------------------------------- /src/dafne/ui/CalcTransformsUI.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'CalcTransformsUI.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.7 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | 13 | 14 | class Ui_CalcTransformsUI(object): 15 | def setupUi(self, CalcTransformsUI): 16 | CalcTransformsUI.setObjectName("CalcTransformsUI") 17 | CalcTransformsUI.resize(412, 218) 18 | self.verticalLayout = QtWidgets.QVBoxLayout(CalcTransformsUI) 19 | self.verticalLayout.setObjectName("verticalLayout") 20 | self.horizontalLayout = QtWidgets.QHBoxLayout() 21 | self.horizontalLayout.setObjectName("horizontalLayout") 22 | self.label = QtWidgets.QLabel(CalcTransformsUI) 23 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed) 24 | sizePolicy.setHorizontalStretch(0) 25 | sizePolicy.setVerticalStretch(0) 26 | sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth()) 27 | self.label.setSizePolicy(sizePolicy) 28 | self.label.setObjectName("label") 29 | self.horizontalLayout.addWidget(self.label) 30 | self.location_Text = QtWidgets.QLineEdit(CalcTransformsUI) 31 | self.location_Text.setEnabled(False) 32 | self.location_Text.setObjectName("location_Text") 33 | self.horizontalLayout.addWidget(self.location_Text) 34 | self.choose_Button = QtWidgets.QPushButton(CalcTransformsUI) 35 | self.choose_Button.setObjectName("choose_Button") 36 | 
self.horizontalLayout.addWidget(self.choose_Button) 37 | self.verticalLayout.addLayout(self.horizontalLayout) 38 | self.progressBar = QtWidgets.QProgressBar(CalcTransformsUI) 39 | self.progressBar.setEnabled(False) 40 | self.progressBar.setProperty("value", 0) 41 | self.progressBar.setObjectName("progressBar") 42 | self.verticalLayout.addWidget(self.progressBar) 43 | self.orientationBox = QtWidgets.QGroupBox(CalcTransformsUI) 44 | self.orientationBox.setObjectName("orientationBox") 45 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.orientationBox) 46 | self.horizontalLayout_2.setObjectName("horizontalLayout_2") 47 | self.original_radio = QtWidgets.QRadioButton(self.orientationBox) 48 | self.original_radio.setChecked(True) 49 | self.original_radio.setObjectName("original_radio") 50 | self.horizontalLayout_2.addWidget(self.original_radio) 51 | self.axial_radio = QtWidgets.QRadioButton(self.orientationBox) 52 | self.axial_radio.setObjectName("axial_radio") 53 | self.horizontalLayout_2.addWidget(self.axial_radio) 54 | self.sagittal_radio = QtWidgets.QRadioButton(self.orientationBox) 55 | self.sagittal_radio.setObjectName("sagittal_radio") 56 | self.horizontalLayout_2.addWidget(self.sagittal_radio) 57 | self.coronal_radio = QtWidgets.QRadioButton(self.orientationBox) 58 | self.coronal_radio.setObjectName("coronal_radio") 59 | self.horizontalLayout_2.addWidget(self.coronal_radio) 60 | self.verticalLayout.addWidget(self.orientationBox) 61 | spacerItem = QtWidgets.QSpacerItem(20, 45, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) 62 | self.verticalLayout.addItem(spacerItem) 63 | self.calculate_button = QtWidgets.QPushButton(CalcTransformsUI) 64 | self.calculate_button.setEnabled(False) 65 | self.calculate_button.setObjectName("calculate_button") 66 | self.verticalLayout.addWidget(self.calculate_button) 67 | 68 | self.retranslateUi(CalcTransformsUI) 69 | QtCore.QMetaObject.connectSlotsByName(CalcTransformsUI) 70 | 71 | def retranslateUi(self, 
CalcTransformsUI): 72 | _translate = QtCore.QCoreApplication.translate 73 | CalcTransformsUI.setWindowTitle(_translate("CalcTransformsUI", "Form")) 74 | self.label.setText(_translate("CalcTransformsUI", "Location:")) 75 | self.choose_Button.setText(_translate("CalcTransformsUI", "Choose...")) 76 | self.orientationBox.setTitle(_translate("CalcTransformsUI", "Orientation")) 77 | self.original_radio.setText(_translate("CalcTransformsUI", "Original")) 78 | self.axial_radio.setText(_translate("CalcTransformsUI", "Axial")) 79 | self.sagittal_radio.setText(_translate("CalcTransformsUI", "Sagittal")) 80 | self.coronal_radio.setText(_translate("CalcTransformsUI", "Coronal")) 81 | self.calculate_button.setText(_translate("CalcTransformsUI", "Calculate Transforms")) 82 | -------------------------------------------------------------------------------- /install_scripts/fix_app_bundle_for_mac.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Code from https://github.com/pyinstaller/pyinstaller/wiki/Recipe-OSX-Code-Signing-Qt 3 | # 4 | 5 | import os 6 | import shutil 7 | import sys 8 | from pathlib import Path 9 | from typing import Generator, List, Optional 10 | 11 | from macholib.MachO import MachO 12 | 13 | 14 | def create_symlink(folder: Path) -> None: 15 | """Create the appropriate symlink in the MacOS folder 16 | pointing to the Resources folder. 17 | """ 18 | sibbling = Path(str(folder).replace("MacOS", "")) 19 | 20 | # PyQt5/Qt/qml/QtQml/Models.2 21 | root = str(sibbling).partition("Contents")[2].lstrip("/") 22 | # ../../../../ 23 | backward = "../" * (root.count("/") + 1) 24 | # ../../../../Resources/PyQt5/Qt/qml/QtQml/Models.2 25 | good_path = f"{backward}Resources/{root}" 26 | 27 | folder.symlink_to(good_path) 28 | 29 | 30 | def fix_dll(dll: Path) -> None: 31 | """Fix the DLL lookup paths to use relative ones for Qt dependencies. 
32 | Inspiration: PyInstaller/depend/dylib.py:mac_set_relative_dylib_deps() 33 | Currently one header is pointing to (we are in the Resources folder): 34 | @loader_path/../../../../QtCore (it is referencing to the old MacOS folder) 35 | It will be converted to: 36 | @loader_path/../../../../../../MacOS/QtCore 37 | """ 38 | 39 | def match_func(pth: str) -> Optional[str]: 40 | """Callback function for MachO.rewriteLoadCommands() that is 41 | called on every lookup path setted in the DLL headers. 42 | By returning None for system libraries, it changes nothing. 43 | Else we return a relative path pointing to the good file 44 | in the MacOS folder. 45 | """ 46 | basename = os.path.basename(pth) 47 | if not basename.startswith("Qt"): 48 | return None 49 | return f"@loader_path{good_path}/{basename}" 50 | 51 | # Resources/PyQt5/Qt/qml/QtQuick/Controls.2/Fusion 52 | root = str(dll.parent).partition("Contents")[2][1:] 53 | # /../../../../../../.. 54 | backward = "/.." * (root.count("/") + 1) 55 | # /../../../../../../../MacOS 56 | good_path = f"{backward}/MacOS" 57 | 58 | # Rewrite Mach headers with corrected @loader_path 59 | dll = MachO(dll) 60 | dll.rewriteLoadCommands(match_func) 61 | with open(dll.filename, "rb+") as f: 62 | for header in dll.headers: 63 | f.seek(0) 64 | dll.write(f) 65 | f.seek(0, 2) 66 | f.flush() 67 | 68 | 69 | def find_problematic_folders(folder: Path) -> Generator[Path, None, None]: 70 | """Recursively yields problematic folders (containing a dot in their name).""" 71 | for path in folder.iterdir(): 72 | if not path.is_dir() or path.is_symlink(): 73 | # Skip simlinks as they are allowed (even with a dot) 74 | continue 75 | if "." in path.name: 76 | yield path 77 | else: 78 | yield from find_problematic_folders(path) 79 | 80 | 81 | def move_contents_to_resources(folder: Path) -> Generator[Path, None, None]: 82 | """Recursively move any non symlink file from a problematic folder 83 | to the sibbling one in Resources. 
84 | """ 85 | for path in folder.iterdir(): 86 | if path.is_symlink(): 87 | continue 88 | if path.name == "qml": 89 | yield from move_contents_to_resources(path) 90 | else: 91 | sibbling = Path(str(path).replace("MacOS", "Resources")) 92 | sibbling.parent.mkdir(parents=True, exist_ok=True) 93 | shutil.move(path, sibbling) 94 | yield sibbling 95 | 96 | 97 | def main(args: List[str]) -> int: 98 | """ 99 | Fix the application to allow codesign (NXDRIVE-1301). 100 | Take one or more .app as arguments: "Nuxeo Drive.app". 101 | To overall process will: 102 | - move problematic folders from MacOS to Resources 103 | - fix the DLLs lookup paths 104 | - create the appropriate symbolic link 105 | """ 106 | for app in args: 107 | name = os.path.basename(app) 108 | print(f">>> [{name}] Fixing Qt folder names") 109 | path = Path(app) / "Contents" / "MacOS" 110 | for folder in find_problematic_folders(path): 111 | for file in move_contents_to_resources(folder): 112 | try: 113 | fix_dll(file) 114 | except (ValueError, IsADirectoryError): 115 | continue 116 | shutil.rmtree(folder) 117 | create_symlink(folder) 118 | print(f" !! 
Fixed {folder}") 119 | print(f">>> [{name}] Application fixed.") 120 | 121 | 122 | if __name__ == "__main__": 123 | sys.exit(main(sys.argv[1:])) -------------------------------------------------------------------------------- /src/dafne/MedSAM/MedSAM_Inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # %% load environment 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import os 6 | 7 | join = os.path.join 8 | import torch 9 | from segment_anything import sam_model_registry 10 | from skimage import io, transform 11 | import torch.nn.functional as F 12 | import argparse 13 | 14 | 15 | # visualization functions 16 | # source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb 17 | # change color to avoid red and green 18 | def show_mask(mask, ax, random_color=False): 19 | if random_color: 20 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 21 | else: 22 | color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6]) 23 | h, w = mask.shape[-2:] 24 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 25 | ax.imshow(mask_image) 26 | 27 | 28 | def show_box(box, ax): 29 | x0, y0 = box[0], box[1] 30 | w, h = box[2] - box[0], box[3] - box[1] 31 | ax.add_patch( 32 | plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2) 33 | ) 34 | 35 | 36 | @torch.no_grad() 37 | def medsam_inference(medsam_model, img_embed, box_1024, H, W): 38 | box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device) 39 | if len(box_torch.shape) == 2: 40 | box_torch = box_torch[:, None, :] # (B, 1, 4) 41 | 42 | sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder( 43 | points=None, 44 | boxes=box_torch, 45 | masks=None, 46 | ) 47 | low_res_logits, _ = medsam_model.mask_decoder( 48 | image_embeddings=img_embed, # (B, 256, 64, 64) 49 | 
image_pe=medsam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64) 50 | sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256) 51 | dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64) 52 | multimask_output=False, 53 | ) 54 | 55 | low_res_pred = torch.sigmoid(low_res_logits) # (1, 1, 256, 256) 56 | 57 | low_res_pred = F.interpolate( 58 | low_res_pred, 59 | size=(H, W), 60 | mode="bilinear", 61 | align_corners=False, 62 | ) # (1, 1, gt.shape) 63 | low_res_pred = low_res_pred.squeeze().cpu().numpy() # (256, 256) 64 | medsam_seg = (low_res_pred > 0.5).astype(np.uint8) 65 | return medsam_seg 66 | 67 | 68 | # %% load model and image 69 | parser = argparse.ArgumentParser( 70 | description="run inference on testing set based on MedSAM" 71 | ) 72 | parser.add_argument( 73 | "-i", 74 | "--data_path", 75 | type=str, 76 | default="assets/img_demo.png", 77 | help="path to the data folder", 78 | ) 79 | parser.add_argument( 80 | "-o", 81 | "--seg_path", 82 | type=str, 83 | default="assets/", 84 | help="path to the segmentation folder", 85 | ) 86 | parser.add_argument( 87 | "--box", 88 | type=list, 89 | default=[95, 255, 190, 350], 90 | help="bounding box of the segmentation target", 91 | ) 92 | parser.add_argument("--device", type=str, default="cuda:0", help="device") 93 | parser.add_argument( 94 | "-chk", 95 | "--checkpoint", 96 | type=str, 97 | default="work_dir/MedSAM/medsam_vit_b.pth", 98 | help="path to the trained model", 99 | ) 100 | args = parser.parse_args() 101 | 102 | device = args.device 103 | medsam_model = sam_model_registry["vit_b"](checkpoint=args.checkpoint) 104 | medsam_model = medsam_model.to(device) 105 | medsam_model.eval() 106 | 107 | img_np = io.imread(args.data_path) 108 | if len(img_np.shape) == 2: 109 | img_3c = np.repeat(img_np[:, :, None], 3, axis=-1) 110 | else: 111 | img_3c = img_np 112 | H, W, _ = img_3c.shape 113 | # %% image preprocessing 114 | img_1024 = transform.resize( 115 | img_3c, (1024, 1024), order=3, preserve_range=True, 
anti_aliasing=True 116 | ).astype(np.uint8) 117 | img_1024 = (img_1024 - img_1024.min()) / np.clip( 118 | img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None 119 | ) # normalize to [0, 1], (H, W, 3) 120 | # convert the shape to (3, H, W) 121 | img_1024_tensor = ( 122 | torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device) 123 | ) 124 | 125 | box_np = np.array([args.box]) 126 | # transfer box_np t0 1024x1024 scale 127 | box_1024 = box_np / np.array([W, H, W, H]) * 1024 128 | with torch.no_grad(): 129 | image_embedding = medsam_model.image_encoder(img_1024_tensor) # (1, 256, 64, 64) 130 | 131 | medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W) 132 | io.imsave( 133 | join(args.seg_path, "seg_" + os.path.basename(args.data_path)), 134 | medsam_seg, 135 | check_contrast=False, 136 | ) 137 | 138 | # %% visualize results 139 | fig, ax = plt.subplots(1, 2, figsize=(10, 5)) 140 | ax[0].imshow(img_3c) 141 | show_box(box_np[0], ax[0]) 142 | ax[0].set_title("Input Image and Bounding Box") 143 | ax[1].imshow(img_3c) 144 | show_mask(medsam_seg, ax[1]) 145 | show_box(box_np[0], ax[1]) 146 | ax[1].set_title("MedSAM Segmentation") 147 | plt.show() 148 | -------------------------------------------------------------------------------- /src/dafne/MedSAM/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | from functools import partial 8 | from pathlib import Path 9 | import urllib.request 10 | import torch 11 | 12 | from .modeling import ( 13 | ImageEncoderViT, 14 | MaskDecoder, 15 | PromptEncoder, 16 | Sam, 17 | TwoWayTransformer, 18 | ) 19 | 20 | 21 | def build_sam_vit_h(checkpoint=None): 22 | return _build_sam( 23 | encoder_embed_dim=1280, 24 | encoder_depth=32, 25 | encoder_num_heads=16, 26 | encoder_global_attn_indexes=[7, 15, 23, 31], 27 | checkpoint=checkpoint, 28 | ) 29 | 30 | 31 | build_sam = build_sam_vit_h 32 | 33 | 34 | def build_sam_vit_l(checkpoint=None): 35 | return _build_sam( 36 | encoder_embed_dim=1024, 37 | encoder_depth=24, 38 | encoder_num_heads=16, 39 | encoder_global_attn_indexes=[5, 11, 17, 23], 40 | checkpoint=checkpoint, 41 | ) 42 | 43 | 44 | def build_sam_vit_b(checkpoint=None): 45 | return _build_sam( 46 | encoder_embed_dim=768, 47 | encoder_depth=12, 48 | encoder_num_heads=12, 49 | encoder_global_attn_indexes=[2, 5, 8, 11], 50 | checkpoint=checkpoint, 51 | ) 52 | 53 | 54 | sam_model_registry = { 55 | "default": build_sam_vit_h, 56 | "vit_h": build_sam_vit_h, 57 | "vit_l": build_sam_vit_l, 58 | "vit_b": build_sam_vit_b, 59 | } 60 | 61 | 62 | def _build_sam( 63 | encoder_embed_dim, 64 | encoder_depth, 65 | encoder_num_heads, 66 | encoder_global_attn_indexes, 67 | checkpoint=None, 68 | ): 69 | prompt_embed_dim = 256 70 | image_size = 1024 71 | vit_patch_size = 16 72 | image_embedding_size = image_size // vit_patch_size 73 | sam = Sam( 74 | image_encoder=ImageEncoderViT( 75 | depth=encoder_depth, 76 | embed_dim=encoder_embed_dim, 77 | img_size=image_size, 78 | mlp_ratio=4, 79 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 80 | num_heads=encoder_num_heads, 81 | patch_size=vit_patch_size, 82 | qkv_bias=True, 83 | use_rel_pos=True, 84 | global_attn_indexes=encoder_global_attn_indexes, 85 | window_size=14, 86 | out_chans=prompt_embed_dim, 87 | ), 88 | prompt_encoder=PromptEncoder( 89 | embed_dim=prompt_embed_dim, 90 | 
image_embedding_size=(image_embedding_size, image_embedding_size), 91 | input_image_size=(image_size, image_size), 92 | mask_in_chans=16, 93 | ), 94 | mask_decoder=MaskDecoder( 95 | num_multimask_outputs=3, 96 | transformer=TwoWayTransformer( 97 | depth=2, 98 | embedding_dim=prompt_embed_dim, 99 | mlp_dim=2048, 100 | num_heads=8, 101 | ), 102 | transformer_dim=prompt_embed_dim, 103 | iou_head_depth=3, 104 | iou_head_hidden_dim=256, 105 | ), 106 | pixel_mean=[123.675, 116.28, 103.53], 107 | pixel_std=[58.395, 57.12, 57.375], 108 | ) 109 | sam.eval() 110 | checkpoint = Path(checkpoint) 111 | if checkpoint.name == "sam_vit_b_01ec64.pth" and not checkpoint.exists(): 112 | cmd = input("Download sam_vit_b_01ec64.pth from facebook AI? [y]/n: ") 113 | if len(cmd) == 0 or cmd.lower() == "y": 114 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 115 | print("Downloading SAM ViT-B checkpoint...") 116 | urllib.request.urlretrieve( 117 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", 118 | checkpoint, 119 | ) 120 | print(checkpoint.name, " is downloaded!") 121 | elif checkpoint.name == "sam_vit_h_4b8939.pth" and not checkpoint.exists(): 122 | cmd = input("Download sam_vit_h_4b8939.pth from facebook AI? [y]/n: ") 123 | if len(cmd) == 0 or cmd.lower() == "y": 124 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 125 | print("Downloading SAM ViT-H checkpoint...") 126 | urllib.request.urlretrieve( 127 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", 128 | checkpoint, 129 | ) 130 | print(checkpoint.name, " is downloaded!") 131 | elif checkpoint.name == "sam_vit_l_0b3195.pth" and not checkpoint.exists(): 132 | cmd = input("Download sam_vit_l_0b3195.pth from facebook AI? 
[y]/n: ") 133 | if len(cmd) == 0 or cmd.lower() == "y": 134 | checkpoint.parent.mkdir(parents=True, exist_ok=True) 135 | print("Downloading SAM ViT-L checkpoint...") 136 | urllib.request.urlretrieve( 137 | "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", 138 | checkpoint, 139 | ) 140 | print(checkpoint.name, " is downloaded!") 141 | 142 | if checkpoint is not None: 143 | with open(checkpoint, "rb") as f: 144 | state_dict = torch.load(f, map_location=torch.device('cpu')) 145 | sam.load_state_dict(state_dict) 146 | return sam 147 | -------------------------------------------------------------------------------- /src/dafne/MedSAM/utils/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This folder contains the code for data organization, splitting, preprocessing, and checkpoint converting. 3 | 4 | 5 | ## Data Organization 6 | Since the orginal data formats and folder structures vary greatly across different dataset, we need to organize them as unified structures, allowing to use the same functions for data preprocessing. 7 | The expected folder structures are as follows: 8 | 9 | 3D nii data 10 | ``` 11 | ----dataset_name 12 | --------images 13 | ------------xxxx_0000.nii.gz 14 | ------------xxxx_0000.nii.gz 15 | --------labels 16 | ------------xxxx.nii.gz 17 | ------------xxxx.nii.gz 18 | ``` 19 | Note: you can also use different suffix for images and labels. Please change them in the following preprocessing scripts as well. 
20 | 21 | 2D data 22 | ``` 23 | ----dataset_name 24 | --------images 25 | ------------xxxx.jpg/png 26 | ------------xxxx.jpg/png 27 | --------labels 28 | ------------xxxx.png 29 | ------------xxxx.jpg/png 30 | ``` 31 | 32 | video data 33 | ``` 34 | ----dataset_name 35 | --------images 36 | ------------video1 37 | ----------------xxxx.png 38 | ------------video2 39 | ----------------xxxx.png 40 | --------labels 41 | ------------video1 42 | ----------------xxxx.png 43 | ------------video2 44 | ----------------xxxx.png 45 | ``` 46 | 47 | Unfortunately, no single script can handle the organization of every dataset. We organized the data manually with commonly used format conversion functions, including `dcm2nii`, `mhd2nii`, `nii2nii`, `nrrd2nii`, `jpg2png`, `tif2png`, and `rle_decode`. These functions are available in `format_convert.py`. 48 | 49 | ## Data Splitting 50 | Common 2D images (e.g., skin cancer dermoscopy, chest X-ray) can be directly split 80%/10%/10% into training, parameter tuning, and internal validation sets, respectively. 3D images (e.g., all the MRI/CT scans) and video data should be split at the case/video level rather than at the 2D slice/frame level. 2D whole-slide images are split at the whole-slide level; since their high resolution prevents sending them to the model directly, we divided them into patches with a fixed size of `1024x1024` after data splitting. 51 | 52 | After the data have been organized, the data splitting can be easily done by running 53 | ```bash 54 | python split.py 55 | ``` 56 | Please set the proper data path in the script.
The resulting folder structure (e.g., for 3D images) is: 57 | 58 | ``` 59 | ----dataset_name 60 | --------images 61 | ------------xxxx_0000.nii.gz 62 | ------------xxxx_0000.nii.gz 63 | --------labels 64 | ------------xxxx.nii.gz 65 | ------------xxxx.nii.gz 66 | --------validation 67 | ------------images 68 | ----------------xxxx_0000.nii.gz 69 | ----------------xxxx_0000.nii.gz 70 | ------------labels 71 | ----------------xxxx.nii.gz 72 | ----------------xxxx.nii.gz 73 | --------testing 74 | ------------images 75 | ----------------xxxx_0000.nii.gz 76 | ----------------xxxx_0000.nii.gz 77 | ------------labels 78 | ----------------xxxx.nii.gz 79 | ----------------xxxx.nii.gz 80 | ``` 81 | 82 | ## Data Preprocessing and Ensembling 83 | 84 | All images are preprocessed and saved as `npy` files. There are two main reasons for choosing this format. First, it allows fast data loading (the main reason); we adopted this practice from [nnU-Net](https://github.com/MIC-DKFZ/nnUNet). Second, the NumPy format is a universal data interface that unifies all the different source formats. For the convenience of debugging and inference, we also save the original images and labels as `npz` files. Spacing information is also saved for CT and MR images. 85 | 86 | The following steps are applied to all images: 87 | - max-min normalization 88 | - resample image size to 1024x1024 89 | - save the pre-processed images and labels as npy files 90 | 91 | Different modalities also have their own additional preprocessing steps based on the data characteristics. 92 | 93 | For CT images, we first adjust the window level and width following the [common practice](https://radiopaedia.org/articles/windowing-ct).
94 | - Soft tissue window level (40) and width (400) 95 | - Chest window level (-600) and width (1500) 96 | - Brain window level (40) and width (80) 97 | 98 | For MR, ultrasound, mammography, and Optical Coherence Tomography (OCT) images, we clip intensities at the 0.5th and 99.5th percentiles of the foreground voxels. For RGB images (e.g., endoscopy, dermoscopy, fundus, and pathology images), intensities that are already within the expected range of [0, 255] remain unchanged; if they fall outside this range, max-min normalization is applied to rescale them to [0, 255]. 99 | 100 | Preprocessing for CT/MR images: 101 | ```bash 102 | python pre_CT_MR.py 103 | ``` 104 | 105 | Preprocessing for grey and RGB images: 106 | ```bash 107 | python pre_grey_rgb.py 108 | ``` 109 | 110 | Note: Please set the corresponding folder paths and modality information in the script; an example is provided there. 111 | 112 | Ensembling different training datasets is very simple. Since all the training data are converted into `npy` files during preprocessing, you just need to merge them into one folder. 113 | 114 | 115 | ## Checkpoint Converting 116 | If the model was trained with multiple GPUs, please use the script `ckpt_convert.py` to convert the checkpoint format, since inference is typically run on a single GPU in practice.
117 | 118 | Set the path to `sam_ckpt_path`, `medsam_ckpt_path`, and `save_path` and run 119 | 120 | ```bash 121 | python ckpt_convert.py 122 | ``` 123 | 124 | -------------------------------------------------------------------------------- /icons/mac_installer_bg.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 19 | 42 | 44 | 52 | 57 | 58 | 66 | 71 | 72 | 80 | 85 | 86 | 87 | 92 | 102 | 109 | 115 | 117 | 121 | 125 | 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /src/dafne/ui/WhatsNew.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Dafne-Imaging Team 2 | from datetime import datetime 3 | import xml.etree.ElementTree as ET 4 | from urllib.parse import urlparse 5 | 6 | import requests 7 | from PyQt5.QtCore import Qt, QObject, pyqtSignal 8 | from PyQt5.QtWidgets import QDialog, QVBoxLayout, QPushButton, QLabel, QSizePolicy 9 | 10 | from ..config import GlobalConfig, save_config 11 | from ..utils.ThreadHelpers import separate_thread_decorator 12 | 13 | MAX_NEWS_ITEMS = 3 14 | BLOG_INDEX = '/blog/index/' 15 | BLOG_ADDRESS = '/blog/' 16 | 17 | 18 | class WhatsNewDialog(QDialog): 19 | def __init__(self, news_list, index_address, *args, **kwargs): 20 | QDialog.__init__(self, *args, **kwargs) 21 | my_layout = QVBoxLayout(self) 22 | self.setLayout(my_layout) 23 | self.setWindowTitle(f"Dafne News") 24 | self.setWindowModality(Qt.ApplicationModal) 25 | 26 | for news in news_list[:MAX_NEWS_ITEMS]: 27 | title_label = QLabel(f'

{news["title"]}

') 28 | title_label.setOpenExternalLinks(True) 29 | title_label.setTextInteractionFlags(Qt.LinksAccessibleByMouse) 30 | title_label.sizePolicy().setVerticalStretch(0) 31 | title_label.setWordWrap(True) 32 | my_layout.addWidget(title_label) 33 | date_label = QLabel(f'{news["date"]}') 34 | date_label.sizePolicy().setVerticalStretch(0) 35 | my_layout.addWidget(date_label) 36 | body_label = QLabel(news["excerpt"]) 37 | body_label.setWordWrap(True) 38 | size_policy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred) 39 | size_policy.setHorizontalStretch(0) 40 | size_policy.setVerticalStretch(1) 41 | size_policy.setHeightForWidth(body_label.sizePolicy().hasHeightForWidth()) 42 | body_label.setAlignment(Qt.AlignLeading | Qt.AlignLeft | Qt.AlignTop) 43 | body_label.setSizePolicy(size_policy) 44 | my_layout.addWidget(body_label) 45 | 46 | more_news_label = QLabel(f'All news...') 47 | more_news_label.setOpenExternalLinks(True) 48 | more_news_label.setTextInteractionFlags(Qt.LinksAccessibleByMouse) 49 | more_news_label.sizePolicy().setVerticalStretch(0) 50 | my_layout.addWidget(more_news_label) 51 | 52 | n_news = min(MAX_NEWS_ITEMS, len(news_list)) 53 | 54 | btn = QPushButton("Close") 55 | btn.clicked.connect(self.close) 56 | my_layout.addWidget(btn) 57 | self.resize(300, 110 * n_news + 60) 58 | self.show() 59 | 60 | 61 | def xml_timestamp_to_datetime(timestamp): 62 | return datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S%z') 63 | 64 | 65 | def datetime_to_xml_timestamp(dt): 66 | return dt.strftime('%Y-%m-%dT%H:%M:%S%z') 67 | 68 | 69 | def check_for_updates(): 70 | last_news_time = xml_timestamp_to_datetime(GlobalConfig['LAST_NEWS']) 71 | # last_news_time = xml_timestamp_to_datetime('2010-11-10T00:00:00+00:00') 72 | try: 73 | r = requests.get(GlobalConfig['NEWS_URL'], timeout=(1, None)) 74 | except requests.exceptions.ConnectionError: 75 | return [], [] 76 | except requests.exceptions.Timeout: 77 | return [], [] 78 | 79 | if r.status_code != 200: 80 | return [], [] 81 
| try: 82 | feed = ET.fromstring(r.text) 83 | except ET.ParseError: 84 | print("Error parsing news feed") 85 | return [], [] 86 | 87 | parsed_uri = urlparse(GlobalConfig['NEWS_URL']) 88 | base_url = f'{parsed_uri.scheme}://{parsed_uri.netloc}' 89 | 90 | xml_ns = {'atom': 'http://www.w3.org/2005/Atom'} 91 | 92 | news_list = [] 93 | newest_time = last_news_time 94 | for entry in feed.findall('atom:entry', xml_ns): 95 | link = entry.find('atom:link', xml_ns).attrib['href'] 96 | # skip the index 97 | if link == BLOG_INDEX: 98 | continue 99 | 100 | news_time = xml_timestamp_to_datetime(entry.find('atom:updated', xml_ns).text) 101 | if news_time > last_news_time: 102 | news_list.append({'date': news_time.strftime('%Y-%m-%d'), 103 | 'link': base_url + link, 104 | 'title': entry.find('atom:title', xml_ns).text, 105 | 'excerpt': entry.find('atom:summary', xml_ns).text}) 106 | if news_time > newest_time: 107 | newest_time = news_time 108 | 109 | GlobalConfig['LAST_NEWS'] = datetime_to_xml_timestamp(newest_time) 110 | news_list.sort(key=lambda x: x['date'], reverse=True) 111 | save_config() 112 | return news_list, base_url + BLOG_ADDRESS 113 | 114 | 115 | class NewsChecker(QObject): 116 | news_ready = pyqtSignal(list, str) 117 | 118 | def __init__(self, *args, **kwargs): 119 | QObject.__init__(self, *args, **kwargs) 120 | 121 | @separate_thread_decorator 122 | def check_news(self): 123 | news_list, index_address = check_for_updates() 124 | if news_list: 125 | self.news_ready.emit(news_list, index_address) 126 | 127 | 128 | def show_news(): 129 | news_list, index_address = check_for_updates() 130 | if news_list: 131 | d = WhatsNewDialog(news_list, index_address) 132 | d.exec() 133 | 134 | 135 | def main(): 136 | import sys 137 | from PyQt5.QtWidgets import QApplication 138 | app = QApplication(sys.argv) 139 | app.setQuitOnLastWindowClosed(True) 140 | show_news() 141 | -------------------------------------------------------------------------------- 
/src/dafne/ui/BatchCalcTransforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright (c) 2021 Dafne-Imaging Team 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | 19 | from PyQt5.QtWidgets import QWidget, QMainWindow, QFileDialog, QMessageBox, QApplication 20 | from PyQt5.QtCore import pyqtSlot, pyqtSignal 21 | from .CalcTransformsUI import Ui_CalcTransformsUI 22 | import os 23 | import numpy as np 24 | from ..utils.RegistrationManager import RegistrationManager 25 | from dicomUtils.misc import dosma_volume_from_path, get_nifti_orientation 26 | from ..utils.ThreadHelpers import separate_thread_decorator 27 | from ..config import GlobalConfig 28 | import sys 29 | 30 | 31 | class CalcTransformWindow(QWidget, Ui_CalcTransformsUI): 32 | update_progress = pyqtSignal(int) 33 | success = pyqtSignal() 34 | 35 | def __init__(self): 36 | QWidget.__init__(self) 37 | self.setupUi(self) 38 | self.setWindowTitle("Dicom transform calculator") 39 | self.registrationManager = None 40 | self.update_progress.connect(self.set_progress) 41 | self.choose_Button.clicked.connect(self.load_data) 42 | self.calculate_button.clicked.connect(self.calculate) 43 | self.success.connect(self.show_success_box) 44 | self.data = None 45 | self.basepath = '' 46 | self.basename = '' 47 | 
48 | @pyqtSlot() 49 | def load_data(self): 50 | if GlobalConfig['ENABLE_NIFTI']: 51 | filter = 'Image files (*.dcm *.ima *.nii *.nii.gz);; Dicom files (*.dcm *.ima);;Nifti files (*.nii *.nii.gz);;All files (*.*)' 52 | else: 53 | filter = 'Dicom files (*.dcm *.ima);;All files (*.*)' 54 | 55 | dataFile, _ = QFileDialog.getOpenFileName(self, caption='Select dataset to import', 56 | filter=filter) 57 | 58 | path = os.path.abspath(dataFile) 59 | print(path) 60 | _, ext = os.path.splitext(path) 61 | dataset_name = os.path.basename(path) 62 | 63 | containing_dir = os.path.dirname(path) 64 | 65 | if ext.lower() not in ['.nii', '.gz']: 66 | path = containing_dir 67 | 68 | 69 | medical_volume = None 70 | basename = '' 71 | try: 72 | medical_volume, affine_valid, title, self.basepath, self.basename = dosma_volume_from_path(path, self, sort=GlobalConfig['DICOM_SORT']) 73 | self.data = medical_volume.volume 74 | except: 75 | pass 76 | print("Basepath", self.basepath) 77 | 78 | if self.data is None: 79 | self.progressBar.setValue(0) 80 | self.progressBar.setEnabled(False) 81 | self.calculate_button.setEnabled(False) 82 | QMessageBox.warning(self, 'Warning', 'Invalid dataset!') 83 | return 84 | 85 | self.data = medical_volume 86 | 87 | self.progressBar.setEnabled(True) 88 | self.registrationManager = None 89 | self.location_Text.setText(containing_dir if not basename else basename) 90 | self.calculate_button.setEnabled(True) 91 | 92 | @pyqtSlot(int) 93 | def set_progress(self, value): 94 | self.progressBar.setValue(value) 95 | 96 | @pyqtSlot() 97 | def show_success_box(self): 98 | QMessageBox.information(self, 'Done', 'Done!') 99 | 100 | @pyqtSlot() 101 | @separate_thread_decorator 102 | def calculate(self): 103 | self.choose_Button.setEnabled(False) 104 | self.calculate_button.setEnabled(False) 105 | 106 | if self.axial_radio.isChecked(): 107 | self.data = self.data.reformat(get_nifti_orientation('axial')) 108 | elif self.sagittal_radio.isChecked(): 109 | self.data = 
self.data.reformat(get_nifti_orientation('sagittal')) 110 | elif self.coronal_radio.isChecked(): 111 | self.data = self.data.reformat(get_nifti_orientation('coronal')) 112 | 113 | 114 | data = list(np.transpose(self.data.volume, [2, 0, 1])) 115 | self.progressBar.setMaximum(len(data)) 116 | 117 | self.registrationManager = RegistrationManager(data, None, os.getcwd(), 118 | GlobalConfig['TEMP_DIR']) 119 | self.registrationManager.set_standard_transforms_name(self.basepath, self.basename) 120 | 121 | self.registrationManager.calc_transforms(lambda value: self.update_progress.emit(value)) 122 | self.choose_Button.setEnabled(True) 123 | self.calculate_button.setEnabled(False) 124 | self.update_progress.emit(0) 125 | self.success.emit() 126 | 127 | 128 | def run(): 129 | app = QApplication(sys.argv) 130 | window = QMainWindow() 131 | widget = CalcTransformWindow() 132 | window.setCentralWidget(widget) 133 | window.setWindowTitle("Dicom transform calculator") 134 | window.show() 135 | sys.exit(app.exec_()) 136 | -------------------------------------------------------------------------------- /src/dafne/ui/BrushPatches.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Dafne-Imaging Team 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 
15 | 16 | from matplotlib.patches import Polygon, Rectangle 17 | import numpy as np 18 | import math 19 | from scipy.ndimage import shift 20 | 21 | 22 | class SquareBrush(Rectangle): 23 | 24 | def __init__(self, *args, **kwargs): 25 | Rectangle.__init__(self, *args, **kwargs) 26 | 27 | def to_mask(self, shape): 28 | """ 29 | convert the brush to a binary mask of the size "shape" 30 | """ 31 | mask = np.zeros(shape, dtype=np.uint8) 32 | 33 | xy = self.get_xy() 34 | h = self.get_height() 35 | w = self.get_width() 36 | # x and y are inverted 37 | x0 = int(np.round(xy[1] + 0.5)) 38 | y0 = int(np.round(xy[0] + 0.5)) 39 | x1 = int(x0 + np.round(h)) 40 | y1 = int(y0 + np.round(w)) 41 | 42 | mask[x0:x1, y0:y1] = 1 43 | return mask 44 | 45 | 46 | class PixelatedCircleBrush(Polygon): 47 | 48 | def __init__(self, center, radius, **kwargs): 49 | self.point_array = None 50 | self.kwargs = kwargs 51 | self.center = None 52 | self.radius = None 53 | self.base_mask = None 54 | Polygon.__init__(self, np.array([[0,0],[1,1]]), **kwargs) 55 | self.center = np.array(center).ravel() # make sure it's a row vector 56 | self.set_radius(radius) 57 | 58 | def get_center(self): 59 | return self.center 60 | 61 | def get_radius(self): 62 | return self.radius 63 | 64 | def set_center(self, center): 65 | self.center = np.array(center).ravel() # make sure it's a row vector 66 | self._recalculate_xy() 67 | 68 | def set_radius(self, radius): 69 | if radius != self.radius: 70 | self.radius = radius 71 | self._recalculate_vertices() 72 | self._recalculate_mask() 73 | self._recalculate_xy() 74 | 75 | def to_mask(self, shape): 76 | mask = np.zeros(shape, dtype=np.uint8) 77 | mask[0:self.base_mask.shape[0], 0:self.base_mask.shape[1]] = self.base_mask 78 | mask = shift(mask, (self.center[1] - self.radius, self.center[0] - self.radius), order=0, prefilter=False) 79 | return mask 80 | 81 | def _recalculate_xy(self): 82 | xy = self.point_array + self.center 83 | self.set_xy(xy) 84 | 85 | def 
_recalculate_vertices(self): 86 | if self.radius == 0: 87 | self.point_array = np.array([[0,0],[1,0],[1,1],[0,1]]) - 0.5 88 | return 89 | 90 | radius = self.radius + 0.5 91 | 92 | # midpoint circle algorithm 93 | x = radius 94 | y = 0 95 | P = 1 - radius 96 | 97 | octant_point_array = [(x, y-0.5)] 98 | 99 | 100 | while x > y: 101 | 102 | y += 1 103 | 104 | # Mid-point inside or on the perimeter 105 | if P <= 0: 106 | P = P + 2 * y + 1 107 | 108 | # Mid-point outside the perimeter 109 | else: 110 | octant_point_array.append((x, y-0.5)) 111 | x -= 1 112 | octant_point_array.append((x, y-0.5)) 113 | P = P + 2 * y - 2 * x + 1 114 | 115 | if x < y: 116 | break 117 | 118 | # assemble the octants 119 | quarter_point_array = octant_point_array[:] 120 | quarter_point_array.extend([(y,x) for x,y in octant_point_array[::-1]]) 121 | point_array = quarter_point_array[:] 122 | point_array.extend([(-x,y) for x,y in quarter_point_array[::-1]]) 123 | point_array.extend([(-x,-y) for x,y in quarter_point_array]) 124 | point_array.extend(([(x,-y) for x,y in quarter_point_array[::-1]])) 125 | 126 | self.point_array = np.array(point_array) 127 | 128 | def _recalculate_mask(self): 129 | if self.radius == 0: 130 | self.base_mask = np.ones((1, 1), dtype=np.uint8) 131 | return 132 | 133 | radius = self.radius 134 | 135 | # midpoint circle algorithm 136 | x = radius 137 | y = 0 138 | P = 1 - radius 139 | 140 | self.base_mask = np.zeros((self.radius * 2 + 1, self.radius * 2 + 1), dtype=np.uint8) 141 | 142 | def fill_mask_line(x, y): 143 | r = radius 144 | self.base_mask[int(r - x): int(r + x + 1), int(r + y)] = 1 145 | self.base_mask[int(r - x):int(r + x + 1), int(r - y)] = 1 146 | 147 | fill_mask_line(x, y) 148 | fill_mask_line(y, x) 149 | 150 | while x > y: 151 | y += 1 152 | 153 | # Mid-point inside or on the perimeter 154 | if P <= 0: 155 | P = P + 2 * y + 1 156 | 157 | # Mid-point outside the perimeter 158 | else: 159 | x -= 1 160 | P = P + 2 * y - 2 * x + 1 161 | 162 | fill_mask_line(x, 
y) 163 | fill_mask_line(y, x) 164 | 165 | if (x < y): 166 | break -------------------------------------------------------------------------------- /src/dafne/utils/mask_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Dafne-Imaging Team 2 | # 3 | # This program is free software: you can redistribute it and/or modify 4 | # it under the terms of the GNU General Public License as published by 5 | # the Free Software Foundation, either version 3 of the License, or 6 | # (at your option) any later version. 7 | # 8 | # This program is distributed in the hope that it will be useful, 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | # GNU General Public License for more details. 12 | # 13 | # You should have received a copy of the GNU General Public License 14 | # along with this program. If not, see . 15 | 16 | import os 17 | import numpy as np 18 | from voxel import DicomWriter, NiftiWriter, MedicalVolume 19 | from matplotlib import pyplot as plt 20 | from scipy import ndimage 21 | 22 | 23 | def save_dicom_masks(base_path: str, mask_dict: dict, affine, dicom_headers: list): 24 | dicom_writer = DicomWriter(num_workers=0) 25 | for name, mask in mask_dict.items(): 26 | n = name.strip() 27 | if n == '': n = '_' 28 | dicom_path = os.path.join(base_path, n) 29 | try: 30 | os.makedirs(dicom_path) 31 | except OSError: 32 | pass 33 | medical_volume = MedicalVolume(mask.astype(np.uint16), affine, dicom_headers) 34 | dicom_writer.save(medical_volume, dicom_path, fname_fmt='image%04d.dcm') 35 | 36 | 37 | def save_npy_masks(base_path, mask_dict): 38 | for name, mask in mask_dict.items(): 39 | npy_name = os.path.join(base_path, name + '.npy') 40 | np.save(npy_name, mask) 41 | 42 | 43 | def save_npz_masks(filename, mask_dict): 44 | np.savez_compressed(filename, **mask_dict) 45 | 46 | 47 | def save_nifti_masks(base_path, 
mask_dict, affine): 48 | nifti_writer = NiftiWriter() 49 | for name, mask in mask_dict.items(): 50 | nifti_name = os.path.join(base_path, name + '.nii.gz') 51 | medical_volume = MedicalVolume(mask.astype(np.uint16), affine) 52 | nifti_writer.save(medical_volume, nifti_name) 53 | 54 | 55 | def make_accumulated_mask(mask_dict): 56 | accumulated_mask = None 57 | current_value = 1 58 | name_list = [] 59 | for name, mask in mask_dict.items(): 60 | name_list.append(name) 61 | if accumulated_mask is None: 62 | accumulated_mask = 1*(mask>0) 63 | else: 64 | accumulated_mask += current_value * (mask>0) 65 | current_value += 1 66 | return accumulated_mask, name_list 67 | 68 | 69 | def write_legend(filename, name_list): 70 | with open(filename, 'w') as f: 71 | f.write('Value,Label\n') 72 | for index, name in enumerate(name_list): 73 | f.write(f'{index+1},{name}\n') 74 | 75 | 76 | def write_itksnap_legend(filename, name_list): 77 | PREAMBLE = """################################################ 78 | # ITK-SnAP Label Description File 79 | # File format: 80 | # IDX -R- -G- -B- -A-- VIS MSH LABEL 81 | # Fields: 82 | # IDX: Zero-based index 83 | # -R-: Red color component (0..255) 84 | # -G-: Green color component (0..255) 85 | # -B-: Blue color component (0..255) 86 | # -A-: Label transparency (0.00 .. 
1.00) 87 | # VIS: Label visibility (0 or 1) 88 | # MSH: Label mesh visibility (0 or 1) 89 | # LABEL: Label description 90 | ################################################\n""" 91 | 92 | line_format = '{id:>5d}{red:>6d}{green:>5d}{blue:>5d}{alpha:>9.2g}{vis:>3d}{mesh:>3d} "{label}"\n' 93 | 94 | cmap = plt.get_cmap('hsv') 95 | nLabels = len(name_list) 96 | nColors = cmap.N 97 | 98 | step = int(nColors/nLabels) 99 | 100 | with open(filename, 'w') as f: 101 | f.write(PREAMBLE) 102 | f.write(line_format.format(id=0, red=0, green=0, blue=0, alpha=0, vis=0, mesh=0, label='Clear Label')) 103 | for index, name in enumerate(name_list): 104 | color = cmap(index*step) 105 | f.write(line_format.format(id=index+1, red=int(color[0]*255), green=int(color[1]*255), blue=int(color[2]*255), 106 | alpha=1, vis=1, mesh=1, label=name)) 107 | 108 | 109 | def save_single_nifti(filename, mask_dict, affine): 110 | nifti_writer = NiftiWriter() 111 | accumulated_mask, name_list = make_accumulated_mask(mask_dict) 112 | legend_name = filename + '.csv' 113 | snap_legend_name = filename + '_itk-snap.txt' 114 | medical_volume = MedicalVolume(accumulated_mask.astype(np.uint16), affine) 115 | nifti_writer.save(medical_volume, filename) 116 | write_legend(legend_name, name_list) 117 | write_itksnap_legend(snap_legend_name, name_list) 118 | 119 | 120 | def save_single_dicom_dataset(base_path, mask_dict, affine, dicom_headers: list): 121 | dicom_writer = DicomWriter(num_workers=0) 122 | accumulated_mask, name_list = make_accumulated_mask(mask_dict) 123 | medical_volume = MedicalVolume(accumulated_mask.astype(np.uint16), affine, dicom_headers) 124 | try: 125 | os.makedirs(base_path) 126 | except OSError: 127 | pass 128 | dicom_writer.save(medical_volume, base_path, fname_fmt='image%04d.dcm') 129 | legend_name = os.path.join(base_path, 'legend.csv') 130 | snap_legend_name = os.path.join(base_path, 'legend_itk-snap.txt') 131 | write_legend(legend_name, name_list) 132 | 
write_itksnap_legend(snap_legend_name, name_list) 133 | 134 | 135 | def distance_mask(mask): 136 | mask = mask.astype(np.uint8) 137 | internal_distance = ndimage.distance_transform_edt(mask) 138 | external_distance = ndimage.distance_transform_edt(1-mask) 139 | return internal_distance - external_distance -------------------------------------------------------------------------------- /src/dafne/ui/ValidateUI.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | ValidateUI 4 | 5 | 6 | 7 | 0 8 | 0 9 | 724 10 | 348 11 | 12 | 13 | 14 | Form 15 | 16 | 17 | 18 | 19 | 20 | Configure... 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 0 31 | 0 32 | 33 | 34 | 35 | Data Location: 36 | 37 | 38 | 39 | 40 | 41 | 42 | false 43 | 44 | 45 | 46 | 47 | 48 | 49 | Choose... 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 0 62 | 0 63 | 64 | 65 | 66 | Masks Location: 67 | 68 | 69 | 70 | 71 | 72 | 73 | false 74 | 75 | 76 | 77 | 78 | 79 | 80 | false 81 | 82 | 83 | Choose folder... 84 | 85 | 86 | 87 | 88 | 89 | 90 | false 91 | 92 | 93 | Choose ROI... 
94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | Current eval: 105 | 106 | 107 | 108 | 109 | 110 | 111 | false 112 | 113 | 114 | 115 | 1 116 | 0 117 | 118 | 119 | 120 | 0 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | Overall progress: 132 | 133 | 134 | 135 | 136 | 137 | 138 | false 139 | 140 | 141 | 142 | 1 143 | 0 144 | 145 | 146 | 147 | 0 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | Qt::Vertical 164 | 165 | 166 | QSizePolicy::Expanding 167 | 168 | 169 | 170 | 20 171 | 45 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | false 180 | 181 | 182 | Evaluate 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | -------------------------------------------------------------------------------- /src/dafne/utils/polyToMask.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2021 Dafne-Imaging Team 4 | # 5 | # This program is free software: you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation, either version 3 of the License, or 8 | # (at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU General Public License 16 | # along with this program. If not, see . 
import numpy as np
from collections import deque


def flood(seedX, seedY, mask):
    """Flood-fill ``mask`` in place with the value 1, starting from ``(seedX, seedY)``.

    Pixels already equal to 1 act as the boundary (the contour). The fill is
    iterative: from each dequeued position it sweeps along the first axis in
    both directions until it hits a boundary pixel. It queues the neighbouring
    positions along the second axis for later processing. (This replaced an
    earlier recursive implementation that crashed Python.)

    :param seedX: index of the seed along the first axis of ``mask``.
    :param seedY: index of the seed along the second axis of ``mask``.
    :param mask: 2D NumPy array (needs ``.shape`` and ``mask[x][y]`` indexing);
        it is modified in place.
    :return: None. If the seed lies outside ``mask`` or is already 1, the mask
        is left unchanged.
    """
    sz = mask.shape

    def is_out_of_bound(p):
        # True when position p falls outside the array
        return p[0] < 0 or p[1] < 0 or p[0] >= sz[0] or p[1] >= sz[1]

    # Check bounds before indexing the seed: a too-large seed used to raise
    # IndexError, while a negative one silently wrapped to the opposite edge.
    if is_out_of_bound((seedX, seedY)) or mask[seedX][seedY] == 1:
        return

    q = deque([(seedX, seedY)])
    while q:  # iterate until the queue is empty
        x, y = q.popleft()
        if is_out_of_bound((x, y)) or mask[x][y] == 1:
            continue
        # travel "east" along the first axis until a contour pixel is reached
        for e in range(x, sz[0]):
            if mask[e][y] == 1:
                break
            mask[e][y] = 1
            # queue the "north"/"south" neighbours; out-of-bound ones are
            # discarded when they are popped
            q.append((e, y - 1))
            q.append((e, y + 1))
        # travel "west" along the first axis until a contour pixel is reached
        for w in range(x - 1, -1, -1):
            if mask[w][y] == 1:
                break
            mask[w][y] = 1
            q.append((w, y - 1))
            q.append((w, y + 1))


def intround(x):
    """Round ``x`` to the nearest integer and return it as an ``int``.

    Uses the built-in :func:`round`, so exact halves round to the nearest even
    integer (e.g. ``intround(2.5) == 2``).
    """
    return int(round(x))

#converts this spline to mask of a defined size. Note! At the moment this will not work properly if the contour touches the edges!
73 | def polyToMask(points, size): 74 | size = (size[1], size[0]) # x is rows and y is columns 75 | mask = np.zeros((size[0]+1, size[1])) # create a mask that is 1 larger than needed 76 | for i in range(0, len(points)): 77 | curPoint = points[i] 78 | try: 79 | nextPoint = points[i+1] 80 | except IndexError: 81 | nextPoint = points[0] # close the polygon 82 | 83 | #print curPoint, nextPoint 84 | if (curPoint[0] == nextPoint[0]) and (curPoint[1] == nextPoint[1]): 85 | continue 86 | if curPoint[0] < 0: curPoint[0] = 0 87 | if curPoint[1] < 0: curPoint[1] = 0 88 | if nextPoint[0] < 0: nextPoint[0] = 0 89 | if nextPoint[1] < 0: nextPoint[1] = 0 90 | 91 | if curPoint[0] > size[0]-1: curPoint[0] = size[0]-1 92 | if curPoint[1] > size[1]-1: curPoint[1] = size[1]-1 93 | if nextPoint[0] > size[0]-1: nextPoint[0] = size[0]-1 94 | if nextPoint[1] > size[1]-1: nextPoint[1] = size[1]-1 95 | 96 | mask[intround(curPoint[0])][intround(curPoint[1])] = 1 # set initial point to 1 97 | # special case for vertical line 98 | if nextPoint[0] == curPoint[0]: 99 | # order start and end 100 | if curPoint[1] < nextPoint[1]: 101 | startY = curPoint[1] 102 | endY = nextPoint[1] 103 | else: 104 | startY = nextPoint[1] 105 | endY = curPoint[1] 106 | for y in range(intround(startY), intround(endY+1)): # how stupid is this? 
107 | mask[intround(curPoint[0])][intround(y)] = 1 108 | else: 109 | # not a vertical line 110 | slope = (nextPoint[1]-curPoint[1])/(nextPoint[0]-curPoint[0]) 111 | if abs(slope) < 1: 112 | #travel along x because line is "flat" 113 | if curPoint[0] < nextPoint[0]: 114 | startX = curPoint[0] 115 | endX = nextPoint[0] 116 | startY = curPoint[1] 117 | else: 118 | startX = nextPoint[0] 119 | endX = curPoint[0] 120 | startY = nextPoint[1] 121 | for x in range(intround(startX), intround(endX+1)): 122 | nextY = startY + (x-startX)*slope 123 | mask[intround(x)][intround(nextY)] = 1 124 | else: 125 | # travel along y because line is "steep" 126 | if curPoint[1] < nextPoint[1]: 127 | startY = curPoint[1] 128 | endY = nextPoint[1] 129 | startX = curPoint[0] 130 | else: 131 | startY = nextPoint[1] 132 | endY = curPoint[1] 133 | startX = nextPoint[0] 134 | for y in range(intround(startY), intround(endY+1)): 135 | nextX = startX + (y-startY)/slope 136 | mask[intround(nextX)][intround(y)] = 1 137 | 138 | contour = np.copy(mask) # save this for later 139 | flood(mask.shape[0]-1, 0, mask) # flood the outside of the contour, by starting on a point that is for sure outside 140 | # now the mask is actually the inverted mask and the contour 141 | # invert the mask and add the contour back 142 | mask = np.logical_or(contour, np.logical_not(mask)) 143 | #mask = np.logical_not(mask) # do not add the contour! Change this for production, this is only for the abstract 144 | return np.transpose(mask[0:len(mask)-1]) # return the mask with the original desired size - transposed because x is rows and y is columns 145 | -------------------------------------------------------------------------------- /src/dafne/MedSAM/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import Tuple

from ..modeling import Sam
from .amg import calculate_stability_score


class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. It also supports extra
    options controlling what information is returned (a single best mask,
    stability scores instead of predicted IoU, extra metrics). See the ONNX
    export script for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        """
        Arguments:
            model: the Sam model whose prompt encoder and mask decoder are reused.
            return_single_mask: if True, forward() keeps only one mask per
                prompt, chosen by select_masks().
            use_stability_score: if True, forward() replaces the decoder's
                predicted scores with calculate_stability_score() on the masks.
            return_extra_metrics: if True, forward() also returns stability
                scores and mask areas of the upscaled masks.
        """
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        # offset passed to calculate_stability_score in forward()
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        """Scale ``input_image_size`` so that its largest entry equals
        ``longest_side``, rounding half up (floor(x + 0.5)). Returns int64."""
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(
        self, point_coords: torch.Tensor, point_labels: torch.Tensor
    ) -> torch.Tensor:
        """Encode point prompts into sparse embeddings.

        Coordinates are shifted by 0.5 and normalized by ``img_size`` before
        positional encoding. Points labelled -1 get their positional encoding
        zeroed and receive ``not_a_point_embed`` instead. A label ``i`` in
        ``[0, num_point_embeddings)`` adds ``point_embeddings[i]``. Selection
        is done by multiplying with boolean masks rather than by branching.
        """
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        # broadcast labels over the embedding dimension
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = (
            point_embedding
            + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = (
                point_embedding
                + self.model.prompt_encoder.point_embeddings[i].weight
                * (point_labels == i)
            )

        return point_embedding

    def _embed_masks(
        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
    ) -> torch.Tensor:
        """Dense prompt embedding. ``has_mask_input`` arithmetically blends the
        downscaled ``input_mask`` and ``no_mask_embed`` (presumably a 0/1 value
        so tracing sees no control flow; confirm against the export script)."""
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
            input_mask
        )
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(
        self, masks: torch.Tensor, orig_im_size: torch.Tensor
    ) -> torch.Tensor:
        """Upsample masks to ``img_size`` x ``img_size``, crop away the padding
        (keeping the resized-image area), then resize to ``orig_im_size`` (h, w)."""
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
            torch.int64
        )
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Keep one mask (and its score) per batch entry.

        Mask index 0 receives a bonus of 1000 * (num_points - 2.5). With three
        or more points it always wins. With two or fewer it is pushed down, so
        the best-scoring of the remaining masks is chosen.
        """
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        """Run prompt encoding, mask decoding and postprocessing.

        Returns ``(upscaled_masks, scores, masks)``. When
        ``return_extra_metrics`` is set, it returns
        ``(upscaled_masks, scores, stability_scores, areas, masks)`` instead.
        ``masks`` are the low-resolution decoder outputs.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            # point_coords.shape[1] is the number of points per prompt
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            # number of pixels above the mask threshold, summed over the last two dims
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks
--------------------------------------------------------------------------------
/src/dafne/ui/ValidateUI.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file 'ValidateUI.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.6 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | 13 | 14 | class Ui_ValidateUI(object): 15 | def setupUi(self, ValidateUI): 16 | ValidateUI.setObjectName("ValidateUI") 17 | ValidateUI.resize(724, 348) 18 | self.verticalLayout = QtWidgets.QVBoxLayout(ValidateUI) 19 | self.verticalLayout.setObjectName("verticalLayout") 20 | self.configure_button = QtWidgets.QPushButton(ValidateUI) 21 | self.configure_button.setObjectName("configure_button") 22 | self.verticalLayout.addWidget(self.configure_button) 23 | self.horizontalLayout = QtWidgets.QHBoxLayout() 24 | self.horizontalLayout.setObjectName("horizontalLayout") 25 | self.label = QtWidgets.QLabel(ValidateUI) 26 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed) 27 | sizePolicy.setHorizontalStretch(0) 28 | sizePolicy.setVerticalStretch(0) 29 | sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth()) 30 | self.label.setSizePolicy(sizePolicy) 31 | self.label.setObjectName("label") 32 | self.horizontalLayout.addWidget(self.label) 33 | self.data_location_Text = QtWidgets.QLineEdit(ValidateUI) 34 | self.data_location_Text.setEnabled(False) 35 | self.data_location_Text.setObjectName("data_location_Text") 36 | self.horizontalLayout.addWidget(self.data_location_Text) 37 | self.data_choose_Button = QtWidgets.QPushButton(ValidateUI) 38 | self.data_choose_Button.setObjectName("data_choose_Button") 39 | self.horizontalLayout.addWidget(self.data_choose_Button) 40 | self.verticalLayout.addLayout(self.horizontalLayout) 41 | self.horizontalLayout_2 = 
QtWidgets.QHBoxLayout() 42 | self.horizontalLayout_2.setObjectName("horizontalLayout_2") 43 | self.label_2 = QtWidgets.QLabel(ValidateUI) 44 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed) 45 | sizePolicy.setHorizontalStretch(0) 46 | sizePolicy.setVerticalStretch(0) 47 | sizePolicy.setHeightForWidth(self.label_2.sizePolicy().hasHeightForWidth()) 48 | self.label_2.setSizePolicy(sizePolicy) 49 | self.label_2.setObjectName("label_2") 50 | self.horizontalLayout_2.addWidget(self.label_2) 51 | self.mask_location_Text = QtWidgets.QLineEdit(ValidateUI) 52 | self.mask_location_Text.setEnabled(False) 53 | self.mask_location_Text.setObjectName("mask_location_Text") 54 | self.horizontalLayout_2.addWidget(self.mask_location_Text) 55 | self.mask_choose_Button = QtWidgets.QPushButton(ValidateUI) 56 | self.mask_choose_Button.setEnabled(False) 57 | self.mask_choose_Button.setObjectName("mask_choose_Button") 58 | self.horizontalLayout_2.addWidget(self.mask_choose_Button) 59 | self.roi_choose_Button = QtWidgets.QPushButton(ValidateUI) 60 | self.roi_choose_Button.setEnabled(False) 61 | self.roi_choose_Button.setObjectName("roi_choose_Button") 62 | self.horizontalLayout_2.addWidget(self.roi_choose_Button) 63 | self.verticalLayout.addLayout(self.horizontalLayout_2) 64 | self.horizontalLayout_3 = QtWidgets.QHBoxLayout() 65 | self.horizontalLayout_3.setObjectName("horizontalLayout_3") 66 | self.currentProgress_Label = QtWidgets.QLabel(ValidateUI) 67 | self.currentProgress_Label.setObjectName("currentProgress_Label") 68 | self.horizontalLayout_3.addWidget(self.currentProgress_Label) 69 | self.slice_progressBar = QtWidgets.QProgressBar(ValidateUI) 70 | self.slice_progressBar.setEnabled(False) 71 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) 72 | sizePolicy.setHorizontalStretch(1) 73 | sizePolicy.setVerticalStretch(0) 74 | 
sizePolicy.setHeightForWidth(self.slice_progressBar.sizePolicy().hasHeightForWidth()) 75 | self.slice_progressBar.setSizePolicy(sizePolicy) 76 | self.slice_progressBar.setProperty("value", 0) 77 | self.slice_progressBar.setObjectName("slice_progressBar") 78 | self.horizontalLayout_3.addWidget(self.slice_progressBar) 79 | self.verticalLayout.addLayout(self.horizontalLayout_3) 80 | self.horizontalLayout_4 = QtWidgets.QHBoxLayout() 81 | self.horizontalLayout_4.setObjectName("horizontalLayout_4") 82 | self.label_4 = QtWidgets.QLabel(ValidateUI) 83 | self.label_4.setObjectName("label_4") 84 | self.horizontalLayout_4.addWidget(self.label_4) 85 | self.overall_progressBar = QtWidgets.QProgressBar(ValidateUI) 86 | self.overall_progressBar.setEnabled(False) 87 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) 88 | sizePolicy.setHorizontalStretch(1) 89 | sizePolicy.setVerticalStretch(0) 90 | sizePolicy.setHeightForWidth(self.overall_progressBar.sizePolicy().hasHeightForWidth()) 91 | self.overall_progressBar.setSizePolicy(sizePolicy) 92 | self.overall_progressBar.setProperty("value", 0) 93 | self.overall_progressBar.setObjectName("overall_progressBar") 94 | self.horizontalLayout_4.addWidget(self.overall_progressBar) 95 | self.verticalLayout.addLayout(self.horizontalLayout_4) 96 | self.status_Label = QtWidgets.QLabel(ValidateUI) 97 | self.status_Label.setText("") 98 | self.status_Label.setObjectName("status_Label") 99 | self.verticalLayout.addWidget(self.status_Label) 100 | spacerItem = QtWidgets.QSpacerItem(20, 45, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) 101 | self.verticalLayout.addItem(spacerItem) 102 | self.evaluate_Button = QtWidgets.QPushButton(ValidateUI) 103 | self.evaluate_Button.setEnabled(False) 104 | self.evaluate_Button.setObjectName("evaluate_Button") 105 | self.verticalLayout.addWidget(self.evaluate_Button) 106 | 107 | self.retranslateUi(ValidateUI) 108 | 
QtCore.QMetaObject.connectSlotsByName(ValidateUI) 109 | 110 | def retranslateUi(self, ValidateUI): 111 | _translate = QtCore.QCoreApplication.translate 112 | ValidateUI.setWindowTitle(_translate("ValidateUI", "Form")) 113 | self.configure_button.setText(_translate("ValidateUI", "Configure...")) 114 | self.label.setText(_translate("ValidateUI", "Data Location:")) 115 | self.data_choose_Button.setText(_translate("ValidateUI", "Choose...")) 116 | self.label_2.setText(_translate("ValidateUI", "Masks Location:")) 117 | self.mask_choose_Button.setText(_translate("ValidateUI", "Choose folder...")) 118 | self.roi_choose_Button.setText(_translate("ValidateUI", "Choose ROI...")) 119 | self.currentProgress_Label.setText(_translate("ValidateUI", "Current eval:")) 120 | self.label_4.setText(_translate("ValidateUI", "Overall progress:")) 121 | self.evaluate_Button.setText(_translate("ValidateUI", "Evaluate")) 122 | -------------------------------------------------------------------------------- /src/dafne/ui/Viewer3D.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Dafne-Imaging Team 2 | # Part of this code are based on "wezel": https://github.com/QIB-Sheffield/wezel/ 3 | 4 | import os 5 | import numpy as np 6 | from PyQt5.QtCore import pyqtSlot, pyqtSignal, Qt 7 | from PyQt5.QtWidgets import QWidget, QVBoxLayout, QApplication, QPushButton, QHBoxLayout, QLabel, QSlider 8 | from ..config.config import GlobalConfig 9 | 10 | os.environ["QT_API"] = "pyqt5" 11 | import pyvista as pv 12 | from pyvistaqt import QtInteractor 13 | 14 | WIDTH = 400 15 | HEIGHT = 400 16 | 17 | OPACITY_ARRAY = np.array([0, 0.2, 0.3, 0.6, 0.2, 0]) 18 | COLORMAP = 'bone' 19 | 20 | class Viewer3D(QWidget): 21 | 22 | hide_signal = pyqtSignal() 23 | 24 | def __init__(self): 25 | super().__init__() 26 | self.plotter = QtInteractor(self) 27 | self.plotter.background_color = 'black' 28 | self.plotter.add_camera_orientation_widget() 29 | self.spacing = 
(1.0, 1.0, 1.0) 30 | self.data = None 31 | self.anatomy = None 32 | self.actor_roi = None 33 | self.actor_anatomy = None 34 | self.anatomy_opacity = 0.2 35 | # Layout 36 | layout = QVBoxLayout() 37 | layout.setContentsMargins(0, 0, 0, 0) 38 | layout.setSpacing(0) 39 | layout.addWidget(self.plotter) 40 | 41 | fit_button = QPushButton("Fit to scene") 42 | layout.addWidget(fit_button) 43 | fit_button.clicked.connect(self.plotter.reset_camera) 44 | 45 | opacity_widget = QWidget() 46 | opacity_widget_layout = QHBoxLayout() 47 | opacity_widget.setLayout(opacity_widget_layout) 48 | opacity_widget_layout.addWidget(QLabel('Anatomy opacity:')) 49 | self.anat_opacity_slider = QSlider(Qt.Horizontal) 50 | self.anat_opacity_slider.setMinimum(0) 51 | self.anat_opacity_slider.setMaximum(100) 52 | self.anat_opacity_slider.setValue(20) 53 | self.anat_opacity_slider.valueChanged.connect(self.set_global_anat_opacity) 54 | opacity_widget_layout.addWidget(self.anat_opacity_slider) 55 | layout.addWidget(opacity_widget) 56 | 57 | 58 | self.setLayout(layout) 59 | self.setWindowTitle("3D Viewer") 60 | screen_width = QApplication.desktop().screenGeometry().width() 61 | self.setGeometry(screen_width - WIDTH, 0, WIDTH, HEIGHT) 62 | self.real_close_flag = False 63 | 64 | @pyqtSlot(list, np.ndarray) 65 | def set_spacing_and_anatomy(self, spacing, anatomy): 66 | self.spacing = spacing 67 | self.anatomy = anatomy.astype(np.uint16) 68 | self.visualize_anatomy() 69 | 70 | @pyqtSlot(int) 71 | def set_global_anat_opacity(self, value): 72 | self.anatomy_opacity = float(value)/100 73 | if self.actor_anatomy is None: 74 | return 75 | 76 | lut = pv.LookupTable(cmap=COLORMAP) 77 | lut.apply_opacity(OPACITY_ARRAY * self.anatomy_opacity) 78 | self.actor_anatomy.prop.apply_lookup_table(lut) 79 | self.plotter.render() 80 | 81 | def visualize_anatomy(self): 82 | self.plotter.remove_actor(self.actor_anatomy, render=False) 83 | if self.anatomy is None or self.spacing is None or self.anatomy_opacity == 0: 84 | 
self.plotter.render() 85 | return 86 | vol = pv.ImageData(dimensions=np.array(self.anatomy.shape)+1, spacing=self.spacing) 87 | vol.cell_data['values'] = self.anatomy.flatten(order='F') 88 | 89 | opacity = (OPACITY_ARRAY* self.anatomy_opacity) 90 | 91 | self.actor_anatomy = self.plotter.add_volume(vol, 92 | scalars='values', 93 | clim=[self.anatomy.min(), self.anatomy.max()], 94 | opacity=opacity, 95 | cmap=COLORMAP, 96 | show_scalar_bar=False) 97 | self.plotter.show() 98 | 99 | @pyqtSlot(list, np.ndarray) 100 | def set_spacing_and_data(self, spacing, data): 101 | """ 102 | Set the data and spacing. 103 | """ 104 | self.spacing = spacing 105 | self.data = data 106 | self.update_data() 107 | 108 | @pyqtSlot(list) 109 | def set_spacing(self, spacing): 110 | """ 111 | Set the affine transformation matrix. 112 | """ 113 | self.spacing = spacing 114 | 115 | @pyqtSlot(np.ndarray) 116 | def set_affine(self, affine): 117 | """ 118 | Set the affine transformation matrix. 119 | """ 120 | column_spacing = np.linalg.norm(affine[:3, 0]) 121 | row_spacing = np.linalg.norm(affine[:3, 1]) 122 | slice_spacing = np.linalg.norm(affine[:3, 2]) 123 | self.spacing = (column_spacing, row_spacing, slice_spacing) # mm 124 | self.data = None 125 | 126 | def update_data(self): 127 | if not self.isVisible(): 128 | return 129 | 130 | camera_position = self.plotter.camera_position 131 | #self.plotter.clear() 132 | self.plotter.remove_actor(self.actor_roi, reset_camera=False, render=False) 133 | if self.data is None or self.spacing is None or not np.any(self.data): 134 | print("No data to plot") 135 | self.plotter.render() 136 | return 137 | 138 | 139 | #grid = pv.UniformGrid(dimensions=self.data.shape, spacing=self.spacing) 140 | grid = pv.ImageData(dimensions=self.data.shape, spacing=self.spacing) 141 | surf = grid.contour([0.5], self.data.flatten(order="F"), method='marching_cubes') 142 | color = GlobalConfig['ROI_COLOR'] 143 | color = [color[0], color[1], color[2]] 144 | self.actor_roi = 
self.plotter.add_mesh(surf, 145 | color=color, 146 | opacity=0.8, 147 | show_edges=False, 148 | smooth_shading=True, 149 | specular=0.5, 150 | show_scalar_bar=False, 151 | render=False 152 | ) 153 | 154 | #restore camera position if it's not the default, which is too narrow 155 | if np.max(np.abs(camera_position[0])) > 1: 156 | self.plotter.camera_position = camera_position 157 | self.plotter.render() 158 | 159 | def real_close(self): 160 | self.real_close_flag = True 161 | self.close() 162 | 163 | def closeEvent(self, event): 164 | if self.real_close_flag: # if the window is closed by the user 165 | event.accept() 166 | self.hide_signal.emit() 167 | event.ignore() 168 | self.hide() 169 | 170 | @pyqtSlot(np.ndarray) 171 | def set_data(self, data): 172 | """ 173 | Set the data to be plotted. 174 | """ 175 | self.data = data 176 | self.update_data() 177 | 178 | @pyqtSlot(int, np.ndarray) 179 | def set_slice(self, slice_number, slice_data): 180 | """ 181 | Set the slice to be plotted. 182 | """ 183 | if self.data is None: 184 | return 185 | self.data[:, :, slice_number] = slice_data 186 | self.update_data() 187 | -------------------------------------------------------------------------------- /icons/dafne_icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 16 | 37 | 39 | 43 | 46 | 52 | 55 | 60 | 65 | 70 | 75 | 80 | 85 | 90 | 91 | 97 | 103 | 109 | 115 | 116 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /src/dafne/MedSAM/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | from typing import List, Tuple, Type 13 | 14 | from .common import LayerNorm2d 15 | 16 | 17 | class MaskDecoder(nn.Module): 18 | def __init__( 19 | self, 20 | *, 21 | transformer_dim: int, 22 | transformer: nn.Module, 23 | num_multimask_outputs: int = 3, 24 | activation: Type[nn.Module] = nn.GELU, 25 | iou_head_depth: int = 3, 26 | iou_head_hidden_dim: int = 256, 27 | ) -> None: 28 | """ 29 | Predicts masks given an image and prompt embeddings, using a 30 | transformer architecture. 31 | 32 | Arguments: 33 | transformer_dim (int): the channel dimension of the transformer 34 | transformer (nn.Module): the transformer used to predict masks 35 | num_multimask_outputs (int): the number of masks to predict 36 | when disambiguating masks 37 | activation (nn.Module): the type of activation to use when 38 | upscaling masks 39 | iou_head_depth (int): the depth of the MLP used to predict 40 | mask quality 41 | iou_head_hidden_dim (int): the hidden dimension of the MLP 42 | used to predict mask quality 43 | """ 44 | super().__init__() 45 | self.transformer_dim = transformer_dim 46 | self.transformer = transformer 47 | 48 | self.num_multimask_outputs = num_multimask_outputs 49 | 50 | self.iou_token = nn.Embedding(1, transformer_dim) 51 | self.num_mask_tokens = num_multimask_outputs + 1 52 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 53 | 54 | self.output_upscaling = nn.Sequential( 55 | nn.ConvTranspose2d( 56 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 57 | ), 58 | LayerNorm2d(transformer_dim // 4), 59 | activation(), 60 | nn.ConvTranspose2d( 61 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 62 | ), 63 | activation(), 64 | ) 65 | self.output_hypernetworks_mlps = nn.ModuleList( 66 | [ 67 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 68 | for i in range(self.num_mask_tokens) 69 | ] 70 | ) 71 | 
72 | self.iou_prediction_head = MLP( 73 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 74 | ) 75 | 76 | def forward( 77 | self, 78 | image_embeddings: torch.Tensor, 79 | image_pe: torch.Tensor, 80 | sparse_prompt_embeddings: torch.Tensor, 81 | dense_prompt_embeddings: torch.Tensor, 82 | multimask_output: bool, 83 | ) -> Tuple[torch.Tensor, torch.Tensor]: 84 | """ 85 | Predict masks given image and prompt embeddings. 86 | 87 | Arguments: 88 | image_embeddings (torch.Tensor): the embeddings from the image encoder 89 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 90 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 91 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 92 | multimask_output (bool): Whether to return multiple masks or a single 93 | mask. 94 | 95 | Returns: 96 | torch.Tensor: batched predicted masks 97 | torch.Tensor: batched predictions of mask quality 98 | """ 99 | masks, iou_pred = self.predict_masks( 100 | image_embeddings=image_embeddings, 101 | image_pe=image_pe, 102 | sparse_prompt_embeddings=sparse_prompt_embeddings, 103 | dense_prompt_embeddings=dense_prompt_embeddings, 104 | ) 105 | 106 | # Select the correct mask or masks for output 107 | if multimask_output: 108 | mask_slice = slice(1, None) 109 | else: 110 | mask_slice = slice(0, 1) 111 | masks = masks[:, mask_slice, :, :] 112 | iou_pred = iou_pred[:, mask_slice] 113 | 114 | # Prepare output 115 | return masks, iou_pred 116 | 117 | def predict_masks( 118 | self, 119 | image_embeddings: torch.Tensor, 120 | image_pe: torch.Tensor, 121 | sparse_prompt_embeddings: torch.Tensor, 122 | dense_prompt_embeddings: torch.Tensor, 123 | ) -> Tuple[torch.Tensor, torch.Tensor]: 124 | """Predicts masks. 
See 'forward' for more details.""" 125 | # Concatenate output tokens 126 | output_tokens = torch.cat( 127 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 128 | ) 129 | output_tokens = output_tokens.unsqueeze(0).expand( 130 | sparse_prompt_embeddings.size(0), -1, -1 131 | ) 132 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 133 | 134 | # Expand per-image data in batch direction to be per-mask 135 | if image_embeddings.shape[0] != tokens.shape[0]: 136 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 137 | else: 138 | src = image_embeddings 139 | src = src + dense_prompt_embeddings 140 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 141 | b, c, h, w = src.shape 142 | 143 | # Run the transformer 144 | hs, src = self.transformer(src, pos_src, tokens) 145 | iou_token_out = hs[:, 0, :] 146 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 147 | 148 | # Upscale mask embeddings and predict masks using the mask tokens 149 | src = src.transpose(1, 2).view(b, c, h, w) 150 | upscaled_embedding = self.output_upscaling(src) 151 | hyper_in_list: List[torch.Tensor] = [] 152 | for i in range(self.num_mask_tokens): 153 | hyper_in_list.append( 154 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) 155 | ) 156 | hyper_in = torch.stack(hyper_in_list, dim=1) 157 | b, c, h, w = upscaled_embedding.shape 158 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 159 | 160 | # Generate mask quality predictions 161 | iou_pred = self.iou_prediction_head(iou_token_out) 162 | 163 | return masks, iou_pred 164 | 165 | 166 | # Lightly adapted from 167 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 168 | class MLP(nn.Module): 169 | def __init__( 170 | self, 171 | input_dim: int, 172 | hidden_dim: int, 173 | output_dim: int, 174 | num_layers: int, 175 | sigmoid_output: bool = False, 176 | ) -> None: 
177 | super().__init__() 178 | self.num_layers = num_layers 179 | h = [hidden_dim] * (num_layers - 1) 180 | self.layers = nn.ModuleList( 181 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 182 | ) 183 | self.sigmoid_output = sigmoid_output 184 | 185 | def forward(self, x): 186 | for i, layer in enumerate(self.layers): 187 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 188 | if self.sigmoid_output: 189 | x = F.sigmoid(x) 190 | return x 191 | -------------------------------------------------------------------------------- /src/dafne/MedSAM/utils/pre_CT_MR.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # %% import packages 3 | # pip install connected-components-3d 4 | import numpy as np 5 | 6 | # import nibabel as nib 7 | import SimpleITK as sitk 8 | import os 9 | 10 | join = os.path.join 11 | from skimage import transform 12 | from tqdm import tqdm 13 | import cc3d 14 | 15 | # convert nii image to npz files, including original image and corresponding masks 16 | modality = "CT" 17 | anatomy = "Abd" # anantomy + dataset name 18 | img_name_suffix = "_0000.nii.gz" 19 | gt_name_suffix = ".nii.gz" 20 | prefix = modality + "_" + anatomy + "_" 21 | 22 | nii_path = "data/FLARE22Train/images" # path to the nii images 23 | gt_path = "data/FLARE22Train/labels" # path to the ground truth 24 | npy_path = "data/npy/" + prefix[:-1] 25 | os.makedirs(join(npy_path, "gts"), exist_ok=True) 26 | os.makedirs(join(npy_path, "imgs"), exist_ok=True) 27 | 28 | image_size = 1024 29 | voxel_num_thre2d = 100 30 | voxel_num_thre3d = 1000 31 | 32 | names = sorted(os.listdir(gt_path)) 33 | print(f"ori \# files {len(names)=}") 34 | names = [ 35 | name 36 | for name in names 37 | if os.path.exists(join(nii_path, name.split(gt_name_suffix)[0] + img_name_suffix)) 38 | ] 39 | print(f"after sanity check \# files {len(names)=}") 40 | 41 | # set label ids that are excluded 42 | remove_label_ids = [ 
43 | 12 44 | ] # remove deodenum since it is scattered in the image, which is hard to specify with the bounding box 45 | tumor_id = None # only set this when there are multiple tumors; convert semantic masks to instance masks 46 | # set window level and width 47 | # https://radiopaedia.org/articles/windowing-ct 48 | WINDOW_LEVEL = 40 # only for CT images 49 | WINDOW_WIDTH = 400 # only for CT images 50 | 51 | # %% save preprocessed images and masks as npz files 52 | for name in tqdm(names[:40]): # use the remaining 10 cases for validation 53 | image_name = name.split(gt_name_suffix)[0] + img_name_suffix 54 | gt_name = name 55 | gt_sitk = sitk.ReadImage(join(gt_path, gt_name)) 56 | gt_data_ori = np.uint8(sitk.GetArrayFromImage(gt_sitk)) 57 | # remove label ids 58 | for remove_label_id in remove_label_ids: 59 | gt_data_ori[gt_data_ori == remove_label_id] = 0 60 | # label tumor masks as instances and remove from gt_data_ori 61 | if tumor_id is not None: 62 | tumor_bw = np.uint8(gt_data_ori == tumor_id) 63 | gt_data_ori[tumor_bw > 0] = 0 64 | # label tumor masks as instances 65 | tumor_inst, tumor_n = cc3d.connected_components( 66 | tumor_bw, connectivity=26, return_N=True 67 | ) 68 | # put the tumor instances back to gt_data_ori 69 | gt_data_ori[tumor_inst > 0] = ( 70 | tumor_inst[tumor_inst > 0] + np.max(gt_data_ori) + 1 71 | ) 72 | 73 | # exclude the objects with less than 1000 pixels in 3D 74 | gt_data_ori = cc3d.dust( 75 | gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True 76 | ) 77 | # remove small objects with less than 100 pixels in 2D slices 78 | 79 | for slice_i in range(gt_data_ori.shape[0]): 80 | gt_i = gt_data_ori[slice_i, :, :] 81 | # remove small objects with less than 100 pixels 82 | # reason: fro such small objects, the main challenge is detection rather than segmentation 83 | gt_data_ori[slice_i, :, :] = cc3d.dust( 84 | gt_i, threshold=voxel_num_thre2d, connectivity=8, in_place=True 85 | ) 86 | # find non-zero slices 87 | z_index, 
_, _ = np.where(gt_data_ori > 0) 88 | z_index = np.unique(z_index) 89 | 90 | if len(z_index) > 0: 91 | # crop the ground truth with non-zero slices 92 | gt_roi = gt_data_ori[z_index, :, :] 93 | # load image and preprocess 94 | img_sitk = sitk.ReadImage(join(nii_path, image_name)) 95 | image_data = sitk.GetArrayFromImage(img_sitk) 96 | # nii preprocess start 97 | if modality == "CT": 98 | lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2 99 | upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2 100 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 101 | image_data_pre = ( 102 | (image_data_pre - np.min(image_data_pre)) 103 | / (np.max(image_data_pre) - np.min(image_data_pre)) 104 | * 255.0 105 | ) 106 | else: 107 | lower_bound, upper_bound = np.percentile( 108 | image_data[image_data > 0], 0.5 109 | ), np.percentile(image_data[image_data > 0], 99.5) 110 | image_data_pre = np.clip(image_data, lower_bound, upper_bound) 111 | image_data_pre = ( 112 | (image_data_pre - np.min(image_data_pre)) 113 | / (np.max(image_data_pre) - np.min(image_data_pre)) 114 | * 255.0 115 | ) 116 | image_data_pre[image_data == 0] = 0 117 | 118 | image_data_pre = np.uint8(image_data_pre) 119 | img_roi = image_data_pre[z_index, :, :] 120 | np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=img_roi, gts=gt_roi, spacing=img_sitk.GetSpacing()) 121 | # save the image and ground truth as nii files for sanity check; 122 | # they can be removed 123 | img_roi_sitk = sitk.GetImageFromArray(img_roi) 124 | gt_roi_sitk = sitk.GetImageFromArray(gt_roi) 125 | sitk.WriteImage( 126 | img_roi_sitk, 127 | join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_img.nii.gz"), 128 | ) 129 | sitk.WriteImage( 130 | gt_roi_sitk, 131 | join(npy_path, prefix + gt_name.split(gt_name_suffix)[0] + "_gt.nii.gz"), 132 | ) 133 | # save the each CT image as npy file 134 | for i in range(img_roi.shape[0]): 135 | img_i = img_roi[i, :, :] 136 | img_3c = np.repeat(img_i[:, :, 
None], 3, axis=-1) 137 | resize_img_skimg = transform.resize( 138 | img_3c, 139 | (image_size, image_size), 140 | order=3, 141 | preserve_range=True, 142 | mode="constant", 143 | anti_aliasing=True, 144 | ) 145 | resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip( 146 | resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None 147 | ) # normalize to [0, 1], (H, W, 3) 148 | gt_i = gt_roi[i, :, :] 149 | resize_gt_skimg = transform.resize( 150 | gt_i, 151 | (image_size, image_size), 152 | order=0, 153 | preserve_range=True, 154 | mode="constant", 155 | anti_aliasing=False, 156 | ) 157 | resize_gt_skimg = np.uint8(resize_gt_skimg) 158 | assert resize_img_skimg_01.shape[:2] == resize_gt_skimg.shape 159 | np.save( 160 | join( 161 | npy_path, 162 | "imgs", 163 | prefix 164 | + gt_name.split(gt_name_suffix)[0] 165 | + "-" 166 | + str(i).zfill(3) 167 | + ".npy", 168 | ), 169 | resize_img_skimg_01, 170 | ) 171 | np.save( 172 | join( 173 | npy_path, 174 | "gts", 175 | prefix 176 | + gt_name.split(gt_name_suffix)[0] 177 | + "-" 178 | + str(i).zfill(3) 179 | + ".npy", 180 | ), 181 | resize_gt_skimg, 182 | ) 183 | --------------------------------------------------------------------------------