├── src
│   ├── dafne
│   │   ├── bin
│   │   │   ├── __init__.py
│   │   │   ├── batch_validate_ui.py
│   │   │   ├── calc_transforms.py
│   │   │   ├── edit_config.py
│   │   │   ├── batch_validate.py
│   │   │   └── dafne.py
│   │   ├── resources
│   │   │   ├── __init__.py
│   │   │   ├── circle.png
│   │   │   ├── square.png
│   │   │   ├── dafne_anim.gif
│   │   │   ├── dafne_logo.png
│   │   │   └── runtime_dependencies.cfg
│   │   ├── config
│   │   │   ├── version.py
│   │   │   └── __init__.py
│   │   ├── MedSAM
│   │   │   ├── segment_anything
│   │   │   │   ├── utils
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── transforms.py
│   │   │   │   │   └── onnx.py
│   │   │   │   ├── modeling
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── common.py
│   │   │   │   │   └── mask_decoder.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── build_sam.py
│   │   │   ├── utils
│   │   │   │   ├── ckpt_convert.py
│   │   │   │   ├── pre_grey_rgb.py
│   │   │   │   ├── format_convert.py
│   │   │   │   ├── split.py
│   │   │   │   ├── README.md
│   │   │   │   └── pre_CT_MR.py
│   │   │   └── MedSAM_Inference.py
│   │   ├── __init__.py
│   │   ├── utils
│   │   │   ├── open_folder.py
│   │   │   ├── __init__.py
│   │   │   ├── resource_utils.py
│   │   │   ├── log.py
│   │   │   ├── compressed_pickle.py
│   │   │   ├── ThreadHelpers.py
│   │   │   ├── mask_utils.py
│   │   │   └── polyToMask.py
│   │   └── ui
│   │       ├── __init__.py
│   │       ├── LogWindow.py
│   │       ├── LogWindowUI.ui
│   │       ├── ModelBrowser.ui
│   │       ├── ModelBrowserUI.py
│   │       ├── LogWindowUI.py
│   │       ├── CalcTransformsUI.ui
│   │       ├── ContourPainter.py
│   │       ├── CalcTransformsUI.py
│   │       ├── WhatsNew.py
│   │       ├── BatchCalcTransforms.py
│   │       ├── BrushPatches.py
│   │       ├── ValidateUI.ui
│   │       ├── ValidateUI.py
│   │       └── Viewer3D.py
│   └── __init__.py
├── icons
│   ├── dafne_icon.icns
│   ├── dafne_icon.ico
│   ├── dafne_icon.png
│   ├── dafne_icon1024.png
│   ├── mac_installer_bg.png
│   ├── calctransform_ico.ico
│   ├── calctransform_ico.png
│   ├── mac_installer_bg.svg
│   └── dafne_icon.svg
├── pyproject.toml
├── use_local_directories_no
├── .gitmodules
├── install_scripts
│   ├── create_windows_installer.bat
│   ├── create_linux_installer.sh
│   ├── entitlements.plist
│   ├── make_mac_icons.sh
│   ├── update_version.py
│   ├── dafne_linux.spec
│   ├── dafne_win.iss
│   ├── dafne_win.spec
│   ├── create_mac_installer.sh
│   ├── dafne_mac.spec
│   └── fix_app_bundle_for_mac.py
├── pyinstaller_hooks
│   ├── hook-pydicom.py
│   ├── hook-dosma.py
│   ├── hook-dafne_dl.py
│   └── hook-dafne.py
├── requirements.txt
├── create_new_version.sh
├── batch_validate.py
├── calc_transforms.py
├── batch_validate_ui.py
├── edit_config.py
├── dafne
├── test
│   ├── __init__.py
│   ├── test_seg.py
│   ├── plotSegmentations.py
│   ├── testDL.py
│   ├── testILearn.py
│   ├── testILearn_AA.py
│   ├── testILearn_split_AA.py
│   ├── validation_ILearn.py
│   └── validation_split_ILearn.py
├── setup.cfg
├── .gitignore
└── README.md
/src/dafne/bin/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 |
--------------------------------------------------------------------------------
/src/dafne/resources/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 |
--------------------------------------------------------------------------------
/src/dafne/config/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | VERSION='1.9-alpha'
--------------------------------------------------------------------------------
/icons/dafne_icon.icns:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/dafne_icon.icns
--------------------------------------------------------------------------------
/icons/dafne_icon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/dafne_icon.ico
--------------------------------------------------------------------------------
/icons/dafne_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/dafne_icon.png
--------------------------------------------------------------------------------
/icons/dafne_icon1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/dafne_icon1024.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 |
3 | from .config.version import VERSION
--------------------------------------------------------------------------------
/icons/mac_installer_bg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/mac_installer_bg.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/icons/calctransform_ico.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/calctransform_ico.ico
--------------------------------------------------------------------------------
/icons/calctransform_ico.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/icons/calctransform_ico.png
--------------------------------------------------------------------------------
/src/dafne/resources/circle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/src/dafne/resources/circle.png
--------------------------------------------------------------------------------
/src/dafne/resources/square.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/src/dafne/resources/square.png
--------------------------------------------------------------------------------
/src/dafne/resources/dafne_anim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/src/dafne/resources/dafne_anim.gif
--------------------------------------------------------------------------------
/src/dafne/resources/dafne_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dafne-imaging/dafne/HEAD/src/dafne/resources/dafne_logo.png
--------------------------------------------------------------------------------
/use_local_directories_no:
--------------------------------------------------------------------------------
1 | #This is a dummy file. If present, the configuration will use the local directories instead of the system-defined ones.
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "dl"]
2 | path = dl
3 | url = https://github.com/dafne-imaging/dafne-dl.git
4 | branch = master
5 | update = merge
6 |
--------------------------------------------------------------------------------
/install_scripts/create_windows_installer.bat:
--------------------------------------------------------------------------------
1 | python -V
2 | python update_version.py
3 | pyinstaller dafne_win.spec --noconfirm
4 | "C:\Program Files (x86)\Inno Setup 6\Compil32.exe" /cc dafne_win.iss
--------------------------------------------------------------------------------
/pyinstaller_hooks/hook-pydicom.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # Hook file for pyinstaller
3 |
4 | from PyInstaller.utils.hooks import collect_submodules
5 |
6 | hiddenimports=collect_submodules('pydicom')
--------------------------------------------------------------------------------
/pyinstaller_hooks/hook-dosma.py:
--------------------------------------------------------------------------------
1 | # Hook file for pyinstaller
2 |
3 | from PyInstaller.utils.hooks import collect_submodules, collect_data_files
4 |
5 | hiddenimports=collect_submodules('dosma')
6 | datas = collect_data_files('dosma')
--------------------------------------------------------------------------------
/pyinstaller_hooks/hook-dafne_dl.py:
--------------------------------------------------------------------------------
1 | # Hook file for pyinstaller
2 |
3 | from PyInstaller.utils.hooks import collect_submodules
4 |
5 | hiddenimports=collect_submodules('dafne_dl') + \
6 | collect_submodules('skimage.filters')
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
--------------------------------------------------------------------------------
/install_scripts/create_linux_installer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright (c) 2022 Dafne-Imaging Team
4 | #
5 |
6 | VERSION=`python update_version.py | tail -n1`
7 | echo $VERSION
8 | ../venv_system/bin/pyinstaller dafne_linux.spec --noconfirm
9 | cd dist
10 | mv dafne "dafne_linux_$VERSION"
--------------------------------------------------------------------------------
/pyinstaller_hooks/hook-dafne.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # Hook file for pyinstaller
3 |
4 | from PyInstaller.utils.hooks import collect_submodules
5 |
6 | hiddenimports=collect_submodules('dafne') + \
7 |     collect_submodules('skimage.segmentation')
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | h5py
2 | numpy
3 | scipy
4 | matplotlib
5 | pyqt5
6 | nibabel
7 | pydicom
8 | dill
9 | progress
10 | appdirs
11 | requests
12 | scikit-image
13 | ormir-pyvoxel
14 | importlib_resources ; python_version < "3.10"
15 | dafne-dl >= 1.4a2
16 | flexidep>=0.0.6
17 | pyvistaqt
18 | torch
19 | torchvision
20 | dafne-dicomUtils
--------------------------------------------------------------------------------
/install_scripts/entitlements.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 |     <key>com.apple.security.cs.allow-jit</key>
6 |     <true/>
7 |     <key>com.apple.security.cs.allow-unsigned-executable-memory</key>
8 |     <true/>
9 | </dict>
10 | </plist>
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from .sam import Sam
9 | from .image_encoder import ImageEncoderViT
10 | from .mask_decoder import MaskDecoder
11 | from .prompt_encoder import PromptEncoder
12 | from .transformer import TwoWayTransformer
13 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | from .build_sam import (
9 | build_sam,
10 | build_sam_vit_h,
11 | build_sam_vit_l,
12 | build_sam_vit_b,
13 | sam_model_registry,
14 | )
15 | from .predictor import SamPredictor
16 | from .automatic_mask_generator import SamAutomaticMaskGenerator
17 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/utils/ckpt_convert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 |
4 | # %% convert medsam model checkpoint to sam checkpoint format for convenient inference
5 | sam_ckpt_path = ""
6 | medsam_ckpt_path = ""
7 | save_path = ""
8 | multi_gpu_ckpt = True # set as True if the model is trained with multi-gpu
9 |
10 | sam_ckpt = torch.load(sam_ckpt_path)
11 | medsam_ckpt = torch.load(medsam_ckpt_path)
12 | sam_keys = sam_ckpt.keys()
13 | for key in sam_keys:
14 | if not multi_gpu_ckpt:
15 | sam_ckpt[key] = medsam_ckpt["model"][key]
16 | else:
17 | sam_ckpt[key] = medsam_ckpt["model"]["module." + key]
18 |
19 | torch.save(sam_ckpt, save_path)
20 |
--------------------------------------------------------------------------------
/create_new_version.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # Copyright (c) 2022 Dafne-Imaging Team
4 | #
5 |
6 | new_version=$1
7 | version_file=src/dafne/config/version.py
8 | current_version=`tail -n 1 $version_file | sed "s/VERSION='\(.*\)'/\1/"`
9 |
10 | if [ "$new_version" == "" ]; then
11 | echo "Usage: $0 "
12 | echo "Current version: $current_version"
13 | exit 1
14 | fi
15 |
16 | echo "# Copyright (c) 2022 Dafne-Imaging Team" > $version_file
17 | echo "# This file was auto-generated. Any changes might be overwritten" >> $version_file
18 | echo "VERSION='$new_version'" >> $version_file
19 |
20 | rm dist/*
21 | python -m build --sdist --wheel
22 | python -m twine upload dist/*
23 |
--------------------------------------------------------------------------------
/src/dafne/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 |
3 | from .config.version import VERSION
4 | __version__ = VERSION
5 |
6 | from . import resources
7 |
8 | import sys
9 | import flexidep
10 |
11 | assert sys.version_info.major == 3, "This software is only compatible with Python 3.x"
12 |
13 | if sys.version_info.minor < 10:
14 | import importlib_resources as pkg_resources
15 | else:
16 | import importlib.resources as pkg_resources
17 |
18 | # install the required resources
19 | if not flexidep.is_frozen():
20 | with pkg_resources.files(resources).joinpath('runtime_dependencies.cfg').open() as f:
21 | dm = flexidep.DependencyManager(config_file=f)
22 | dm.install_interactive()
23 |
--------------------------------------------------------------------------------
/src/dafne/utils/open_folder.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 | import os
4 |
5 |
6 | def open_folder(folder_path):
7 | if not os.path.isdir(folder_path):
8 | folder_path = os.path.dirname(folder_path)
9 |
10 | print('Opening folder:', folder_path)
11 |
12 | if sys.platform == 'win32':
13 | proc_name = 'explorer'
14 | elif sys.platform == 'darwin':
15 | proc_name = 'open'
16 | elif sys.platform.startswith('linux'):
17 | proc_name = 'xdg-open'
18 | else:
19 | raise NotImplementedError('Unsupported platform')
20 |
21 | try:
22 | subprocess.run([proc_name, folder_path])
23 | except Exception as e:
24 | print(f'Error while opening folder: {e}')
25 |
--------------------------------------------------------------------------------
/batch_validate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # generic stub for executable scripts residing in bin.
3 | # This code will execute the main function of a script residing in bin having the same name as the script.
4 | # The main function must be named "main" and must be in the global scope.
5 |
6 | import os
7 | import sys
8 | import importlib
9 |
10 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
11 | if src_path not in sys.path:
12 | sys.path.append(src_path)
13 |
14 | this_script = os.path.basename(__file__)
15 | this_script_name = os.path.splitext(this_script)[0]
16 |
17 | import_module_name = f'dafne.bin.{this_script_name}'
18 |
19 | i = importlib.import_module(import_module_name)
20 |
21 | if __name__ == '__main__':
22 | i.main()
--------------------------------------------------------------------------------
/calc_transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # generic stub for executable scripts residing in bin.
3 | # This code will execute the main function of a script residing in bin having the same name as the script.
4 | # The main function must be named "main" and must be in the global scope.
5 |
6 | import os
7 | import sys
8 | import importlib
9 |
10 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
11 | if src_path not in sys.path:
12 | sys.path.append(src_path)
13 |
14 | this_script = os.path.basename(__file__)
15 | this_script_name = os.path.splitext(this_script)[0]
16 |
17 | import_module_name = f'dafne.bin.{this_script_name}'
18 |
19 | i = importlib.import_module(import_module_name)
20 |
21 | if __name__ == '__main__':
22 | i.main()
--------------------------------------------------------------------------------
/batch_validate_ui.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # generic stub for executable scripts residing in bin.
3 | # This code will execute the main function of a script residing in bin having the same name as the script.
4 | # The main function must be named "main" and must be in the global scope.
5 |
6 | import os
7 | import sys
8 | import importlib
9 |
10 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
11 | if src_path not in sys.path:
12 | sys.path.append(src_path)
13 |
14 | this_script = os.path.basename(__file__)
15 | this_script_name = os.path.splitext(this_script)[0]
16 |
17 | import_module_name = f'dafne.bin.{this_script_name}'
18 |
19 | i = importlib.import_module(import_module_name)
20 |
21 | if __name__ == '__main__':
22 | i.main()
--------------------------------------------------------------------------------
/edit_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | # generic stub for executable scripts residing in bin.
3 | # This code will execute the main function of a script residing in bin having the same name as the script.
4 | # The main function must be named "main" and must be in the global scope.
5 |
6 | import os
7 | import sys
8 | import importlib
9 |
10 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
11 | if src_path not in sys.path:
12 | sys.path.append(src_path)
13 |
14 |
15 | this_script = os.path.basename(__file__)
16 | this_script_name = os.path.splitext(this_script)[0]
17 |
18 | import_module_name = f'dafne.bin.{this_script_name}'
19 |
20 | i = importlib.import_module(import_module_name)
21 |
22 | if __name__ == '__main__':
23 | i.main()
--------------------------------------------------------------------------------
/dafne:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2022 Dafne-Imaging Team
3 | # generic stub for executable scripts residing in bin.
4 | # This code will execute the main function of a script residing in bin having the same name as the script.
5 | # The main function must be named "main" and must be in the global scope.
6 |
7 | import os
8 | import sys
9 | import importlib
10 |
11 | src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'src'))
12 | if src_path not in sys.path:
13 | sys.path.append(src_path)
14 |
15 | this_script = os.path.basename(__file__)
16 | this_script_name = os.path.splitext(this_script)[0]
17 |
18 | import_module_name = f'dafne.bin.{this_script_name}'
19 |
20 | i = importlib.import_module(import_module_name)
21 |
22 | if __name__ == '__main__':
23 | i.main()
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 |
16 | #!/usr/bin/env python3
17 | # -*- coding: utf-8 -*-
18 |
--------------------------------------------------------------------------------
/src/dafne/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 |
16 | from .config import *
17 | from .version import VERSION
--------------------------------------------------------------------------------
/src/dafne/ui/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 |
16 | #!/usr/bin/env python3
17 | # -*- coding: utf-8 -*-
18 |
--------------------------------------------------------------------------------
/src/dafne/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
15 |
16 | #!/usr/bin/env python3
17 | # -*- coding: utf-8 -*-
18 |
--------------------------------------------------------------------------------
/src/dafne/utils/resource_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | import os
3 | import sys
4 |
5 | assert sys.version_info.major == 3, "This software is only compatible with Python 3.x"
6 |
7 | if sys.version_info.minor < 10:
8 | import importlib_resources as pkg_resources
9 | else:
10 | import importlib.resources as pkg_resources
11 |
12 | from contextlib import contextmanager
13 |
14 | from .. import resources
15 |
16 | @contextmanager
17 | def get_resource_path(resource_name):
18 | if getattr(sys, '_MEIPASS', None):
19 | yield os.path.join(sys._MEIPASS, 'resources', resource_name) # PyInstaller support. If _MEIPASS is set, we are in a Pyinstaller environment
20 | else:
21 | with pkg_resources.as_file(pkg_resources.files(resources).joinpath(resource_name)) as resource:
22 | yield str(resource)
23 |
--------------------------------------------------------------------------------
/src/dafne/resources/runtime_dependencies.cfg:
--------------------------------------------------------------------------------
1 | [Global]
2 | interactive initialization = False
3 | use gui = Yes
4 | local install = False
5 | package manager = pip
6 | id = network.dafne.dafne
7 | optional packages =
8 | radiomics
9 | priority = radiomics
10 |
11 | [Packages]
12 | # in macos, install tensorflow-metal after tensorflow
13 | tensorflow =
14 | tensorflow ; platform_machine == 'x86_64'
15 | tensorflow_macos ++tensorflow-metal ; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version <= "3.11"
16 | tensorflow_macos ; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version > "3.11"
17 |
18 | SimpleITK =
19 | SimpleITK-SimpleElastix
20 | SimpleITK
21 |
22 | # uninstall SimpleITK after radiomics, as the user should choose to install SimpleITK-SimpleElastix instead
23 | radiomics =
24 | pyradiomics --SimpleITK
--------------------------------------------------------------------------------
/src/dafne/bin/batch_validate_ui.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Copyright (c) 2021 Dafne-Imaging Team
4 | #
5 | # This program is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
16 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
17 |
18 | from ..ui.BatchValidateWindow import run
19 |
20 | def main():
21 | run()
--------------------------------------------------------------------------------
/src/dafne/bin/calc_transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | from ..ui.BatchCalcTransforms import run
20 |
21 | def main():
22 | run()
--------------------------------------------------------------------------------
/install_scripts/make_mac_icons.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | mkdir dafne_icon.iconset
4 | sips -z 16 16 dafne_icon1024.png --out dafne_icon.iconset/icon_16x16.png
5 | sips -z 32 32 dafne_icon1024.png --out dafne_icon.iconset/icon_16x16@2x.png
6 | sips -z 32 32 dafne_icon1024.png --out dafne_icon.iconset/icon_32x32.png
7 | sips -z 64 64 dafne_icon1024.png --out dafne_icon.iconset/icon_32x32@2x.png
8 | sips -z 128 128 dafne_icon1024.png --out dafne_icon.iconset/icon_128x128.png
9 | sips -z 256 256 dafne_icon1024.png --out dafne_icon.iconset/icon_128x128@2x.png
10 | sips -z 256 256 dafne_icon1024.png --out dafne_icon.iconset/icon_256x256.png
11 | sips -z 512 512 dafne_icon1024.png --out dafne_icon.iconset/icon_256x256@2x.png
12 | sips -z 512 512 dafne_icon1024.png --out dafne_icon.iconset/icon_512x512.png
13 | cp dafne_icon1024.png dafne_icon.iconset/icon_512x512@2x.png
14 | iconutil -c icns dafne_icon.iconset
15 | rm -R dafne_icon.iconset
--------------------------------------------------------------------------------
/src/dafne/utils/log.py:
--------------------------------------------------------------------------------
1 | from PyQt5.QtCore import QObject, pyqtSignal
2 |
3 |
4 | class LogStream(QObject):
5 |
6 | updated = pyqtSignal(str)
7 |
8 | def __init__(self, file, old_descriptor = None, parent=None):
9 | super(LogStream, self).__init__(parent)
10 | self.fdesc = open(file, 'w')
11 | self.old_descriptor = old_descriptor
12 | self.data = ''
13 |
14 | def write(self, data):
15 | self.fdesc.write(data)
16 | self.fdesc.flush()
17 | if self.old_descriptor is not None:
18 | self.old_descriptor.write(data)
19 | self.old_descriptor.flush()
20 | self.data += data
21 | self.updated.emit(data)
22 |
23 | def writelines(self, lines):
24 | for line in lines:
25 | self.write(line)
26 |
27 | def __getattr__(self, item):
28 | return getattr(self.fdesc, item)
29 |
30 | def get_data(self):
31 | return self.data
32 |
33 | def close(self):
34 | self.fdesc.close()
35 |
36 | log_objects = {}
37 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = dafne
3 | version = attr: dafne.config.version.VERSION
4 | author = Francesco Santini
5 | author_email = francesco.santini@unibas.ch
6 | description = Dafne - Deep Anatomical Federated Network
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | url = https://github.com/dafne-imaging/dafne
10 | project_urls =
11 | Bug Tracker = https://github.com/dafne-imaging/dafne/issues
12 | classifiers =
13 | Programming Language :: Python :: 3
14 | License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+)
15 | Operating System :: OS Independent
16 |
17 | [options]
18 | package_dir =
19 | = src
20 | packages = find:
21 | include_package_data = True
22 | python_requires = >=3.6
23 | install_requires = file: requirements.txt
24 |
25 | [options.packages.find]
26 | where = src
27 |
28 | [options.package_data]
29 | dafne = resources/*
30 |
31 | [options.entry_points]
32 | console_scripts =
33 | dafne = dafne.bin.dafne:main
34 | dafne_calc_transforms = dafne.bin.calc_transforms:main
35 | dafne_edit_config = dafne.bin.edit_config:main
--------------------------------------------------------------------------------
/test/test_seg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
18 |
# Smoke test: run the trained classifier model against two sample DICOM
# images (thigh and leg). Requires the model files and test images to be
# present locally.
from test.testDL import testSegmentation
from test.testDL import testClassification

# Segmentation smoke tests are currently disabled; kept for manual runs.
#testSegmentation('models/thigh.model', 'testImages/thigh_test.dcm')
#testSegmentation('models/leg.model', 'testImages/leg_test.dcm')

testClassification('models/classifier.model', 'testImages/thigh_test.dcm')
testClassification('models/classifier.model', 'testImages/leg_test.dcm')
--------------------------------------------------------------------------------
/install_scripts/update_version.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2022 Dafne-Imaging Team
3 | import shutil
4 | import sys
5 | import os
6 |
7 | sys.path.append(os.path.abspath(os.path.join('..', 'src')))
8 |
9 | from dafne.config.version import VERSION
10 |
# Rewrite the Windows Inno Setup script in place, substituting the current
# package version into the #define and the installer output filename.
shutil.move('dafne_win.iss', 'dafne_win.iss.old')

with open('dafne_win.iss.old', 'r') as orig_file:
    with open('dafne_win.iss', 'w') as new_file:
        for line in orig_file:
            if line.startswith('#define MyAppVersion'):
                new_file.write(f'#define MyAppVersion "{VERSION}"\n')
            elif line.startswith('OutputBaseFilename='):
                new_file.write(f'OutputBaseFilename=dafne_windows_setup_{VERSION}\n')
            else:
                new_file.write(line)

# Same treatment for the macOS PyInstaller spec: any line containing
# 'version=' is replaced wholesale.
# NOTE(review): this assumes exactly one 'version=' line exists in the spec
# and that the hard-coded replacement indentation matches its context in the
# BUNDLE() call — verify whenever dafne_mac.spec changes shape.
shutil.move('dafne_mac.spec', 'dafne_mac.spec.old')
with open('dafne_mac.spec.old', 'r') as orig_file:
    with open('dafne_mac.spec', 'w') as new_file:
        for line in orig_file:
            if 'version=' in line:
                new_file.write(f" version='{VERSION}')\n")
            else:
                new_file.write(line)

# Echo the version so calling shell/batch scripts can capture it.
print(VERSION)
--------------------------------------------------------------------------------
/test/plotSegmentations.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | import numpy as np
20 | import matplotlib.pyplot as plt
21 |
def plotSegmentations(ima, segmentations):
    """Show one figure per ROI: the grey image with the mask overlaid in red.

    ima: 2D image array.
    segmentations: mapping of ROI label -> binary mask (same shape as ima).
    """
    for roi_label, roi_mask in segmentations.items():
        plt.figure()
        # Grey image replicated over three channels, dimmed to 60% so the
        # red mask overlay stands out.
        overlay = np.stack([ima, ima, ima], axis=-1)
        overlay = overlay / overlay.max() * 0.6
        overlay[:, :, 0] += 0.4 * roi_mask
        plt.imshow(overlay)
        plt.title(roi_label)
--------------------------------------------------------------------------------
/src/dafne/bin/edit_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) 2022 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | from ..config.config import show_config_dialog, save_config, load_config
20 |
21 | import sys
22 | from PyQt5.QtWidgets import QApplication
23 |
def main():
    """Standalone entry point: show the Dafne configuration dialog and
    persist the settings only when the user accepts them."""
    application = QApplication(sys.argv)
    application.setQuitOnLastWindowClosed(True)

    load_config()
    if show_config_dialog(None, True):
        save_config()
        print('Configuration saved')
        return
    print('Aborted')
--------------------------------------------------------------------------------
/src/dafne/utils/compressed_pickle.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 |
3 | import pickle
4 | import bz2
5 |
# Highest BZ2 setting: smallest output at the cost of compression speed.
COMPRESS_LEVEL = 9


def compressed_dump(obj, file, **kwargs):
    """Pickle *obj* into *file* through a BZ2 compression layer.

    Extra keyword arguments are forwarded to pickle.dump.
    """
    with bz2.BZ2File(file, 'wb', compresslevel=COMPRESS_LEVEL) as stream:
        pickle.dump(obj, stream, **kwargs)
12 |
13 |
def compressed_dumps(obj, **kwargs):
    """Return *obj* pickled and BZ2-compressed, as a bytes object."""
    pickled = pickle.dumps(obj, **kwargs)
    return bz2.compress(pickled, compresslevel=COMPRESS_LEVEL)
16 |
17 |
def compressed_load(file, **kwargs):
    """Unpickle and return one object from a BZ2-compressed *file*."""
    with bz2.BZ2File(file, 'rb') as stream:
        return pickle.load(stream, **kwargs)
21 |
22 |
def compressed_loads(compressed_bytes, **kwargs):
    """Inverse of compressed_dumps: BZ2-decompress, then unpickle."""
    raw = bz2.decompress(compressed_bytes)
    return pickle.loads(raw, **kwargs)
25 |
26 |
def loads(byte_array, **kwargs):
    """
    Generic loads to replace pickle.loads: accept either a BZ2-compressed
    or a plain pickle byte string.
    """
    try:
        result = compressed_loads(byte_array, **kwargs)
    except OSError:
        # bz2 raises OSError on a non-BZ2 payload: fall back to plain pickle.
        print("Loading uncompressed pickle")
        result = pickle.loads(byte_array, **kwargs)
    return result
36 |
37 |
def load(file, **kwargs):
    """
    Generic load: try reading *file* as a BZ2-compressed pickle first,
    falling back to a plain pickle on failure.
    """
    try:
        return compressed_load(file, **kwargs)
    except OSError:
        # bz2 raises OSError on a non-BZ2 stream: retry uncompressed.
        print("Loading uncompressed pickle")
        try:
            file.seek(0)
        except Exception:
            # Was a bare `except:` — narrowed so KeyboardInterrupt/SystemExit
            # are no longer swallowed. Non-seekable objects are still
            # tolerated (best-effort rewind).
            pass
        return pickle.load(file, **kwargs)
51 |
52 |
# Writers always emit the compressed format; the load/loads wrappers above
# remain able to read both compressed and plain pickles.
dump = compressed_dump
dumps = compressed_dumps
55 |
--------------------------------------------------------------------------------
/install_scripts/dafne_linux.spec:
--------------------------------------------------------------------------------
1 | # -*- mode: python ; coding: utf-8 -*-
# Raise the recursion limit: PyInstaller's module-graph analysis of the
# heavy dependencies below (tensorflow, torch, vtk) can exceed the default.
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)

block_cipher = None


# Analyze the entry script; bundle the license and resource files, and list
# the imports PyInstaller cannot discover statically.
a = Analysis(
    ['../dafne'],
    pathex=['../src'],
    binaries=[],
    datas=[('../LICENSE', '.'), ('../src/dafne/resources/*', 'resources/')],
    hiddenimports=[
        'dafne',
        'pydicom',
        'SimpleITK',
        'tensorflow',
        'skimage',
        'nibabel',
        'dafne_dl',
        'cmath',
        # NOTE(review): 'ormir-pyvoxel' looks like a distribution name, not
        # an importable module name; hiddenimports expects module names —
        # verify this entry actually has an effect.
        'ormir-pyvoxel',
        'pyvistaqt',
        'pyvista',
        'vtk',
        'torch',
        'torchvision'],
    hookspath=['../pyinstaller_hooks'],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    win_no_prefer_redirects=False,
    win_private_assemblies=False,
    cipher=block_cipher,
    noarchive=False,
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

# One-file windowed build (console=False hides the terminal window).
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.zipfiles,
    a.datas,
    [],
    name='dafne',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    upx_exclude=[],
    runtime_tmpdir=None,
    console=False,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
59 |
--------------------------------------------------------------------------------
/test/testDL.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
18 |
19 | from src.dafne_dl import DynamicDLModel
20 | from dicomUtils import loadDicomFile
21 | #import numpy as np
22 | from test.plotSegmentations import plotSegmentations
23 | import matplotlib.pyplot as plt
24 |
def testSegmentation(modelPath, dicomPath):
    """Run a segmentation model on one DICOM image and plot the masks.

    modelPath: path to a serialized DynamicDLModel.
    dicomPath: path to the DICOM file to segment.
    """
    # Context manager closes the model file; the original leaked the
    # inline open() handle.
    with open(modelPath, 'rb') as model_file:
        thighModel = DynamicDLModel.Load(model_file)

    ima, info = loadDicomFile(dicomPath)

    resolution = info.PixelSpacing
    out = thighModel({'image': ima, 'resolution': resolution})

    plotSegmentations(ima, out)

    plt.show()
36 |
def testClassification(modelPath, dicomPath):
    """Run a classifier model on one DICOM image and print its output.

    modelPath: path to a serialized DynamicDLModel.
    dicomPath: path to the DICOM file to classify.
    """
    # Context manager closes the model file; the original leaked the
    # inline open() handle.
    with open(modelPath, 'rb') as model_file:
        classModel = DynamicDLModel.Load(model_file)

    ima, info = loadDicomFile(dicomPath)

    resolution = info.PixelSpacing
    out = classModel({'image': ima, 'resolution': resolution})
    print(out)
--------------------------------------------------------------------------------
/src/dafne/utils/ThreadHelpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
15 |
16 | from PyQt5.QtCore import QRunnable, pyqtSlot, QThreadPool
17 | from functools import wraps
18 | import traceback
19 |
20 | threadpool = QThreadPool()
21 |
class Runner(QRunnable):
    """QRunnable that calls *func(*args, **kwargs)* on a pool thread.

    If the first positional argument is an object (conventionally the bound
    method's self), a best-effort 'separate_thread_running' flag is set on
    it for the duration of the call.
    """

    def __init__(self, func, *args, **kwargs):
        QRunnable.__init__(self)
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def _set_running_flag(self, value):
        # Best-effort only: there may be no positional args, or args[0] may
        # not accept attributes. Narrowed from a bare `except:` so
        # KeyboardInterrupt/SystemExit are not swallowed.
        try:
            setattr(self.args[0], 'separate_thread_running', value)
        except (IndexError, AttributeError):
            pass

    @pyqtSlot()
    def run(self):
        self._set_running_flag(True)
        try:
            self.func(*self.args, **self.kwargs)
        finally:
            # Always clear the flag — the original left it stuck at True
            # whenever func raised.
            self._set_running_flag(False)
41 |
def separate_thread_decorator(func):
    """Decorator: execute the wrapped callable asynchronously on the shared
    thread pool (fire-and-forget; the wrapper returns None immediately)."""
    @wraps(func)
    def run_wrapper(*args, **kwargs):
        threadpool.start(Runner(func, *args, **kwargs))
    return run_wrapper
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/modeling/common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from typing import Type
12 |
13 |
class MLPBlock(nn.Module):
    """Two-layer MLP: Linear(embed -> mlp), activation, Linear(mlp -> embed).

    Attribute names (lin1, lin2, act) are part of the checkpoint layout and
    must not change.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.act(self.lin1(x))
        return self.lin2(hidden)
28 |
29 |
30 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
31 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
class LayerNorm2d(nn.Module):
    """Channel-wise LayerNorm for NCHW tensors: normalize over dim 1, then
    apply a learnable per-channel scale and shift."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(variance + self.eps)
        # Broadcast the (C,) parameters over H and W.
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
45 |
--------------------------------------------------------------------------------
/src/dafne/ui/LogWindow.py:
--------------------------------------------------------------------------------
1 | from PyQt5 import Qt
2 | from PyQt5.QtCore import pyqtSlot
3 | from PyQt5.QtWidgets import QDialog
4 |
5 | from .LogWindowUI import Ui_LogWindow
6 |
7 | from ..utils.log import log_objects
8 |
9 |
class LogWindow(QDialog, Ui_LogWindow):
    """Dialog showing the captured stdout/stderr transcripts, appending new
    text live as the installed log streams emit it."""

    def __init__(self, parent=None):
        super(LogWindow, self).__init__(parent)
        self.setupUi(self)
        self.setWindowTitle("Log")
        self.refresh_btn.clicked.connect(self.refresh)
        self.resize(1024, 768)
        self.refresh()
        # Subscribe to live updates from whichever streams are installed.
        for stream_name, slot in (('stdout', self.append_output),
                                  ('stderr', self.append_error)):
            try:
                log_objects[stream_name].updated.connect(slot)
            except KeyError:
                pass

    @staticmethod
    def _append_to(widget, data):
        # Pin the view to the end, insert, and keep it pinned.
        widget.moveCursor(Qt.QTextCursor.End)
        widget.insertPlainText(data)
        widget.moveCursor(Qt.QTextCursor.End)

    @pyqtSlot(str)
    def append_output(self, data):
        self._append_to(self.output_text, data)

    @pyqtSlot(str)
    def append_error(self, data):
        self._append_to(self.error_text, data)

    def refresh(self):
        """Reload both panes from the streams' full transcripts."""
        self.output_text.clear()
        self.error_text.clear()

        if 'stdout' in log_objects:
            self.output_text.appendPlainText(log_objects['stdout'].get_data())
        else:
            self.output_text.appendPlainText("Not available")

        if 'stderr' in log_objects:
            self.error_text.appendPlainText(log_objects['stderr'].get_data())
        else:
            self.error_text.appendPlainText("Not available")

        self.output_text.moveCursor(Qt.QTextCursor.End)
        self.error_text.moveCursor(Qt.QTextCursor.End)
56 |
57 |
58 |
--------------------------------------------------------------------------------
/test/testILearn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2021 Dafne-Imaging Team
4 | #
5 | # This program is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
17 |
18 | from src.dafne_dl import LocalModelProvider
19 | import numpy as np
20 | import pickle
21 |
# Incremental-learning smoke test for the Thigh model. With GENERATE_PICKLE
# the per-slice training pairs are rebuilt from the raw arrays; otherwise
# they are read back from the cached pickles.
GENERATE_PICKLE = False

model_provider = LocalModelProvider('models')
segmenter = model_provider.load_model('Thigh')

data_in = np.load('testImages/test_data.npy')
segment_in = np.load('testImages/test_segment.npz')

n_slices = data_in.shape[2]
resolution = np.array([1.037037, 1.037037])

if GENERATE_PICKLE:
    print("Generating data")
    # Skip the first/last few slices.
    slice_range = range(4, n_slices - 5)
    image_list = []
    seg_list = []

    # 'slice_index', not 'slice': avoid shadowing the builtin.
    for slice_index in slice_range:
        image_list.append(data_in[:, :, slice_index].squeeze())
        seg_dict = {}
        for roi_name in segment_in:
            seg_dict[roi_name] = segment_in[roi_name][:, :, slice_index]
        seg_list.append(seg_dict)

    # Context managers ensure the pickles are flushed and closed; the
    # original passed inline open() handles that were never closed.
    with open('testImages/test_segment.pickle', 'wb') as seg_file:
        pickle.dump(seg_list, seg_file)
    with open('testImages/test_data.pickle', 'wb') as data_file:
        pickle.dump(image_list, data_file)
else:
    with open('testImages/test_segment.pickle', 'rb') as seg_file:
        seg_list = pickle.load(seg_file)
    with open('testImages/test_data.pickle', 'rb') as data_file:
        image_list = pickle.load(data_file)

print('Performing incremental learning')
segmenter.incremental_learn({'image_list': image_list, 'resolution': resolution}, seg_list)
54 |
--------------------------------------------------------------------------------
/install_scripts/dafne_win.iss:
--------------------------------------------------------------------------------
1 | ; Script generated by the Inno Setup Script Wizard.
2 | ; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES!
3 |
4 | #define MyAppName "Dafne"
5 | #define MyAppVersion "1.5-alpha"
6 | #define MyAppPublisher "Dafne-imaging"
7 | #define MyAppURL "https://dafne.network/"
8 | #define MyAppExeName "dafne.exe"
9 |
10 | [Setup]
11 | ; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications.
12 | ; (To generate a new GUID, click Tools | Generate GUID inside the IDE.)
13 | AppId={{451322B2-10C5-4BA0-88DC-BB8933F78678}
14 | AppName={#MyAppName}
15 | AppVersion={#MyAppVersion}
16 | ;AppVerName={#MyAppName} {#MyAppVersion}
17 | AppPublisher={#MyAppPublisher}
18 | AppPublisherURL={#MyAppURL}
19 | AppSupportURL={#MyAppURL}
20 | AppUpdatesURL={#MyAppURL}
21 | ArchitecturesAllowed=x64
22 | ArchitecturesInstallIn64BitMode=x64
23 | DefaultDirName={autopf}\{#MyAppName}
24 | DisableProgramGroupPage=auto
25 | DefaultGroupName={#MyAppName}
26 | LicenseFile=dist\dafne\LICENSE
27 | ; Uncomment the following line to run in non administrative install mode (install for current user only.)
28 | ;PrivilegesRequired=lowest
29 | PrivilegesRequiredOverridesAllowed=dialog
30 | OutputDir=C:\dafne
31 | OutputBaseFilename=dafne_windows_setup_1.5-alpha
32 | SetupIconFile=..\icons\dafne_icon.ico
33 | Compression=lzma
34 | SolidCompression=yes
35 | WizardStyle=modern
36 |
37 | [Languages]
38 | Name: "english"; MessagesFile: "compiler:Default.isl"
39 |
40 | [Tasks]
41 | Name: "desktopicon"; Description: "{cm:CreateDesktopIcon}"; GroupDescription: "{cm:AdditionalIcons}"; Flags: unchecked
42 |
43 | [Files]
44 | Source: "dist\dafne\{#MyAppExeName}"; DestDir: "{app}"; Flags: ignoreversion
45 | Source: "dist\dafne\*"; DestDir: "{app}"; Flags: ignoreversion recursesubdirs createallsubdirs
46 | ; NOTE: Don't use "Flags: ignoreversion" on any shared system files
47 |
48 | [Icons]
49 | Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"
50 | Name: "{autodesktop}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; Tasks: desktopicon
51 |
52 | [Run]
53 | Filename: "{app}\{#MyAppExeName}"; Description: "{cm:LaunchProgram,{#StringChange(MyAppName, '&', '&&')}}"; Flags: nowait postinstall skipifsilent
54 |
55 |
--------------------------------------------------------------------------------
/src/dafne/ui/LogWindowUI.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | LogWindow
4 |
5 |
6 |
7 | 0
8 | 0
9 | 885
10 | 519
11 |
12 |
13 |
14 | Form
15 |
16 |
17 | -
18 |
19 |
-
20 |
21 |
-
22 |
23 |
24 | Output:
25 |
26 |
27 |
28 | -
29 |
30 |
31 | false
32 |
33 |
34 | QPlainTextEdit::NoWrap
35 |
36 |
37 | true
38 |
39 |
40 |
41 |
42 |
43 | -
44 |
45 |
-
46 |
47 |
48 | Error:
49 |
50 |
51 |
52 | -
53 |
54 |
55 | false
56 |
57 |
58 | QPlainTextEdit::NoWrap
59 |
60 |
61 | true
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 | -
70 |
71 |
72 | Refresh
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
--------------------------------------------------------------------------------
/src/dafne/bin/batch_validate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Copyright (c) 2021 Dafne-Imaging Team
4 | #
5 | # This program is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
17 |
18 | import argparse
19 |
20 |
def str2bool(text):
    """Parse a boolean command-line value ('true'/'1'/'yes'/... -> True).

    argparse's `type=bool` is a known trap: bool('False') is True because
    any non-empty string is truthy. This helper interprets the usual
    spellings correctly while keeping the same CLI syntax.
    """
    return str(text).strip().lower() in ('1', 'true', 't', 'yes', 'y')


def main():
    """Entry point: run batch model validation over a dataset directory."""
    parser = argparse.ArgumentParser(description='Batch model validation')
    parser.add_argument('--classification', type=str, help='Classification to use')
    parser.add_argument('--timestamp_start', type=int, help='Timestamp start')
    parser.add_argument('--timestamp_end', type=int, help='Timestamp end')
    parser.add_argument('--upload_stats', type=str2bool, help='Upload stats')
    parser.add_argument('--save_local', type=str2bool, help='Save local')
    parser.add_argument('--local_filename', type=str, help='Local filename')
    parser.add_argument('--roi', type=str, help='Reference ROI file')
    parser.add_argument('--masks', type=str, help='Reference Mask dataset')
    parser.add_argument('--comment', type=str, help='Comment for logging')
    parser.add_argument('dataset', type=str, help='Dataset for validation')

    args = parser.parse_args()
    # Only forward the options BatchValidator's constructor understands.
    args_to_pass = ['classification', 'timestamp_start', 'timestamp_end', 'upload_stats', 'save_local', 'local_filename']
    args_dict = {k: v for k, v in vars(args).items() if v is not None and k in args_to_pass}
    dataset = args.dataset
    roi = args.roi
    masks = args.masks

    # Deferred import: BatchValidator pulls in heavy dependencies.
    from ..utils.BatchValidator import BatchValidator
    validator = BatchValidator(**args_dict)
    validator.load_directory(dataset)
    if roi:
        validator.loadROIPickle(roi)
    elif masks:
        validator.mask_import(masks)

    # Explicit check instead of `assert`: asserts are stripped under -O.
    if not validator.mask_list:
        raise ValueError('No masks found')
    validator.calculate(args.comment)
51 |
52 |
--------------------------------------------------------------------------------
/test/testILearn_AA.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2021 Dafne-Imaging Team
4 | #
5 | # This program is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
17 |
18 | from src.dafne_dl import LocalModelProvider
19 | import numpy as np
20 | import pickle
21 |
# Incremental-learning smoke test (AA variant): training pairs are taken
# from a per-slice object array keyed by muscle name.
GENERATE_PICKLE = True

model_provider = LocalModelProvider('models')
segmenter = model_provider.load_model('Thigh')

data_in = np.load('testImages/test_data.npy')
segment_in = np.load('testImages/test_segment.npy', allow_pickle=True)

n_slices = data_in.shape[2]
resolution = np.array([1.037037, 1.037037])

# Label index -> thigh muscle ROI name.
LABELS_DICT = {
    1: 'VL',
    2: 'VM',
    3: 'VI',
    4: 'RF',
    5: 'SAR',
    6: 'GRA',
    7: 'AM',
    8: 'SM',
    9: 'ST',
    10: 'BFL',
    11: 'BFS',
    12: 'AL'
}

if GENERATE_PICKLE:
    print("Generating data")
    slice_range = range(4, n_slices - 5)  # range(15,25)
    image_list = []
    seg_list = []

    # 'slice_index', not 'slice': avoid shadowing the builtin.
    for slice_index in slice_range:
        image_list.append(data_in[:, :, slice_index].squeeze())
        seg_dict = {}
        for k, v in LABELS_DICT.items():
            seg_dict[v] = segment_in[slice_index][v][:, :]  # segment_in[roi_name][:,:,slice]
        seg_list.append(seg_dict)

    # Context managers ensure the pickles are flushed and closed; the
    # original passed inline open() handles that were never closed.
    with open('testImages/test_segment.pickle', 'wb') as seg_file:
        pickle.dump(seg_list, seg_file)
    with open('testImages/test_data.pickle', 'wb') as data_file:
        pickle.dump(image_list, data_file)
else:
    with open('testImages/test_segment.pickle', 'rb') as seg_file:
        seg_list = pickle.load(seg_file)
    with open('testImages/test_data.pickle', 'rb') as data_file:
        image_list = pickle.load(data_file)

print('Performing incremental learning')
segmenter.incremental_learn({'image_list': image_list, 'resolution': resolution}, seg_list)
69 |
--------------------------------------------------------------------------------
/test/testILearn_split_AA.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2021 Dafne-Imaging Team
4 | #
5 | # This program is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
17 |
18 | from src.dafne_dl import LocalModelProvider
19 | import numpy as np
20 | import pickle
21 |
# Incremental-learning smoke test for the split Thigh model; same data
# pipeline as testILearn_AA.py but loading the 'Thigh-Split' model.
GENERATE_PICKLE = False

model_provider = LocalModelProvider('models')
segmenter = model_provider.load_model('Thigh-Split')

data_in = np.load('testImages/test_data.npy')
segment_in = np.load('testImages/test_segment.npy', allow_pickle=True)

n_slices = data_in.shape[2]
resolution = np.array([1.037037, 1.037037])

# Label index -> thigh muscle ROI name.
LABELS_DICT = {
    1: 'VL',
    2: 'VM',
    3: 'VI',
    4: 'RF',
    5: 'SAR',
    6: 'GRA',
    7: 'AM',
    8: 'SM',
    9: 'ST',
    10: 'BFL',
    11: 'BFS',
    12: 'AL'
}

if GENERATE_PICKLE:
    print("Generating data")
    slice_range = range(4, n_slices - 5)  # range(15,25)
    image_list = []
    seg_list = []

    # 'slice_index', not 'slice': avoid shadowing the builtin.
    for slice_index in slice_range:
        image_list.append(data_in[:, :, slice_index].squeeze())
        seg_dict = {}
        for k, v in LABELS_DICT.items():
            seg_dict[v] = segment_in[slice_index][v][:, :]  # segment_in[roi_name][:,:,slice]
        seg_list.append(seg_dict)

    # Context managers ensure the pickles are flushed and closed; the
    # original passed inline open() handles that were never closed.
    with open('testImages/test_segment.pickle', 'wb') as seg_file:
        pickle.dump(seg_list, seg_file)
    with open('testImages/test_data.pickle', 'wb') as data_file:
        pickle.dump(image_list, data_file)
else:
    with open('testImages/test_segment.pickle', 'rb') as seg_file:
        seg_list = pickle.load(seg_file)
    with open('testImages/test_data.pickle', 'rb') as data_file:
        image_list = pickle.load(data_file)

print('Performing incremental learning')
segmenter.incremental_learn({'image_list': image_list, 'resolution': resolution}, seg_list)
69 |
--------------------------------------------------------------------------------
/src/dafne/ui/ModelBrowser.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | ModelBrowser
4 |
5 |
6 |
7 | 0
8 | 0
9 | 1024
10 | 768
11 |
12 |
13 |
14 | Dafne Model Browser
15 |
16 |
17 | -
18 |
19 |
20 |
21 | Models
22 |
23 |
24 |
25 |
26 | -
27 |
28 |
29 | true
30 |
31 |
32 | QAbstractItemView::NoEditTriggers
33 |
34 |
35 | true
36 |
37 |
38 | false
39 |
40 |
41 |
42 | Name
43 |
44 |
45 |
46 |
47 | Info
48 |
49 |
50 |
51 |
52 | -
53 |
54 |
55 | Qt::Horizontal
56 |
57 |
58 | QDialogButtonBox::Cancel|QDialogButtonBox::Ok
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 | buttonBox
68 | accepted()
69 | ModelBrowser
70 | accept()
71 |
72 |
73 | 248
74 | 254
75 |
76 |
77 | 157
78 | 274
79 |
80 |
81 |
82 |
83 | buttonBox
84 | rejected()
85 | ModelBrowser
86 | reject()
87 |
88 |
89 | 316
90 | 260
91 |
92 |
93 | 286
94 | 274
95 |
96 |
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/src/dafne/ui/ModelBrowserUI.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'ModelBrowser.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.7
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 |
13 |
class Ui_ModelBrowser(object):
    """Generated UI scaffold for the model-browser dialog (pyuic5 output
    from ModelBrowser.ui): a model tree, a two-column details table and an
    Ok/Cancel button box. Regenerate with pyuic5 instead of hand-editing."""

    def setupUi(self, ModelBrowser):
        ModelBrowser.setObjectName("ModelBrowser")
        ModelBrowser.resize(1024, 768)
        self.verticalLayout = QtWidgets.QVBoxLayout(ModelBrowser)
        self.verticalLayout.setObjectName("verticalLayout")
        # Tree listing the available models.
        self.model_tree = QtWidgets.QTreeWidget(ModelBrowser)
        self.model_tree.setObjectName("model_tree")
        self.verticalLayout.addWidget(self.model_tree)
        # Read-only Name/Info table for the selected model's details.
        self.details_table = QtWidgets.QTableWidget(ModelBrowser)
        self.details_table.setEnabled(True)
        self.details_table.setEditTriggers(QtWidgets.QAbstractItemView.NoEditTriggers)
        self.details_table.setObjectName("details_table")
        self.details_table.setColumnCount(2)
        self.details_table.setRowCount(0)
        item = QtWidgets.QTableWidgetItem()
        self.details_table.setHorizontalHeaderItem(0, item)
        item = QtWidgets.QTableWidgetItem()
        self.details_table.setHorizontalHeaderItem(1, item)
        self.details_table.horizontalHeader().setStretchLastSection(True)
        self.details_table.verticalHeader().setVisible(False)
        self.verticalLayout.addWidget(self.details_table)
        # Standard Ok/Cancel buttons wired to accept/reject below.
        self.buttonBox = QtWidgets.QDialogButtonBox(ModelBrowser)
        self.buttonBox.setOrientation(QtCore.Qt.Horizontal)
        self.buttonBox.setStandardButtons(QtWidgets.QDialogButtonBox.Cancel|QtWidgets.QDialogButtonBox.Ok)
        self.buttonBox.setObjectName("buttonBox")
        self.verticalLayout.addWidget(self.buttonBox)

        self.retranslateUi(ModelBrowser)
        self.buttonBox.accepted.connect(ModelBrowser.accept) # type: ignore
        self.buttonBox.rejected.connect(ModelBrowser.reject) # type: ignore
        QtCore.QMetaObject.connectSlotsByName(ModelBrowser)

    def retranslateUi(self, ModelBrowser):
        # Apply translatable captions for the window, tree and table headers.
        _translate = QtCore.QCoreApplication.translate
        ModelBrowser.setWindowTitle(_translate("ModelBrowser", "Dafne Model Browser"))
        self.model_tree.headerItem().setText(0, _translate("ModelBrowser", "Models"))
        item = self.details_table.horizontalHeaderItem(0)
        item.setText(_translate("ModelBrowser", "Name"))
        item = self.details_table.horizontalHeaderItem(1)
        item.setText(_translate("ModelBrowser", "Info"))
55 |
--------------------------------------------------------------------------------
/src/dafne/ui/LogWindowUI.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'LogWindowUI.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.4
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 |
13 |
class Ui_LogWindow(object):
    """UI scaffolding for the log window.

    Two side-by-side read-only text panes (stdout on the left, stderr on
    the right), each with a label above it, and a refresh button below.
    """

    @staticmethod
    def _create_log_pane(parent):
        # Both panes share the same configuration: read-only, no undo
        # history, no line wrapping.
        pane = QtWidgets.QPlainTextEdit(parent)
        pane.setUndoRedoEnabled(False)
        pane.setLineWrapMode(QtWidgets.QPlainTextEdit.NoWrap)
        pane.setReadOnly(True)
        return pane

    def setupUi(self, LogWindow):
        LogWindow.setObjectName("LogWindow")
        LogWindow.resize(885, 519)

        self.verticalLayout_3 = QtWidgets.QVBoxLayout(LogWindow)
        self.verticalLayout_3.setObjectName("verticalLayout_3")
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setObjectName("horizontalLayout")

        # Left column: standard-output pane.
        self.verticalLayout = QtWidgets.QVBoxLayout()
        self.verticalLayout.setObjectName("verticalLayout")
        self.label = QtWidgets.QLabel(LogWindow)
        self.label.setObjectName("label")
        self.verticalLayout.addWidget(self.label)
        self.output_text = self._create_log_pane(LogWindow)
        self.output_text.setObjectName("output_text")
        self.verticalLayout.addWidget(self.output_text)
        self.horizontalLayout.addLayout(self.verticalLayout)

        # Right column: standard-error pane.
        self.verticalLayout_2 = QtWidgets.QVBoxLayout()
        self.verticalLayout_2.setObjectName("verticalLayout_2")
        self.label_2 = QtWidgets.QLabel(LogWindow)
        self.label_2.setObjectName("label_2")
        self.verticalLayout_2.addWidget(self.label_2)
        self.error_text = self._create_log_pane(LogWindow)
        self.error_text.setObjectName("error_text")
        self.verticalLayout_2.addWidget(self.error_text)
        self.horizontalLayout.addLayout(self.verticalLayout_2)

        self.verticalLayout_3.addLayout(self.horizontalLayout)
        self.refresh_btn = QtWidgets.QPushButton(LogWindow)
        self.refresh_btn.setObjectName("refresh_btn")
        self.verticalLayout_3.addWidget(self.refresh_btn)

        self.retranslateUi(LogWindow)
        QtCore.QMetaObject.connectSlotsByName(LogWindow)

    def retranslateUi(self, LogWindow):
        """Install the user-visible strings for the window and its widgets."""
        _translate = QtCore.QCoreApplication.translate
        LogWindow.setWindowTitle(_translate("LogWindow", "Form"))
        self.label.setText(_translate("LogWindow", "Output:"))
        self.label_2.setText(_translate("LogWindow", "Error:"))
        self.refresh_btn.setText(_translate("LogWindow", "Refresh"))
60 |
--------------------------------------------------------------------------------
/install_scripts/dafne_win.spec:
--------------------------------------------------------------------------------
# -*- mode: python ; coding: utf-8 -*-
# PyInstaller spec file for the Windows build.
# Builds two executables that share one merged dependency set:
#   - dafne:           the main segmentation GUI
#   - calc_transforms: the standalone transform-calculation CLI tool
# The build names (Analysis, PYZ, EXE, COLLECT, MERGE) are injected by
# PyInstaller when the spec is executed — they are not Python imports.

block_cipher = None  # no bytecode encryption

# Analysis of the main application. The scientific packages are listed as
# hidden imports because PyInstaller's static analysis does not always
# discover them on its own.
a_dafne = Analysis(['..\\dafne'],
                   pathex=['..\\src'],
                   binaries=[],
                   # Ship the license and the resource files next to the exe.
                   datas=[('..\\LICENSE', '.'), ('..\\src\\dafne\\resources\\*', 'resources\\')],
                   hiddenimports = ['pydicom',
                                    'dafne',
                                    'SimpleITK',
                                    'tensorflow',
                                    'skimage',
                                    'nibabel',
                                    'dafne_dl',
                                    'pyvistaqt',
                                    'pyvista',
                                    'vtk',
                                    'torch',
                                    'torchvision'],
                   hookspath=['..\\pyinstaller_hooks'],
                   runtime_hooks=[],
                   excludes=[],
                   win_no_prefer_redirects=False,
                   win_private_assemblies=False,
                   cipher=block_cipher,
                   noarchive=False)

# Analysis of the secondary calc_transforms tool (much smaller dependency set).
a_calc_tra = Analysis(['..\\calc_transforms.py'],
                      pathex=['..\\src'],
                      binaries=[],
                      datas=[],
                      hiddenimports=[
                          'dafne',
                          'pydicom',
                          'SimpleITK'],
                      hookspath=[],
                      runtime_hooks=[],
                      excludes=[],
                      win_no_prefer_redirects=False,
                      win_private_assemblies=False,
                      cipher=block_cipher,
                      noarchive=False)

# Deduplicate the shared binaries/modules between the two builds.
MERGE( (a_dafne, 'dafne', 'dafne'), (a_calc_tra, 'calc_transforms', 'calc_transforms') )

# --- main GUI: archive, bootstrap exe, and collected distribution folder ---
pyz_dafne = PYZ(a_dafne.pure, a_dafne.zipped_data,
                cipher=block_cipher)
exe_dafne = EXE(pyz_dafne,
                a_dafne.scripts,
                [],
                exclude_binaries=True,
                name='dafne',
                debug=False,
                bootloader_ignore_signals=False,
                strip=False,
                upx=True,
                icon='..\\icons\\dafne_icon.ico',
                console=True)  # keep a console window for log output
coll_dafne = COLLECT(exe_dafne,
                     a_dafne.binaries,
                     a_dafne.zipfiles,
                     a_dafne.datas,
                     strip=False,
                     upx=True,
                     upx_exclude=[],
                     name='dafne')


# --- calc_transforms CLI: archive, bootstrap exe, and collected folder ---
pyz_calc_tra = PYZ(a_calc_tra.pure, a_calc_tra.zipped_data,
                   cipher=block_cipher)
exe_calc_tra = EXE(pyz_calc_tra,
                   a_calc_tra.scripts,
                   [],
                   exclude_binaries=True,
                   name='calc_transforms',
                   debug=False,
                   bootloader_ignore_signals=False,
                   strip=False,
                   upx=True,
                   icon='..\\icons\\calctransform_ico.ico',
                   console=True )
coll_calc_tra = COLLECT(exe_calc_tra,
                        a_calc_tra.binaries,
                        a_calc_tra.zipfiles,
                        a_calc_tra.datas,
                        strip=False,
                        upx=True,
                        upx_exclude=[],
                        name='calc_transforms')
92 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # avoid sending user configuration
2 | config.pickle
3 |
4 | # avoid sending data
5 | *.dcm
6 | *.model
7 | *.hdf5
8 |
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 |
14 | # C extensions
15 | *.so
16 |
17 | # Distribution / packaging
18 | .Python
19 | build/
20 | develop-eggs/
21 | dist/
22 | downloads/
23 | eggs/
24 | .eggs/
25 | lib/
26 | lib64/
27 | parts/
28 | sdist/
29 | var/
30 | wheels/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | #*.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 | cover/
61 |
62 | # Translations
63 | *.mo
64 | *.pot
65 |
66 | # Django stuff:
67 | *.log
68 | local_settings.py
69 | db.sqlite3
70 | db.sqlite3-journal
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | .pybuilder/
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | # For a library or package, you might want to ignore these files since the code is
95 | # intended to run in multiple environments; otherwise, check them in:
96 | # .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 |
142 | # pytype static type analyzer
143 | .pytype/
144 |
145 | # Cython debug symbols
146 | cython_debug/
147 |
148 | .vscode
149 | .DS_Store
150 |
151 | tmp
152 | config_localhost.pickle
153 | config_server.pickle
154 |
--------------------------------------------------------------------------------
/test/validation_ILearn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 |
19 | import numpy as np
20 | import pickle
21 | from generate_thigh_model import coscia_unet as unet
22 | import src.dafne_dl.common.preprocess_train as pretrain
23 |
# Build the thigh U-Net and load the incrementally-trained weights.
model = unet()
# model.load_weights('weights/weights_coscia.hdf5')  ## old
model.load_weights('Weights_incremental/thigh/weights - 5 - 86965.35.hdf5')  ## incremental

# Load test images and reference segmentations; close the files promptly
# instead of leaking the handles from bare open() calls.
with open('testImages/test_segment.pickle', 'rb') as f:
    seg_list = pickle.load(f)
with open('testImages/test_data.pickle', 'rb') as f:
    image_list = pickle.load(f)

# Label index -> muscle name for the thigh model.
LABELS_DICT = {
    1: 'VL',
    2: 'VM',
    3: 'VI',
    4: 'RF',
    5: 'SAR',
    6: 'GRA',
    7: 'AM',
    8: 'SM',
    9: 'ST',
    10: 'BFL',
    11: 'BFS',
    12: 'AL'
}
'''
LABELS_DICT = {
    1: 'SOL',
    2: 'GM',
    3: 'GL',
    4: 'TA',
    5: 'ELD',
    6: 'PE',
}
'''
MODEL_RESOLUTION = np.array([1.037037, 1.037037])
MODEL_SIZE = (432, 432)

# Resample images/masks to the model resolution and size.
image_list, mask_list = pretrain.common_input_process(
    LABELS_DICT, MODEL_RESOLUTION, MODEL_SIZE,
    {'image_list': image_list, 'resolution': MODEL_RESOLUTION}, seg_list)

ch = mask_list[0].shape[2]  # number of label channels
aggregated_masks = []
mask_list_no_overlap = []
for masks in mask_list:
    agg, new_masks = pretrain.calc_aggregated_masks_and_remove_overlap(masks)
    aggregated_masks.append(agg)
    mask_list_no_overlap.append(new_masks)

for slice_number, img in enumerate(image_list):
    segmentation = model.predict(np.expand_dims(np.stack([img, np.zeros(MODEL_SIZE)], axis=-1), axis=0))
    segmentationnum = np.argmax(np.squeeze(segmentation[0, :, :, :ch]), axis=2)
    # One-hot encode the predicted label map. np.eye indexing replaces the
    # original per-pixel double loop over the (hard-coded) 432x432 grid and
    # adapts automatically to MODEL_SIZE.
    y_pred = np.eye(ch, dtype='float32')[segmentationnum]
    y_true = mask_list_no_overlap[slice_number]
    # Per-channel Dice coefficient, averaged over channels (epsilon guards
    # against empty channels where both masks are all-zero).
    intersection = 2.0 * (y_pred * y_true).sum(axis=(0, 1))
    union = y_true.sum(axis=(0, 1)) + y_pred.sum(axis=(0, 1))
    acc = (intersection / (union + 0.000001)).sum() / ch
    print(str(slice_number) + '__' + str(acc))
88 |
--------------------------------------------------------------------------------
/test/validation_split_ILearn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 |
19 | import numpy as np
20 | import pickle
21 | from generate_thigh_split_model import coscia_unet as unet
22 | import src.dafne_dl.common.preprocess_train as pretrain
23 |
# Build the split thigh U-Net and load the incrementally-trained weights.
model = unet()
# model.load_weights('weights/weights_coscia_split.hdf5')  ## old
model.load_weights('Weights_incremental_split/thigh/weights - 5 - 47699.60.hdf5')  ## incremental

# Load test images and reference segmentations; close the files promptly
# instead of leaking the handles from bare open() calls.
with open('testImages/test_segment.pickle', 'rb') as f:
    seg_list = pickle.load(f)
with open('testImages/test_data.pickle', 'rb') as f:
    image_list = pickle.load(f)

# Label index -> muscle name for the thigh model.
LABELS_DICT = {
    1: 'VL',
    2: 'VM',
    3: 'VI',
    4: 'RF',
    5: 'SAR',
    6: 'GRA',
    7: 'AM',
    8: 'SM',
    9: 'ST',
    10: 'BFL',
    11: 'BFS',
    12: 'AL'
}
'''
LABELS_DICT = {
    1: 'SOL',
    2: 'GM',
    3: 'GL',
    4: 'TA',
    5: 'ELD',
    6: 'PE',
}
'''
MODEL_RESOLUTION = np.array([1.037037, 1.037037])
MODEL_SIZE = (432, 432)
MODEL_SIZE_SPLIT = (250, 250)

# Resample and split images/masks to the model resolution and split size.
image_list, mask_list = pretrain.common_input_process_split(
    LABELS_DICT, MODEL_RESOLUTION, MODEL_SIZE, MODEL_SIZE_SPLIT,
    {'image_list': image_list, 'resolution': MODEL_RESOLUTION}, seg_list)

ch = mask_list[0].shape[2]  # number of label channels
aggregated_masks = []
mask_list_no_overlap = []
for masks in mask_list:
    agg, new_masks = pretrain.calc_aggregated_masks_and_remove_overlap(masks)
    aggregated_masks.append(agg)
    mask_list_no_overlap.append(new_masks)

for slice_number, img in enumerate(image_list):
    segmentation = model.predict(np.expand_dims(np.stack([img, np.zeros(MODEL_SIZE_SPLIT)], axis=-1), axis=0))
    segmentationnum = np.argmax(np.squeeze(segmentation[0, :, :, :ch]), axis=2)
    # One-hot encode the predicted label map. np.eye indexing replaces the
    # original per-pixel double loop over MODEL_SIZE_SPLIT.
    y_pred = np.eye(ch, dtype='float32')[segmentationnum]
    y_true = mask_list_no_overlap[slice_number]
    # Per-channel Dice coefficient, averaged over channels (epsilon guards
    # against empty channels where both masks are all-zero).
    intersection = 2.0 * (y_pred * y_true).sum(axis=(0, 1))
    union = y_true.sum(axis=(0, 1)) + y_pred.sum(axis=(0, 1))
    acc = (intersection / (union + 0.000001)).sum() / ch
    print(str(slice_number) + '__' + str(acc))
89 |
--------------------------------------------------------------------------------
/install_scripts/create_mac_installer.sh:
--------------------------------------------------------------------------------
#!/bin/zsh
#
# Copyright (c) 2022 Dafne-Imaging Team
#
# Build, sign, notarize and package the macOS Dafne distribution into a DMG.

# note: checking signature/notarization: spctl -a -vvv

git pull
git checkout master

APPNAME=Dafne
VERSION=$(python update_version.py | tail -n 1)
ARCH=$(uname -a | sed -E -n 's/.*(arm64|x86_64)$/\1/p')
DMG_NAME=dafne_mac_${VERSION}_$ARCH.dmg
CODESIGN_IDENTITY="Francesco Santini"
USE_ALTOOL=False

# Prefer the CommandLineTools notarytool binary when present; otherwise let
# xcrun resolve it.
NOTARYTOOL() {
    if [ -f /Library/Developer/CommandLineTools/usr/bin/notarytool ]
    then
        /Library/Developer/CommandLineTools/usr/bin/notarytool "$@"
    else
        xcrun notarytool "$@"
    fi
}

echo $VERSION
pyinstaller dafne_mac.spec --noconfirm
# Bail out if the build directory is missing so we don't sign stale artifacts.
cd dist || exit 1
echo "Fixing app bundle"
python ../fix_app_bundle_for_mac.py "$APPNAME.app"
echo "Signing code"
# Sign code outside MacOS.
# Use -print0 / xargs -0 so file names containing spaces survive the pipeline
# (plain `find | xargs` splits on whitespace).
find "$APPNAME.app/Contents/Resources" -name '*.dylib' -print0 | xargs -0 codesign --force -v -s "$CODESIGN_IDENTITY"
find "$APPNAME.app/Contents/Resources" -name '*.so' -print0 | xargs -0 codesign --force -v -s "$CODESIGN_IDENTITY"
# sign the app
codesign --deep --force -v -s "$CODESIGN_IDENTITY" "$APPNAME.app"
find "$APPNAME.app/Contents" -path '*bin/*' -print0 | xargs -0 codesign --force -o runtime --timestamp --entitlements ../entitlements.plist -v -s "$CODESIGN_IDENTITY"

# Resign the app with the correct entitlement
codesign --force -o runtime --timestamp --entitlements ../entitlements.plist -v -s "$CODESIGN_IDENTITY" "$APPNAME.app"

echo "Creating DMG"
# Create-dmg at some point stopped working because of a newline prepended to the mount point. If this fails, check for this bug.
create-dmg --volname "Dafne" --volicon "$APPNAME.app/Contents/Resources/dafne_icon.icns" \
    --eula "$APPNAME.app/Contents/Resources/LICENSE" --background ../../icons/mac_installer_bg.png \
    --window-size 420 220 --icon-size 64 --icon "$APPNAME.app" 46 31 \
    --app-drop-link 236 90 "$DMG_NAME" "$APPNAME.app"
codesign -s "$CODESIGN_IDENTITY" "$DMG_NAME"
echo "Notarizing app"

if [ "$USE_ALTOOL" = "True" ]
then
    # make sure that the credentials are stored in keychain with
    # xcrun altool --store-password-in-keychain-item "AC_PASSWORD" -u "" -p ""
    # Note: password is a "app-specific password" created in the appleID site to bypass 2FA
    xcrun altool --notarize-app \
        --primary-bundle-id "network.dafne.dafne" --password "@keychain:AC_PASSWORD" \
        --file "$DMG_NAME"

    echo "Check the status in 1 hour or so with:"
    echo 'xcrun altool --notarization-history 0 -p "@keychain:AC_PASSWORD"'
    echo 'If there are any errors check'
    echo 'xcrun altool --notarization-info -p "@keychain:AC_PASSWORD"'
    echo 'if successful, staple the ticket running'
    echo "xcrun stapler staple $DMG_NAME"
else
    # alternative with notarytool
    # store credentials:
    # xcrun notarytool store-credentials "AC_PASSWORD" --apple-id --password --team-id
    echo "This can take up to 1 hour"
    NOTARYTOOL submit "$DMG_NAME" --wait --keychain-profile "AC_PASSWORD"
    # This will wait for the notarization to complete
    echo 'If failed, save log file with:'
    echo 'xcrun notarytool log --keychain-profile "AC_PASSWORD" notarization.log'
    xcrun stapler staple "$DMG_NAME"
fi
79 |
--------------------------------------------------------------------------------
/src/dafne/bin/dafne.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) 2022 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program.  If not, see <https://www.gnu.org/licenses/>.
18 |
19 | import os
20 |
21 | from ..ui.WhatsNew import show_news
22 | # Hide tensorflow warnings; set to 1 to see warnings
23 | from ..utils import log
24 |
25 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # or any {'0', '1', '2', '3'}
26 |
27 | from ..ui.MuscleSegmentation import MuscleSegmentation
28 | from ..config import GlobalConfig, load_config
29 |
30 | import matplotlib
31 | matplotlib.use("Qt5Agg")
32 | import argparse
33 | import matplotlib.pyplot as plt
34 |
35 | MODELS_DIR = 'models_old'
36 |
37 |
def main():
    """Entry point for the Dafne GUI.

    Parses command-line arguments, loads the global configuration,
    optionally redirects stdout/stderr to log files, then opens the main
    segmentation window and (unless --quit is given) starts the event loop.
    """
    parser = argparse.ArgumentParser(description="Muscle segmentation tool.")
    parser.add_argument('path', nargs='?', type=str)
    parser.add_argument('-c', '--class', dest='dl_class', type=str, help='Specify the deep learning model to use for the dataset')
    parser.add_argument('-r', '--register', action='store_true', help='Perform the registration after loading.')
    parser.add_argument('-m', '--save-masks', action='store_true', help='Convert saved ROIs to masks.')
    parser.add_argument('-d', '--save-dicoms', action='store_true', help='Save ROIs as dicoms in addition to numpy')
    # Fixed help text: quitting is useful together with -r or -m
    # (the original said "-r or -q", which is self-referential).
    parser.add_argument('-q', '--quit', action='store_true', help='Quit after loading the dataset (useful with -r or -m options).')
    parser.add_argument('-rm', '--remote-model', action='store_true', help='Force remote model')
    parser.add_argument('-lm', '--local-model', action='store_true', help='Force local model')

    args = parser.parse_args()

    load_config()

    # Command-line switches override the configured model provider.
    if args.remote_model:
        GlobalConfig['MODEL_PROVIDER'] = 'Remote'

    if args.local_model:
        GlobalConfig['MODEL_PROVIDER'] = 'Local'

    if GlobalConfig['REDIRECT_OUTPUT']:
        import sys

        # Tee stdout/stderr into log files; echo to the console only when
        # ECHO_OUTPUT is configured.
        log.log_objects['stdout'] = log.LogStream(GlobalConfig['OUTPUT_LOG_FILE'], sys.stdout if GlobalConfig['ECHO_OUTPUT'] else None)
        log.log_objects['stderr'] = log.LogStream(GlobalConfig['ERROR_LOG_FILE'], sys.stderr if GlobalConfig['ECHO_OUTPUT'] else None)

        sys.stdout = log.log_objects['stdout']
        sys.stderr = log.log_objects['stderr']

    imFig = MuscleSegmentation()

    # Treat an empty --class value the same as "not given".
    dl_class = args.dl_class or None

    if args.path:
        imFig.loadDirectory(args.path, dl_class)

    if args.save_dicoms:
        imFig.saveDicom = True

    if args.register:
        imFig.calcTransforms()

    if args.save_masks:
        imFig.saveResults()

    # Keep the matplotlib/Qt event loop running unless --quit was requested.
    if not args.quit:
        plt.show()
89 |
90 |
91 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://badge.fury.io/py/dafne)
2 | [](https://www.dafne.network/files/documentation.pdf)
3 | [](https://www.dafne.network/documentation/)
4 |
5 | # Dafne
6 | Deep Anatomical Federated Network is a program for the segmentation of medical images. It relies on a server to provide deep learning models to aid the segmentation, and incremental learning is used to improve the performance. See https://www.dafne.network/ for documentation and user information.
7 |
8 | ## Windows binary installation
9 | Please install the Visual Studio Redistributable Package under windows: https://aka.ms/vs/16/release/vc_redist.x64.exe
10 | Then, run the provided installer.
11 |
12 | ## Mac binary installation
13 | Install the Dafne App from the downloaded .dmg file as usual. Make sure to download the archive appropriate for your architecture (x86 or arm).
14 |
15 | ## Linux binary installation
16 | The Linux distribution is a self-contained executable file. Simply download it, make it executable, and run it.
17 |
18 | ## pip installation
19 | Dafne can also be installed with pip
20 | `pip install dafne`
21 |
22 | # Citing
23 | If you are writing a scientific paper, and you used Dafne for your data evaluation, please cite the following paper:
24 |
25 | > Santini F, Wasserthal J, Agosti A, et al. *Deep Anatomical Federated Network (Dafne): an open client/server framework for the continuous collaborative improvement of deep-learning-based medical image segmentation*. 2023 doi: [10.48550/arXiv.2302.06352](https://doi.org/10.48550/arXiv.2302.06352).
26 |
27 |
28 | # Notes for developers
29 |
30 | ## dafne
31 |
32 | Run:
33 | `python dafne.py <path_to_data>`
34 |
35 |
36 | ## Notes for the DL models
37 |
38 | ### Apply functions
39 | The input of the apply function is:
40 | ```
41 | dict({
42 | 'image': np.array (2D image)
43 | 'resolution': sequence with two elements (image resolution in mm)
44 | 'split_laterality': True/False (indicates whether the ROIs should be split in L/R if applicable)
45 | 'classification': str - The classification tag of the image (optional, to identify model variants)
46 | })
47 | ```
48 |
49 | The output of the classifier is a string.
50 | The output of the segmenters is:
51 | ```
52 | dict({
53 | roi_name_1: np.array (2D mask),
54 | roi_name_2: ...
55 | })
56 | ```
57 |
58 | ### Incremental learn functions
59 | The input of the incremental learn functions are:
60 | ```
61 | training data: dict({
62 | 'resolution': sequence (see above)
63 | 'classification': str (see above)
64 | 'image_list': list([
65 | - np.array (2D image)
66 | - np.array (2D image)
67 | - ...
68 | ])
69 | })
70 |
71 | training outputs: list([
72 | - dict({
73 | roi_name_1: np.array (2D mask)
74 | roi_name_2: ...
75 | })
76 | - dict...
77 | ```
78 |
79 | Every entry in the training outputs list corresponds to an entry in the image_list inside the training data.
80 | So `len(training_data['image_list']) == len(training_outputs)`.
81 |
82 | # Acknowledgments
83 | Input/Output is based on [DOSMA](https://github.com/ad12/DOSMA) - GPLv3 license
84 |
85 | This software includes the [Segment Anything Model (SAM)](https://github.com/facebookresearch/segment-anything) - Apache 2.0 license
86 |
87 | Other packages required for this project are listed in requirements.txt
88 |
--------------------------------------------------------------------------------
/install_scripts/dafne_mac.spec:
--------------------------------------------------------------------------------
# -*- mode: python ; coding: utf-8 -*-
# PyInstaller spec file for the macOS build.
# Builds the main Dafne GUI (run_dafne.py) and the calc_transforms CLI tool
# with a merged dependency set, then wraps the GUI collection into a
# Dafne.app bundle. Build names (Analysis, PYZ, EXE, COLLECT, MERGE, BUNDLE)
# are injected by PyInstaller when the spec is executed.

block_cipher = None  # no bytecode encryption
# Deep dependency graphs (torch, vtk, ...) can exceed the default recursion limit.
import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)

# Analysis of the main application. Scientific packages are listed as hidden
# imports because PyInstaller's static analysis does not always discover them.
a_dafne = Analysis(['../run_dafne.py'],
                   pathex=['../src'],
                   binaries=[],
                   # Ship the license and the resource files inside the bundle.
                   datas=[('../LICENSE', '.'), ('../src/dafne/resources/*', 'resources/')],
                   hiddenimports = ['dafne',
                                    'pydicom',
                                    'SimpleITK',
                                    'tensorflow',
                                    'skimage',
                                    'nibabel',
                                    'dafne_dl',
                                    'cmath',
                                    'ormir-pyvoxel',
                                    'pyvistaqt',
                                    'pyvista',
                                    'vtk',
                                    'torch',
                                    'torchvision'],
                   hookspath=['../pyinstaller_hooks'],
                   runtime_hooks=[],
                   excludes=[],
                   win_no_prefer_redirects=False,
                   win_private_assemblies=False,
                   cipher=block_cipher,
                   noarchive=False)

# Analysis of the secondary calc_transforms tool (smaller dependency set).
a_calc_tra = Analysis(['../calc_transforms.py'],
                      pathex=['../src'],
                      binaries=[],
                      datas=[],
                      hiddenimports=[
                          'dafne',
                          'pydicom',
                          'SimpleITK',
                          'ormir-pyvoxel'],
                      hookspath=[],
                      runtime_hooks=[],
                      excludes=[],
                      win_no_prefer_redirects=False,
                      win_private_assemblies=False,
                      cipher=block_cipher,
                      noarchive=False)

# Deduplicate the shared binaries/modules between the two builds.
MERGE( (a_dafne, 'dafne', 'dafne'), (a_calc_tra, 'calc_transforms', 'calc_transforms') )

# --- main GUI: archive, bootstrap exe, and collected distribution folder ---
pyz_dafne = PYZ(a_dafne.pure, a_dafne.zipped_data,
                cipher=block_cipher)
exe_dafne = EXE(pyz_dafne,
                a_dafne.scripts,
                [],
                exclude_binaries=True,
                name='dafne',
                debug=False,
                icon='../icons/dafne_icon.icns',
                bootloader_ignore_signals=False,
                strip=False,
                upx=True,
                console=False )  # GUI app: no console window
coll_dafne = COLLECT(exe_dafne,
                     a_dafne.binaries,
                     a_dafne.zipfiles,
                     a_dafne.datas,
                     strip=False,
                     upx=True,
                     upx_exclude=[],
                     name='dafne')

# --- calc_transforms CLI: archive, bootstrap exe, and collected folder ---
pyz_calc_tra = PYZ(a_calc_tra.pure, a_calc_tra.zipped_data,
                   cipher=block_cipher)
exe_calc_tra = EXE(pyz_calc_tra,
                   a_calc_tra.scripts,
                   [],
                   exclude_binaries=True,
                   name='calc_transforms',
                   debug=False,
                   bootloader_ignore_signals=False,
                   strip=False,
                   upx=True,
                   icon='../icons/calctransform_ico.ico',
                   console=True )
coll_calc_tra = COLLECT(exe_calc_tra,
                        a_calc_tra.binaries,
                        a_calc_tra.zipfiles,
                        a_calc_tra.datas,
                        strip=False,
                        upx=True,
                        upx_exclude=[],
                        name='calc_transforms')

# Wrap the GUI collection in a .app bundle.
# NOTE(review): the version string is hard-coded here — verify it is kept in
# sync with update_version.py / the package version.
app = BUNDLE(coll_dafne,
             name='Dafne.app',
             icon='../icons/dafne_icon.icns',
             bundle_identifier='network.dafne.dafne',
             version='1.8-alpha3')
100 |
--------------------------------------------------------------------------------
/src/dafne/ui/CalcTransformsUI.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | CalcTransformsUI
4 |
5 |
6 |
7 | 0
8 | 0
9 | 412
10 | 218
11 |
12 |
13 |
14 | Form
15 |
16 |
17 | -
18 |
19 |
-
20 |
21 |
22 |
23 | 0
24 | 0
25 |
26 |
27 |
28 | Location:
29 |
30 |
31 |
32 | -
33 |
34 |
35 | false
36 |
37 |
38 |
39 | -
40 |
41 |
42 | Choose...
43 |
44 |
45 |
46 |
47 |
48 | -
49 |
50 |
51 | false
52 |
53 |
54 | 0
55 |
56 |
57 |
58 | -
59 |
60 |
61 | Orientation
62 |
63 |
64 |
-
65 |
66 |
67 | Original
68 |
69 |
70 | true
71 |
72 |
73 |
74 | -
75 |
76 |
77 | Axial
78 |
79 |
80 |
81 | -
82 |
83 |
84 | Sagittal
85 |
86 |
87 |
88 | -
89 |
90 |
91 | Coronal
92 |
93 |
94 |
95 |
96 |
97 |
98 | -
99 |
100 |
101 | Qt::Vertical
102 |
103 |
104 | QSizePolicy::Expanding
105 |
106 |
107 |
108 | 20
109 | 45
110 |
111 |
112 |
113 |
114 | -
115 |
116 |
117 | false
118 |
119 |
120 | Calculate Transforms
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/utils/pre_grey_rgb.py:
--------------------------------------------------------------------------------
#%% import packages
# Preprocess 2D grey/RGB images + masks into npz/npy files for MedSAM training.
import numpy as np
import os
join = os.path.join
from skimage import io, transform
from tqdm import tqdm

# convert 2D data to npy files, including images and corresponding masks
modality = 'dd'  # e.g., 'Dermoscopy'
anatomy = 'dd'  # e.g., 'SkinCancer'
img_name_suffix = '.png'  # extension of the input images
gt_name_suffix = '.png'  # extension of the ground-truth masks
prefix = modality + '_' + anatomy + '_'
save_suffix = '.npy'
image_size = 1024  # target side length of the resized outputs
img_path = 'path to /images' # path to the images
gt_path = 'path to/labels' # path to the corresponding annotations
npy_path = 'path to/data/npy/' + prefix[:-1] # save npy path e.g., MedSAM/data/npy/; don't miss the `/`
os.makedirs(join(npy_path, "gts"), exist_ok=True)
os.makedirs(join(npy_path, "imgs"), exist_ok=True)
names = sorted(os.listdir(gt_path))
# NOTE(review): '\#' is not a valid escape sequence inside this f-string; it
# prints a literal backslash before '#' — probably just '#' was intended.
print(f'ori \# files {len(names)=}')

# set label ids that are excluded
remove_label_ids = []
tumor_id = None # only set this when there are multiple tumors in one image; convert semantic masks to instance masks
label_id_offset = 0
do_intensity_cutoff = False # True for grey images
#%% save preprocessed images and masks as npz files
for name in tqdm(names):
    image_name = name.split(gt_name_suffix)[0] + img_name_suffix
    gt_name = name
    npy_save_name = prefix + gt_name.split(gt_name_suffix)[0]+save_suffix
    gt_data_ori = np.uint8(io.imread(join(gt_path, gt_name)))
    # remove label ids
    for remove_label_id in remove_label_ids:
        gt_data_ori[gt_data_ori==remove_label_id] = 0
    # label tumor masks as instances and remove from gt_data_ori
    if tumor_id is not None:
        tumor_bw = np.uint8(gt_data_ori==tumor_id)
        gt_data_ori[tumor_bw>0] = 0
        # label tumor masks as instances
        # NOTE(review): `cc3d` (connected-components-3d) is never imported in
        # this file, so this branch raises NameError whenever tumor_id is set.
        # Add `import cc3d` before enabling tumor_id.
        tumor_inst, tumor_n = cc3d.connected_components(tumor_bw, connectivity=26, return_N=True)
        # put the tumor instances back to gt_data_ori
        gt_data_ori[tumor_inst>0] = tumor_inst[tumor_inst>0] + label_id_offset + 1

    # crop the ground truth with non-zero slices
    image_data = io.imread(join(img_path, image_name))
    # rescale >8-bit images into the 0-255 range
    if np.max(image_data) > 255.0:
        image_data = np.uint8((image_data-image_data.min()) / (np.max(image_data)-np.min(image_data))*255.0)
    # replicate single-channel (grey) images to three channels
    if len(image_data.shape) == 2:
        image_data = np.repeat(np.expand_dims(image_data, -1), 3, -1)
    assert len(image_data.shape) == 3, 'image data is not three channels: img shape:' + str(image_data.shape) + image_name
    # convert three channel to one channel
    if image_data.shape[-1] > 3:
        image_data = image_data[:,:,:3]
    # image preprocess start
    if do_intensity_cutoff:
        # clip to the 0.5/99.5 intensity percentiles of the non-zero pixels,
        # then rescale to 0-255; background (zero) pixels stay zero
        lower_bound, upper_bound = np.percentile(image_data[image_data>0], 0.5), np.percentile(image_data[image_data>0], 99.5)
        image_data_pre = np.clip(image_data, lower_bound, upper_bound)
        image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0
        image_data_pre[image_data==0] = 0
        image_data_pre = np.uint8(image_data_pre)
    else:
        # print('no intensity cutoff')
        image_data_pre = image_data.copy()
    # full-resolution image/mask pair saved together as a compressed npz archive
    np.savez_compressed(join(npy_path, prefix + gt_name.split(gt_name_suffix)[0]+'.npz'), imgs=image_data_pre, gts=gt_data_ori)
    # cubic interpolation (order=3) for the image, nearest-neighbour (order=0)
    # for the label map so label values are not blended
    resize_img = transform.resize(image_data_pre, (image_size, image_size), order=3, mode='constant', preserve_range=True, anti_aliasing=True)
    resize_img01 = resize_img/255.0
    resize_gt = transform.resize(gt_data_ori, (image_size, image_size), order=0, mode='constant', preserve_range=True, anti_aliasing=False)
    # save resize img and gt as npy
    np.save(join(npy_path, "imgs", npy_save_name), resize_img01)
    np.save(join(npy_path, "gts", npy_save_name), resize_gt.astype(np.uint8))
75 |
--------------------------------------------------------------------------------
/src/dafne/ui/ContourPainter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 |
16 | from matplotlib.patches import Circle, Polygon
17 | from ..utils.pySplineInterp import SplineInterpROIClass
18 |
19 | MIN_KNOT_RADIUS = 0.5 # if radius is smaller than this, knots are not painted
20 |
21 |
class ContourPainter:
    """
    Class to paint a series of ROIs on a pyplot axes object.

    Keeps one Circle patch per spline knot and one Polygon patch per ROI
    curve. Patches are rebuilt whenever the ROI list, the color, or the
    knot radius changes, and are added to an axes by draw().
    """
    def __init__(self, color, knot_radius):
        self._knots = []    # Circle patches marking the spline knots
        self._curves = []   # Polygon patches of the interpolated curves
        self.rois = []
        self.color = color
        self.knot_radius = knot_radius
        self.painted = False  # True once patches have been added to an axes

    def set_color(self, color):
        """Set the contour color and rebuild all patches."""
        self.color = color
        self.recalculate_patches()

    def set_radius(self, radius):
        """Set the knot radius and rebuild all patches.

        Bug fix: this previously assigned ``self.radius``, an attribute that
        is never read anywhere; knots are painted with ``self.knot_radius``,
        so the setter silently had no effect.
        """
        self.knot_radius = radius
        self.recalculate_patches()

    def clear_patches(self, axes=None):
        """Remove all painted patches.

        If *axes* is given, every patch currently attached to it is removed;
        otherwise the internally tracked knot/curve patches are hidden and
        detached individually (best-effort: removal errors are ignored).
        """
        if not self.painted:
            return
        self.painted = False
        if axes:
            # axes.patches = [] # Error with new matplotlib!
            try:
                for patch in axes.patches:
                    patch.remove()
            except Exception as e:
                print("Error removing patches", e)
            return
        for knot in self._knots:
            try:
                knot.set_visible(False)
            except Exception:
                pass
            try:
                knot.remove()
            except Exception:
                pass
        for curve in self._curves:
            try:
                curve.set_visible(False)
            except Exception:
                pass
            try:
                curve.remove()
            except Exception:
                pass

    def recalculate_patches(self):
        """Rebuild the Circle/Polygon patches from the current ROI list."""
        self.clear_patches()
        self._knots = []
        self._curves = []
        for roi in self.rois:
            # tiny knots are not painted at all
            if self.knot_radius >= MIN_KNOT_RADIUS:
                for knot in roi.knots:
                    self._knots.append(Circle(knot,
                                              self.knot_radius,
                                              facecolor='none',
                                              edgecolor=self.color,
                                              linewidth=1.0))
            try:
                self._curves.append(Polygon(roi.getCurve(),
                                            facecolor='none',
                                            edgecolor=self.color,
                                            zorder=1))
            except Exception:
                # getCurve() can fail for degenerate ROIs; skip the curve
                pass

    def add_roi(self, roi: 'SplineInterpROIClass'):
        """Add a ROI to the painter and rebuild the patches."""
        self.rois.append(roi)
        self.recalculate_patches()

    def clear_rois(self, axes=None):
        """Remove every ROI and all associated patches."""
        self.clear_patches(axes)
        self._knots = []
        self._curves = []
        self.rois = []

    def draw(self, axes, clear_first=False):
        """Add all patches to *axes* and draw them, optionally clearing first."""
        if clear_first:
            self.clear_patches()
        for knot in self._knots:
            axes.add_patch(knot)
            axes.draw_artist(knot)
            self.painted = True
        for curve in self._curves:
            axes.add_patch(curve)
            axes.draw_artist(curve)
            self.painted = True
117 |
118 |
119 |
120 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import numpy as np
9 | import torch
10 | from torch.nn import functional as F
11 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore
12 |
13 | from copy import deepcopy
14 | from typing import Tuple
15 |
16 |
class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(
            image.shape[0], image.shape[1], self.target_length
        )
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(
        self, coords: np.ndarray, original_size: Tuple[int, ...]
    ) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format. Always returns a float array.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length)
        # Bug fix: np.empty_like() inherited the input dtype, so integer
        # pixel coordinates were silently truncated after scaling. Compute
        # in float instead (this matches the upstream SAM implementation).
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(
        self, boxes: np.ndarray, original_size: Tuple[int, ...]
    ) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # Expects an image in BCHW format. May not exactly match apply_image.
        target_size = self.get_preprocess_shape(
            image.shape[2], image.shape[3], self.target_length
        )
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        # work on a float copy so the caller's tensor is never mutated
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(
        oldh: int, oldw: int, long_side_length: int
    ) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        # round-half-up to the nearest integer
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
112 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/utils/format_convert.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | join = os.path.join
4 | import random
5 | import numpy as np
6 | from skimage import io
7 | import SimpleITK as sitk
8 |
9 |
def dcm2nii(dcm_path, nii_path):
    """
    Convert dicom files to nii files
    """
    series_reader = sitk.ImageSeriesReader()
    # gather every slice of the GDCM series found under dcm_path
    slice_files = series_reader.GetGDCMSeriesFileNames(dcm_path)
    series_reader.SetFileNames(slice_files)
    volume = series_reader.Execute()
    sitk.WriteImage(volume, nii_path)
19 |
def mhd2nii(mhd_path, nii_path):
    """
    Convert mhd files to nii files
    """
    # read and immediately rewrite in the NIfTI format
    sitk.WriteImage(sitk.ReadImage(mhd_path), nii_path)
26 |
def nii2nii(nii_path, nii_gz_path):
    """
    Convert nii files to nii.gz files, which can reduce the file size
    """
    # round-trip through SimpleITK; compression is implied by the .gz suffix
    sitk.WriteImage(sitk.ReadImage(nii_path), nii_gz_path)
33 |
def nrrd2nii(nrrd_path, nii_path):
    """
    Convert nrrd files to nii files
    """
    volume = sitk.ReadImage(nrrd_path)
    sitk.WriteImage(volume, nii_path)
40 |
def jpg2png(jpg_path, png_path):
    """
    Convert jpg files to png files
    """
    # decode with skimage and re-encode under the new extension
    io.imsave(png_path, io.imread(jpg_path))
47 |
def patchfy(img, mask, outpath, basename):
    """
    Patchfy the image and mask into 1024x1024 patches

    Patches are written as PNGs to <outpath>/images and <outpath>/labels,
    named "<basename>_<row>_<col>.png"; edges are zero-padded so the full
    image is always covered.
    """
    image_patch_dir = join(outpath, "images")
    mask_patch_dir = join(outpath, "labels")
    os.makedirs(image_patch_dir, exist_ok=True)
    os.makedirs(mask_patch_dir, exist_ok=True)
    assert img.shape[:2] == mask.shape
    patch_height = 1024
    patch_width = 1024

    img_height, img_width = img.shape[:2]
    mask_height, mask_width = mask.shape

    # zero-pad the bottom/right edges so both arrays tile evenly
    if img_height % patch_height:
        pad_bottom = patch_height - img_height % patch_height
        img = np.pad(img, ((0, pad_bottom), (0, 0), (0, 0)), mode="constant")
    if img_width % patch_width:
        pad_right = patch_width - img_width % patch_width
        img = np.pad(img, ((0, 0), (0, pad_right), (0, 0)), mode="constant")
    if mask_height % patch_height:
        pad_bottom = patch_height - mask_height % patch_height
        mask = np.pad(mask, ((0, pad_bottom), (0, 0)), mode="constant")
    if mask_width % patch_width:
        pad_right = patch_width - mask_width % patch_width
        mask = np.pad(mask, ((0, 0), (0, pad_right)), mode="constant")

    assert img.shape[:2] == mask.shape
    assert img.shape[0] % patch_height == 0
    assert img.shape[1] % patch_width == 0
    assert mask.shape[0] % patch_height == 0
    assert mask.shape[1] % patch_width == 0

    # number of patches along each axis: ceil division of the ORIGINAL size
    height_steps = -(-img_height // patch_height)
    width_steps = -(-img_width // patch_width)

    for row in range(height_steps):
        for col in range(width_steps):
            y0 = row * patch_height
            x0 = col * patch_width
            img_patch = img[y0:y0 + patch_height, x0:x0 + patch_width, :]
            mask_patch = mask[y0:y0 + patch_height, x0:x0 + patch_width]
            assert img_patch.shape[:2] == mask_patch.shape
            assert img_patch.shape[0] == patch_height
            assert img_patch.shape[1] == patch_width
            print(f"img_patch.shape: {img_patch.shape}, mask_patch.shape: {mask_patch.shape}")
            io.imsave(join(image_patch_dir, f"{basename}_{row}_{col}.png"), img_patch)
            io.imsave(join(mask_patch_dir, f"{basename}_{row}_{col}.png"), mask_patch)
93 |
94 |
def rle_decode(mask_rle, img_shape):
    """Decode a run-length encoded mask into a binary image.

    mask_rle: run-length as string formated (start length)
    img_shape: (height, width) of the array to build
    Returns a numpy uint8 array of shape (width, height) — note the final
    transpose — with 255 on mask pixels and 0 on background.
    """
    seq = mask_rle.split()
    starts = np.array(list(map(int, seq[0::2])))
    lengths = np.array(list(map(int, seq[1::2])))
    assert len(starts) == len(lengths)
    ends = starts + lengths
    # Bug fix: np.product was deprecated and removed in NumPy 2.0;
    # np.prod is the supported spelling and behaves identically here.
    img = np.zeros((np.prod(img_shape),), dtype=np.uint8)
    for begin, end in zip(starts, ends):
        img[begin:end] = 255
    # https://stackoverflow.com/a/46574906/4521646
    img.shape = img_shape
    return img.T
--------------------------------------------------------------------------------
/src/dafne/MedSAM/utils/split.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | join = os.path.join
4 | import random
5 |
# NOTE(review): `path_nii = ''` is an empty STRING, not None, so the 3D branch
# below always runs (`'' is not None` is True) and would operate on the
# relative paths 'images'/'labels'. Fill in a real path (or set None) first.
path_nii = '' # please complete path; two subfolders: images and labels
path_video = None # or specify the path
path_2d = None # or specify the path

#%% split 3D nii data
if path_nii is not None:
    img_path = join(path_nii, 'images')
    gt_path = join(path_nii, 'labels')
    gt_names = sorted(os.listdir(gt_path))
    img_suffix = '_0000.nii.gz'  # nnU-Net style image naming
    gt_suffix = '.nii.gz'
    # split 20% data for validation and testing
    validation_path = join(path_nii, 'validation')
    os.makedirs(join(validation_path, 'images'), exist_ok=True)
    os.makedirs(join(validation_path, 'labels'), exist_ok=True)
    testing_path = join(path_nii, 'testing')
    os.makedirs(join(testing_path, 'images'), exist_ok=True)
    os.makedirs(join(testing_path, 'labels'), exist_ok=True)
    # NOTE(review): random.sample is unseeded — each run picks a new split
    candidates = random.sample(gt_names, int(len(gt_names)*0.2))
    # split half of test names for validation
    validation_names = random.sample(candidates, int(len(candidates)*0.5))
    test_names = [name for name in candidates if name not in validation_names]
    # move validation and testing data to corresponding folders
    for name in validation_names:
        img_name = name.split(gt_suffix)[0] + img_suffix
        os.rename(join(img_path, img_name), join(validation_path, 'images', img_name))
        os.rename(join(gt_path, name), join(validation_path, 'labels', name))
    for name in test_names:
        img_name = name.split(gt_suffix)[0] + img_suffix
        os.rename(join(img_path, img_name), join(testing_path, 'images', img_name))
        os.rename(join(gt_path, name), join(testing_path, 'labels', name))


##% split 2D images
if path_2d is not None:
    img_path = join(path_2d, 'images')
    gt_path = join(path_2d, 'labels')
    gt_names = sorted(os.listdir(gt_path))
    img_suffix = '.png'
    gt_suffix = '.png'
    # split 20% data for validation and testing
    validation_path = join(path_2d, 'validation')
    os.makedirs(join(validation_path, 'images'), exist_ok=True)
    os.makedirs(join(validation_path, 'labels'), exist_ok=True)
    testing_path = join(path_2d, 'testing')
    os.makedirs(join(testing_path, 'images'), exist_ok=True)
    os.makedirs(join(testing_path, 'labels'), exist_ok=True)
    candidates = random.sample(gt_names, int(len(gt_names)*0.2))
    # split half of test names for validation
    validation_names = random.sample(candidates, int(len(candidates)*0.5))
    test_names = [name for name in candidates if name not in validation_names]
    # move validation and testing data to corresponding folders
    for name in validation_names:
        img_name = name.split(gt_suffix)[0] + img_suffix
        os.rename(join(img_path, img_name), join(validation_path, 'images', img_name))
        os.rename(join(gt_path, name), join(validation_path, 'labels', name))

    for name in test_names:
        img_name = name.split(gt_suffix)[0] + img_suffix
        os.rename(join(img_path, img_name), join(testing_path, 'images', img_name))
        os.rename(join(gt_path, name), join(testing_path, 'labels', name))

#%% split video data
# Each video is a FOLDER of frames; whole folders are moved, no suffix mapping.
if path_video is not None:
    img_path = join(path_video, 'images')
    gt_path = join(path_video, 'labels')
    gt_folders = sorted(os.listdir(gt_path))
    # split 20% videos for validation and testing
    validation_path = join(path_video, 'validation')
    os.makedirs(join(validation_path, 'images'), exist_ok=True)
    os.makedirs(join(validation_path, 'labels'), exist_ok=True)
    testing_path = join(path_video, 'testing')
    os.makedirs(join(testing_path, 'images'), exist_ok=True)
    os.makedirs(join(testing_path, 'labels'), exist_ok=True)
    candidates = random.sample(gt_folders, int(len(gt_folders)*0.2))
    # split half of test names for validation
    validation_names = random.sample(candidates, int(len(candidates)*0.5))
    test_names = [name for name in candidates if name not in validation_names]
    # move validation and testing data to corresponding folders
    for name in validation_names:
        os.rename(join(img_path, name), join(validation_path, 'images', name))
        os.rename(join(gt_path, name), join(validation_path, 'labels', name))
    for name in test_names:
        os.rename(join(img_path, name), join(testing_path, 'images', name))
        os.rename(join(gt_path, name), join(testing_path, 'labels', name))
91 |
--------------------------------------------------------------------------------
/src/dafne/ui/CalcTransformsUI.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'CalcTransformsUI.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.7
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 |
13 |
class Ui_CalcTransformsUI(object):
    """Auto-generated (pyuic5) UI builder for the CalcTransforms form.

    Do not edit by hand — regenerate from CalcTransformsUI.ui instead
    (see the warning in the file header).
    """

    def setupUi(self, CalcTransformsUI):
        """Build the widget tree of the form and register object names."""
        CalcTransformsUI.setObjectName("CalcTransformsUI")
        CalcTransformsUI.resize(412, 218)
        self.verticalLayout = QtWidgets.QVBoxLayout(CalcTransformsUI)
        self.verticalLayout.setObjectName("verticalLayout")
        # top row: "Location:" label, read-only path field, "Choose..." button
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.label = QtWidgets.QLabel(CalcTransformsUI)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth())
        self.label.setSizePolicy(sizePolicy)
        self.label.setObjectName("label")
        self.horizontalLayout.addWidget(self.label)
        self.location_Text = QtWidgets.QLineEdit(CalcTransformsUI)
        self.location_Text.setEnabled(False)
        self.location_Text.setObjectName("location_Text")
        self.horizontalLayout.addWidget(self.location_Text)
        self.choose_Button = QtWidgets.QPushButton(CalcTransformsUI)
        self.choose_Button.setObjectName("choose_Button")
        self.horizontalLayout.addWidget(self.choose_Button)
        self.verticalLayout.addLayout(self.horizontalLayout)
        # progress bar, disabled until a calculation is started
        self.progressBar = QtWidgets.QProgressBar(CalcTransformsUI)
        self.progressBar.setEnabled(False)
        self.progressBar.setProperty("value", 0)
        self.progressBar.setObjectName("progressBar")
        self.verticalLayout.addWidget(self.progressBar)
        # orientation group: four mutually exclusive radio buttons
        self.orientationBox = QtWidgets.QGroupBox(CalcTransformsUI)
        self.orientationBox.setObjectName("orientationBox")
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.orientationBox)
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.original_radio = QtWidgets.QRadioButton(self.orientationBox)
        self.original_radio.setChecked(True)
        self.original_radio.setObjectName("original_radio")
        self.horizontalLayout_2.addWidget(self.original_radio)
        self.axial_radio = QtWidgets.QRadioButton(self.orientationBox)
        self.axial_radio.setObjectName("axial_radio")
        self.horizontalLayout_2.addWidget(self.axial_radio)
        self.sagittal_radio = QtWidgets.QRadioButton(self.orientationBox)
        self.sagittal_radio.setObjectName("sagittal_radio")
        self.horizontalLayout_2.addWidget(self.sagittal_radio)
        self.coronal_radio = QtWidgets.QRadioButton(self.orientationBox)
        self.coronal_radio.setObjectName("coronal_radio")
        self.horizontalLayout_2.addWidget(self.coronal_radio)
        self.verticalLayout.addWidget(self.orientationBox)
        spacerItem = QtWidgets.QSpacerItem(20, 45, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
        self.verticalLayout.addItem(spacerItem)
        # action button, disabled until a location is chosen
        self.calculate_button = QtWidgets.QPushButton(CalcTransformsUI)
        self.calculate_button.setEnabled(False)
        self.calculate_button.setObjectName("calculate_button")
        self.verticalLayout.addWidget(self.calculate_button)

        self.retranslateUi(CalcTransformsUI)
        QtCore.QMetaObject.connectSlotsByName(CalcTransformsUI)

    def retranslateUi(self, CalcTransformsUI):
        """Install the (translatable) user-visible strings on the widgets."""
        _translate = QtCore.QCoreApplication.translate
        CalcTransformsUI.setWindowTitle(_translate("CalcTransformsUI", "Form"))
        self.label.setText(_translate("CalcTransformsUI", "Location:"))
        self.choose_Button.setText(_translate("CalcTransformsUI", "Choose..."))
        self.orientationBox.setTitle(_translate("CalcTransformsUI", "Orientation"))
        self.original_radio.setText(_translate("CalcTransformsUI", "Original"))
        self.axial_radio.setText(_translate("CalcTransformsUI", "Axial"))
        self.sagittal_radio.setText(_translate("CalcTransformsUI", "Sagittal"))
        self.coronal_radio.setText(_translate("CalcTransformsUI", "Coronal"))
        self.calculate_button.setText(_translate("CalcTransformsUI", "Calculate Transforms"))
82 |
--------------------------------------------------------------------------------
/install_scripts/fix_app_bundle_for_mac.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Code from https://github.com/pyinstaller/pyinstaller/wiki/Recipe-OSX-Code-Signing-Qt
3 | #
4 |
5 | import os
6 | import shutil
7 | import sys
8 | from pathlib import Path
9 | from typing import Generator, List, Optional
10 |
11 | from macholib.MachO import MachO
12 |
13 |
def create_symlink(folder: Path) -> None:
    """Create the appropriate symlink in the MacOS folder
    pointing to the Resources folder.
    """
    # The twin path with the "MacOS" component dropped,
    # e.g. PyQt5/Qt/qml/QtQml/Models.2
    resources_twin = Path(str(folder).replace("MacOS", ""))
    relative_root = str(resources_twin).partition("Contents")[2].lstrip("/")

    # Climb back out of Contents/..., e.g. "../../../../"
    up_levels = "../" * (relative_root.count("/") + 1)

    # e.g. ../../../../Resources/PyQt5/Qt/qml/QtQml/Models.2
    folder.symlink_to(f"{up_levels}Resources/{relative_root}")
28 |
29 |
def fix_dll(dll: Path) -> None:
    """Fix the DLL lookup paths to use relative ones for Qt dependencies.
    Inspiration: PyInstaller/depend/dylib.py:mac_set_relative_dylib_deps()
    Currently one header is pointing to (we are in the Resources folder):
    @loader_path/../../../../QtCore (it is referencing to the old MacOS folder)
    It will be converted to:
    @loader_path/../../../../../../MacOS/QtCore
    """

    def match_func(pth: str) -> Optional[str]:
        """Callback function for MachO.rewriteLoadCommands() that is
        called on every lookup path setted in the DLL headers.
        By returning None for system libraries, it changes nothing.
        Else we return a relative path pointing to the good file
        in the MacOS folder.
        """
        basename = os.path.basename(pth)
        # only Qt libraries were relocated; leave every other path untouched
        if not basename.startswith("Qt"):
            return None
        return f"@loader_path{good_path}/{basename}"

    # Resources/PyQt5/Qt/qml/QtQuick/Controls.2/Fusion
    root = str(dll.parent).partition("Contents")[2][1:]
    # /../../../../../../.. (one ".." per path component below Contents)
    backward = "/.." * (root.count("/") + 1)
    # /../../../../../../../MacOS
    good_path = f"{backward}/MacOS"

    # Rewrite Mach headers with corrected @loader_path
    # NOTE(review): `dll` is rebound from a Path to a MachO object here.
    dll = MachO(dll)
    dll.rewriteLoadCommands(match_func)
    with open(dll.filename, "rb+") as f:
        for header in dll.headers:
            # presumably MachO.write() seeks to each header's own offset
            # internally; this mirrors the upstream PyInstaller recipe —
            # confirm against that recipe before changing the seek/write loop
            f.seek(0)
            dll.write(f)
        f.seek(0, 2)
        f.flush()
67 |
68 |
def find_problematic_folders(folder: Path) -> Generator[Path, None, None]:
    """Recursively yields problematic folders (containing a dot in their name)."""
    for entry in folder.iterdir():
        # Symlinks are allowed even with a dot; plain files never matter.
        if entry.is_dir() and not entry.is_symlink():
            if "." in entry.name:
                yield entry
            else:
                yield from find_problematic_folders(entry)
79 |
80 |
def move_contents_to_resources(folder: Path) -> Generator[Path, None, None]:
    """Recursively move any non symlink file from a problematic folder
    to the sibbling one in Resources.
    """
    for entry in folder.iterdir():
        if entry.is_symlink():
            continue
        if entry.name == "qml":
            # descend into qml folders instead of moving them wholesale
            yield from move_contents_to_resources(entry)
            continue
        destination = Path(str(entry).replace("MacOS", "Resources"))
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(entry, destination)
        yield destination
95 |
96 |
def main(args: List[str]) -> int:
    """
    Fix the application to allow codesign (NXDRIVE-1301).
    Take one or more .app as arguments: "Nuxeo Drive.app".
    The overall process will:
        - move problematic folders from MacOS to Resources
        - fix the DLLs lookup paths
        - create the appropriate symbolic link

    Returns 0 on success so that ``sys.exit(main(...))`` reports success
    explicitly (bug fix: the function previously fell through returning
    None despite its ``-> int`` annotation).
    """
    for app in args:
        name = os.path.basename(app)
        print(f">>> [{name}] Fixing Qt folder names")
        path = Path(app) / "Contents" / "MacOS"
        for folder in find_problematic_folders(path):
            for file in move_contents_to_resources(folder):
                try:
                    fix_dll(file)
                except (ValueError, IsADirectoryError):
                    # non-Mach-O files (and directories) are simply skipped
                    continue
            shutil.rmtree(folder)
            create_symlink(folder)
            print(f" !! Fixed {folder}")
        print(f">>> [{name}] Application fixed.")
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
--------------------------------------------------------------------------------
/src/dafne/MedSAM/MedSAM_Inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # %% load environment
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import os
6 |
7 | join = os.path.join
8 | import torch
9 | from segment_anything import sam_model_registry
10 | from skimage import io, transform
11 | import torch.nn.functional as F
12 | import argparse
13 |
14 |
15 | # visualization functions
16 | # source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
17 | # change color to avoid red and green
def show_mask(mask, ax, random_color=False):
    """Overlay *mask* on *ax* as a semi-transparent RGBA image."""
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        # fixed yellow-ish color (avoids red and green)
        rgba = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
26 |
27 |
def show_box(box, ax):
    """Draw an (x0, y0, x1, y1) bounding box on *ax* as a blue rectangle."""
    x0, y0 = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    rect = plt.Rectangle(
        (x0, y0), width, height, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2
    )
    ax.add_patch(rect)
34 |
35 |
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Run MedSAM's prompt encoder and mask decoder for one box prompt and
    return the binary segmentation, upsampled to the original (H, W)."""
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.ndim == 2:
        # add the per-box prompt dimension: (B, 4) -> (B, 1, 4)
        box_torch = box_torch[:, None, :]

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )

    # logits -> probabilities -> upsample to the original resolution
    probabilities = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
    probabilities = F.interpolate(
        probabilities,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, gt.shape)
    probabilities = probabilities.squeeze().cpu().numpy()
    return (probabilities > 0.5).astype(np.uint8)
66 |
67 |
# %% load model and image
parser = argparse.ArgumentParser(
    description="run inference on testing set based on MedSAM"
)
parser.add_argument(
    "-i",
    "--data_path",
    type=str,
    default="assets/img_demo.png",
    help="path to the data folder",
)
parser.add_argument(
    "-o",
    "--seg_path",
    type=str,
    default="assets/",
    help="path to the segmentation folder",
)
# NOTE(review): `type=list` makes argparse split a command-line string into a
# list of single CHARACTERS — only the default value behaves as intended.
# Confirm and consider a proper parser before relying on --box from the CLI.
parser.add_argument(
    "--box",
    type=list,
    default=[95, 255, 190, 350],
    help="bounding box of the segmentation target",
)
parser.add_argument("--device", type=str, default="cuda:0", help="device")
parser.add_argument(
    "-chk",
    "--checkpoint",
    type=str,
    default="work_dir/MedSAM/medsam_vit_b.pth",
    help="path to the trained model",
)
args = parser.parse_args()

device = args.device
# build a SAM ViT-B model and load the MedSAM checkpoint weights
medsam_model = sam_model_registry["vit_b"](checkpoint=args.checkpoint)
medsam_model = medsam_model.to(device)
medsam_model.eval()

img_np = io.imread(args.data_path)
# replicate grey (2D) images to three channels
if len(img_np.shape) == 2:
    img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
else:
    img_3c = img_np
H, W, _ = img_3c.shape
# %% image preprocessing
img_1024 = transform.resize(
    img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
).astype(np.uint8)
# min-max normalisation; the denominator is clipped to avoid division by zero
img_1024 = (img_1024 - img_1024.min()) / np.clip(
    img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
)  # normalize to [0, 1], (H, W, 3)
# convert the shape to (3, H, W)
img_1024_tensor = (
    torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
)

box_np = np.array([args.box])
# transfer box_np to the 1024x1024 scale
box_1024 = box_np / np.array([W, H, W, H]) * 1024
with torch.no_grad():
    image_embedding = medsam_model.image_encoder(img_1024_tensor)  # (1, 256, 64, 64)

medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)
# write the binary mask next to the input, prefixed with "seg_"
io.imsave(
    join(args.seg_path, "seg_" + os.path.basename(args.data_path)),
    medsam_seg,
    check_contrast=False,
)

# %% visualize results
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(img_3c)
show_box(box_np[0], ax[0])
ax[0].set_title("Input Image and Bounding Box")
ax[1].imshow(img_3c)
show_mask(medsam_seg, ax[1])
show_box(box_np[0], ax[1])
ax[1].set_title("MedSAM Segmentation")
plt.show()
148 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/build_sam.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | from functools import partial
8 | from pathlib import Path
9 | import urllib.request
10 | import torch
11 |
12 | from .modeling import (
13 | ImageEncoderViT,
14 | MaskDecoder,
15 | PromptEncoder,
16 | Sam,
17 | TwoWayTransformer,
18 | )
19 |
20 |
def build_sam_vit_h(checkpoint=None):
    """Build a SAM model with the ViT-H image encoder, optionally loading weights."""
    vit_h_config = dict(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
    )
    return _build_sam(checkpoint=checkpoint, **vit_h_config)
29 |
30 |
# Backwards-compatible alias: the default SAM build is the ViT-H variant.
build_sam = build_sam_vit_h
32 |
33 |
def build_sam_vit_l(checkpoint=None):
    """Build a SAM model with the ViT-L image encoder, optionally loading weights."""
    vit_l_config = dict(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
    )
    return _build_sam(checkpoint=checkpoint, **vit_l_config)
42 |
43 |
def build_sam_vit_b(checkpoint=None):
    """Build a SAM model with the ViT-B image encoder, optionally loading weights."""
    vit_b_config = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_sam(checkpoint=checkpoint, **vit_b_config)
52 |
53 |
# Maps a model-type string (e.g. from the command line) to its builder function.
sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
}
60 |
61 |
# Official download URLs for the original SAM checkpoints, keyed by file name.
_SAM_CHECKPOINT_URLS = {
    "sam_vit_b_01ec64.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    "sam_vit_h_4b8939.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "sam_vit_l_0b3195.pth": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
}


def _maybe_download_checkpoint(checkpoint):
    """If *checkpoint* names a known SAM weight file that is missing, offer to download it.

    :param checkpoint: pathlib.Path to the expected checkpoint file.
    """
    url = _SAM_CHECKPOINT_URLS.get(checkpoint.name)
    if url is None or checkpoint.exists():
        return
    cmd = input(f"Download {checkpoint.name} from facebook AI? [y]/n: ")
    if len(cmd) == 0 or cmd.lower() == "y":
        checkpoint.parent.mkdir(parents=True, exist_ok=True)
        print(f"Downloading {checkpoint.name}...")
        urllib.request.urlretrieve(url, checkpoint)
        print(checkpoint.name, " is downloaded!")


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Assemble a Sam model and optionally load a checkpoint.

    :param encoder_embed_dim: embedding dimension of the ViT image encoder
    :param encoder_depth: number of transformer blocks in the image encoder
    :param encoder_num_heads: attention heads per block
    :param encoder_global_attn_indexes: block indexes using global (not windowed) attention
    :param checkpoint: path to a .pth state dict, or None for random initialization
    :return: a Sam model in eval mode
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size  # 64x64 embedding grid
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    # Bug fix: only touch the filesystem when a checkpoint path was given.
    # The previous code called Path(checkpoint) unconditionally, which raised
    # TypeError for checkpoint=None and made its later
    # `if checkpoint is not None` check dead code.
    if checkpoint is not None:
        checkpoint = Path(checkpoint)
        _maybe_download_checkpoint(checkpoint)
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f, map_location=torch.device('cpu'))
        sam.load_state_dict(state_dict)
    return sam
147 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/utils/README.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 | This folder contains the code for data organization, splitting, preprocessing, and checkpoint converting.
3 |
4 |
5 | ## Data Organization
Since the original data formats and folder structures vary greatly across different datasets, we need to organize them into unified structures, allowing us to use the same functions for data preprocessing.
7 | The expected folder structures are as follows:
8 |
9 | 3D nii data
10 | ```
11 | ----dataset_name
12 | --------images
13 | ------------xxxx_0000.nii.gz
14 | ------------xxxx_0000.nii.gz
15 | --------labels
16 | ------------xxxx.nii.gz
17 | ------------xxxx.nii.gz
18 | ```
19 | Note: you can also use different suffix for images and labels. Please change them in the following preprocessing scripts as well.
20 |
21 | 2D data
22 | ```
23 | ----dataset_name
24 | --------images
25 | ------------xxxx.jpg/png
26 | ------------xxxx.jpg/png
27 | --------labels
28 | ------------xxxx.png
29 | ------------xxxx.jpg/png
30 | ```
31 |
32 | video data
33 | ```
34 | ----dataset_name
35 | --------images
36 | ------------video1
37 | ----------------xxxx.png
38 | ------------video2
39 | ----------------xxxx.png
40 | --------labels
41 | ------------video1
42 | ----------------xxxx.png
43 | ------------video2
44 | ----------------xxxx.png
45 | ```
46 |
Unfortunately, it is impossible to have one script to finish all the data organization. We manually organized the datasets with commonly used data-format conversion functions, including `dcm2nii`, `mhd2nii`, `nii2nii`, `nrrd2nii`, `jpg2png`, `tif2png`, `rle_decode`. These functions are available in `format_convert.py`
48 |
49 | ## Data Splitting
For common 2D images (e.g., skin cancer dermoscopy, chest X-Ray), they can be directly separated into 80%/10%/10% for training, parameter tuning, and internal validation, respectively. For 3D images (e.g., all the MRI/CT scans) and video data, they should be split at the case/video level rather than the 2D slice/frame level. For 2D whole-slide images, the splitting is at the whole-slide level. Since they cannot be directly sent to the model because of the high resolution, we divided them into patches with a fixed size of `1024x1024` after data splitting.
51 |
52 | After finishing the data organization, the data splitting can be easily done by running
53 | ```bash
54 | python split.py
55 | ```
56 | Please set the proper data path in the script. The expected folder structures (e.g., 3D images) are
57 |
58 | ```
59 | ----dataset_name
60 | --------images
61 | ------------xxxx_0000.nii.gz
62 | ------------xxxx_0000.nii.gz
63 | --------labels
64 | ------------xxxx.nii.gz
65 | ------------xxxx.nii.gz
66 | --------validation
67 | ------------images
68 | ----------------xxxx_0000.nii.gz
69 | ----------------xxxx_0000.nii.gz
70 | ------------labels
71 | ----------------xxxx.nii.gz
72 | ----------------xxxx.nii.gz
73 | --------testing
74 | ------------images
75 | ----------------xxxx_0000.nii.gz
76 | ----------------xxxx_0000.nii.gz
77 | ------------labels
78 | ----------------xxxx.nii.gz
79 | ----------------xxxx.nii.gz
80 | ```
81 |
82 | ## Data Preprocessing and Ensembling
83 |
84 | All the images will be preprocessed as `npy` files. There are two main reasons for choosing this format. First, it allows fast data loading (main reason). We learned this point from [nnU-Net](https://github.com/MIC-DKFZ/nnUNet). Second, numpy file is a universal data interface to unify all the different data formats. For the convenience of debugging and inference, we also saved the original images and labels as `npz` files. Spacing information is also saved for CT and MR images.
85 |
86 | The following steps are applied to all images
87 | - max-min normalization
- resample image size to 1024x1024
89 | - save the pre-processed images and labels as npy files
90 |
91 | Different modalities also have their own additional pre-process steps based on the data features.
92 |
For CT images, we first adjust the window level and width following the [common practice](https://radiopaedia.org/articles/windowing-ct).
94 | - Soft tissue window level (40) and width (400)
95 | - Chest window level (-600) and width (1500)
96 | - Brain window level (40) and width (80)
97 |
For MR, ultrasound, mammography, and Optical Coherence Tomography (OCT) images, we apply an intensity cut-off at the 0.5 and 99.5 percentiles of the foreground voxels. Regarding RGB images (e.g., endoscopy, dermoscopy, fundus, and pathology images), if they are already within the expected intensity range of [0, 255], their intensities remain unchanged. However, if they fall outside this range, max-min normalization is applied to rescale the intensity values to [0, 255].
99 |
100 | Preprocess for CT/MR images:
101 | ```bash
102 | python pre_CT_MR.py
103 | ```
104 |
105 | Preprocess for grey and RGB images:
106 | ```bash
107 | python pre_grey_rgb.py
108 | ```
109 |
Note: Please set the corresponding folder path and modality information. We provide an example in the script.
111 |
112 | Data ensembling of different training datasets is very simple. Since all the training data are converted into `npy` files during preprocessing, you just need to merge them into one folder.
113 |
114 |
115 | ## Checkpoint Converting
116 | If the model is trained with multiple GPUs, please use the script `ckpt_convert.py` to convert the format since users only use one GPU for model inference in real practice.
117 |
118 | Set the path to `sam_ckpt_path`, `medsam_ckpt_path`, and `save_path` and run
119 |
120 | ```bash
121 | python ckpt_convert.py
122 | ```
123 |
124 |
--------------------------------------------------------------------------------
/icons/mac_installer_bg.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
129 |
--------------------------------------------------------------------------------
/src/dafne/ui/WhatsNew.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Dafne-Imaging Team
2 | from datetime import datetime
3 | import xml.etree.ElementTree as ET
4 | from urllib.parse import urlparse
5 |
6 | import requests
7 | from PyQt5.QtCore import Qt, QObject, pyqtSignal
8 | from PyQt5.QtWidgets import QDialog, QVBoxLayout, QPushButton, QLabel, QSizePolicy
9 |
10 | from ..config import GlobalConfig, save_config
11 | from ..utils.ThreadHelpers import separate_thread_decorator
12 |
13 | MAX_NEWS_ITEMS = 3
14 | BLOG_INDEX = '/blog/index/'
15 | BLOG_ADDRESS = '/blog/'
16 |
17 |
class WhatsNewDialog(QDialog):
    """Application-modal dialog showing up to MAX_NEWS_ITEMS news entries.

    Each entry renders a title label, a date label, and an excerpt body,
    followed by an "all news" link and a Close button.

    NOTE(review): several label f-strings below appear empty/truncated in this
    revision (title_label and more_news_label); confirm the intended link
    markup against version control.
    """

    def __init__(self, news_list, index_address, *args, **kwargs):
        # news_list: list of dicts with 'date', 'link', 'title', 'excerpt' keys
        # (as produced by check_for_updates); index_address: URL of the blog index.
        QDialog.__init__(self, *args, **kwargs)
        my_layout = QVBoxLayout(self)
        self.setLayout(my_layout)
        self.setWindowTitle(f"Dafne News")
        self.setWindowModality(Qt.ApplicationModal)

        # One (title, date, excerpt) widget triple per news item, newest first.
        for news in news_list[:MAX_NEWS_ITEMS]:
            title_label = QLabel(f'')
            title_label.setOpenExternalLinks(True)
            title_label.setTextInteractionFlags(Qt.LinksAccessibleByMouse)
            title_label.sizePolicy().setVerticalStretch(0)
            title_label.setWordWrap(True)
            my_layout.addWidget(title_label)
            date_label = QLabel(f'{news["date"]}')
            date_label.sizePolicy().setVerticalStretch(0)
            my_layout.addWidget(date_label)
            body_label = QLabel(news["excerpt"])
            body_label.setWordWrap(True)
            # The excerpt is the only vertically stretchable widget per item.
            size_policy = QSizePolicy(QSizePolicy.Preferred, QSizePolicy.Preferred)
            size_policy.setHorizontalStretch(0)
            size_policy.setVerticalStretch(1)
            size_policy.setHeightForWidth(body_label.sizePolicy().hasHeightForWidth())
            body_label.setAlignment(Qt.AlignLeading | Qt.AlignLeft | Qt.AlignTop)
            body_label.setSizePolicy(size_policy)
            my_layout.addWidget(body_label)

        more_news_label = QLabel(f'All news...')
        more_news_label.setOpenExternalLinks(True)
        more_news_label.setTextInteractionFlags(Qt.LinksAccessibleByMouse)
        more_news_label.sizePolicy().setVerticalStretch(0)
        my_layout.addWidget(more_news_label)

        n_news = min(MAX_NEWS_ITEMS, len(news_list))

        btn = QPushButton("Close")
        btn.clicked.connect(self.close)
        my_layout.addWidget(btn)
        # Heuristic sizing: ~110px per displayed item plus fixed chrome.
        self.resize(300, 110 * n_news + 60)
        self.show()
59 |
60 |
def xml_timestamp_to_datetime(timestamp):
    """Parse an Atom-feed timestamp string into a timezone-aware datetime."""
    atom_format = '%Y-%m-%dT%H:%M:%S%z'
    return datetime.strptime(timestamp, atom_format)
63 |
64 |
def datetime_to_xml_timestamp(dt):
    """Format a timezone-aware datetime back into the Atom-feed timestamp form."""
    atom_format = '%Y-%m-%dT%H:%M:%S%z'
    return dt.strftime(atom_format)
67 |
68 |
def check_for_updates():
    """Fetch the Atom news feed and collect items newer than the last seen one.

    Returns a tuple (news_list, blog_url). On any network, HTTP, or parse
    failure it returns ([], []). As a side effect it advances
    GlobalConfig['LAST_NEWS'] to the newest item seen and persists the config.
    """
    last_news_time = xml_timestamp_to_datetime(GlobalConfig['LAST_NEWS'])
    # last_news_time = xml_timestamp_to_datetime('2010-11-10T00:00:00+00:00')
    try:
        response = requests.get(GlobalConfig['NEWS_URL'], timeout=(1, None))
    except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
        return [], []

    if response.status_code != 200:
        return [], []

    try:
        feed = ET.fromstring(response.text)
    except ET.ParseError:
        print("Error parsing news feed")
        return [], []

    # Derive the site root from the configured feed URL.
    parsed_uri = urlparse(GlobalConfig['NEWS_URL'])
    base_url = f'{parsed_uri.scheme}://{parsed_uri.netloc}'

    xml_ns = {'atom': 'http://www.w3.org/2005/Atom'}

    news_list = []
    newest_time = last_news_time
    for entry in feed.findall('atom:entry', xml_ns):
        link = entry.find('atom:link', xml_ns).attrib['href']
        # The blog index page is not a news item.
        if link == BLOG_INDEX:
            continue

        news_time = xml_timestamp_to_datetime(entry.find('atom:updated', xml_ns).text)
        if news_time <= last_news_time:
            continue
        news_list.append({'date': news_time.strftime('%Y-%m-%d'),
                          'link': base_url + link,
                          'title': entry.find('atom:title', xml_ns).text,
                          'excerpt': entry.find('atom:summary', xml_ns).text})
        newest_time = max(newest_time, news_time)

    GlobalConfig['LAST_NEWS'] = datetime_to_xml_timestamp(newest_time)
    news_list.sort(key=lambda item: item['date'], reverse=True)
    save_config()
    return news_list, base_url + BLOG_ADDRESS
113 |
114 |
class NewsChecker(QObject):
    """Background news poller: emits news_ready(news_list, index_address) when
    there are unseen news items."""

    news_ready = pyqtSignal(list, str)

    def __init__(self, *args, **kwargs):
        QObject.__init__(self, *args, **kwargs)

    @separate_thread_decorator
    def check_news(self):
        # Runs in a worker thread (see separate_thread_decorator); only emits
        # when check_for_updates found something new.
        news_list, index_address = check_for_updates()
        if news_list:
            self.news_ready.emit(news_list, index_address)
126 |
127 |
def show_news():
    """Synchronously fetch the news and, if there is anything new, display it."""
    news_list, index_address = check_for_updates()
    if not news_list:
        return
    dialog = WhatsNewDialog(news_list, index_address)
    dialog.exec()
133 |
134 |
def main():
    """Standalone entry point: create a Qt application and show the news dialog."""
    import sys
    from PyQt5.QtWidgets import QApplication
    app = QApplication(sys.argv)
    app.setQuitOnLastWindowClosed(True)
    # No app.exec_() call here: WhatsNewDialog.exec() drives its own modal
    # event loop, so main() returns once the dialog is closed.
    show_news()
141 |
--------------------------------------------------------------------------------
/src/dafne/ui/BatchCalcTransforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright (c) 2021 Dafne-Imaging Team
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 |
19 | from PyQt5.QtWidgets import QWidget, QMainWindow, QFileDialog, QMessageBox, QApplication
20 | from PyQt5.QtCore import pyqtSlot, pyqtSignal
21 | from .CalcTransformsUI import Ui_CalcTransformsUI
22 | import os
23 | import numpy as np
24 | from ..utils.RegistrationManager import RegistrationManager
25 | from dicomUtils.misc import dosma_volume_from_path, get_nifti_orientation
26 | from ..utils.ThreadHelpers import separate_thread_decorator
27 | from ..config import GlobalConfig
28 | import sys
29 |
30 |
class CalcTransformWindow(QWidget, Ui_CalcTransformsUI):
    """Widget that loads a DICOM/NIfTI dataset and computes slice-to-slice
    registration transforms via RegistrationManager.

    Signals are used to marshal progress/success notifications from the
    worker thread (calculate) back to the GUI thread.
    """

    update_progress = pyqtSignal(int)  # emitted with the number of processed slices
    success = pyqtSignal()             # emitted when the calculation finished

    def __init__(self):
        QWidget.__init__(self)
        self.setupUi(self)
        self.setWindowTitle("Dicom transform calculator")
        self.registrationManager = None
        self.update_progress.connect(self.set_progress)
        self.choose_Button.clicked.connect(self.load_data)
        self.calculate_button.clicked.connect(self.calculate)
        self.success.connect(self.show_success_box)
        self.data = None       # MedicalVolume once a dataset is loaded
        self.basepath = ''     # filled by dosma_volume_from_path
        self.basename = ''     # filled by dosma_volume_from_path

    @pyqtSlot()
    def load_data(self):
        """Ask the user for a dataset file and load it as a MedicalVolume."""
        if GlobalConfig['ENABLE_NIFTI']:
            filter = 'Image files (*.dcm *.ima *.nii *.nii.gz);; Dicom files (*.dcm *.ima);;Nifti files (*.nii *.nii.gz);;All files (*.*)'
        else:
            filter = 'Dicom files (*.dcm *.ima);;All files (*.*)'

        dataFile, _ = QFileDialog.getOpenFileName(self, caption='Select dataset to import',
                                                  filter=filter)

        path = os.path.abspath(dataFile)
        print(path)
        _, ext = os.path.splitext(path)
        dataset_name = os.path.basename(path)  # NOTE(review): unused local

        containing_dir = os.path.dirname(path)

        # For DICOM, load the whole containing directory instead of one file.
        if ext.lower() not in ['.nii', '.gz']:
            path = containing_dir

        medical_volume = None
        basename = ''
        try:
            medical_volume, affine_valid, title, self.basepath, self.basename = dosma_volume_from_path(path, self, sort=GlobalConfig['DICOM_SORT'])
            self.data = medical_volume.volume
        # NOTE(review): bare except silently swallows all errors (including
        # KeyboardInterrupt); failure is only surfaced via the warning below.
        except:
            pass
        print("Basepath", self.basepath)

        if self.data is None:
            # Loading failed: reset the UI to its disabled state.
            self.progressBar.setValue(0)
            self.progressBar.setEnabled(False)
            self.calculate_button.setEnabled(False)
            QMessageBox.warning(self, 'Warning', 'Invalid dataset!')
            return

        self.data = medical_volume

        self.progressBar.setEnabled(True)
        self.registrationManager = None
        # NOTE(review): the local `basename` is never reassigned, so this
        # always shows containing_dir; possibly self.basename was intended.
        self.location_Text.setText(containing_dir if not basename else basename)
        self.calculate_button.setEnabled(True)

    @pyqtSlot(int)
    def set_progress(self, value: int):
        """GUI-thread slot updating the progress bar."""
        self.progressBar.setValue(value)

    @pyqtSlot()
    def show_success_box(self):
        """GUI-thread slot shown when the calculation completes."""
        QMessageBox.information(self, 'Done', 'Done!')

    @pyqtSlot()
    @separate_thread_decorator
    def calculate(self):
        """Worker-thread routine: reorient the volume and compute transforms."""
        self.choose_Button.setEnabled(False)
        self.calculate_button.setEnabled(False)

        # Reformat to the orientation selected by the radio buttons.
        if self.axial_radio.isChecked():
            self.data = self.data.reformat(get_nifti_orientation('axial'))
        elif self.sagittal_radio.isChecked():
            self.data = self.data.reformat(get_nifti_orientation('sagittal'))
        elif self.coronal_radio.isChecked():
            self.data = self.data.reformat(get_nifti_orientation('coronal'))

        # Convert the 3D volume to a list of 2D slices (slice axis first).
        data = list(np.transpose(self.data.volume, [2, 0, 1]))
        self.progressBar.setMaximum(len(data))

        self.registrationManager = RegistrationManager(data, None, os.getcwd(),
                                                       GlobalConfig['TEMP_DIR'])
        self.registrationManager.set_standard_transforms_name(self.basepath, self.basename)

        # Progress callback re-emits through the signal to stay thread-safe.
        self.registrationManager.calc_transforms(lambda value: self.update_progress.emit(value))
        self.choose_Button.setEnabled(True)
        self.calculate_button.setEnabled(False)
        self.update_progress.emit(0)
        self.success.emit()
126 |
127 |
def run():
    """Launch the transform-calculator widget as a standalone Qt application."""
    app = QApplication(sys.argv)
    main_window = QMainWindow()
    main_window.setCentralWidget(CalcTransformWindow())
    main_window.setWindowTitle("Dicom transform calculator")
    main_window.show()
    sys.exit(app.exec_())
136 |
--------------------------------------------------------------------------------
/src/dafne/ui/BrushPatches.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 |
16 | from matplotlib.patches import Polygon, Rectangle
17 | import numpy as np
18 | import math
19 | from scipy.ndimage import shift
20 |
21 |
class SquareBrush(Rectangle):
    """Square painting brush backed by a matplotlib Rectangle patch."""

    def __init__(self, *args, **kwargs):
        Rectangle.__init__(self, *args, **kwargs)

    def to_mask(self, shape):
        """
        convert the brush to a binary mask of the size "shape"
        """
        mask = np.zeros(shape, dtype=np.uint8)

        corner = self.get_xy()
        height = self.get_height()
        width = self.get_width()
        # Mask rows track the patch's y coordinate and columns its x
        # coordinate (x and y are inverted with respect to the array).
        row_start = int(np.round(corner[1] + 0.5))
        col_start = int(np.round(corner[0] + 0.5))
        row_stop = int(row_start + np.round(height))
        col_stop = int(col_start + np.round(width))

        mask[row_start:row_stop, col_start:col_stop] = 1
        return mask
44 |
45 |
class PixelatedCircleBrush(Polygon):
    """Circular brush rendered as a pixel-exact polygon, with a matching
    binary mask.

    Both the outline vertices and the filled mask are generated with the
    midpoint circle algorithm, so the displayed outline and the painted
    pixels agree exactly.
    """

    def __init__(self, center, radius, **kwargs):
        # All attributes are populated by set_radius()/_recalculate_* below.
        self.point_array = None   # outline vertices, relative to the center
        self.kwargs = kwargs
        self.center = None
        self.radius = None
        self.base_mask = None     # (2r+1)x(2r+1) filled disc
        # Dummy initial vertices; real ones are computed by set_radius().
        Polygon.__init__(self, np.array([[0,0],[1,1]]), **kwargs)
        self.center = np.array(center).ravel() # make sure it's a row vector
        self.set_radius(radius)

    def get_center(self):
        return self.center

    def get_radius(self):
        return self.radius

    def set_center(self, center):
        # Moving the brush only requires re-translating the cached vertices.
        self.center = np.array(center).ravel() # make sure it's a row vector
        self._recalculate_xy()

    def set_radius(self, radius):
        # Vertices and mask only depend on the radius, so skip work when unchanged.
        if radius != self.radius:
            self.radius = radius
            self._recalculate_vertices()
            self._recalculate_mask()
            self._recalculate_xy()

    def to_mask(self, shape):
        # Paste the precomputed disc at the top-left corner, then shift it so
        # that it is centered on self.center (order=0 keeps it binary).
        mask = np.zeros(shape, dtype=np.uint8)
        mask[0:self.base_mask.shape[0], 0:self.base_mask.shape[1]] = self.base_mask
        mask = shift(mask, (self.center[1] - self.radius, self.center[0] - self.radius), order=0, prefilter=False)
        return mask

    def _recalculate_xy(self):
        # Translate the center-relative vertices to absolute coordinates.
        xy = self.point_array + self.center
        self.set_xy(xy)

    def _recalculate_vertices(self):
        """Compute the outline vertices of the pixelated circle (center-relative)."""
        if self.radius == 0:
            # Radius 0 degenerates to a single pixel (unit square).
            self.point_array = np.array([[0,0],[1,0],[1,1],[0,1]]) - 0.5
            return

        # Half-pixel offset so the outline passes along pixel borders.
        radius = self.radius + 0.5

        # midpoint circle algorithm
        x = radius
        y = 0
        P = 1 - radius

        # Trace one octant; the other seven are mirrored below.
        octant_point_array = [(x, y-0.5)]


        while x > y:

            y += 1

            # Mid-point inside or on the perimeter
            if P <= 0:
                P = P + 2 * y + 1

            # Mid-point outside the perimeter
            else:
                octant_point_array.append((x, y-0.5))
                x -= 1
                octant_point_array.append((x, y-0.5))
                P = P + 2 * y - 2 * x + 1

            if x < y:
                break

        # assemble the octants
        # Mirror the single octant into a quarter, then the quarter into the
        # full circle; reversals keep the vertices in drawing order.
        quarter_point_array = octant_point_array[:]
        quarter_point_array.extend([(y,x) for x,y in octant_point_array[::-1]])
        point_array = quarter_point_array[:]
        point_array.extend([(-x,y) for x,y in quarter_point_array[::-1]])
        point_array.extend([(-x,-y) for x,y in quarter_point_array])
        point_array.extend(([(x,-y) for x,y in quarter_point_array[::-1]]))

        self.point_array = np.array(point_array)

    def _recalculate_mask(self):
        """Rasterize the filled disc into self.base_mask ((2r+1)x(2r+1), uint8)."""
        if self.radius == 0:
            self.base_mask = np.ones((1, 1), dtype=np.uint8)
            return

        radius = self.radius

        # midpoint circle algorithm
        x = radius
        y = 0
        P = 1 - radius

        self.base_mask = np.zeros((self.radius * 2 + 1, self.radius * 2 + 1), dtype=np.uint8)

        def fill_mask_line(x, y):
            # Fill the vertical scanlines at +-y spanning rows r-x .. r+x.
            r = radius
            self.base_mask[int(r - x): int(r + x + 1), int(r + y)] = 1
            self.base_mask[int(r - x):int(r + x + 1), int(r - y)] = 1

        fill_mask_line(x, y)
        fill_mask_line(y, x)

        while x > y:
            y += 1

            # Mid-point inside or on the perimeter
            if P <= 0:
                P = P + 2 * y + 1

            # Mid-point outside the perimeter
            else:
                x -= 1
                P = P + 2 * y - 2 * x + 1

            fill_mask_line(x, y)
            fill_mask_line(y, x)

            if (x < y):
                break
--------------------------------------------------------------------------------
/src/dafne/utils/mask_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Dafne-Imaging Team
2 | #
3 | # This program is free software: you can redistribute it and/or modify
4 | # it under the terms of the GNU General Public License as published by
5 | # the Free Software Foundation, either version 3 of the License, or
6 | # (at your option) any later version.
7 | #
8 | # This program is distributed in the hope that it will be useful,
9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 | # GNU General Public License for more details.
12 | #
13 | # You should have received a copy of the GNU General Public License
14 | # along with this program. If not, see .
15 |
16 | import os
17 | import numpy as np
18 | from voxel import DicomWriter, NiftiWriter, MedicalVolume
19 | from matplotlib import pyplot as plt
20 | from scipy import ndimage
21 |
22 |
def save_dicom_masks(base_path: str, mask_dict: dict, affine, dicom_headers: list):
    """Write every mask in ``mask_dict`` as its own DICOM series under ``base_path``."""
    writer = DicomWriter(num_workers=0)
    for name, mask in mask_dict.items():
        # An empty (or whitespace-only) name would map to base_path itself.
        safe_name = name.strip() or '_'
        output_dir = os.path.join(base_path, safe_name)
        try:
            os.makedirs(output_dir)
        except OSError:
            pass  # directory already exists
        volume = MedicalVolume(mask.astype(np.uint16), affine, dicom_headers)
        writer.save(volume, output_dir, fname_fmt='image%04d.dcm')
35 |
36 |
def save_npy_masks(base_path, mask_dict):
    """Save each mask as ``<base_path>/<name>.npy``."""
    for name, mask in mask_dict.items():
        np.save(os.path.join(base_path, name + '.npy'), mask)
41 |
42 |
def save_npz_masks(filename, mask_dict):
    """Store all masks in one compressed .npz archive, keyed by mask name."""
    np.savez_compressed(filename, **mask_dict)
45 |
46 |
def save_nifti_masks(base_path, mask_dict, affine):
    """Write every mask as an individual compressed NIfTI file under ``base_path``."""
    writer = NiftiWriter()
    for name, mask in mask_dict.items():
        volume = MedicalVolume(mask.astype(np.uint16), affine)
        writer.save(volume, os.path.join(base_path, name + '.nii.gz'))
53 |
54 |
def make_accumulated_mask(mask_dict):
    """Merge binary masks into one labelled mask.

    Masks get values 1..N in insertion order; where masks overlap the values
    sum. Returns (accumulated_mask, name_list) where name_list[i] is the
    label for value i+1; accumulated_mask is None for an empty dict.
    """
    accumulated_mask = None
    name_list = []
    for value, (name, mask) in enumerate(mask_dict.items(), start=1):
        name_list.append(name)
        layer = value * (mask > 0)
        accumulated_mask = layer if accumulated_mask is None else accumulated_mask + layer
    return accumulated_mask, name_list
67 |
68 |
def write_legend(filename, name_list):
    """Write a CSV legend mapping mask values (1-based) to label names."""
    lines = ['Value,Label\n']
    lines.extend(f'{value},{label}\n' for value, label in enumerate(name_list, start=1))
    with open(filename, 'w') as f:
        f.writelines(lines)
74 |
75 |
def write_itksnap_legend(filename, name_list):
    """Write an ITK-SnAP label-description file for the labels in ``name_list``.

    Label 0 is the reserved "Clear Label"; labels 1..N are assigned colors
    spread over the HSV colormap.
    """
    PREAMBLE = """################################################
# ITK-SnAP Label Description File
# File format:
# IDX -R- -G- -B- -A-- VIS MSH LABEL
# Fields:
# IDX: Zero-based index
# -R-: Red color component (0..255)
# -G-: Green color component (0..255)
# -B-: Blue color component (0..255)
# -A-: Label transparency (0.00 .. 1.00)
# VIS: Label visibility (0 or 1)
# MSH: Label mesh visibility (0 or 1)
# LABEL: Label description
################################################\n"""

    line_format = '{id:>5d}{red:>6d}{green:>5d}{blue:>5d}{alpha:>9.2g}{vis:>3d}{mesh:>3d} "{label}"\n'

    cmap = plt.get_cmap('hsv')
    nLabels = len(name_list)
    nColors = cmap.N

    # Spread labels evenly over the colormap. Guard against an empty
    # name_list (previously a ZeroDivisionError) and against more labels than
    # colors (previously step == 0, giving every label the same color).
    step = max(1, nColors // nLabels) if nLabels else 0

    with open(filename, 'w') as f:
        f.write(PREAMBLE)
        f.write(line_format.format(id=0, red=0, green=0, blue=0, alpha=0, vis=0, mesh=0, label='Clear Label'))
        for index, name in enumerate(name_list):
            color = cmap(index * step)
            f.write(line_format.format(id=index + 1, red=int(color[0] * 255), green=int(color[1] * 255), blue=int(color[2] * 255),
                                       alpha=1, vis=1, mesh=1, label=name))
107 |
108 |
def save_single_nifti(filename, mask_dict, affine):
    """Combine all masks into one labelled NIfTI volume plus legend files."""
    accumulated_mask, name_list = make_accumulated_mask(mask_dict)
    volume = MedicalVolume(accumulated_mask.astype(np.uint16), affine)
    NiftiWriter().save(volume, filename)
    # Companion legends documenting what each integer label means.
    write_legend(filename + '.csv', name_list)
    write_itksnap_legend(filename + '_itk-snap.txt', name_list)
118 |
119 |
def save_single_dicom_dataset(base_path, mask_dict, affine, dicom_headers: list):
    """Combine all masks into one labelled DICOM series plus legend files."""
    accumulated_mask, name_list = make_accumulated_mask(mask_dict)
    volume = MedicalVolume(accumulated_mask.astype(np.uint16), affine, dicom_headers)
    try:
        os.makedirs(base_path)
    except OSError:
        pass  # directory already exists
    DicomWriter(num_workers=0).save(volume, base_path, fname_fmt='image%04d.dcm')
    # Companion legends documenting what each integer label means.
    write_legend(os.path.join(base_path, 'legend.csv'), name_list)
    write_itksnap_legend(os.path.join(base_path, 'legend_itk-snap.txt'), name_list)
133 |
134 |
def distance_mask(mask):
    """Signed Euclidean distance map: positive inside the mask, negative outside."""
    binary = mask.astype(np.uint8)
    inside = ndimage.distance_transform_edt(binary)
    outside = ndimage.distance_transform_edt(1 - binary)
    return inside - outside
--------------------------------------------------------------------------------
/src/dafne/ui/ValidateUI.ui:
--------------------------------------------------------------------------------
1 |
2 |
3 | ValidateUI
4 |
5 |
6 |
7 | 0
8 | 0
9 | 724
10 | 348
11 |
12 |
13 |
14 | Form
15 |
16 |
17 | -
18 |
19 |
20 | Configure...
21 |
22 |
23 |
24 | -
25 |
26 |
-
27 |
28 |
29 |
30 | 0
31 | 0
32 |
33 |
34 |
35 | Data Location:
36 |
37 |
38 |
39 | -
40 |
41 |
42 | false
43 |
44 |
45 |
46 | -
47 |
48 |
49 | Choose...
50 |
51 |
52 |
53 |
54 |
55 | -
56 |
57 |
-
58 |
59 |
60 |
61 | 0
62 | 0
63 |
64 |
65 |
66 | Masks Location:
67 |
68 |
69 |
70 | -
71 |
72 |
73 | false
74 |
75 |
76 |
77 | -
78 |
79 |
80 | false
81 |
82 |
83 | Choose folder...
84 |
85 |
86 |
87 | -
88 |
89 |
90 | false
91 |
92 |
93 | Choose ROI...
94 |
95 |
96 |
97 |
98 |
99 | -
100 |
101 |
-
102 |
103 |
104 | Current eval:
105 |
106 |
107 |
108 | -
109 |
110 |
111 | false
112 |
113 |
114 |
115 | 1
116 | 0
117 |
118 |
119 |
120 | 0
121 |
122 |
123 |
124 |
125 |
126 | -
127 |
128 |
-
129 |
130 |
131 | Overall progress:
132 |
133 |
134 |
135 | -
136 |
137 |
138 | false
139 |
140 |
141 |
142 | 1
143 | 0
144 |
145 |
146 |
147 | 0
148 |
149 |
150 |
151 |
152 |
153 | -
154 |
155 |
156 |
157 |
158 |
159 |
160 | -
161 |
162 |
163 | Qt::Vertical
164 |
165 |
166 | QSizePolicy::Expanding
167 |
168 |
169 |
170 | 20
171 | 45
172 |
173 |
174 |
175 |
176 | -
177 |
178 |
179 | false
180 |
181 |
182 | Evaluate
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
--------------------------------------------------------------------------------
/src/dafne/utils/polyToMask.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2021 Dafne-Imaging Team
4 | #
5 | # This program is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU General Public License
16 | # along with this program. If not, see .
17 |
18 | import numpy as np
19 | from collections import deque
20 |
# iteratively flood-fill a mask in place (queue-based; replaces an old
# recursive version that overflowed the interpreter stack)
def flood(seedX, seedY, mask):
    """Flood-fill ``mask`` with 1s starting at (seedX, seedY), in place.

    Pixels already equal to 1 act as barriers (the contour). Fills whole runs
    along the first axis at a time and queues the side neighbours of every
    filled pixel.
    """
    rows, cols = mask.shape[0], mask.shape[1]
    if mask[seedX][seedY] == 1:
        return  # seed sits on a barrier: nothing to fill
    pending = deque()
    pending.append((seedX, seedY))

    def outside(x, y):
        # true when (x, y) falls off the grid
        return x < 0 or y < 0 or x >= rows or y >= cols

    while pending:  # process until the queue is drained
        x, y = pending.popleft()
        if outside(x, y) or mask[x][y] == 1:
            continue
        # scan in both directions along the first axis until a barrier,
        # filling pixels and queueing their neighbours in the second axis
        for step, begin in ((1, x), (-1, x - 1)):
            cx = begin
            while 0 <= cx < rows and mask[cx][y] != 1:
                mask[cx][y] = 1
                pending.append((cx, y - 1))
                pending.append((cx, y + 1))
                cx += step
68 |
def intround(x):
    """Round ``x`` to the nearest integer and return it as an int.

    Uses Python's built-in round(), i.e. ties round to the nearest even
    integer (banker's rounding).
    """
    nearest = round(x)
    return int(nearest)
71 |
def polyToMask(points, size):
    """Rasterize a closed polygon into a binary mask.

    Parameters:
        points: sequence of [x, y] vertex pairs. The polygon is closed
            automatically (last vertex connects back to the first).
        size: (rows, cols) of the desired output mask.

    Returns:
        Boolean numpy array of shape ``size`` with the polygon contour and
        its interior set to True.

    Bug fix: vertices that need clamping into range are now copied first, so
    the caller's ``points`` list is no longer modified in place. The output is
    unchanged because clamping is idempotent and deterministic.

    Known limitation (unchanged): results are unreliable when the contour
    touches the image edges, because the outside flood fill may be blocked.
    """
    size = (size[1], size[0])  # internally x indexes rows and y indexes columns
    # one extra row so the flood-fill seed below is guaranteed to be outside
    mask = np.zeros((size[0] + 1, size[1]))
    n_points = len(points)
    for i in range(n_points):
        # copy the endpoints: clamping must not mutate the caller's vertices
        curPoint = list(points[i])
        nextPoint = list(points[(i + 1) % n_points])  # wraps to close the polygon

        if (curPoint[0] == nextPoint[0]) and (curPoint[1] == nextPoint[1]):
            continue  # zero-length segment: nothing to draw

        # clamp both endpoints into the valid grid range
        for p in (curPoint, nextPoint):
            if p[0] < 0: p[0] = 0
            if p[1] < 0: p[1] = 0
            if p[0] > size[0] - 1: p[0] = size[0] - 1
            if p[1] > size[1] - 1: p[1] = size[1] - 1

        mask[intround(curPoint[0])][intround(curPoint[1])] = 1  # segment start pixel
        if nextPoint[0] == curPoint[0]:
            # vertical segment: iterate over y only
            startY = min(curPoint[1], nextPoint[1])
            endY = max(curPoint[1], nextPoint[1])
            # NOTE: intround(endY + 1), not intround(endY) + 1 — they differ
            # at half-integers under banker's rounding (original behavior kept)
            for y in range(intround(startY), intround(endY + 1)):
                mask[intround(curPoint[0])][intround(y)] = 1
        else:
            slope = (nextPoint[1] - curPoint[1]) / (nextPoint[0] - curPoint[0])
            if abs(slope) < 1:
                # shallow line: step along x and interpolate y
                if curPoint[0] < nextPoint[0]:
                    startX, endX, startY = curPoint[0], nextPoint[0], curPoint[1]
                else:
                    startX, endX, startY = nextPoint[0], curPoint[0], nextPoint[1]
                for x in range(intround(startX), intround(endX + 1)):
                    nextY = startY + (x - startX) * slope
                    mask[intround(x)][intround(nextY)] = 1
            else:
                # steep line: step along y and interpolate x
                if curPoint[1] < nextPoint[1]:
                    startY, endY, startX = curPoint[1], nextPoint[1], curPoint[0]
                else:
                    startY, endY, startX = nextPoint[1], curPoint[1], nextPoint[0]
                for y in range(intround(startY), intround(endY + 1)):
                    nextX = startX + (y - startY) / slope
                    mask[intround(nextX)][intround(y)] = 1

    contour = np.copy(mask)  # keep the pure contour before flooding
    # flood the *outside* of the contour, seeding in the guaranteed-outside extra row
    flood(mask.shape[0] - 1, 0, mask)
    # mask now holds contour + outside; invert it and re-add the contour
    mask = np.logical_or(contour, np.logical_not(mask))
    # drop the helper row and transpose back to (rows, cols)
    return np.transpose(mask[0:len(mask) - 1])
145 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/utils/onnx.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 |
12 | from typing import Tuple
13 |
14 | from ..modeling import Sam
15 | from .amg import calculate_stability_score
16 |
17 |
class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export
    script for details.
    """

    def __init__(
        self,
        model: Sam,
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        # offset passed to calculate_stability_score when it replaces IoU predictions
        self.stability_score_offset = 1.0
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        """Scale an (h, w) size tensor so its longest side equals `longest_side`."""
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        # floor(x + 0.5) is a trace-friendly round-half-up
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(
        self, point_coords: torch.Tensor, point_labels: torch.Tensor
    ) -> torch.Tensor:
        """Embed prompt points; label -1 marks padding ("not a point")."""
        point_coords = point_coords + 0.5  # shift to pixel centers
        point_coords = point_coords / self.img_size  # normalize to [0, 1]
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # zero out padded points, then substitute the dedicated "not a point" embedding;
        # masked sums instead of branches keep the module traceable for ONNX
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = (
            point_embedding
            + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
        )

        # add the learned embedding for each point label the same masked way
        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = (
                point_embedding
                + self.model.prompt_encoder.point_embeddings[i].weight
                * (point_labels == i)
            )

        return point_embedding

    def _embed_masks(
        self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
    ) -> torch.Tensor:
        """Embed an optional mask prompt; `has_mask_input` selects mask vs. no-mask embedding."""
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
            input_mask
        )
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(
        self, masks: torch.Tensor, orig_im_size: torch.Tensor
    ) -> torch.Tensor:
        """Upscale low-res masks to model size, crop the padding, resize to the original image."""
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        # size of the image before square padding was applied
        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
            torch.int64
        )
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pick the single best mask per batch element from the decoder's candidates."""
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow.
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)
        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ):
        """ONNX-traceable forward pass: prompt embeddings -> upscaled masks and scores."""
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            # replace the predicted IoU scores with computed stability scores
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks
159 |
--------------------------------------------------------------------------------
/src/dafne/ui/ValidateUI.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file 'ValidateUI.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.6
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 |
13 |
class Ui_ValidateUI(object):
    """Auto-generated (pyuic5) widget layout for the validation form.

    NOTE(review): this class is regenerated from ValidateUI.ui; do not
    hand-edit the widget-building code, only the .ui file.
    """

    def setupUi(self, ValidateUI):
        """Build all widgets and layouts and attach them to *ValidateUI*."""
        ValidateUI.setObjectName("ValidateUI")
        ValidateUI.resize(724, 348)
        self.verticalLayout = QtWidgets.QVBoxLayout(ValidateUI)
        self.verticalLayout.setObjectName("verticalLayout")
        # top row: configuration button
        self.configure_button = QtWidgets.QPushButton(ValidateUI)
        self.configure_button.setObjectName("configure_button")
        self.verticalLayout.addWidget(self.configure_button)
        # row: data location label + (read-only) path field + chooser button
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.label = QtWidgets.QLabel(ValidateUI)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.label.sizePolicy().hasHeightForWidth())
        self.label.setSizePolicy(sizePolicy)
        self.label.setObjectName("label")
        self.horizontalLayout.addWidget(self.label)
        self.data_location_Text = QtWidgets.QLineEdit(ValidateUI)
        self.data_location_Text.setEnabled(False)
        self.data_location_Text.setObjectName("data_location_Text")
        self.horizontalLayout.addWidget(self.data_location_Text)
        self.data_choose_Button = QtWidgets.QPushButton(ValidateUI)
        self.data_choose_Button.setObjectName("data_choose_Button")
        self.horizontalLayout.addWidget(self.data_choose_Button)
        self.verticalLayout.addLayout(self.horizontalLayout)
        # row: masks location label + path field + folder/ROI chooser buttons
        # (initially disabled until a data location is selected)
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.label_2 = QtWidgets.QLabel(ValidateUI)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.label_2.sizePolicy().hasHeightForWidth())
        self.label_2.setSizePolicy(sizePolicy)
        self.label_2.setObjectName("label_2")
        self.horizontalLayout_2.addWidget(self.label_2)
        self.mask_location_Text = QtWidgets.QLineEdit(ValidateUI)
        self.mask_location_Text.setEnabled(False)
        self.mask_location_Text.setObjectName("mask_location_Text")
        self.horizontalLayout_2.addWidget(self.mask_location_Text)
        self.mask_choose_Button = QtWidgets.QPushButton(ValidateUI)
        self.mask_choose_Button.setEnabled(False)
        self.mask_choose_Button.setObjectName("mask_choose_Button")
        self.horizontalLayout_2.addWidget(self.mask_choose_Button)
        self.roi_choose_Button = QtWidgets.QPushButton(ValidateUI)
        self.roi_choose_Button.setEnabled(False)
        self.roi_choose_Button.setObjectName("roi_choose_Button")
        self.horizontalLayout_2.addWidget(self.roi_choose_Button)
        self.verticalLayout.addLayout(self.horizontalLayout_2)
        # row: per-evaluation progress bar
        self.horizontalLayout_3 = QtWidgets.QHBoxLayout()
        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
        self.currentProgress_Label = QtWidgets.QLabel(ValidateUI)
        self.currentProgress_Label.setObjectName("currentProgress_Label")
        self.horizontalLayout_3.addWidget(self.currentProgress_Label)
        self.slice_progressBar = QtWidgets.QProgressBar(ValidateUI)
        self.slice_progressBar.setEnabled(False)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(1)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.slice_progressBar.sizePolicy().hasHeightForWidth())
        self.slice_progressBar.setSizePolicy(sizePolicy)
        self.slice_progressBar.setProperty("value", 0)
        self.slice_progressBar.setObjectName("slice_progressBar")
        self.horizontalLayout_3.addWidget(self.slice_progressBar)
        self.verticalLayout.addLayout(self.horizontalLayout_3)
        # row: overall progress bar
        self.horizontalLayout_4 = QtWidgets.QHBoxLayout()
        self.horizontalLayout_4.setObjectName("horizontalLayout_4")
        self.label_4 = QtWidgets.QLabel(ValidateUI)
        self.label_4.setObjectName("label_4")
        self.horizontalLayout_4.addWidget(self.label_4)
        self.overall_progressBar = QtWidgets.QProgressBar(ValidateUI)
        self.overall_progressBar.setEnabled(False)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
        sizePolicy.setHorizontalStretch(1)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.overall_progressBar.sizePolicy().hasHeightForWidth())
        self.overall_progressBar.setSizePolicy(sizePolicy)
        self.overall_progressBar.setProperty("value", 0)
        self.overall_progressBar.setObjectName("overall_progressBar")
        self.horizontalLayout_4.addWidget(self.overall_progressBar)
        self.verticalLayout.addLayout(self.horizontalLayout_4)
        # status text, vertical spacer, and the Evaluate button at the bottom
        self.status_Label = QtWidgets.QLabel(ValidateUI)
        self.status_Label.setText("")
        self.status_Label.setObjectName("status_Label")
        self.verticalLayout.addWidget(self.status_Label)
        spacerItem = QtWidgets.QSpacerItem(20, 45, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
        self.verticalLayout.addItem(spacerItem)
        self.evaluate_Button = QtWidgets.QPushButton(ValidateUI)
        self.evaluate_Button.setEnabled(False)
        self.evaluate_Button.setObjectName("evaluate_Button")
        self.verticalLayout.addWidget(self.evaluate_Button)

        self.retranslateUi(ValidateUI)
        QtCore.QMetaObject.connectSlotsByName(ValidateUI)

    def retranslateUi(self, ValidateUI):
        """Set all user-visible (translatable) strings."""
        _translate = QtCore.QCoreApplication.translate
        ValidateUI.setWindowTitle(_translate("ValidateUI", "Form"))
        self.configure_button.setText(_translate("ValidateUI", "Configure..."))
        self.label.setText(_translate("ValidateUI", "Data Location:"))
        self.data_choose_Button.setText(_translate("ValidateUI", "Choose..."))
        self.label_2.setText(_translate("ValidateUI", "Masks Location:"))
        self.mask_choose_Button.setText(_translate("ValidateUI", "Choose folder..."))
        self.roi_choose_Button.setText(_translate("ValidateUI", "Choose ROI..."))
        self.currentProgress_Label.setText(_translate("ValidateUI", "Current eval:"))
        self.label_4.setText(_translate("ValidateUI", "Overall progress:"))
        self.evaluate_Button.setText(_translate("ValidateUI", "Evaluate"))
122 |
--------------------------------------------------------------------------------
/src/dafne/ui/Viewer3D.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2023 Dafne-Imaging Team
2 | # Part of this code are based on "wezel": https://github.com/QIB-Sheffield/wezel/
3 |
4 | import os
5 | import numpy as np
6 | from PyQt5.QtCore import pyqtSlot, pyqtSignal, Qt
7 | from PyQt5.QtWidgets import QWidget, QVBoxLayout, QApplication, QPushButton, QHBoxLayout, QLabel, QSlider
8 | from ..config.config import GlobalConfig
9 |
10 | os.environ["QT_API"] = "pyqt5"
11 | import pyvista as pv
12 | from pyvistaqt import QtInteractor
13 |
14 | WIDTH = 400
15 | HEIGHT = 400
16 |
17 | OPACITY_ARRAY = np.array([0, 0.2, 0.3, 0.6, 0.2, 0])
18 | COLORMAP = 'bone'
19 |
class Viewer3D(QWidget):
    """3D viewer window based on pyvista/pyvistaqt.

    Renders the current ROI as a marching-cubes surface and, optionally, the
    anatomy as a semi-transparent volume rendering. Closing the window only
    hides it (emitting ``hide_signal``) unless :meth:`real_close` is called
    first.
    """

    hide_signal = pyqtSignal()

    def __init__(self):
        super().__init__()
        self.plotter = QtInteractor(self)
        self.plotter.background_color = 'black'
        self.plotter.add_camera_orientation_widget()
        self.spacing = (1.0, 1.0, 1.0)
        self.data = None            # 3D ROI mask to render as a surface
        self.anatomy = None         # 3D anatomy volume (uint16) for volume rendering
        self.actor_roi = None
        self.actor_anatomy = None
        self.anatomy_opacity = 0.2
        # Layout
        layout = QVBoxLayout()
        layout.setContentsMargins(0, 0, 0, 0)
        layout.setSpacing(0)
        layout.addWidget(self.plotter)

        fit_button = QPushButton("Fit to scene")
        layout.addWidget(fit_button)
        fit_button.clicked.connect(self.plotter.reset_camera)

        opacity_widget = QWidget()
        opacity_widget_layout = QHBoxLayout()
        opacity_widget.setLayout(opacity_widget_layout)
        opacity_widget_layout.addWidget(QLabel('Anatomy opacity:'))
        self.anat_opacity_slider = QSlider(Qt.Horizontal)
        self.anat_opacity_slider.setMinimum(0)
        self.anat_opacity_slider.setMaximum(100)
        self.anat_opacity_slider.setValue(20)  # matches the 0.2 default opacity
        self.anat_opacity_slider.valueChanged.connect(self.set_global_anat_opacity)
        opacity_widget_layout.addWidget(self.anat_opacity_slider)
        layout.addWidget(opacity_widget)

        self.setLayout(layout)
        self.setWindowTitle("3D Viewer")
        # place the window in the top-right corner of the screen
        screen_width = QApplication.desktop().screenGeometry().width()
        self.setGeometry(screen_width - WIDTH, 0, WIDTH, HEIGHT)
        self.real_close_flag = False

    @pyqtSlot(list, np.ndarray)
    def set_spacing_and_anatomy(self, spacing, anatomy):
        """Set voxel spacing and the anatomy volume, then (re)render the anatomy."""
        self.spacing = spacing
        self.anatomy = anatomy.astype(np.uint16)
        self.visualize_anatomy()

    @pyqtSlot(int)
    def set_global_anat_opacity(self, value):
        """Slider callback: rescale the anatomy opacity transfer function (value in 0-100)."""
        self.anatomy_opacity = float(value) / 100
        if self.actor_anatomy is None:
            return

        lut = pv.LookupTable(cmap=COLORMAP)
        lut.apply_opacity(OPACITY_ARRAY * self.anatomy_opacity)
        self.actor_anatomy.prop.apply_lookup_table(lut)
        self.plotter.render()

    def visualize_anatomy(self):
        """Rebuild the anatomy volume actor from self.anatomy and self.spacing."""
        self.plotter.remove_actor(self.actor_anatomy, render=False)
        if self.anatomy is None or self.spacing is None or self.anatomy_opacity == 0:
            self.plotter.render()
            return
        # cell data needs one more point than cells along each axis
        vol = pv.ImageData(dimensions=np.array(self.anatomy.shape) + 1, spacing=self.spacing)
        vol.cell_data['values'] = self.anatomy.flatten(order='F')

        opacity = (OPACITY_ARRAY * self.anatomy_opacity)

        self.actor_anatomy = self.plotter.add_volume(vol,
                                                     scalars='values',
                                                     clim=[self.anatomy.min(), self.anatomy.max()],
                                                     opacity=opacity,
                                                     cmap=COLORMAP,
                                                     show_scalar_bar=False)
        self.plotter.show()

    @pyqtSlot(list, np.ndarray)
    def set_spacing_and_data(self, spacing, data):
        """
        Set the data and spacing.
        """
        self.spacing = spacing
        self.data = data
        self.update_data()

    @pyqtSlot(list)
    def set_spacing(self, spacing):
        """
        Set the voxel spacing.
        """
        self.spacing = spacing

    @pyqtSlot(np.ndarray)
    def set_affine(self, affine):
        """
        Derive the voxel spacing from a 4x4 affine matrix and reset the data.
        """
        # spacing along each axis is the norm of the corresponding affine column
        column_spacing = np.linalg.norm(affine[:3, 0])
        row_spacing = np.linalg.norm(affine[:3, 1])
        slice_spacing = np.linalg.norm(affine[:3, 2])
        self.spacing = (column_spacing, row_spacing, slice_spacing)  # mm
        self.data = None  # old data is no longer valid under a new geometry

    def update_data(self):
        """Rebuild the ROI surface actor from self.data; no-op while hidden."""
        if not self.isVisible():
            return

        camera_position = self.plotter.camera_position
        self.plotter.remove_actor(self.actor_roi, reset_camera=False, render=False)
        if self.data is None or self.spacing is None or not np.any(self.data):
            print("No data to plot")
            self.plotter.render()
            return

        grid = pv.ImageData(dimensions=self.data.shape, spacing=self.spacing)
        surf = grid.contour([0.5], self.data.flatten(order="F"), method='marching_cubes')
        color = GlobalConfig['ROI_COLOR']
        color = [color[0], color[1], color[2]]
        self.actor_roi = self.plotter.add_mesh(surf,
                                               color=color,
                                               opacity=0.8,
                                               show_edges=False,
                                               smooth_shading=True,
                                               specular=0.5,
                                               show_scalar_bar=False,
                                               render=False
                                               )

        # restore camera position if it's not the default, which is too narrow
        if np.max(np.abs(camera_position[0])) > 1:
            self.plotter.camera_position = camera_position
        self.plotter.render()

    def real_close(self):
        """Really close the window (closeEvent otherwise only hides it)."""
        self.real_close_flag = True
        self.close()

    def closeEvent(self, event):
        """Accept the close only after real_close(); otherwise hide the window instead."""
        if self.real_close_flag:
            event.accept()
            # bug fix: without this return, the ignore()/hide() below overrode
            # the accept and the window could never actually be closed
            return
        self.hide_signal.emit()
        event.ignore()
        self.hide()

    @pyqtSlot(np.ndarray)
    def set_data(self, data):
        """
        Set the data to be plotted.
        """
        self.data = data
        self.update_data()

    @pyqtSlot(int, np.ndarray)
    def set_slice(self, slice_number, slice_data):
        """
        Set the slice to be plotted.
        """
        if self.data is None:
            return
        self.data[:, :, slice_number] = slice_data
        self.update_data()
187 |
--------------------------------------------------------------------------------
/icons/dafne_icon.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
126 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/segment_anything/modeling/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import torch
9 | from torch import nn
10 | from torch.nn import functional as F
11 |
12 | from typing import List, Tuple, Type
13 |
14 | from .common import LayerNorm2d
15 |
16 |
class MaskDecoder(nn.Module):
    """SAM mask decoder: predicts segmentation masks and their quality scores
    from image embeddings and prompt embeddings via a two-way transformer."""

    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        # +1 extra token for the single-mask (non-disambiguating) output
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        # two transposed convolutions upscale the embedding 4x in each dimension
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        # one small MLP per mask token maps it to the upscaled-embedding channels
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        # (index 0 is the single-mask output; 1..N are the disambiguation masks)
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens (IoU token first, then the mask tokens)
        output_tokens = torch.cat(
            [self.iou_token.weight, self.mask_tokens.weight], dim=0
        )
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        if image_embeddings.shape[0] != tokens.shape[0]:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            src = image_embeddings
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        # dot product of each hyper-network output with the upscaled embedding
        # produces one low-resolution mask per mask token
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
164 |
165 |
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    """Simple feed-forward MLP with ReLU activations between layers.

    Arguments:
      input_dim (int): size of the input features
      hidden_dim (int): width of every hidden layer
      output_dim (int): size of the output layer
      num_layers (int): total number of linear layers
      sigmoid_output (bool): if True, apply a sigmoid to the final output
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        # (num_layers - 1) hidden layers, all of width hidden_dim
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        """Apply the layers; ReLU after every layer except the last."""
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            # torch.sigmoid replaces F.sigmoid, which is deprecated
            x = torch.sigmoid(x)
        return x
191 |
--------------------------------------------------------------------------------
/src/dafne/MedSAM/utils/pre_CT_MR.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # %% import packages
3 | # pip install connected-components-3d
4 | import numpy as np
5 |
6 | # import nibabel as nib
7 | import SimpleITK as sitk
8 | import os
9 |
10 | join = os.path.join
11 | from skimage import transform
12 | from tqdm import tqdm
13 | import cc3d
14 |
# convert nii image to npz files, including original image and corresponding masks
modality = "CT"
anatomy = "Abd"  # anatomy + dataset name
img_name_suffix = "_0000.nii.gz"
gt_name_suffix = ".nii.gz"
prefix = modality + "_" + anatomy + "_"

nii_path = "data/FLARE22Train/images"  # path to the nii images
gt_path = "data/FLARE22Train/labels"  # path to the ground truth
npy_path = "data/npy/" + prefix[:-1]
os.makedirs(join(npy_path, "gts"), exist_ok=True)
os.makedirs(join(npy_path, "imgs"), exist_ok=True)

image_size = 1024  # side length of the resized 2D slices
voxel_num_thre2d = 100  # minimum object size (pixels) kept per 2D slice
voxel_num_thre3d = 1000  # minimum object size (voxels) kept in 3D

names = sorted(os.listdir(gt_path))
# "\#" was an invalid escape sequence (SyntaxWarning on Python >= 3.12,
# printed a literal backslash); a plain "#" is intended
print(f"ori # files {len(names)=}")
# keep only the ground-truth files whose matching image file exists
names = [
    name
    for name in names
    if os.path.exists(join(nii_path, name.split(gt_name_suffix)[0] + img_name_suffix))
]
print(f"after sanity check # files {len(names)=}")

# set label ids that are excluded
remove_label_ids = [
    12
]  # remove duodenum since it is scattered in the image, which is hard to specify with the bounding box
tumor_id = None  # only set this when there are multiple tumors; convert semantic masks to instance masks
# set window level and width
# https://radiopaedia.org/articles/windowing-ct
WINDOW_LEVEL = 40  # only for CT images
WINDOW_WIDTH = 400  # only for CT images
50 |
# %% save preprocessed images and masks as npz files
for name in tqdm(names[:40]):  # use the remaining 10 cases for validation
    gt_name = name
    # case id shared by the image/gt file pair (hoisted: it was recomputed
    # for every output path below)
    base_name = gt_name.split(gt_name_suffix)[0]
    image_name = base_name + img_name_suffix
    gt_sitk = sitk.ReadImage(join(gt_path, gt_name))
    gt_data_ori = np.uint8(sitk.GetArrayFromImage(gt_sitk))
    # remove excluded label ids
    for remove_label_id in remove_label_ids:
        gt_data_ori[gt_data_ori == remove_label_id] = 0
    # label tumor masks as instances and remove from gt_data_ori
    if tumor_id is not None:
        tumor_bw = np.uint8(gt_data_ori == tumor_id)
        gt_data_ori[tumor_bw > 0] = 0
        # label tumor masks as instances
        tumor_inst, tumor_n = cc3d.connected_components(
            tumor_bw, connectivity=26, return_N=True
        )
        # put the tumor instances back to gt_data_ori, numbered after the
        # largest remaining semantic label
        gt_data_ori[tumor_inst > 0] = (
            tumor_inst[tumor_inst > 0] + np.max(gt_data_ori) + 1
        )

    # exclude the objects with less than 1000 voxels in 3D
    gt_data_ori = cc3d.dust(
        gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True
    )
    # remove small objects with less than 100 pixels in 2D slices
    for slice_i in range(gt_data_ori.shape[0]):
        gt_i = gt_data_ori[slice_i, :, :]
        # remove small objects with less than 100 pixels
        # reason: for such small objects, the main challenge is detection rather than segmentation
        gt_data_ori[slice_i, :, :] = cc3d.dust(
            gt_i, threshold=voxel_num_thre2d, connectivity=8, in_place=True
        )
    # find non-zero slices
    z_index, _, _ = np.where(gt_data_ori > 0)
    z_index = np.unique(z_index)

    if len(z_index) > 0:
        # crop the ground truth with non-zero slices
        gt_roi = gt_data_ori[z_index, :, :]
        # load image and preprocess
        img_sitk = sitk.ReadImage(join(nii_path, image_name))
        image_data = sitk.GetArrayFromImage(img_sitk)
        # nii preprocess start
        if modality == "CT":
            # CT: window the intensities, then rescale to [0, 255]
            lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2
            upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2
            image_data_pre = np.clip(image_data, lower_bound, upper_bound)
            # NOTE(review): denominator is max - min of the clipped volume;
            # a constant volume would divide by zero — assumed not to occur
            image_data_pre = (
                (image_data_pre - np.min(image_data_pre))
                / (np.max(image_data_pre) - np.min(image_data_pre))
                * 255.0
            )
        else:
            # MR (or other): clip to the 0.5-99.5 percentiles of the
            # non-zero voxels, rescale to [0, 255], keep background at 0
            lower_bound, upper_bound = np.percentile(
                image_data[image_data > 0], 0.5
            ), np.percentile(image_data[image_data > 0], 99.5)
            image_data_pre = np.clip(image_data, lower_bound, upper_bound)
            image_data_pre = (
                (image_data_pre - np.min(image_data_pre))
                / (np.max(image_data_pre) - np.min(image_data_pre))
                * 255.0
            )
            image_data_pre[image_data == 0] = 0

        image_data_pre = np.uint8(image_data_pre)
        img_roi = image_data_pre[z_index, :, :]
        np.savez_compressed(
            join(npy_path, prefix + base_name + ".npz"),
            imgs=img_roi,
            gts=gt_roi,
            spacing=img_sitk.GetSpacing(),
        )
        # save the image and ground truth as nii files for sanity check;
        # they can be removed
        img_roi_sitk = sitk.GetImageFromArray(img_roi)
        gt_roi_sitk = sitk.GetImageFromArray(gt_roi)
        sitk.WriteImage(
            img_roi_sitk,
            join(npy_path, prefix + base_name + "_img.nii.gz"),
        )
        sitk.WriteImage(
            gt_roi_sitk,
            join(npy_path, prefix + base_name + "_gt.nii.gz"),
        )
        # save each 2D slice as npy files: image as a 3-channel float array
        # in [0, 1], gt resized with nearest-neighbour to keep integer labels
        for i in range(img_roi.shape[0]):
            img_i = img_roi[i, :, :]
            img_3c = np.repeat(img_i[:, :, None], 3, axis=-1)
            resize_img_skimg = transform.resize(
                img_3c,
                (image_size, image_size),
                order=3,
                preserve_range=True,
                mode="constant",
                anti_aliasing=True,
            )
            resize_img_skimg_01 = (resize_img_skimg - resize_img_skimg.min()) / np.clip(
                resize_img_skimg.max() - resize_img_skimg.min(), a_min=1e-8, a_max=None
            )  # normalize to [0, 1], (H, W, 3)
            gt_i = gt_roi[i, :, :]
            resize_gt_skimg = transform.resize(
                gt_i,
                (image_size, image_size),
                order=0,
                preserve_range=True,
                mode="constant",
                anti_aliasing=False,
            )
            resize_gt_skimg = np.uint8(resize_gt_skimg)
            assert resize_img_skimg_01.shape[:2] == resize_gt_skimg.shape
            slice_name = prefix + base_name + "-" + str(i).zfill(3) + ".npy"
            np.save(join(npy_path, "imgs", slice_name), resize_img_skimg_01)
            np.save(join(npy_path, "gts", slice_name), resize_gt_skimg)
--------------------------------------------------------------------------------