├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── question.md └── workflows │ ├── linter.yml │ ├── python-publish.yml │ ├── tests.yml │ └── tests_mr.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── TPTBox ├── __init__.py ├── core │ ├── __init__.py │ ├── bids_constants.py │ ├── bids_files.py │ ├── compat.py │ ├── dicom │ │ ├── __init__.py │ │ ├── dicom2nii_utils.py │ │ ├── dicom_extract.py │ │ ├── dicom_header_to_keys.py │ │ ├── fix_brocken.py │ │ └── nii2dicom.py │ ├── internal │ │ ├── __init__.py │ │ ├── ants_load.py │ │ ├── deep_learning_utils.py │ │ └── nii_help.py │ ├── nii_poi_abstract.py │ ├── nii_wrapper.py │ ├── nii_wrapper_math.py │ ├── np_utils.py │ ├── poi.py │ ├── poi_fun │ │ ├── __init__.py │ │ ├── _help.py │ │ ├── pixel_based_point_finder.py │ │ ├── poi_abstract.py │ │ ├── poi_global.py │ │ ├── ray_casting.py │ │ ├── save_load.py │ │ ├── strategies.py │ │ ├── vertebra_direction.py │ │ └── vertebra_pois_non_centroids.py │ ├── sitk_utils.py │ └── vert_constants.py ├── logger │ ├── __init__.py │ ├── log_constants.py │ └── log_file.py ├── mesh3D │ ├── __init__.py │ ├── mesh.py │ ├── mesh_colors.py │ └── snapshot3D.py ├── registration │ ├── __init__.py │ ├── deepali │ │ ├── __init__.py │ │ ├── _hooks.py │ │ ├── _utils.py │ │ ├── deepali_model.py │ │ ├── deepali_trainer.py │ │ └── spine_rigid_elements_reg.py │ ├── deformable │ │ ├── __init__.py │ │ ├── _deepali │ │ │ ├── __init__.py │ │ │ ├── deform_reg_pair.py │ │ │ ├── deformable_config.yaml │ │ │ ├── engine.py │ │ │ ├── hooks.py │ │ │ ├── metrics.py │ │ │ ├── optim.py │ │ │ └── registration_losses.py │ │ ├── _grid_search_vert.py │ │ ├── deformable_reg.py │ │ ├── deformable_reg_old.py │ │ ├── grid_search.py │ │ └── settings.json │ ├── ridged_intensity │ │ ├── __init__.py │ │ ├── affine_deepali.py │ │ └── register.py │ ├── ridged_points │ │ ├── __init__.py │ │ └── point_registration.py │ └── script_ax2sag.py ├── segmentation │ ├── TotalVibeSeg │ │ 
├── __init__.py │ │ ├── auto_download.py │ │ ├── inference_nnunet.py │ │ └── totalvibeseg.py │ ├── __init__.py │ ├── nnUnet_utils │ │ ├── __init__.py │ │ ├── data_iterators.py │ │ ├── default_preprocessor.py │ │ ├── export_prediction.py │ │ ├── get_network_from_plans.py │ │ ├── inference_api.py │ │ ├── plans_handler.py │ │ ├── predictor.py │ │ └── sliding_window_prediction.py │ ├── oar_segmentator │ │ ├── __init__.py │ │ ├── map_to_binary.py │ │ └── run.py │ └── spineps.py ├── spine │ ├── __init__.py │ ├── snapshot2D │ │ ├── __init__.py │ │ ├── snapshot_modular.py │ │ └── snapshot_templates.py │ └── spinestats │ │ ├── __init__.py │ │ ├── angles.py │ │ ├── distances.py │ │ ├── ivd_pois.py │ │ └── make_endplate.py ├── stitching │ ├── README.md │ ├── __init__.py │ ├── stitching.jpg │ ├── stitching.py │ └── stitching_tools.py └── tests │ ├── __init__.py │ ├── sample_ct │ ├── sub-ct_label-22_ct.nii.gz │ ├── sub-ct_seg-subreg_label-22_msk.nii.gz │ └── sub-ct_seg-vert_label-22_msk.nii.gz │ ├── sample_mri │ ├── sub-mri_label-6_T2w.nii.gz │ ├── sub-mri_seg-subreg_label-6_msk.nii.gz │ └── sub-mri_seg-vert_label-6_msk.nii.gz │ ├── speedtests │ ├── __init__.py │ ├── speedtest.py │ ├── speedtest_cc3d.py │ ├── speedtest_cc3d_crop.py │ ├── speedtest_connected_components.py │ ├── speedtest_connected_components_labelwise.py │ ├── speedtest_connected_components_simple.py │ ├── speedtest_count_nonzero.py │ ├── speedtest_crop.py │ ├── speedtest_dilate.py │ ├── speedtest_extract_label.py │ ├── speedtest_extract_label_loop.py │ ├── speedtest_extract_label_nii.py │ ├── speedtest_fillholes.py │ ├── speedtest_filter_connected_components.py │ ├── speedtest_isempty.py │ ├── speedtest_maplabels.py │ ├── speedtest_npunique.py │ └── speedtest_uncrop.py │ ├── test_cc3d.py │ └── test_utils.py ├── examples ├── dicom_select │ └── __init__.py ├── nako │ ├── README.md │ ├── __init__.py │ ├── dicom2nii_bids.py │ ├── stitching_T2w.py │ └── stitching_vibe.py └── registration │ ├── IVD_transfer │ ├── 
__init__.py │ └── transfare_spine_seg.py │ └── atlas_poi_transfer_leg │ ├── __init__.py │ ├── atlas_poi_transfer.py │ ├── atlas_poi_transfer_leg_ct.py │ ├── example.ipynb │ └── example.py ├── pyproject.toml ├── seg_all.ipynb ├── tutorials ├── internal │ └── tutorial_continuous_testing.ipynb ├── tutorial_BIDS_files.ipynb ├── tutorial_Dataset_processing.ipynb ├── tutorial_Nifty.ipynb ├── tutorial_POI.ipynb ├── tutorial_logger.ipynb ├── tutorial_pointregistation.ipynb └── tutorial_snapshot2D.ipynb └── unit_tests ├── __init__.py ├── test_bids_dataset_parallel.py ├── test_bids_file.py ├── test_bids_file_print.py ├── test_centroids.py ├── test_centroids_save.py ├── test_compat.py ├── test_nii.py ├── test_nii_wrapper_auto.py ├── test_nputils.py ├── test_poi.py ├── test_poi_autogen.py ├── test_poi_global.py ├── test_reg_seg.py ├── test_slicing.py ├── test_stiching.py ├── test_testsamples.py └── test_vertconstants.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us find the error and fix it 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear description of what the bug is. 12 | 13 | **How to Reproduce** 14 | Minimal steps to reproduce the behavior (or send code snippets): 15 | 1. Go to '...' 16 | 2. Install '....' 17 | 3. Run commands '....' 18 | 19 | **Expected behavior** 20 | A clear description of what you expected to happen or should happen. 21 | 22 | **Screenshots** 23 | If possible, add screenshots to help explain your problem. 24 | 25 | **Environment** 26 | 27 | ### operating system and version? 28 | e.g. Ubuntu 23.10 LTS 29 | 30 | ### Python environment and version? 31 | e.g. Conda environment with Python 3.10. Check your Python version with: 32 | ```sh 33 | python --version 34 | ``` 35 | 36 | **Additional context** 37 | Add any other context about the problem here. 
38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Description** 11 | Describe what feature and/or behavior you would like to have. E.g. I want ... 12 | 13 | **Is your feature request related to a problem? Please describe.** 14 | A clear description of what the problem is. Ex. I'm always frustrated when [...] 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. You can also cite other packages or relate to code. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Please ask your question, make sure to read the FAQ before 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Your question** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | -------------------------------------------------------------------------------- /.github/workflows/linter.yml: -------------------------------------------------------------------------------- 1 | name: ruff linter 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | ruff: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: actions/setup-python@v3 15 | - run: pip install ruff 16 | - run: ruff check . 
#--fix 17 | #- uses: chartboost/ruff-action@v1 18 | # with: 19 | # fix_args: --fix 20 | # with: 21 | # args: --check . 22 | - uses: stefanzweifel/git-auto-commit-action@v4 23 | with: 24 | commit_message: 'style fixes by ruff' 25 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: release_to_pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | deploy: 12 | name: Publish to test PyPI 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: '3.10' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install build 25 | pip install twine 26 | - name: Build package 27 | run: python -m build 28 | #| # 29 | #poetry version $(git describe --tags --abbrev=0) 30 | #poetry build 31 | - name: Upload to PyPI 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 35 | run: | 36 | twine upload dist/*.whl 37 | #- name: Publish package 38 | # #uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 39 | # uses: pypa/gh-action-pypi-publish@release/v1 40 | # with: 41 | # verbose: true 42 | # #user: Hendrik_Code 43 | # password: ${{ secrets.PYPI_API_TOKEN }} 44 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: tests 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | 10 | jobs: 11 
| build: 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [ubuntu-latest, windows-latest] 17 | python-version: ["3.9", "3.10", "3.11", "3.12"] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Configure python 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install poetry 29 | python -m pip install flake8 pytest 30 | - name: Install dependancies 31 | run: | 32 | python -m poetry install 33 | - name: Lint with flake8 34 | run: | 35 | # stop the build if there are Python syntax errors or undefined names 36 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 37 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 38 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 39 | - name: Test with pytest and create coverage report 40 | run: | 41 | python -m poetry run coverage run --source=TPTBox -m pytest 42 | python -m poetry run coverage xml 43 | - name: Upload coverage results to Codecov (Only on merge to main) 44 | # Only upload to Codecov after a merge to the main branch 45 | if: github.ref == 'refs/heads/main' && github.event_name == 'push' 46 | uses: codecov/codecov-action@v4 47 | with: 48 | token: ${{ secrets.CODECOV_TOKEN }} 49 | -------------------------------------------------------------------------------- /.github/workflows/tests_mr.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: tests 5 | 6 | on: 7 | pull_request: 8 | branches: [ "main" ] 9 | 10 | jobs: 11 | build: 12 | runs-on: 
${{ matrix.os }} 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | os: [ubuntu-latest] 17 | python-version: ["3.10"] 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v5 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install flake8 pytest 29 | pip install -e . 30 | - name: Lint with flake8 31 | run: | 32 | # stop the build if there are Python syntax errors or undefined names 33 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 34 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 35 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 36 | - name: Test with pytest 37 | run: | 38 | pytest 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | papers/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | *_tmp.* 29 | MNIST 30 | lightning_logs 31 | version_0 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # dotenv 87 | .env 88 | 89 | # virtualenv 90 | .venv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # project directories 108 | data 109 | data/* 110 | datasets/* 111 | output/A/ 112 | output/B/ 113 | output*/ 114 | unet/out*/ 115 | *.gz 116 | *.nii 117 | *.png 118 | *.csv 119 | 120 | # model checkpoints 121 | *.pth 122 | *.pckl 123 | *.tfrecord 124 | */settings.json 125 | unet/*_seg/ 126 | *.jpg 127 | 128 | *.txt 129 | weights/ 130 | *.gif 131 | *.png 132 | *.npy 133 | 134 | lightning_logs 135 | lightning_logs* 136 | fid 137 | log_* 138 | logs_* 139 | 140 | */test_data/* 141 | 142 | *test_Kati.py 143 | /tmp/ 144 | */mokeups/* 145 | Conventional_Registration_Tutorial.ipynb 146 | airlab 147 | /TPTBox/registation/mokeups/* 148 | *.pyc 149 | test_*.ipynb 150 | .*DS_Store 151 | *tmp*.ipynb 152 | *.json 153 | tmp_*.py 154 | *.cache 155 | *tmp*.py 156 | tmp 157 | tmp/* 158 | !/TPTBox/tests/sample_mri/* 159 | !/TPTBox/tests/sample_ct/* 160 | poetry.lock 161 | *.dcm 162 | *.pkl 163 | 
tutorials/tutorial_data_processing/* 164 | tutorials/*PixelPandemonium/* 165 | tutorials/dataset-PixelPandemonium/* 166 | *.html 167 | _*.py 168 | dicom_select -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v3.2.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-added-large-files 11 | 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | # Ruff version. 14 | rev: v0.11.0 15 | hooks: 16 | # Run the linter. 17 | - id: ruff 18 | types_or: [ python, pyi ] 19 | args: [ --fix ] 20 | # Run the formatter. 21 | - id: ruff-format 22 | types_or: [ python, pyi ] 23 | -------------------------------------------------------------------------------- /TPTBox/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | # POI 4 | import sys 5 | from pathlib import Path 6 | 7 | sys.path.append(str(Path(__file__).parent)) 8 | # packages 9 | from TPTBox import core # noqa: I001 10 | from TPTBox.core import bids_files, np_utils 11 | 12 | # BIDS 13 | from TPTBox.core.bids_files import BIDS_FILE, BIDS_Family, BIDS_Global_info, Searchquery, Subject_Container 14 | 15 | # NII 16 | from TPTBox.core.nii_wrapper import ( 17 | NII, 18 | Image_Reference, 19 | Interpolateable_Image_Reference, 20 | to_nii, 21 | to_nii_interpolateable, 22 | to_nii_optional, 23 | to_nii_seg, 24 | Has_Grid, 25 | ) 26 | from TPTBox.core.poi import AX_CODES, POI, POI_Reference, calc_centroids, calc_poi_from_subreg_vert 27 | from TPTBox.core.poi import calc_poi_from_two_segs 28 | from TPTBox.core.poi import calc_poi_from_two_segs as 
calc_poi_labeled_buffered 29 | from TPTBox.core.poi_fun.poi_global import POI_Global 30 | from TPTBox.core.vert_constants import ZOOMS, Location, Vertebra_Instance, v_idx2name, v_idx_order, v_name2idx 31 | 32 | # Logger 33 | from TPTBox.logger import Log_Type, Logger, Logger_Interface, Print_Logger, String_Logger 34 | from TPTBox.logger.log_file import No_Logger 35 | 36 | Centroids = POI 37 | 38 | __all__ = [ 39 | "AX_CODES", 40 | "BIDS_FILE", 41 | "NII", 42 | "POI", 43 | "ZOOMS", 44 | "BIDS_Family", 45 | "BIDS_Global_info", 46 | "Image_Reference", 47 | "Interpolateable_Image_Reference", 48 | "Location", 49 | "Log_Type", 50 | "POI_Global", 51 | "POI_Reference", 52 | "Print_Logger", 53 | "Searchquery", 54 | "Subject_Container", 55 | "Vertebra_Instance", 56 | "bids_files", 57 | "calc_centroids", 58 | "calc_poi_from_subreg_vert", 59 | "calc_poi_from_two_segs", 60 | "core", 61 | "load_poi", 62 | "np_utils", 63 | "to_nii", 64 | "v_idx2name", 65 | "v_idx_order", 66 | "v_name2idx", 67 | ] 68 | -------------------------------------------------------------------------------- /TPTBox/core/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | # POI 4 | # packages 5 | from . 
import bids_files, np_utils, sitk_utils 6 | 7 | # BIDS 8 | from .bids_files import BIDS_FILE, BIDS_Family, BIDS_Global_info, Searchquery, Subject_Container 9 | 10 | # NII 11 | from .nii_wrapper import NII, Image_Reference, Interpolateable_Image_Reference, to_nii, to_nii_interpolateable, to_nii_optional, to_nii_seg 12 | from .poi import AX_CODES, POI, POI_Reference, calc_centroids, calc_poi_from_subreg_vert, calc_poi_from_two_segs 13 | from .poi_fun.poi_global import POI_Global 14 | from .vert_constants import ZOOMS, Location, v_idx2name, v_idx_order, v_name2idx 15 | 16 | __all__ = ["bids_files", "np_utils", "sitk_utils"] 17 | -------------------------------------------------------------------------------- /TPTBox/core/compat.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | def zip_strict(*iterables): 5 | """ 6 | A strict version of zip that raises a ValueError if the input iterables have different lengths. 7 | 8 | Converts each iterable to a list to check lengths. This assumes all iterables are finite. 9 | 10 | Args: 11 | *iterables: Finite iterables to be zipped together. 12 | 13 | Returns: 14 | An iterator of tuples, where the i-th tuple contains the i-th element from each iterable. 15 | 16 | Raises: 17 | ValueError: If the input iterables have different lengths. 
18 | """ 19 | lists = [list(it) for it in iterables] 20 | lengths = [len(lst) for lst in lists] 21 | if len(set(lengths)) != 1: 22 | raise ValueError(f"Length mismatch: {lengths}") 23 | return zip(*lists) 24 | -------------------------------------------------------------------------------- /TPTBox/core/dicom/__init__.py: -------------------------------------------------------------------------------- 1 | from TPTBox.core.dicom.dicom_extract import extract_dicom_folder 2 | -------------------------------------------------------------------------------- /TPTBox/core/dicom/fix_brocken.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pickle 4 | from pathlib import Path 5 | 6 | from tqdm import tqdm 7 | 8 | from TPTBox import BIDS_FILE, NII, BIDS_Global_info, Print_Logger 9 | 10 | ### 11 | # _find_all_broken(), opens all files and checks if they can be read. (Some T2Haste are 2D an are considered brocken by our software) 12 | # _test_and_replace() read the file generated by find_all_brocken and exports them into a new dataset, to be copied later back. 
13 | # source_folder = Path("_Nachlieferung_20_25/") 14 | source_folder = Path("TODO") 15 | source_folders = { 16 | "mevibe": { 17 | "rule": "part", 18 | "eco1-pip1": "ME_vibe_fatquant_pre_Eco1_PIP1", 19 | "eco4-pop1": "ME_vibe_fatquant_pre_Eco4_POP1", 20 | "fat-fraction": "ME_vibe_fatquant_pre_Output_FP", 21 | "water-fraction": "ME_vibe_fatquant_pre_Output_WP", 22 | "eco3-in1": "ME_vibe_fatquant_pre_Eco3_IN1", 23 | "fat": "ME_vibe_fatquant_pre_Output_F", 24 | "water": "ME_vibe_fatquant_pre_Output_W", 25 | "eco2-opp2": "ME_vibe_fatquant_pre_Eco2_OPP2", 26 | "eco5-arb1": "ME_vibe_fatquant_pre_Eco5_ARB1", 27 | "r2s": "ME_vibe_fatquant_pre_Output_R2s_Eff", 28 | "eco0-opp1": "ME_vibe_fatquant_pre_Eco0_OPP1", 29 | }, 30 | "T2haste": {"rule": "acq", "ax": "T2_HASTE_TRA_COMPOSED"}, 31 | "vibe": { 32 | "rule": "part", 33 | "fat": "3D_GRE_TRA_F", 34 | "inphase": "3D_GRE_TRA_in", 35 | "outphase": "3D_GRE_TRA_opp", 36 | "water": "3D_GRE_TRA_W", 37 | }, 38 | "T2w": { 39 | "rule": "chunk", 40 | "LWS": "III_T2_TSE_SAG_LWS", 41 | "BWS": "II_T2_TSE_SAG_BWS", 42 | "HWS": "I_T2_TSE_SAG_HWS", 43 | }, 44 | "pd": {"rule": "acq", "iso": "PD_FS_SPC_COR"}, 45 | } 46 | 47 | 48 | def test_nii(path: Path | str | BIDS_FILE): 49 | if isinstance(path, str): 50 | path = Path(path) 51 | if path.exists(): 52 | try: 53 | NII.load(path, True).copy().unique() 54 | except Exception: 55 | return False 56 | return True 57 | 58 | 59 | def _find_all_broken(path: str = "TODO/dataset-nako/", parents=None): 60 | brocken = [] 61 | subj_id = 0 62 | with open("broken.pkl", "rb") as w: 63 | brocken = pickle.load(w) 64 | subj_id = int(brocken[-1].get("sub")) 65 | print(f"{len(brocken)=}; continue at {subj_id=}") 66 | 67 | if parents is None: 68 | parents = ["rawdata"] 69 | bgi = BIDS_Global_info([path], parents=parents) 70 | for s, subj in tqdm(bgi.enumerate_subjects(sort=True)): 71 | if int(s) <= subj_id: 72 | continue 73 | q = subj.new_query(flatten=True) 74 | q.filter_filetype("nii.gz") 75 | for f in 
q.loop_list(): 76 | if not test_nii(f): 77 | print("BROKEN:", f) 78 | brocken.append(f) 79 | with open("broken.pkl", "wb") as w: 80 | pickle.dump(brocken, w) 81 | with open("broken2.pkl", "wb") as w: 82 | pickle.dump(brocken, w) 83 | 84 | 85 | source_folder_encrypted = Path("/media/veracrypt1/NAKO-732_MRT/") 86 | source_folder_encrypted_alternative = Path("Nachlieferung") 87 | 88 | 89 | def _test_and_replace(out_folder="~/dataset-nako"): 90 | from TPTBox.core.dicom.dicom_extract import extract_dicom_folder 91 | 92 | with open("TODO", "rb") as w: 93 | brocken = pickle.load(w) 94 | print(len(brocken)) 95 | for bf in tqdm(brocken): 96 | bf: BIDS_FILE 97 | 98 | # Localize the zip 99 | subj = bf.get("sub") 100 | mod = bf.format 101 | sub_key = bf.get(source_folders[mod]["rule"]) 102 | # print(source_folders[mod].keys(), sub_key) 103 | assert sub_key is not None, mod 104 | 105 | # IF MEVIBE Export all 106 | if mod == "mevibe": 107 | out_files = {} 108 | for i in source_folders["mevibe"].values(): 109 | f = source_folder_encrypted / i 110 | f2 = Path("/NON/NON/ON") 111 | try: 112 | f2 = next((f).glob(f"{subj}*.zip")) 113 | except StopIteration: 114 | try: 115 | f2 = next((f).glob(f"{subj}*.zip")) 116 | except StopIteration: 117 | continue 118 | if f2.exists(): 119 | out_files.update(extract_dicom_folder(f2, Path(out_folder), make_subject_chunks=3, verbose=False)) 120 | else: 121 | print(f, f.exists()) 122 | else: 123 | try: 124 | zip_file = next((source_folder_encrypted / source_folders[mod][sub_key]).glob(f"{subj}*.zip")) 125 | except StopIteration: 126 | try: 127 | zip_file = next((source_folder_encrypted_alternative / source_folders[mod][sub_key]).glob(f"{subj}*.zip")) 128 | except StopIteration: 129 | print((source_folder_encrypted / source_folders[mod][sub_key]) / (f"{subj}*.zip")) 130 | continue 131 | ## Call the extraction 132 | out_files = extract_dicom_folder(zip_file, Path(out_folder), make_subject_chunks=3, verbose=False) 133 | ## -- Testing --- 134 | # Save 
over brocken... 135 | for o in out_files.values(): 136 | if o is not None and not test_nii(o): 137 | Print_Logger().on_fail("Still Broken ", out_files) 138 | 139 | 140 | if __name__ == "__main__": 141 | # test_and_replace() 142 | # find_all_broken() 143 | pass 144 | -------------------------------------------------------------------------------- /TPTBox/core/dicom/nii2dicom.py: -------------------------------------------------------------------------------- 1 | # Original Source: https://github.com/amine0110/nifti2dicom/blob/main/nifti2dicom.py 2 | # Added that you can add json, we extract back into the dicom header 3 | from __future__ import annotations 4 | 5 | import json 6 | import os 7 | import time 8 | from glob import glob 9 | from pathlib import Path 10 | 11 | import SimpleITK as sitk # noqa: N813 12 | from pydicom.tag import Tag 13 | 14 | from TPTBox import Print_Logger 15 | 16 | 17 | def writeSlices(series_tag_values: dict, new_img: sitk.Image, i, out_dir: str | Path, name="slice"): 18 | image_slice: sitk.Image = new_img[:, :, i] 19 | writer = sitk.ImageFileWriter() 20 | writer.KeepOriginalImageUIDOn() 21 | 22 | # Tags shared by the series. 23 | for k, v in series_tag_values.items(): 24 | image_slice.SetMetaData(str(k), str(v)) 25 | 26 | # Slice specific tags. 27 | image_slice.SetMetaData("0008|0012", time.strftime("%Y%m%d")) # Instance Creation Date 28 | image_slice.SetMetaData("0008|0013", time.strftime("%H%M%S")) # Instance Creation Time 29 | 30 | # Check if modality is specified in the metadata, otherwise set a default value 31 | modality = series_tag_values.get("0008|0060", "MR") # Defaulting to MR (Magnetic Resonance) 32 | image_slice.SetMetaData("0008|0060", modality) # Modality 33 | 34 | # (0020, 0032) image position patient determines the 3D spacing between slices. 
35 | image_slice.SetMetaData("0020|0032", "\\".join(map(str, new_img.TransformIndexToPhysicalPoint((0, 0, i))))) # Image Position (Patient) 36 | image_slice.SetMetaData("0020|0013", str(i)) # Instance Number 37 | 38 | # Write to the output directory and add the extension dcm, to force writing in DICOM format. 39 | writer.SetFileName(str(Path(out_dir, f"{name}{str(i).zfill(4)}.dcm"))) 40 | writer.Execute(image_slice) 41 | 42 | 43 | def nifti2dicom_1file( 44 | in_nii: str | Path, out_dir: str | Path, no_json_ok=False, secondary=False, json_path: None | str | Path = None, out_name="slice" 45 | ): 46 | """ 47 | This function converts one NIfTI file into a DICOM series. 48 | 49 | Parameters: 50 | - in_nii: Path to the NIfTI file 51 | - out_dir: Path to the output directory 52 | - no_json_ok: Whether to proceed without a JSON metadata file 53 | - secondary: Whether the images are derived/secondary 54 | - json_path: Optional path to the JSON metadata file 55 | """ 56 | out_dir = Path(out_dir) 57 | out_dir.mkdir(exist_ok=True) 58 | 59 | # Get JSON metadata 60 | json_path = Path(str(in_nii).split(".")[0] + ".json") if json_path is None else Path(json_path) 61 | if json_path.exists(): 62 | with open(json_path) as j: 63 | meta: dict = json.load(j) 64 | series_tag_values = {} 65 | for k, v in meta.items(): 66 | try: 67 | tag = Tag(k) 68 | a = tag.json_key 69 | a = a[:4] + "|" + a[4:] 70 | series_tag_values[a] = v 71 | except ValueError: 72 | Print_Logger().on_fail(k, "cannot be converted to DICOM tag") 73 | else: 74 | if not no_json_ok: 75 | raise FileNotFoundError(json_path) 76 | series_tag_values = {} 77 | 78 | new_img = sitk.ReadImage(in_nii) 79 | series_tag_values["0008|0031"] = time.strftime("%H%M%S") # Modification Time 80 | series_tag_values["0008|0031"] = time.strftime("%Y%m%d") # Modification Date 81 | direction = new_img.GetDirection() 82 | series_tag_values["0020|0037"] = "\\".join( 83 | map(str, (direction[0], direction[3], direction[6], direction[1], direction[4], 
direction[7])) 84 | ) # Image Orientation (Patient) 85 | 86 | if secondary: 87 | series_tag_values["0008|0008"] = "DERIVED\\SECONDARY" 88 | 89 | # Generate unique Series and Study Instance UIDs if not already present 90 | if "0020|000e" not in series_tag_values: 91 | series_tag_values["0020|000e"] = ( 92 | f"1.2.826.0.1.3680043.2.1125.{series_tag_values['0008|0031']}.1{series_tag_values['0008|0031']}" # Series Instance UID 93 | ) 94 | series_tag_values["0020|000d"] = ( 95 | f"1.2.826.0.1.3680043.2.1125.1{series_tag_values['0008|0031']}.1{series_tag_values['0008|0031']}" # Study Instance UID 96 | ) 97 | # series_tag_values["0008|103e"] = "Created-Pycad" 98 | else: 99 | series_tag_values["0020|000e"] = ( 100 | f"{series_tag_values['0020|000e'][:27]}{series_tag_values['0008|0031']}.1{series_tag_values['0008|0031']}" # Series Instance UID 101 | ) 102 | series_tag_values["0020|000d"] = ( 103 | f"{series_tag_values['0020|000e'][:28]}{series_tag_values['0008|0031']}.1{series_tag_values['0008|0031']}" # Study Instance UID 104 | ) 105 | # series_tag_values["0008|103e"] = "Created-Pycad" 106 | 107 | for i in range(new_img.GetDepth()): 108 | writeSlices(series_tag_values, new_img, i, out_dir, name=out_name) 109 | 110 | 111 | def nifti2dicom_mfiles(nifti_dir, out_dir=""): 112 | """ 113 | This function converts multiple NIfTI files into DICOM series. 114 | 115 | Parameters: 116 | - nifti_dir: Path to the directory containing NIfTI files 117 | - out_dir: Path to the output directory 118 | 119 | Each NIfTI file's folder will be created automatically, so no need to create an empty folder for each patient. 
120 | """ 121 | images = glob(nifti_dir + "/*.nii.gz") # noqa: PTH207 122 | 123 | for image in images: 124 | o_path = Path(out_dir, os.path.basename(image)[:-7]) # noqa: PTH119 125 | os.makedirs(o_path, exist_ok=True) # noqa: PTH103 126 | 127 | nifti2dicom_1file(image, o_path) 128 | 129 | 130 | if __name__ == "__main__": 131 | nifti2dicom_1file("sub-spinegan0004_ses-20210617_sequ-6_part-fat_dixon.nii.gz", "out_test") 132 | -------------------------------------------------------------------------------- /TPTBox/core/internal/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .ants_load import ants_to_nifti, nifti_to_ants 4 | -------------------------------------------------------------------------------- /TPTBox/core/internal/ants_load.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import ants 6 | import numpy as np 7 | 8 | if TYPE_CHECKING: 9 | from nibabel.nifti1 import Nifti1Image 10 | 11 | 12 | def nifti_to_ants(nib_image: Nifti1Image): 13 | """ 14 | Convert a Nifti image to an ANTsPy image. 15 | 16 | Parameters 17 | ---------- 18 | nib_image : Nifti1Image 19 | The Nifti image to be converted. 20 | 21 | Returns 22 | ------- 23 | ants_image : ants.ANTsImage 24 | The converted ANTs image. 
25 | """ 26 | ndim = nib_image.ndim 27 | 28 | if ndim < 3: 29 | raise NotImplementedError("Conversion is only implemented for 3D or higher images.") 30 | q_form = nib_image.get_qform() 31 | spacing = nib_image.header["pixdim"][1 : ndim + 1] 32 | 33 | origin = np.zeros(ndim) 34 | origin[:3] = np.dot(np.diag([-1, -1, 1]), q_form[:3, 3]) 35 | 36 | direction = np.eye(ndim) 37 | direction[:3, :3] = np.dot(np.diag([-1, -1, 1]), q_form[:3, :3]) / spacing[:3] 38 | 39 | ants_img = ants.from_numpy( 40 | data=nib_image.get_fdata(), 41 | origin=origin.tolist(), 42 | spacing=spacing.tolist(), 43 | direction=direction, 44 | ) 45 | "add nibabel conversion (lacey import to prevent forced dependency)" 46 | 47 | return ants_img 48 | 49 | 50 | def get_ras_affine_from_ants(ants_img) -> np.ndarray: 51 | """ 52 | Convert ANTs image affine to RAS coordinate system. 53 | Source: https://github.com/fepegar/torchio/blob/main/src/torchio/data/io.py 54 | Parameters 55 | ---------- 56 | ants_img : ants.ANTsImage 57 | The ANTs image whose affine is to be converted. 58 | 59 | Returns 60 | ------- 61 | affine : np.ndarray 62 | The affine matrix in RAS coordinates. 
63 | """ 64 | spacing = np.array(ants_img.spacing) 65 | direction_lps = np.array(ants_img.direction) 66 | origin_lps = np.array(ants_img.origin) 67 | direction_length = direction_lps.shape[0] * direction_lps.shape[1] 68 | if direction_length == 9: 69 | rotation_lps = direction_lps.reshape(3, 3) 70 | elif direction_length == 4: # 2D case (1, W, H, 1) 71 | rotation_lps_2d = direction_lps.reshape(2, 2) 72 | rotation_lps = np.eye(3) 73 | rotation_lps[:2, :2] = rotation_lps_2d 74 | spacing = np.append(spacing, 1) 75 | origin_lps = np.append(origin_lps, 0) 76 | elif direction_length == 16: # Fix potential bad NIfTI 77 | rotation_lps = direction_lps.reshape(4, 4)[:3, :3] 78 | spacing = spacing[:-1] 79 | origin_lps = origin_lps[:-1] 80 | else: 81 | raise NotImplementedError(f"Unexpected direction length = {direction_length}.") 82 | 83 | rotation_ras = np.dot(np.diag([-1, -1, 1]), rotation_lps) 84 | rotation_ras_zoom = rotation_ras * spacing 85 | translation_ras = np.dot(np.diag([-1, -1, 1]), origin_lps) 86 | 87 | affine = np.eye(4) 88 | affine[:3, :3] = rotation_ras_zoom 89 | affine[:3, 3] = translation_ras 90 | 91 | return affine 92 | 93 | 94 | def ants_to_nifti(img, header=None) -> Nifti1Image: 95 | """ 96 | Convert an ANTs image to a Nifti image. 97 | 98 | Parameters 99 | ---------- 100 | img : ants.ANTsImage 101 | The ANTs image to be converted. 102 | header : Nifti1Header, optional 103 | Optional header to use for the Nifti image. 104 | 105 | Returns 106 | ------- 107 | img : Nifti1Image 108 | The converted Nifti image. 
109 | """ 110 | from nibabel.nifti1 import Nifti1Image 111 | 112 | affine = get_ras_affine_from_ants(img) 113 | arr = img.numpy() 114 | 115 | if header is not None: 116 | header.set_data_dtype(arr.dtype) 117 | 118 | return Nifti1Image(arr, affine, header) 119 | 120 | 121 | # Legacy names for backwards compatibility 122 | from_nibabel = nifti_to_ants 123 | to_nibabel = ants_to_nifti 124 | 125 | if __name__ == "__main__": 126 | import nibabel as nib 127 | 128 | fn = ants.get_ants_data("mni") 129 | ants_img = ants.image_read(fn) 130 | nii_mni: Nifti1Image = nib.load(fn) 131 | ants_mni = to_nibabel(ants_img) 132 | assert (ants_mni.get_qform() == nii_mni.get_qform()).all() 133 | assert (ants_mni.affine == nii_mni.affine).all() 134 | temp = from_nibabel(nii_mni) 135 | 136 | assert ants.image_physical_space_consistency(ants_img, temp) 137 | 138 | fn = ants.get_data("ch2") 139 | ants_mni = ants.image_read(fn) 140 | nii_mni = nib.load(fn) 141 | ants_mni = to_nibabel(ants_mni) 142 | assert (ants_mni.get_qform() == nii_mni.get_qform()).all() 143 | 144 | nii_org = nib.load(fn) 145 | ants_org = ants.image_read(fn) 146 | temp = ants_org 147 | for _ in range(10): 148 | temp = to_nibabel(ants_org) 149 | assert (temp.get_qform() == nii_org.get_qform()).all() 150 | assert (ants_mni.affine == nii_mni.affine).all() 151 | temp = from_nibabel(temp) 152 | assert ants.image_physical_space_consistency(ants_org, temp) 153 | for _ in range(10): 154 | temp = from_nibabel(nii_org) 155 | assert ants.image_physical_space_consistency(ants_org, temp) 156 | temp = to_nibabel(temp) 157 | 158 | assert (temp.get_qform() == nii_org.get_qform()).all() 159 | assert (ants_mni.affine == nii_mni.affine).all() 160 | -------------------------------------------------------------------------------- /TPTBox/core/internal/deep_learning_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import torch 4 | 5 | from 
def get_device(ddevice: DEVICES, gpu_id: int):
    """Resolve a ``torch.device`` from a device-kind string and a GPU index.

    Args:
        ddevice: One of ``"cpu"``, ``"cuda"`` or ``"mps"``.
        gpu_id: CUDA device index; ignored for ``"cpu"`` and ``"mps"``.

    Returns:
        torch.device: The resolved device object.
    """
    if ddevice == "cpu":
        device = torch.device("cpu")
    elif ddevice == "cuda":
        device = torch.device(type="cuda", index=gpu_id)
    elif ddevice == "mps":
        device = torch.device("mps")
    else:
        # Exhaustiveness guard: unreachable for valid DEVICES literals.
        never_called(ddevice)
    return device
def timing(f):
    """Decorator that logs the wall-clock duration of every call to ``f``."""

    @wraps(f)
    def timed(*args, **kw):
        start = time()
        out = f(*args, **kw)
        elapsed = time() - start
        _log.on_neutral(f"func:{f.__name__!r} took: {elapsed:2.4f} sec")
        return out

    return timed
class Log_Type(Enum):
    """The different types of Logs supported"""

    # Each member maps to a (color, prefix) pair in ``type2bcolors`` below.
    TEXT = auto()  # plain output, "[*]" prefix, no color
    NEUTRAL = auto()  # plain output, "[ ]" prefix
    SAVE = auto()  # save/write notifications (cyan)
    WARNING = auto()  # warning message (yellow)
    WARNING_THROW = auto()  # warning variant (yellow); presumably also raises — TODO confirm
    LOG = auto()  # general log line (blue, "[#]")
    OK = auto()  # success (green, "[+]")
    FAIL = auto()  # failure (red, "[!]")
    Yellow = auto()  # explicit yellow text
    STRANGE = auto()  # unexpected-but-not-fatal events (pink, "[-]")
    UNDERLINE = auto()  # underlined text
    ITALICS = auto()  # italic text
    BOLD = auto()  # bold text
    DOCKER = auto()  # docker-related output, "[Docker]" prefix
    TOTALSEG = auto()  # TotalSegmentator-related output, "[TOTALSEG]" prefix
    STAGE = auto()  # pipeline stage marker (blue background)
"\033[90m" 45 | LightRed = "\033[91m" 46 | LightGreen = "\033[92m" 47 | LightYellow = "\033[93m" 48 | LightBlue = "\033[94m" 49 | LightMagenta = "\033[95m" 50 | LightCyan = "\033[96m" 51 | # Modes 52 | BOLD = "\033[1m" 53 | UNDERLINE = "\033[4m" 54 | DISABLE = "\033[02m" 55 | STRIKETHROUGH = "\033[09m" 56 | REVERSE = "\033[07m" 57 | ITALICS = "\033[3m" 58 | # Background Colors 59 | BG_BLACK = "\033[40m" 60 | BG_RED = "\033[41m" 61 | BG_GREEN = "\033[42m" 62 | BG_ORANGE = "\033[43m" 63 | BG_BLUE = "\033[44m" 64 | BG_PURPLE = "\033[45m" 65 | BG_CYAN = "\033[46m" 66 | BG_GRAY = "\033[47m" 67 | # End of line (cleans color) 68 | ENDC = "\033[0m" 69 | 70 | 71 | # Defines for each Log_Type the corresponding color to be used as well as its prefix 72 | # TODO make bcolors enum more fancy 73 | type2bcolors: dict[Log_Type, tuple[str, str]] = { 74 | Log_Type.TEXT: (bcolors.ENDC, "[*]"), 75 | Log_Type.NEUTRAL: (bcolors.ENDC, "[ ]"), 76 | Log_Type.SAVE: (bcolors.CYAN, "[*]"), 77 | Log_Type.WARNING: (bcolors.YELLOW, "[?]"), 78 | Log_Type.WARNING_THROW: (bcolors.YELLOW, "[?]"), 79 | Log_Type.LOG: (bcolors.BLUE, "[#]"), 80 | Log_Type.OK: (bcolors.GREEN, "[+]"), 81 | Log_Type.FAIL: (bcolors.RED, "[!]"), 82 | Log_Type.Yellow: (bcolors.Yellow2, "[*]"), 83 | Log_Type.STRANGE: (bcolors.PINK, "[-]"), 84 | Log_Type.UNDERLINE: (bcolors.UNDERLINE, "[_]"), 85 | Log_Type.ITALICS: (bcolors.ITALICS, "[ ]"), 86 | Log_Type.BOLD: (bcolors.BOLD, "[*]"), 87 | Log_Type.DOCKER: (bcolors.ITALICS, "[Docker]"), 88 | Log_Type.TOTALSEG: (bcolors.ITALICS, "[TOTALSEG]"), 89 | Log_Type.STAGE: (bcolors.BG_BLUE, "[*]"), 90 | } 91 | 92 | 93 | def datatype_to_string(text, log_type: Log_Type): 94 | """Processes given text into a readable string 95 | 96 | Args: 97 | text (str): _description_ 98 | log_type (Log_Type): _description_ 99 | 100 | Returns: 101 | _type_: _description_ 102 | """ 103 | if isinstance(text, dict): 104 | return _dict_to_string(text, log_type) 105 | return str(text) 106 | 107 | 108 | def 
def format_time_short(t: struct_time) -> str:
    """Render ``t`` as ``date-Y-M-D_time-H-M-S`` (components not zero-padded)."""
    return f"date-{t.tm_year}-{t.tm_mon}-{t.tm_mday}_time-{t.tm_hour}-{t.tm_min}-{t.tm_sec}"
class RGB_Color:
    """An RGB color held as a numpy array of three 0-255 integer channels."""

    def __init__(self, rgb: tuple[int, int, int]):
        # Bug fix: the previous check `isinstance(rgb, tuple) and [isinstance(i, int) ...]`
        # used a list comprehension whose truthiness only reflected non-emptiness,
        # so element types were never validated. `all(...)` performs the intended
        # check; np.integer is accepted so ndarray-sourced tuples keep working.
        assert isinstance(rgb, tuple) and len(rgb) == 3 and all(isinstance(i, (int, np.integer)) for i in rgb), (
            "did not receive a tuple of 3 ints"
        )
        self.rgb = np.array(rgb)

    @classmethod
    def init_separate(cls, r: int, g: int, b: int):
        """Build a color from three separate channel values."""
        return cls((r, g, b))

    @classmethod
    def init_list(cls, rgb: list[int] | np.ndarray):
        """Build a color from a 3-element list or an int numpy array."""
        assert len(rgb) == 3, "rgb requires exactly three integers"
        if isinstance(rgb, np.ndarray):
            assert rgb.dtype == int, "rgb numpy array not of type int!"
        return cls(tuple(rgb))

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return "RGB_Color-" + str(self.rgb)

    def __call__(self, normed: bool = False):
        """Return the raw 0-255 array, or the array scaled to [0, 1] if ``normed``."""
        if normed:
            return self.rgb / 255.0
        return self.rgb

    def __getitem__(self, item):
        # Indexed access always returns the channel normalized to [0, 1].
        return self.rgb[item] / 255.0
RGB_Color.init_list([106, 90, 205]) 75 | ITK_17 = RGB_Color.init_list([221, 160, 221]) 76 | ITK_18 = RGB_Color.init_list([233, 150, 122]) 77 | ITK_19 = RGB_Color.init_list([165, 42, 42]) 78 | 79 | ITK_20 = RGB_Color.init_list([255, 250, 250]) 80 | ITK_21 = RGB_Color.init_list([147, 112, 219]) 81 | ITK_22 = RGB_Color.init_list([218, 112, 214]) 82 | ITK_23 = RGB_Color.init_list([75, 0, 130]) 83 | ITK_24 = RGB_Color.init_list([255, 182, 193]) 84 | ITK_25 = RGB_Color.init_list([60, 179, 113]) 85 | ITK_26 = RGB_Color.init_list([255, 235, 205]) 86 | ITK_27 = RGB_Color.init_list([255, 228, 196]) 87 | ITK_28 = RGB_Color.init_list([218, 165, 32]) 88 | ITK_29 = RGB_Color.init_list([0, 128, 128]) 89 | ITK_30 = RGB_Color.init_list([188, 143, 143]) 90 | ITK_31 = RGB_Color.init_list([255, 105, 180]) 91 | ITK_32 = RGB_Color.init_list([255, 218, 185]) 92 | ITK_33 = RGB_Color.init_list([222, 184, 135]) 93 | ITK_34 = RGB_Color.init_list([127, 255, 0]) 94 | ITK_35 = RGB_Color.init_list([139, 69, 19]) 95 | ITK_36 = RGB_Color.init_list([124, 252, 0]) 96 | ITK_37 = RGB_Color.init_list([255, 255, 224]) 97 | ITK_38 = RGB_Color.init_list([70, 130, 180]) 98 | ITK_39 = RGB_Color.init_list([0, 100, 0]) 99 | ITK_40 = RGB_Color.init_list([238, 130, 238]) 100 | ## Subregions 101 | ITK_41 = RGB_Color.init_list([238, 232, 170]) 102 | ITK_42 = RGB_Color.init_list([240, 255, 240]) 103 | ITK_43 = RGB_Color.init_list([245, 222, 179]) 104 | ITK_44 = RGB_Color.init_list([184, 134, 11]) 105 | ITK_45 = RGB_Color.init_list([32, 178, 170]) 106 | ITK_46 = RGB_Color.init_list([255, 20, 147]) 107 | ITK_47 = RGB_Color.init_list([25, 25, 112]) 108 | ITK_48 = RGB_Color.init_list([112, 128, 144]) 109 | ITK_49 = RGB_Color.init_list([34, 139, 34]) 110 | ITK_50 = RGB_Color.init_list([248, 248, 255]) 111 | ITK_51 = RGB_Color.init_list([245, 255, 250]) 112 | ITK_52 = RGB_Color.init_list([255, 160, 122]) 113 | ITK_53 = RGB_Color.init_list([144, 238, 144]) 114 | ITK_54 = RGB_Color.init_list([173, 255, 47]) 115 | ITK_55 = 
def get_color_by_label(label: int):
    """Look up the RGB_Color for ``label``, folding unmapped labels into the 1-50 range."""
    try:
        return _color_mapping_by_label[label]
    except KeyError:
        # Outside the precomputed 1-149 range: wrap deterministically into 1-50.
        return _color_mapping_by_label[label % 50 + 1]
ridged_points_from_subreg_vert 5 | 6 | except Exception: 7 | pass 8 | try: 9 | from .deepali.spine_rigid_elements_reg import Rigid_Elements_Registration 10 | 11 | except Exception: 12 | pass 13 | 14 | try: 15 | from .deepali.deepali_model import General_Registration 16 | 17 | except Exception: 18 | pass 19 | -------------------------------------------------------------------------------- /TPTBox/registration/deepali/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | try: 4 | from .spine_rigid_elements_reg import Rigid_Elements_Registration 5 | 6 | except Exception: 7 | pass 8 | 9 | try: 10 | from .deepali_model import General_Registration 11 | 12 | except Exception: 13 | pass 14 | -------------------------------------------------------------------------------- /TPTBox/registration/deformable/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .deformable_reg import Deformable_Registration 4 | -------------------------------------------------------------------------------- /TPTBox/registration/deformable/_deepali/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/registration/deformable/_deepali/__init__.py -------------------------------------------------------------------------------- /TPTBox/registration/deformable/_deepali/deformable_config.yaml: -------------------------------------------------------------------------------- 1 | energy: 2 | be: 3 | - weight: 0.001 4 | - name: BSplineBending 5 | - stride: 1 6 | seg: 7 | name: LNCC 8 | model: 9 | name: SVFFD 10 | stride: 11 | - 8 12 | - 8 13 | - 16 14 | transpose: false 15 | optim: 16 | lr: 0.001 17 | max_steps: 1000 18 | min_delta: -0.0001 19 | name: Adam 20 | pyramid: 21 | dims: [x,y,z] 
22 | levels: 3 23 | spacing: [1.40625,1.40625,3.0] 24 | #spacing: [2.232 , 2.232, 3.0] 25 | spacing: 26 | [1.40625,1.40625,3.0] 27 | -------------------------------------------------------------------------------- /TPTBox/registration/deformable/_deepali/engine.py: -------------------------------------------------------------------------------- 1 | r"""Engine for iterative optimization-based registration.""" 2 | 3 | # TODO: Use ignite.engine instead of custom RegistrationEngine. 4 | 5 | from __future__ import annotations 6 | 7 | import math 8 | from collections import OrderedDict 9 | from collections.abc import Callable 10 | from timeit import default_timer as timer 11 | 12 | import torch 13 | from torch import Tensor 14 | from torch.optim import Optimizer 15 | from torch.utils.hooks import RemovableHandle 16 | 17 | from TPTBox.registration.deformable._deepali.optim import slope_of_least_squares_fit 18 | from TPTBox.registration.deformable._deepali.registration_losses import RegistrationLoss, RegistrationResult 19 | 20 | PROFILING = False 21 | 22 | 23 | class RegistrationEngine: 24 | r"""Minimize registration loss until convergence.""" 25 | 26 | def __init__( 27 | self, 28 | loss: RegistrationLoss, 29 | optimizer: Optimizer, 30 | max_steps: int = 500, 31 | min_delta: float = 1e-6, 32 | min_value: float = float("nan"), 33 | max_history: int = 10, 34 | ): 35 | r"""Initialize registration loop.""" 36 | self.loss = loss 37 | self.optimizer = optimizer 38 | self.num_steps = 0 39 | self.max_steps = max_steps 40 | self.min_delta = min_delta 41 | self.min_value = min_value 42 | self.max_history = max(2, max_history) 43 | self.loss_values = [] 44 | self._eval_hooks = OrderedDict() 45 | self._step_hooks = OrderedDict() 46 | 47 | @property 48 | def loss_value(self) -> float: 49 | if not self.loss_values: 50 | return float("inf") 51 | return self.loss_values[-1] 52 | 53 | def step(self) -> float: 54 | r"""Perform one registration step. 
    def step(self) -> float:
        r"""Perform one registration step.

        Returns:
            Loss value prior to taking gradient step.

        """
        num_evals = 0

        def closure() -> float:
            # One loss evaluation + backward pass. The optimizer may invoke this
            # closure several times per .step() (num_evals counts the calls).
            self.optimizer.zero_grad()
            t_start = timer()
            result = self.loss.eval()
            if PROFILING:
                print(f"Forward pass in {timer() - t_start:.3f}s")
            loss = result["loss"]
            assert isinstance(loss, Tensor)
            t_start = timer()
            loss.backward()
            if PROFILING:
                print(f"Backward pass in {timer() - t_start:.3f}s")
            nonlocal num_evals
            num_evals += 1
            # Eval hooks fire after backward but before the parameter update.
            with torch.no_grad():
                for hook in self._eval_hooks.values():
                    hook(self, self.num_steps, num_evals, result)
            return float(loss)

        loss_value = self.optimizer.step(closure)
        assert loss_value is not None

        # Step hooks fire once per gradient step, after parameters were updated.
        with torch.no_grad():
            for hook in self._step_hooks.values():
                hook(self, self.num_steps, num_evals, loss_value)

        return loss_value
95 | 96 | """ 97 | self.loss_values = [] 98 | self.num_steps = 0 99 | while self.num_steps < self.max_steps and not self.converged(): 100 | value = self.step() 101 | self.num_steps += 1 102 | if math.isnan(value): 103 | raise RuntimeError(f"NaN value in registration loss at gradient step {self.num_steps}") 104 | if math.isinf(value): 105 | raise RuntimeError(f"Inf value in registration loss at gradient step {self.num_steps}") 106 | self.loss_values.append(value) 107 | if len(self.loss_values) > self.max_history: 108 | self.loss_values.pop(0) 109 | return self.loss_value 110 | 111 | def converged(self) -> bool: 112 | r"""Check convergence criteria.""" 113 | values = self.loss_values 114 | if not values: 115 | return False 116 | value = values[-1] 117 | epsilon = abs(self.min_delta * value) if self.min_delta < 0 else self.min_delta 118 | slope = slope_of_least_squares_fit(values) 119 | if abs(slope) < epsilon: 120 | return True 121 | return value < self.min_value 122 | 123 | def register_eval_hook(self, hook: Callable[[RegistrationEngine, int, int, RegistrationResult], None]) -> RemovableHandle: 124 | r"""Registers a evaluation hook. 125 | 126 | The hook will be called every time after the registration loss has been evaluated 127 | during a single step of the optimizer, and the backward pass was performed, but 128 | before the model parameters are updated by taking a gradient step. 129 | 130 | hook(self, num_steps: int, num_evals: int, result: RegistrationResult) -> None 131 | 132 | Returns: 133 | A handle that can be used to remove the added hook by calling ``handle.remove()`` 134 | 135 | """ 136 | handle = RemovableHandle(self._eval_hooks) 137 | self._eval_hooks[handle.id] = hook 138 | return handle 139 | 140 | def register_step_hook(self, hook: Callable[[RegistrationEngine, int, int, float], None]) -> RemovableHandle: 141 | r"""Registers a gradient step hook. 142 | 143 | The hook will be called every time after a gradient step of the optimizer. 
def noop(_reg: RegistrationEngine, *_args, **_kwargs) -> None:
    r"""Dummy no-op loss evaluation hook."""


def normalize_linear_grad(reg: RegistrationEngine, *_args, **_kwargs) -> None:
    r"""Loss evaluation hook for normalization of linear transformation gradient after backward pass."""
    # Collect every parameter that actually received a gradient.
    params_with_grad = [p for group in reg.optimizer.param_groups for p in group["params"] if p.grad is not None]
    if not params_with_grad:
        return
    # Divide all gradients by the single largest absolute gradient entry.
    denom = max(p.grad.abs().max() for p in params_with_grad)
    for p in params_with_grad:
        p.grad /= denom
def normalize_grad_hook(transform) -> RegistrationEvalHook:
    r"""Loss evaluation hook for normalization of transformation gradient after backward pass."""
    # Linear transforms get max-abs normalization; non-rigid ones get per-vector L2.
    return normalize_linear_grad if is_linear_transform(transform) else normalize_nonrigid_grad


def _smooth_nonrigid_grad(reg: RegistrationEngine, sigma: float = 1) -> None:
    r"""Loss evaluation hook for Gaussian smoothing of non-rigid transformation gradient after backward pass."""
    if sigma <= 0:
        return
    kernel = gaussian1d(sigma)
    for group in reg.optimizer.param_groups:
        for p in group["params"]:
            if p.grad is not None:
                # In-place update so the optimizer sees the smoothed gradient.
                p.grad.copy_(deepali_functional.conv(p.grad, kernel))


def smooth_grad_hook(transform, sigma: float) -> RegistrationEvalHook:
    r"""Loss evaluation hook for Gaussian smoothing of non-rigid gradient after backward pass."""
    if is_linear_transform(transform):
        # Smoothing only applies to non-rigid transforms.
        return noop

    def fn(reg: RegistrationEngine, *_args, **_kwargs):
        return _smooth_nonrigid_grad(reg, sigma=sigma)

    return fn
14 | 15 | As presented in the work by `De Vos 2020: `_ 16 | 17 | """ 18 | 19 | def __init__( 20 | self, 21 | intensity_range: tuple[float, float] | None = None, 22 | nbins: int = 32, 23 | sigma: float = 0.1, 24 | use_mask: bool = False, 25 | ): 26 | super().__init__() 27 | self.intensity_range = intensity_range 28 | self.nbins = nbins 29 | self.sigma = sigma 30 | if use_mask: 31 | self.forward = self.masked_metric 32 | else: 33 | self.forward = self.metric 34 | 35 | def metric(self, fixed: Tensor, warped: Tensor) -> Tensor: 36 | with torch.no_grad(): 37 | if self.intensity_range: 38 | fixed_range = self.intensity_range 39 | warped_range = self.intensity_range 40 | else: 41 | fixed_range = fixed.min(), fixed.max() 42 | warped_range = warped.min(), warped.max() 43 | 44 | bins_fixed = torch.linspace( 45 | fixed_range[0], 46 | fixed_range[1], 47 | self.nbins, 48 | dtype=fixed.dtype, 49 | device=fixed.device, 50 | ) 51 | bins_warped = torch.linspace( 52 | warped_range[0], 53 | warped_range[1], 54 | self.nbins, 55 | dtype=fixed.dtype, 56 | device=fixed.device, 57 | ) 58 | 59 | return -nmi_gauss(fixed, warped, bins_fixed, bins_warped, sigma=self.sigma).mean() 60 | 61 | 62 | def nmi_gauss(x1, x2, x1_bins, x2_bins, sigma=1e-3, e=1e-10): 63 | assert x1.shape == x2.shape, "Inputs are not of similar shape" 64 | 65 | def gaussian_window(x, bins, sigma): 66 | assert x.ndim == 2, "Input tensor should be 2-dimensional." 
def calculate_dice(mask1, mask2, label_class=0):
    """
    from https://github.com/voxelmorph/
    Dice score of a specified class between two label masks.
    (classes are encoded by label class number, not one-hot)

    Args:
        mask1: (numpy.array, shape (N, 1, *sizes)) segmentation mask 1
        mask2: (numpy.array, shape (N, 1, *sizes)) segmentation mask 2
        label_class: (int or float)

    Returns:
        volume_dice
    """
    assert mask1.ndim == mask2.ndim
    fg1 = (mask1 == label_class).astype(np.float32)
    fg2 = (mask2 == label_class).astype(np.float32)

    # Reduce over the spatial axes only, keeping batch and channel separate.
    spatial_axes = tuple(range(2, mask1.ndim))
    overlap = np.sum(fg1 * fg2, axis=spatial_axes)
    denom = np.sum(fg1, axis=spatial_axes) + np.sum(fg2, axis=spatial_axes) + 1e-7  # epsilon avoids 0/0
    return np.mean(2 * overlap / denom)
def calculate_jacobian_det(disp):
    """
    Calculate Jacobian determinant of displacement field of one image/volume (2D/3D)
    from https://github.com/voxelmorph/

    Args:
        disp: (numpy.ndarray, shape (*sizes, ndim)) Displacement field

    Returns:
        jac_det: (numpy.ndarray, shape (*sizes)) Point-wise Jacobian determinant
    """
    # Round-trip through SimpleITK, which implements the determinant filter in C++.
    field_image = sitk.GetImageFromArray(disp, isVector=True)
    det_image = sitk.DisplacementFieldJacobianDeterminant(field_image)
    return sitk.GetArrayFromImage(det_image)
def slope_of_least_squares_fit(values: Sequence[float]) -> float:
    r"""Compute slope of least squares fit of line to last n objective function values

    See also:
        - https://www.che.udel.edu/pdf/FittingData.pdf
        - https://en.wikipedia.org/wiki/1_%2B_2_%2B_3_%2B_4_%2B_%E2%8B%AF
        - https://proofwiki.org/wiki/Sum_of_Sequence_of_Squares

    """
    n = len(values)
    # Guard clauses: too few points for a fit / trivial two-point slope.
    if n < 2:
        return float("nan")
    if n == 2:
        return values[1] - values[0]
    # Closed-form least-squares slope with x = 1..n.  All terms carry one
    # factor of n less than the textbook formula; it cancels in the quotient.
    mean_x = (n + 1) / 2
    sum_x_sq = n * (n + 1) * (2 * n + 1) / 6
    total_y = sum(values)
    weighted_sum = sum((i + 1) * y for i, y in enumerate(values))
    return (weighted_sum - mean_x * total_y) / (sum_x_sq - n * mean_x * mean_x)
sub = to_nii(sub, True) # .resample_from_to(vert) 34 | L1 = vert.extract_label(21) # noqa: N806 35 | L2 = vert.extract_label(22) # noqa: N806 36 | c1 = L1.compute_crop(dist=10) 37 | L1.apply_crop_(c1) 38 | c2 = L2.compute_crop(dist=10) 39 | L2.apply_crop_(c2) 40 | L1.save(out / "L1.nii.gz") 41 | L2.save(out / "L2.nii.gz") 42 | sub1 = sub.apply_crop(c1) 43 | sub2 = sub.apply_crop(c2) 44 | # Compute Points 45 | poi1 = calc_poi_from_two_segs(L1, sub1, out / "L1_cdt.json") 46 | poi2 = calc_poi_from_two_segs(L2, sub2, out / "L2_cdt.json") 47 | # Point registration 48 | reg = ridged_points_from_poi(poi1, poi2) 49 | L2_preg_sub = reg.transform_nii(sub2 * L2) # noqa: N806 50 | L2_preg_sub.save(out / "L2_preg.nii.gz") 51 | L2_preg = reg.transform_nii(L2) # noqa: N806 52 | # Deformable Registration 53 | reg_deform = Deformable_Registration(L1, L2_preg, config=setting) 54 | reg_deform.transform_nii(L2_preg_sub).save(out / "L2_reg_large_no_be.nii.gz") 55 | 56 | 57 | def get_femurs(img: Image_Reference, seg_id=13): 58 | """Returns left (2) and right (1) 59 | 60 | Args: 61 | img (Image_Reference): _description_ 62 | 63 | Returns: 64 | _type_: _description_ 65 | """ 66 | nii = to_nii(img, True) 67 | # Extract Femurs 68 | cc = nii.extract_label(seg_id) 69 | # CC 70 | cc = cc.get_connected_components() 71 | print(f"Warning more than two cc {cc.max()=}") if cc.max() >= 3 else None 72 | cc[cc > 2] = 0 73 | # Compute Poi of two largest CC (id = 1,2, cause sorted) 74 | femur_poi = calc_centroids(cc, second_stage=seg_id) 75 | # Extract left femur 76 | dim = femur_poi.get_axis("R") 77 | mirror = femur_poi.orientation[dim] == "R" 78 | left_id = 0 if femur_poi[1, seg_id][dim] > femur_poi[2, seg_id][dim] else 1 79 | if mirror: 80 | left_id = 1 - left_id 81 | left_id += 1 82 | get_additonal_point(femur_poi, cc, seg_id) 83 | if left_id == 2: 84 | return cc, femur_poi 85 | else: 86 | return cc.map_labels_({1: 2, 2: 1}, verbose=False), femur_poi.map_labels_(label_map_region={1: 2, 2: 1}, 
def get_additonal_point(poi: POI, cc: NII, seg_id):
    """Add two auxiliary direction points per connected component to ``poi``.

    For each component label ``i`` in (1, 2) the principal axis of the
    component's mask is computed via PCA, oriented toward superior, and two
    points are stored 20 units along (+) and against (-) that axis from the
    component's existing centroid ``poi[i, seg_id]``:

    - ``poi[i, seg_id + 100]``: point toward superior
    - ``poi[i, seg_id + 200]``: point toward inferior

    Returns silently (possibly after a partial update) if the PCA cannot be
    computed, e.g. for an empty component.
    """
    # Superior axis and mirroring are properties of the whole image; hoist
    # them out of the per-component loop.
    dim = poi.get_axis("S")
    mirror = poi.orientation[dim] != "S"
    for i in range(1, 3):
        try:
            normal_vector = calculate_pca_normal_np(cc.extract_label(i).get_array(), pca_component=0, zoom=poi.zoom)
        except ValueError:
            return
        # Flip the PCA axis so it consistently points toward superior.
        flip = -1 if (normal_vector[dim] < 0 and mirror) or (normal_vector[dim] > 0 and not mirror) else 1
        # BUGFIX: previously both iterations stored under region key 1
        # (``poi[1, ...]``), so component 2 clobbered component 1's points
        # and never received its own; store under the component key ``i``.
        poi[i, seg_id + 100] = np.array(poi[i, seg_id]) + normal_vector * 20 * flip
        poi[i, seg_id + 200] = np.array(poi[i, seg_id]) + normal_vector * 20 * -flip
127 | c0 = femur_0_.compute_crop(dist=10) 128 | c1 = femur_1_reg.compute_crop(dist=10) 129 | ex_slice = [slice(min(a.start, b.start), max(a.stop, b.stop)) for a, b in zip(c1, c0, strict=False)] 130 | femur_0 = femur_0_.apply_crop(ex_slice) 131 | femur_1_reg.apply_crop_(ex_slice) 132 | 133 | # Deformable registration 134 | reg_deform = Deformable_Registration(femur_0, femur_1_reg, config=setting, verbose=99) 135 | reg_deform.transform_nii(femur_1_reg).resample_from_to(femur_0_).save(p / "femur_0001_reg.nii.gz") 136 | reg_deform.transform_nii(subreg).resample_from_to(femur_0_).save(p / "subreg_reg.nii.gz") 137 | 138 | # femur_1.save(p / "test_f1.nii.gz") 139 | # femur_0.save(p / "test_f0.nii.gz") 140 | -------------------------------------------------------------------------------- /TPTBox/registration/deformable/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "loss": 3 | { 4 | "config": { 5 | "be": { 6 | "stride": 1, 7 | "name": "BSplineBending" 8 | }, 9 | "seg": { 10 | "name": "LNCC" 11 | } 12 | }, 13 | "weights": { 14 | "be": 0.001, 15 | "seg": 1 16 | } 17 | }, 18 | "model":{ 19 | "name": "SVFFD", 20 | "args": {"stride": [8, 8, 16], "transpose": false}, 21 | "init": null 22 | }, 23 | "optim":{ 24 | "name": "Adam", 25 | "args":{"lr": 0.001}, 26 | "loop": {"max_steps": 1000, "min_delta": -0.0001} 27 | }, 28 | "pyramid":{ 29 | "levels":3, 30 | "coarsest_level":2, 31 | "finest_level":0, 32 | "finest_spacing":null, 33 | "min_size":16, 34 | "pyramid_dims":["x", "y", "z"] 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /TPTBox/registration/ridged_intensity/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/registration/ridged_intensity/__init__.py 
class HiddenPrints:
    """Context manager that silences ``print`` by routing stdout to the null device."""

    def __enter__(self):
        # Remember the live stream, then swap in a throwaway devnull handle.
        self._saved_stdout = sys.stdout
        sys.stdout = open(os.devnull, "w")  # noqa: SIM115

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the real stream first, then release the devnull handle.
        devnull_handle = sys.stdout
        sys.stdout = self._saved_stdout
        devnull_handle.close()
def only_change_affine(nii: NII, transform: Affine):
    """Return a copy of ``nii`` whose affine is composed with ``transform``; voxel data is untouched."""
    # Compose: apply the registration transform after the image's own affine.
    composed_affine = np.dot(transform.as_affine(), nii.affine)
    reheadered = nib.nifti1.Nifti1Image(nii.get_array(), composed_affine)
    return NII(reheadered, nii.seg)
def crop_shared_(a: NII, b: NII):
    """Crop ``a`` and ``b`` in place to a crop covering both foregrounds; returns the crop slices."""
    # Grow the crop of ``a`` so it also covers the foreground of ``b``.
    shared_crop = a.compute_crop()
    shared_crop = b.compute_crop(other_crop=shared_crop)
    print(shared_crop)
    for img in (a, b):
        img.apply_crop_(shared_crop)
    return shared_crop
def get_weights_dir(idx, model_path: Path | None = None) -> Path:
    """Resolve the base weights directory and return the ``Dataset{idx:03}`` subpath.

    Resolution order for the base directory:
    1. The ``TOTALVIBE_WEIGHTS_PATH`` environment variable, if set.
    2. ``model_path``, if given and it already exists.
    3. ``<package>/nnUNet/nnUNet_results`` next to this package.

    The base directory is created if missing.  The returned
    ``Dataset{idx:03}`` subdirectory itself is NOT created here — callers
    use its (non-)existence to decide whether a download is needed.
    """
    if env_name in os.environ:
        weights_dir: Path = Path(os.environ[env_name])
    elif model_path is not None and model_path.exists():
        weights_dir = model_path
    else:
        assert Path(__file__).parent.name == "TotalVibeSeg", Path(__file__).parent

        weights_dir = Path(__file__).parent.parent / "nnUNet/nnUNet_results"

    # parents=True: an env-var supplied path may be several directory levels
    # deep; the previous parent/child mkdir pair failed whenever more than
    # one ancestor was missing.
    weights_dir.mkdir(parents=True, exist_ok=True)
    weights_dir = weights_dir / f"Dataset{idx:03}"

    return weights_dir
def _download(weights_url, weights_dir, text="", is_zip=True) -> None:
    """Fetch ``weights_url`` with a tqdm progress bar; optionally extract it.

    Args:
        weights_url: Full URL of the file to download.
        weights_dir: Target directory; zip archives are extracted into it,
            while the downloaded file itself lands in ``weights_dir.parent``.
        text: Human-readable label used in console messages.
        is_zip: When True, extract the download into ``weights_dir`` and
            delete the archive afterwards.

    Best-effort: if the initial size probe fails (bad URL, no network), a
    message is printed and the function returns without raising.
    """
    try:
        # Retrieve file size
        with urllib.request.urlopen(str(weights_url)) as response:
            file_size = int(response.info().get("Content-Length", -1))
    except Exception:
        print("Download attempt failed:", weights_url)
        return
    print(f"Downloading {text}...")

    with tqdm(total=file_size, unit="B", unit_scale=True, unit_divisor=1024, desc=Path(weights_url).name) as pbar:

        def update_progress(block_num: int, block_size: int, total_size: int) -> None:
            # urlretrieve reports cumulative block counts; convert to the
            # delta tqdm expects, and adopt the server-reported total.
            if pbar.total != total_size:
                pbar.total = total_size
            pbar.update(block_num * block_size - pbar.n)

        zip_path = weights_dir.parent / Path(weights_url).name
        # Download the file
        urllib.request.urlretrieve(str(weights_url), zip_path, reporthook=update_progress)
    if is_zip:
        print(f"Extracting {text}...")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(weights_dir)
        os.remove(zip_path)  # noqa: PTH107
addendum_download(idx) 112 | return weights_dir 113 | -------------------------------------------------------------------------------- /TPTBox/segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from TPTBox.segmentation.spineps import run_spineps_all, run_spineps_single 4 | from TPTBox.segmentation.TotalVibeSeg import run_totalvibeseg 5 | -------------------------------------------------------------------------------- /TPTBox/segmentation/nnUnet_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/segmentation/nnUnet_utils/__init__.py -------------------------------------------------------------------------------- /TPTBox/segmentation/nnUnet_utils/data_iterators.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/MIC-DKFZ/nnUNet 2 | # Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring 3 | # method for deep learning-based biomedical image segmentation. Nature methods, 18(2), 203-211. 
4 | from __future__ import annotations 5 | 6 | import numpy as np 7 | import torch 8 | from batchgenerators.dataloading.data_loader import DataLoader 9 | 10 | from TPTBox.segmentation.nnUnet_utils.default_preprocessor import DefaultPreprocessor 11 | from TPTBox.segmentation.nnUnet_utils.plans_handler import ConfigurationManager, PlansManager 12 | 13 | 14 | class PreprocessAdapterFromNpy(DataLoader): 15 | def __init__( 16 | self, 17 | list_of_images: list[np.ndarray], 18 | list_of_segs_from_prev_stage: list[np.ndarray] | None, 19 | list_of_image_properties: list[dict], 20 | truncated_of_names: list[str] | None, 21 | plans_manager: PlansManager, 22 | dataset_json: dict, 23 | configuration_manager: ConfigurationManager, 24 | num_threads_in_multithreaded: int = 1, 25 | verbose: bool = False, 26 | ): 27 | preprocessor = DefaultPreprocessor(verbose=verbose) 28 | self.preprocessor, self.plans_manager, self.configuration_manager, self.dataset_json, self.truncated_of_names = ( 29 | preprocessor, 30 | plans_manager, 31 | configuration_manager, 32 | dataset_json, 33 | truncated_of_names, 34 | ) 35 | self.label_manager = plans_manager.get_label_manager(dataset_json) 36 | 37 | if list_of_segs_from_prev_stage is None: 38 | list_of_segs_from_prev_stage = [None] * len(list_of_images) # type: ignore 39 | if truncated_of_names is None: 40 | truncated_of_names = [None] * len(list_of_images) # type: ignore 41 | 42 | super().__init__( 43 | list(zip(list_of_images, list_of_segs_from_prev_stage, list_of_image_properties, truncated_of_names)), # type: ignore 44 | 1, 45 | num_threads_in_multithreaded, 46 | seed_for_shuffle=1, 47 | return_incomplete=True, 48 | shuffle=False, 49 | infinite=False, 50 | sampling_probabilities=None, 51 | ) 52 | 53 | self.indices = list(range(len(list_of_images))) 54 | 55 | def generate_train_batch(self): 56 | idx = self.get_indices()[0] 57 | image = self._data[idx][0] 58 | seg_prev_stage = self._data[idx][1] 59 | props = self._data[idx][2] 60 | of_name = 
def convert_labelmap_to_one_hot(
    segmentation: np.ndarray | torch.Tensor, all_labels: list | torch.Tensor | np.ndarray | tuple, output_dtype=None
) -> np.ndarray | torch.Tensor:
    """
    if output_dtype is None then we use np.uint8/torch.uint8
    if input is torch.Tensor then output will be on the same device

    np.ndarray is faster than torch.Tensor

    if segmentation is torch.Tensor, this function will be faster if it is LongTensor. If it is something else we have
    to cast which takes time.

    IMPORTANT: This function only works properly if your labels are consecutive integers, so something like 0, 1, 2, 3, ...
    DO NOT use it with 0, 32, 123, 255, ... or whatever (fix your labels, yo)
    """
    one_hot_shape = (len(all_labels), *segmentation.shape)
    if isinstance(segmentation, torch.Tensor):
        dtype = torch.uint8 if output_dtype is None else output_dtype
        result = torch.zeros(one_hot_shape, dtype=dtype, device=segmentation.device)
        # scatter_ requires int64 indices; measured ~2x faster than a
        # per-label comparison loop on tensors.
        result.scatter_(0, segmentation[None].long(), 1)
    else:
        dtype = np.uint8 if output_dtype is None else output_dtype
        result = np.zeros(one_hot_shape, dtype=dtype)
        # Per-label equality was the fastest numpy variant in benchmarks
        # (an np.eye fancy-index takes about twice as long).
        for channel, label in enumerate(all_labels):
            result[channel] = segmentation == label
    return result
def get_network_from_plans(
    plans_manager: PlansManager,
    dataset_json: dict,
    configuration_manager: ConfigurationManager,
    num_input_channels: int,
    num_output_channels: int | None = None,
    deep_supervision: bool = True,
):
    """
    we may have to change this in the future to accommodate other plans -> network mappings

    num_input_channels can differ depending on whether we do cascade. Its best to make this info available in the
    trainer rather than inferring it again from the plans here.

    Args:
        plans_manager: Parsed nnU-Net plans; supplies the label manager.
        dataset_json: Raw dataset.json contents for the trained dataset.
        configuration_manager: Configuration entry describing architecture,
            kernel sizes, strides and feature counts.
        num_input_channels: Number of input channels (cascade may add some).
        num_output_channels: Only forwarded on the new-style "architecture"
            config path; the legacy path derives the head count from labels.
        deep_supervision: Whether the decoder emits auxiliary outputs.

    Returns:
        An initialized PlainConvUNet / ResidualEncoderUNet (legacy path) or
        whatever nnunetv2's own factory builds (new-style config path).
    """
    # New-style plans embed the full architecture spec; delegate entirely to
    # nnunetv2's factory in that case.
    if "architecture" in configuration_manager.configuration:
        from nnunetv2.utilities.get_network_from_plans import get_network_from_plans as plans

        class_name = configuration_manager.configuration["architecture"]["network_class_name"]
        kwargs = configuration_manager.configuration["architecture"]["arch_kwargs"]
        _kw_requires_import = configuration_manager.configuration["architecture"]["_kw_requires_import"]
        return plans(class_name, kwargs, _kw_requires_import, num_input_channels, num_output_channels, deep_supervision=deep_supervision)  # type: ignore
    # Legacy path: reconstruct the network from individual plan fields.
    num_stages = len(configuration_manager.conv_kernel_sizes)

    # Kernel dimensionality decides Conv2d vs Conv3d.
    dim = len(configuration_manager.conv_kernel_sizes[0])
    conv_op = convert_dim_to_conv_op(dim)

    label_manager = plans_manager.get_label_manager(dataset_json)

    segmentation_network_class_name = configuration_manager.UNet_class_name
    mapping = {"PlainConvUNet": PlainConvUNet, "ResidualEncoderUNet": ResidualEncoderUNet}
    # Both supported architectures currently share the same construction kwargs.
    kwargs = {
        "PlainConvUNet": {
            "conv_bias": True,
            "norm_op": get_matching_instancenorm(conv_op),
            "norm_op_kwargs": {"eps": 1e-5, "affine": True},
            "dropout_op": None,
            "dropout_op_kwargs": None,
            "nonlin": nn.LeakyReLU,
            "nonlin_kwargs": {"inplace": True},
        },
        "ResidualEncoderUNet": {
            "conv_bias": True,
            "norm_op": get_matching_instancenorm(conv_op),
            "norm_op_kwargs": {"eps": 1e-5, "affine": True},
            "dropout_op": None,
            "dropout_op_kwargs": None,
            "nonlin": nn.LeakyReLU,
            "nonlin_kwargs": {"inplace": True},
        },
    }
    assert segmentation_network_class_name in mapping.keys(), (
        "The network architecture specified by the plans file "
        "is non-standard (maybe your own?). Yo'll have to dive "
        "into either this "
        "function (get_network_from_plans) or "
        "the init of your nnUNetModule to accomodate that."
    )
    network_class = mapping[segmentation_network_class_name]

    # ResidualEncoderUNet names its per-stage depth parameter differently.
    conv_or_blocks_per_stage = {
        "n_conv_per_stage"
        if network_class != ResidualEncoderUNet
        else "n_blocks_per_stage": configuration_manager.n_conv_per_stage_encoder,
        "n_conv_per_stage_decoder": configuration_manager.n_conv_per_stage_decoder,
    }
    # network class name!!
    model = network_class(
        input_channels=num_input_channels,
        n_stages=num_stages,
        # Features double per stage, capped at the plans' maximum.
        features_per_stage=[
            min(configuration_manager.UNet_base_num_features * 2**i, configuration_manager.unet_max_num_features) for i in range(num_stages)
        ],
        conv_op=conv_op,
        kernel_sizes=configuration_manager.conv_kernel_sizes,
        strides=configuration_manager.pool_op_kernel_sizes,
        num_classes=label_manager.num_segmentation_heads,
        deep_supervision=deep_supervision,
        **conv_or_blocks_per_stage,
        **kwargs[segmentation_network_class_name],
    )
    model.apply(InitWeights_He(1e-2))
    if network_class == ResidualEncoderUNet:
        model.apply(init_last_bn_before_add_to_0)
    return model
def compute_steps_for_sliding_window(image_size: tuple[int, ...], tile_size: tuple[int, ...], tile_step_size: float) -> list[list[int]]:
    """Compute the per-dimension start coordinates of sliding-window tiles.

    Args:
        image_size: spatial size of the image, one entry per dimension.
        tile_size: patch size, one entry per dimension (must not exceed ``image_size``).
        tile_step_size: relative step width in (0, 1]; 1 means non-overlapping tiles.

    Returns:
        One list of start coordinates per dimension.

    Raises:
        AssertionError: if the image is smaller than the tile or ``tile_step_size`` is out of range.
    """
    # BUG FIX: the original asserted on the list comprehension itself; a non-empty list is
    # always truthy, so the size check never fired. ``all(...)`` makes it effective.
    assert all(i >= j for i, j in zip(image_size, tile_size)), "image size must be as large or larger than patch_size"
    assert 0 < tile_step_size <= 1, "step_size must be larger than 0 and smaller or equal to 1"

    # our step width is patch_size*step_size at most, but can be narrower. For example if we have
    # image size of 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps
    # starting at coordinate 0, 23, 46
    target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]

    num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]

    steps = []
    for dim in range(len(tile_size)):
        # the highest step value for this dimension is
        max_step_value = image_size[dim] - tile_size[dim]
        # With a single step the divisor does not matter (there is only one step at 0);
        # the huge sentinel keeps the expression well-defined.
        actual_step_size = max_step_value / (num_steps[dim] - 1) if num_steps[dim] > 1 else 99999999999
        steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]

        steps.append(steps_here)

    return steps
def save_resampled_segmentation(seg_nii: NII, in_file: BIDS_FILE, parent, org: NII | POI, idx):
    """Resample a segmentation back onto the original grid and write it to disk.

    The output path is derived from ``in_file`` with a ``seg-oar-{idx}`` tag under ``parent``.
    """
    target = in_file.get_changed_path("nii.gz", "msk", parent=parent, info={"seg": f"oar-{idx}"}, non_strict_mode=True)
    resampled = seg_nii.resample_from_to(org, verbose=False, mode="nearest")
    resampled.save(target)
def check_gpu_memory(gpu_id, threshold=50):
    """Return True when the given GPU's memory utilization exceeds ``threshold`` percent.

    Unknown ``gpu_id`` values are treated as "not over threshold" (returns False).
    """
    matching = (g for g in GPUtil.getGPUs() if g.id == gpu_id)
    gpu = next(matching, None)
    if gpu is None:
        return False
    return gpu.memoryUtil * 100 > threshold
Pausing submission...") 104 | time.sleep(10) # Pause for 10 seconds before checking again 105 | 106 | # Submit run_oar_segmentor task to the executor 107 | futures.append(executor.submit(run_oar_segmentor, i, gpu=gpu_id, override=override)) 108 | 109 | # Wait for all tasks to complete 110 | for future in as_completed(futures): 111 | try: 112 | future.result() # This will raise any exceptions encountered 113 | except Exception as e: 114 | print(f"Error in execution: {e}") 115 | 116 | 117 | if __name__ == "__main__": 118 | # Example usage 119 | bgi = "/DATA/NAS/datasets_processed/CT_spine/dataset-shockroom-without-fx/" 120 | 121 | run_oar_segmentor_in_parallel(bgi, ("rawdata_fixed",), gpu_id=0, threshold=50, max_workers=16, override=False) 122 | -------------------------------------------------------------------------------- /TPTBox/spine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/spine/__init__.py -------------------------------------------------------------------------------- /TPTBox/spine/snapshot2D/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .snapshot_modular import Image_Modes, Snapshot_Frame, Visualization_Type, colors_itk, create_snapshot 4 | from .snapshot_templates import ct_mri_snapshot, mip_shot, mri_snapshot, poi_snapshot, snapshot, spline_shot, vibe_snapshot 5 | -------------------------------------------------------------------------------- /TPTBox/spine/spinestats/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from .angles import ( 4 | compute_lordosis_and_kyphosis, 5 | compute_max_cobb_angle, 6 | compute_max_cobb_angle_multi, 7 | plot_cobb_and_lordosis_and_kyphosis, 8 | plot_cobb_angle, 9 | 
def _compute_distance(
    poi: POI,
    l1: Location,
    l2: Location,
    key: str,
    vert: Image_Reference | None = None,
    subreg: Image_Reference | None = None,
    all_pois_computed=False,
    recompute=False,
):
    """Compute the distance between two POI locations and cache it in ``poi.info[key]``.

    ``all_pois_computed`` requires both locations to be present in the POI already;
    ``recompute`` forces recomputation even if ``poi.info[key]`` already exists.

    Args:
        poi: POI object that receives the result.
        l1: first location of the distance pair.
        l2: second location of the distance pair.
        key: ``poi.info`` key under which the distance is stored.
        vert: vertebra mask; required when the locations are not precomputed.
        subreg: subregion mask; required when the locations are not precomputed.
        all_pois_computed: skip POI computation and assume both locations exist.
        recompute: recompute even when ``key`` is already present.

    Returns:
        The (possibly extended) POI object with ``poi.info[key]`` set.

    Raises:
        ValueError: when inputs needed for computing the missing locations are absent.
    """
    if key in poi.info and not recompute:
        return poi
    if not all_pois_computed:
        # BUG FIX: the original condition read ``l2.value in poi.keys_region()``, which raised
        # only when l2 WAS already present. Both locations must be available (or computable).
        if vert is None or (subreg is None and (l1.value not in poi.keys_region() or l2.value not in poi.keys_region())):
            raise ValueError(f"{vert=} and {subreg=} must be set or precomputed all pois; {all_pois_computed=} -- {poi.keys_region()=}")
    else:
        all_pois_computed = True  # normalize truthy values to True
    if not all_pois_computed:
        poi = calc_poi_from_subreg_vert(vert, subreg, extend_to=poi, subreg_id=[l1, l2])
    poi.info[key] = poi.calculate_distances_poi_two_locations(l1, l2, keep_zoom=False)
    return poi
def compute_all_distances(
    poi: POI,
    vert: Image_Reference | None = None,
    subreg: Image_Reference | None = None,
    all_pois_computed=False,
    recompute=False,
):
    """Compute every distance registered in ``distances_funs`` and cache it on the POI.

    Each entry maps an ``poi.info`` key to the pair of locations whose distance is stored
    under that key; the POI is threaded through the individual computations.
    """
    for info_key, (loc_a, loc_b) in distances_funs.items():
        poi = _compute_distance(poi, loc_a, loc_b, info_key, vert, subreg, all_pois_computed, recompute)
    return poi
19 | Switches: 20 | [-no_bias] If set: Do not use n4_bias_field_correction. It speeds up the process, but n4_bias_field_correction helps in roughly aligning the histogram. 21 | [-bias_crop] crop empty spaces by the bias field mask. 22 | [-crop] crop empty space away 23 | [-sr] Store the ramp and stitching of the images in a 4d nii.gz 24 | Optional: 25 | [-hists] Use histogram matching to put the images in the roughly same histogram. The previous image is used when hist_n is not set. 26 | [-hist_n HISTOGRAM_NAME] path to an image that should be used for histogram matching 27 | [-ramp_e RAMP_EDGE_MIN_VALUE] The ramp is only considering values above this minimum value 28 | [-ms MIN_SPACING] Set the minimum Spacing (in mm) 29 | [-dtype DTYPE] Force a dtype 30 | ``` 31 | 32 | Example: 33 | 34 | Given the image a.nii.gz,b.nii.gz,c.nii.gz and the segmentations a_msk.nii.gz,b_msk.nii.gz,c_msk.nii.gz. The images can be merged with: 35 | 36 | ```bash 37 | stitching.py -i a.nii.gz b.nii.gz c.nii.gz -o out.nii.gz 38 | stitching.py -i a_msk.nii.gz b_msk.nii.gz c_msk.nii.gz -o out_msk.nii.gz -seg 39 | ``` 40 | 41 | ### Install as a package 42 | 43 | Install on Python 3.10 or higher 44 | ```bash 45 | pip install TPTBox 46 | ``` 47 | 48 | ```python 49 | from TPTBox import NII 50 | from TPTBox.stitching import stitching 51 | out_nii,_ = stitching([NII.load("a.nii.gz",seg=False), NII.load("b.nii.gz",seg=False), NII.load("c.nii.gz",seg=False)], out="out.nii.gz") 52 | 53 | ``` 54 | 55 | or 56 | 57 | 58 | ```python 59 | from TPTBox.stitching import stitching_raw 60 | stitching_raw(["a.nii.gz", "b.nii.gz", "c.nii.gz"], "out.nii.gz", is_segmentation=False) 61 | ``` 62 | 63 | 64 | ### Cite 65 | ``` 66 | Graf, R., Platzek, PS., Riedel, E.O. et al. Generating synthetic high-resolution spinal STIR and T1w images from T2w FSE and low-resolution axial Dixon. Eur Radiol (2024). 
def stitching(
    bids_files: list[BIDS_FILE | NII | str | Path],
    out: BIDS_FILE | str | Path,
    is_seg=False,
    is_ct: bool = False,
    verbose_stitching=False,
    bias_field=False,
    kick_out_fully_integrated_images=True,
    verbose=True,
    dtype: type = float,
    match_histogram=False,
    store_ramp=False,
):
    """Stitch a list of globally aligned images into one NIfTI file.

    Thin convenience wrapper around :func:`stitching_raw` that resolves BIDS paths,
    loads the inputs as nibabel images and picks the background value by modality.
    """
    if isinstance(out, BIDS_FILE):
        out_path = str(out.file["nii.gz"])
    else:
        out_path = str(out)
    niftis = [to_nii(f).nii for f in bids_files]
    logger.print("stitching", out_path, verbose=verbose)
    # CT air is -1024 HU; MRI background is 0.
    background = -1024 if is_ct else 0
    return stitching_raw(
        niftis,
        out_path,
        match_histogram=match_histogram,
        store_ramp=store_ramp,
        verbose=verbose_stitching,
        min_value=background,
        bias_field=bias_field,
        kick_out_fully_integrated_images=kick_out_fully_integrated_images,
        is_segmentation=is_seg,
        dtype=dtype,
    )
_crop_borders(n4_bias(chunks[k]["nii"], spline_param=200)[0], k, cut) 76 | # # chunks[k]["n4"].apply_crop_slice_(cut[k]) 77 | # chunks_m = {k: chunks[k]["n4"] for k in chunks.keys()} 78 | # chunks_a = list([l.nii for l in chunks_m.values()]) 79 | chunks_a = [a["nii"].nii for a in chunks.values()] 80 | # Stitch three chunks together 81 | stitched, _ = stitching_raw( 82 | chunks_a, 83 | output=None, 84 | match_histogram=False, 85 | store_ramp=False, 86 | verbose=False, 87 | bias_field=False, 88 | save=False, 89 | ) 90 | stitched_nii = to_nii(stitched) 91 | if n4_after_stitch: 92 | stitched_nii, _ = n4_bias(stitched_nii) 93 | slices = ( 94 | _center_frontal(stitched_nii.shape[0]), 95 | slice(None), 96 | slice(None), 97 | ) 98 | stitched_nii.apply_crop_(slices) 99 | return stitched_nii.set_dtype_(np.uint16) 100 | 101 | 102 | def n4_bias( 103 | nii: NII, 104 | threshold: int = 70, 105 | spline_param: int = 200, 106 | dtype2nii: bool = False, 107 | norm: int = -1, 108 | ): 109 | from ants.utils.convert_nibabel import from_nibabel 110 | 111 | # print("n4 bias", nii.dtype) 112 | mask = nii.get_array() 113 | mask[mask < threshold] = 0 114 | mask[mask != 0] = 1 115 | mask_nii = nii.set_array(mask) 116 | mask_nii.seg = True 117 | mask_nii.dilate_msk_(mm=3, verbose=False) 118 | n4: NII = nii.n4_bias_field_correction(mask=from_nibabel(mask_nii.nii), spline_param=spline_param) 119 | if norm != -1: 120 | n4 *= norm / n4.max() 121 | if dtype2nii: 122 | n4.set_dtype_(nii.dtype) 123 | return n4, mask_nii 124 | 125 | 126 | def _center_frontal(size): 127 | """Calculating the location of the frontalplain +- 256""" 128 | # The multiplicator-factor was empirically evaluated in an excel sheet with 11 random MRIs 129 | lower_bound = size * 0.2 130 | upper_bound = size * 0.75 131 | distance = int((upper_bound - lower_bound) / 2) 132 | middle = int(lower_bound + distance) 133 | # rather push it little back over the top in dorsal direction rather then not putting it into account 134 | 
small_value = middle - 122 135 | big_value = middle + 134 136 | return slice(small_value, big_value) 137 | -------------------------------------------------------------------------------- /TPTBox/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/tests/__init__.py -------------------------------------------------------------------------------- /TPTBox/tests/sample_ct/sub-ct_label-22_ct.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/tests/sample_ct/sub-ct_label-22_ct.nii.gz -------------------------------------------------------------------------------- /TPTBox/tests/sample_ct/sub-ct_seg-subreg_label-22_msk.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/tests/sample_ct/sub-ct_seg-subreg_label-22_msk.nii.gz -------------------------------------------------------------------------------- /TPTBox/tests/sample_ct/sub-ct_seg-vert_label-22_msk.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/tests/sample_ct/sub-ct_seg-vert_label-22_msk.nii.gz -------------------------------------------------------------------------------- /TPTBox/tests/sample_mri/sub-mri_label-6_T2w.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/TPTBox/tests/sample_mri/sub-mri_label-6_T2w.nii.gz -------------------------------------------------------------------------------- 
def speed_test_input(
    inp,
    functions: list[Callable],
    assert_equal_function: Callable | None = None,
    print_output: bool = False,
    *args,
    **kwargs,
):
    """Time each candidate function on a deep copy of the same input.

    Each function is called once (optionally printing its result) and then timed with
    ``timeit`` (5 repetitions). Tuple/list inputs are unpacked as positional arguments.

    Args:
        inp: input passed to every function (deep-copied per function).
        functions: candidate callables; executed in random order.
        assert_equal_function: optional ``(a, b) -> bool`` used to verify that all
            functions produced equivalent results.
        print_output: print each function's name and result once.
        *args, **kwargs: extra arguments forwarded to every function.

    Returns:
        Mapping of function name to measured time (seconds for 5 calls).
    """
    # BUG FIX: the original shuffled ``functions`` in place, mutating the caller's list.
    # Shuffle a copy instead so the argument is left untouched.
    order = list(functions)
    random.shuffle(order)
    timings: dict[str, float] = {}
    outputs: dict = {}
    for f in order:
        input_copy = deepcopy(inp)

        def run_once(f=f, input_copy=input_copy):
            # Mirror the original call convention: unpack sequences as positionals.
            if isinstance(input_copy, (tuple, list)):
                return f(*input_copy, *args, **kwargs)
            return f(input_copy, *args, **kwargs)

        out = run_once()
        if print_output:
            print(f.__name__, out)

        # Passing a callable to timeit avoids the fragile string + ``globals=locals()`` setup.
        timings[f.__name__] = timeit.timeit(run_once, number=5)
        outputs[f.__name__] = out

    if assert_equal_function is not None and len(functions) > 1:
        reference = outputs[order[0].__name__]
        for o in outputs.values():
            assertion = assert_equal_function(o, reference)
            assert assertion, f"speed_test: nonequal results given the assert_equal_function, got {o, reference}"
    return timings
time_sums.items(): 87 | # print(k, "\t", round(sum(v) / repeats, ndigits=6), "+-", round(np.std(v), ndigits=6)) 88 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_cc3d.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | if __name__ == "__main__": 4 | # speed test dilation 5 | import random 6 | 7 | import numpy as np 8 | from cc3d import statistics as cc3dstats 9 | from scipy.ndimage import center_of_mass 10 | 11 | from TPTBox.core.np_utils import ( 12 | _to_labels, 13 | np_bbox_binary, 14 | np_bounding_boxes, 15 | np_center_of_mass, 16 | np_map_labels, 17 | np_unique, 18 | np_unique_withoutzero, 19 | np_volume, 20 | ) 21 | from TPTBox.tests.speedtests.speedtest import speed_test 22 | from TPTBox.tests.test_utils import get_nii 23 | 24 | def get_nii_array(): 25 | num_points = random.randint(1, 30) 26 | nii, points, orientation, sizes = get_nii(x=(140, 140, 150), num_point=num_points) 27 | # nii.map_labels_({1: -1}, verbose=False) 28 | arr = nii.get_seg_array().astype(np.uint8) 29 | # arr[arr == 1] = -1 30 | arr_r = arr.copy() 31 | return arr_r 32 | 33 | def cc3d_com(arr: np.ndarray): 34 | return np_center_of_mass(arr) 35 | 36 | def center_of_mass_one(arr: np.ndarray): 37 | coms = center_of_mass(arr) 38 | return coms 39 | 40 | def center_of_mass_(arr: np.ndarray): 41 | cc_label_set = np_unique(arr) 42 | coms = {} 43 | for c in cc_label_set: 44 | if c == 0: 45 | continue 46 | com = center_of_mass(arr == c) 47 | coms[c] = com 48 | return coms 49 | 50 | def bbox_(arr: np.ndarray): 51 | cc_label_set = np_unique(arr) 52 | coms = {} 53 | for c in cc_label_set: 54 | if c == 0: 55 | continue 56 | com = np_bbox_binary(arr == c) 57 | coms[c] = com 58 | return coms 59 | 60 | speed_test( 61 | repeats=50, 62 | get_input_func=get_nii_array, 63 | functions=[cc3d_com, center_of_mass_], 64 | assert_equal_function=lambda x, y: 
np.all([x[i][0] == y[i][0] for i in x.keys()]), # noqa: ARG005 65 | # np.all([x[i] == y[i] for i in range(len(x))]) 66 | ) 67 | # print(time_measures) 68 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_cc3d_crop.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from scipy.ndimage import center_of_mass 7 | 8 | from TPTBox.core.nii_wrapper import NII 9 | from TPTBox.core.np_utils import ( 10 | _to_labels, 11 | cc3dstatistics, 12 | np_bbox_binary, 13 | np_extract_label, 14 | ) 15 | from TPTBox.tests.speedtests.speedtest import speed_test 16 | from TPTBox.tests.test_utils import get_nii 17 | 18 | def get_nii_array(): 19 | num_points = random.randint(5, 10) 20 | nii, points, orientation, sizes = get_nii(x=(300, 300, 300), num_point=num_points) 21 | # nii.map_labels_({1: -1}, verbose=False) 22 | arr = nii.get_seg_array().astype(np.uint) 23 | # arr[arr == 1] = -1 24 | # arr_r = arr.copy() 25 | return arr 26 | 27 | def normal(arr): 28 | return cc3dstatistics(arr, use_crop=False) 29 | 30 | def crop(arr): 31 | crop = np_bbox_binary(arr) 32 | arr = arr[crop] 33 | return cc3dstatistics(arr, use_crop=False) 34 | 35 | speed_test( 36 | repeats=50, 37 | get_input_func=get_nii_array, 38 | functions=[normal, crop], 39 | assert_equal_function=lambda x, y: True, # noqa: ARG005 40 | # np.all([x[i] == y[i] for i in range(len(x))]) 41 | ) 42 | # print(time_measures) 43 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_connected_components.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 
9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import ( 11 | _to_labels, 12 | connected_components, 13 | np_bbox_binary, 14 | np_calc_overlapping_labels, 15 | np_connected_components, 16 | np_extract_label, 17 | np_unique, 18 | np_unique_withoutzero, 19 | ) 20 | from TPTBox.tests.speedtests.speedtest import speed_test 21 | from TPTBox.tests.test_utils import get_nii 22 | 23 | def get_nii_array(): 24 | num_points = random.randint(50, 51) 25 | nii, points, orientation, sizes = get_nii(x=(100, 100, 100), num_point=num_points) 26 | # nii.map_labels_({1: -1}, verbose=False) 27 | arr = nii.get_seg_array().astype(np.uint8) 28 | # arr[arr == 1] = -1 29 | # arr_r = arr.copy() 30 | return arr 31 | 32 | def np_naive_cc(arr: np.ndarray): 33 | return np_connected_components(arr)[0][1] 34 | 35 | def np_naive_cc_extract(arr: np.ndarray): 36 | return np_connected_components(arr, use_extract2=True)[0][1] 37 | 38 | def np_naive_cc_gcrop(arr: np.ndarray): 39 | crop = np_bbox_binary(arr) 40 | arrc = arr[crop] 41 | connectivity = 3 42 | connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 43 | 44 | labels: list[int] = np_unique(arrc) 45 | zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype) 46 | 47 | subreg_cc = {} 48 | subreg_cc_n = {} 49 | for idx, subreg in enumerate(labels): # type:ignore 50 | img_subreg = np_extract_label(arrc, subreg, inplace=False) 51 | labels_out, n = connected_components(img_subreg, connectivity=connectivity, return_N=True) 52 | arrn = zarr[idx] 53 | arrn[crop] = labels_out 54 | subreg_cc[subreg] = arrn 55 | subreg_cc_n[subreg] = n 56 | return subreg_cc[1] # , subreg_cc_n 57 | 58 | def np_crop_cc(arr: np.ndarray): 59 | # crop 60 | connectivity = 3 61 | connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 62 | 63 | labels: list[int] = np_unique(arr) 64 | 65 | subreg_cc = {} 66 | 
subreg_cc_n = {} 67 | for subreg in labels: # type:ignore 68 | img_subreg = np_extract_label(arr, subreg, inplace=False) 69 | crop = np_bbox_binary(img_subreg) 70 | img_subregc = img_subreg[crop] 71 | labels_out, n = connected_components(img_subregc, connectivity=connectivity, return_N=True) 72 | img_subreg[crop] = labels_out 73 | 74 | subreg_cc[subreg] = img_subreg 75 | subreg_cc_n[subreg] = n 76 | return subreg_cc[1] # , subreg_cc_n 77 | 78 | # def np_cc_once(arr: np.ndarray): 79 | # # call cc once, then relabel 80 | # connectivity = 3 81 | # connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 82 | # 83 | # labels: list[int] = np_unique(arr) 84 | # 85 | # subreg_cc = {} 86 | # subreg_cc_n = {} 87 | # crop = np_bbox_binary(arr) 88 | # arrc = arr[crop] 89 | # zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype) 90 | # 91 | # labels_out = connected_components(arrc, connectivity=connectivity, return_N=False) 92 | # for sidx, subreg in enumerate(labels): # type:ignore 93 | # arrcc[crop][np.logical_and()] 94 | # # arr[s == subreg] 95 | # # img_subreg = np_extract_label(arrc, subreg, inplace=False) 96 | # # lcrop = np_bbox_binary(img_subreg) 97 | # img_subregc = img_subreg[lcrop] 98 | # img_subreg[lcrop] = labels_out[lcrop] * img_subregc 99 | # 100 | # arrcc = zarr[sidx] 101 | # arrcc[crop] = img_subreg 102 | # subreg_cc[subreg] = arrcc 103 | # subreg_cc_n[subreg] = len(np_unique_withoutzero(img_subreg[lcrop])) 104 | # return subreg_cc[1] # , subreg_cc_n 105 | 106 | def np_cc_once_lcrop(arr: np.ndarray): 107 | # call cc once, then relabel 108 | connectivity = 3 109 | connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 110 | 111 | labels: list[int] = np_unique(arr) 112 | 113 | subreg_cc = {} 114 | subreg_cc_n = {} 115 | # crop = np_bbox_binary(arr) 116 | arrc = arr # [crop] 117 | 118 | labels_out = 
connected_components(arrc, connectivity=connectivity, return_N=False) 119 | for subreg in labels: # type:ignore 120 | img_subreg = np_extract_label(arrc, subreg, inplace=False) 121 | lcrop = np_bbox_binary(img_subreg) 122 | img_subregc = img_subreg[lcrop] 123 | img_subreg[lcrop] = labels_out[lcrop] * img_subregc 124 | 125 | # arrcc = np.zeros(arr.shape, dtype=arr.dtype) 126 | # arrcc[crop] = img_subreg 127 | arrcc = img_subreg 128 | subreg_cc[subreg] = arrcc 129 | subreg_cc_n[subreg] = len(np_unique_withoutzero(img_subreg[lcrop])) 130 | return subreg_cc[1] # , subreg_cc_n 131 | 132 | speed_test( 133 | repeats=50, 134 | get_input_func=get_nii_array, 135 | functions=[np_naive_cc, np_naive_cc_extract], 136 | assert_equal_function=lambda x, y: np.count_nonzero(x) == np.count_nonzero(y), 137 | # np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005 138 | # np.all([x[i] == y[i] for i in range(len(x))]) 139 | ) 140 | # print(time_measures) 141 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_connected_components_labelwise.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import ( 11 | _connected_components, 12 | _to_labels, 13 | np_bbox_binary, 14 | np_calc_overlapping_labels, 15 | np_connected_components, 16 | np_connected_components_per_label, 17 | np_connected_components_per_label2, 18 | np_extract_label, 19 | np_unique, 20 | np_unique_withoutzero, 21 | ) 22 | from TPTBox.tests.speedtests.speedtest import speed_test 23 | from TPTBox.tests.test_utils import get_nii 24 | 25 | def get_nii_array(): 26 | num_points = random.randint(10, 31) 27 | nii, points, orientation, sizes = get_nii(x=(300, 
300, 300), num_point=num_points) 28 | # nii.map_labels_({1: -1}, verbose=False) 29 | arr = nii.get_seg_array().astype(np.uint8) 30 | # arr[arr == 1] = -1 31 | # arr_r = arr.copy() 32 | return arr 33 | 34 | def np_cc_labelwise1(arr: np.ndarray): 35 | return np_connected_components_per_label(arr)[0][1] 36 | 37 | def np_cc_labelwise2(arr: np.ndarray): 38 | return np_connected_components_per_label2(arr)[0][1] 39 | 40 | def np_cc_labelwise2crop(arr: np.ndarray): 41 | return np_connected_components_per_label2(arr, use_crop=True)[0][1] 42 | 43 | # def np_cc_once(arr: np.ndarray): 44 | # # call cc once, then relabel 45 | # connectivity = 3 46 | # connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 47 | # 48 | # labels: list[int] = np_unique(arr) 49 | # 50 | # subreg_cc = {} 51 | # subreg_cc_n = {} 52 | # crop = np_bbox_binary(arr) 53 | # arrc = arr[crop] 54 | # zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype) 55 | # 56 | # labels_out = connected_components(arrc, connectivity=connectivity, return_N=False) 57 | # for sidx, subreg in enumerate(labels): # type:ignore 58 | # arrcc[crop][np.logical_and()] 59 | # # arr[s == subreg] 60 | # # img_subreg = np_extract_label(arrc, subreg, inplace=False) 61 | # # lcrop = np_bbox_binary(img_subreg) 62 | # img_subregc = img_subreg[lcrop] 63 | # img_subreg[lcrop] = labels_out[lcrop] * img_subregc 64 | # 65 | # arrcc = zarr[sidx] 66 | # arrcc[crop] = img_subreg 67 | # subreg_cc[subreg] = arrcc 68 | # subreg_cc_n[subreg] = len(np_unique_withoutzero(img_subreg[lcrop])) 69 | # return subreg_cc[1] # , subreg_cc_n 70 | 71 | speed_test( 72 | repeats=50, 73 | get_input_func=get_nii_array, 74 | functions=[ 75 | np_cc_labelwise1, 76 | np_cc_labelwise2, 77 | ], 78 | assert_equal_function=lambda x, y: True, # np.array_equal(x, y), # noqa: ARG005 79 | # np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005 80 | # np.all([x[i] == y[i] for i in 
range(len(x))]) 81 | ) 82 | # print(time_measures) 83 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_connected_components_simple.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import ( 11 | _connected_components, 12 | _to_labels, 13 | np_bbox_binary, 14 | np_calc_overlapping_labels, 15 | np_connected_components, 16 | np_extract_label, 17 | np_unique, 18 | np_unique_withoutzero, 19 | ) 20 | from TPTBox.tests.speedtests.speedtest import speed_test 21 | from TPTBox.tests.test_utils import get_nii 22 | 23 | def get_nii_array(): 24 | num_points = random.randint(50, 51) 25 | nii, points, orientation, sizes = get_nii(x=(200, 200, 200), num_point=num_points) 26 | # nii.map_labels_({1: -1}, verbose=False) 27 | arr = nii.get_seg_array().astype(np.uint8) 28 | # arr[arr == 1] = -1 29 | # arr_r = arr.copy() 30 | return arr, 3, True 31 | 32 | def np_naive_cc(arr: np.ndarray): 33 | return np_connected_components(arr)[0][1] 34 | 35 | def np_cc_once_N(arr: np.ndarray, connectivity: int = 3, include_zero: bool = True): 36 | connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 37 | if include_zero: 38 | arr[arr == 0] = arr.max() + 1 39 | labels_out, n = _connected_components(arr, connectivity=connectivity, return_N=True) 40 | return labels_out 41 | 42 | def np_cc_once_N_false(arr: np.ndarray, connectivity: int = 3, include_zero: bool = True): # noqa: ARG001 43 | connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 44 | labels_out, n = _connected_components(arr, 
connectivity=connectivity, return_N=True) 45 | return labels_out 46 | 47 | def np_cc_once(arr: np.ndarray, connectivity: int = 3, include_zero: bool = True): 48 | connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 49 | if include_zero: 50 | arr[arr == 0] = arr.max() + 1 51 | labels_out = _connected_components(arr, connectivity=connectivity, return_N=False) 52 | # N = np_unique(labels_out) 53 | return labels_out 54 | 55 | # def np_cc_once(arr: np.ndarray): 56 | # # call cc once, then relabel 57 | # connectivity = 3 58 | # connectivity = min((connectivity + 1) * 2, 8) if arr.ndim == 2 else 6 if connectivity == 1 else 18 if connectivity == 2 else 26 59 | # 60 | # labels: list[int] = np_unique(arr) 61 | # 62 | # subreg_cc = {} 63 | # subreg_cc_n = {} 64 | # crop = np_bbox_binary(arr) 65 | # arrc = arr[crop] 66 | # zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype) 67 | # 68 | # labels_out = connected_components(arrc, connectivity=connectivity, return_N=False) 69 | # for sidx, subreg in enumerate(labels): # type:ignore 70 | # arrcc[crop][np.logical_and()] 71 | # # arr[s == subreg] 72 | # # img_subreg = np_extract_label(arrc, subreg, inplace=False) 73 | # # lcrop = np_bbox_binary(img_subreg) 74 | # img_subregc = img_subreg[lcrop] 75 | # img_subreg[lcrop] = labels_out[lcrop] * img_subregc 76 | # 77 | # arrcc = zarr[sidx] 78 | # arrcc[crop] = img_subreg 79 | # subreg_cc[subreg] = arrcc 80 | # subreg_cc_n[subreg] = len(np_unique_withoutzero(img_subreg[lcrop])) 81 | # return subreg_cc[1] # , subreg_cc_n 82 | 83 | speed_test( 84 | repeats=50, 85 | get_input_func=get_nii_array, 86 | functions=[ 87 | np_cc_once_N, 88 | np_cc_once_N_false, 89 | ], 90 | assert_equal_function=lambda x, y: True, # np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005 91 | # np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005 92 | # np.all([x[i] == y[i] for i in range(len(x))]) 93 | ) 94 | 
# print(time_measures) 95 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_count_nonzero.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import _to_labels, np_bbox_binary, np_count_nonzero, np_volume 11 | from TPTBox.tests.speedtests.speedtest import speed_test 12 | from TPTBox.tests.test_utils import get_nii 13 | 14 | def get_nii_array(): 15 | num_points = random.randint(1, 10) 16 | nii, points, orientation, sizes = get_nii(x=(350, 350, 350), num_point=num_points) 17 | # nii.map_labels_({1: -1}, verbose=False) 18 | arr = nii.get_seg_array().astype(np.uint) 19 | # arr[arr == 1] = -1 20 | # arr_r = arr.copy() 21 | return arr 22 | 23 | def np_naive_count(arr: np.ndarray): 24 | return np.count_nonzero(arr) 25 | 26 | def np_count(arr: np.ndarray): 27 | return sum(np_volume(arr).values()) 28 | 29 | def np_countgreater(arr: np.ndarray): 30 | return (arr > 0).sum() 31 | 32 | def np_countcrop(arr: np.ndarray): 33 | crop = np_bbox_binary(arr) 34 | arrc = arr[crop] 35 | return np.count_nonzero(arrc) 36 | 37 | speed_test( 38 | repeats=50, 39 | get_input_func=get_nii_array, 40 | functions=[np_naive_count, np_count, np_countgreater, np_countcrop], 41 | assert_equal_function=lambda x, y: x == y, # noqa: ARG005 42 | # np.all([x[i] == y[i] for i in range(len(x))]) 43 | ) 44 | # print(time_measures) 45 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_crop.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | 
from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import ( 11 | _to_labels, 12 | np_bbox_binary, 13 | np_dilate_msk, 14 | ) 15 | from TPTBox.tests.speedtests.speedtest import speed_test 16 | from TPTBox.tests.test_utils import get_nii 17 | 18 | def get_nii_array(): 19 | num_points = random.randint(1, 5) 20 | nii, points, orientation, sizes = get_nii(x=(150, 150, 150), num_point=num_points) 21 | # nii.map_labels_({1: -1}, verbose=False) 22 | # arr = nii.get_seg_array().astype(int) 23 | # arr[arr == 1] = -1 24 | # arr_r = arr.copy() 25 | return nii 26 | 27 | def nii_compute_crop(nii: NII): 28 | return nii.compute_crop(dist=1, minimum=2) 29 | 30 | def np_compute_crop(nii: NII): 31 | return np_bbox_binary(nii.get_seg_array() > 2, px_dist=1) 32 | 33 | speed_test( 34 | repeats=50, 35 | get_input_func=get_nii_array, 36 | functions=[nii_compute_crop, np_compute_crop], 37 | assert_equal_function=lambda x, y: np.all([x[i] == y[i] for i in range(len(x))]), 38 | ) 39 | # print(time_measures) 40 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_dilate.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import ( 11 | _to_labels, 12 | np_dilate_msk, 13 | np_erode_msk, 14 | ) 15 | from TPTBox.tests.speedtests.speedtest import speed_test 16 | from TPTBox.tests.test_utils import get_nii 17 | 18 | def get_nii_array(): 19 | num_points = random.randint(1, 10) 20 | nii, points, orientation, sizes = get_nii(x=(100, 100, 100), num_point=num_points) 21 | # nii.map_labels_({1: -1}, verbose=False) 22 | # arr = 
nii.get_seg_array().astype(int) 23 | # arr[arr == 1] = -1 24 | # arr_r = arr.copy() 25 | return nii 26 | 27 | def nii_dilate_withoutcrop(nii: NII): 28 | return nii.dilate_msk_(n_pixel=1, use_crop=False).get_seg_array() 29 | 30 | def nii_dilate_withcrop(nii: NII): 31 | return nii.dilate_msk_(n_pixel=1, use_crop=True).get_seg_array() 32 | 33 | def np_dilate_withoutcrop(nii: NII): 34 | return np_dilate_msk(nii.get_seg_array(), n_pixel=1, use_crop=False) 35 | 36 | def np_dilate_withcrop(nii: NII): 37 | return np_dilate_msk(nii.get_seg_array(), n_pixel=1, use_crop=True) 38 | 39 | def np_dilate_withLcrop(nii: NII): 40 | return np_dilate_msk(nii.get_seg_array(), n_pixel=1, use_crop=False, use_local_crop=True) 41 | 42 | def np_dilate_withBcrop(nii: NII): 43 | return np_dilate_msk(nii.get_seg_array(), n_pixel=1, use_crop=True, use_local_crop=True) 44 | 45 | #### ERODE 46 | 47 | def nii_erode_withoutcrop(nii: NII): 48 | return nii.erode_msk_(n_pixel=1, use_crop=False).get_seg_array() 49 | 50 | def nii_erode_withcrop(nii: NII): 51 | return nii.erode_msk_(n_pixel=2, use_crop=True).get_seg_array() 52 | 53 | def np_erode_withoutcrop(nii: NII): 54 | return np_erode_msk(nii.get_seg_array(), n_pixel=1, use_crop=False) 55 | 56 | def np_erode_withcrop(nii: NII): 57 | return np_erode_msk(nii.get_seg_array(), n_pixel=1, use_crop=True) 58 | 59 | speed_test( 60 | repeats=50, 61 | get_input_func=get_nii_array, 62 | functions=[ 63 | np_erode_withoutcrop, 64 | np_erode_withcrop, 65 | ], 66 | assert_equal_function=lambda x, y: np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005 67 | # np.all([x[i] == y[i] for i in range(len(x))]) 68 | ) 69 | # print(time_measures) 70 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_extract_label.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as 
np 6 | from scipy.ndimage import center_of_mass 7 | 8 | from TPTBox.core.nii_wrapper import NII 9 | from TPTBox.core.np_utils import ( 10 | _to_labels, 11 | np_extract_label, 12 | ) 13 | from TPTBox.tests.speedtests.speedtest import speed_test 14 | from TPTBox.tests.test_utils import get_nii 15 | 16 | def get_nii_array(): 17 | num_points = random.randint(5, 10) 18 | nii, points, orientation, sizes = get_nii(x=(500, 500, 500), num_point=num_points) 19 | # nii.map_labels_({1: -1}, verbose=False) 20 | arr = nii.get_seg_array().astype(np.uint8) 21 | # arr[arr == 1] = -1 22 | # arr_r = arr.copy() 23 | return arr 24 | 25 | # def nii_extract(nii: NII): 26 | # return nii.extract_label([1, 2, 3, 4, 5]).get_seg_array() 27 | 28 | extract_one_label = 2 29 | extract_label = [2, 3, 4, 5] 30 | 31 | def dummy(arr_bin: np.ndarray): 32 | return arr_bin + arr_bin 33 | 34 | # EXTRACT ONE LABEL 35 | 36 | def np_extract_one(arr: np.ndarray): 37 | return np_extract_label(arr, extract_one_label) 38 | 39 | def np_extract_one_nii(arr: np.ndarray): 40 | arr = arr.copy() 41 | arr[arr != extract_one_label] = 0 42 | arr[arr == extract_one_label] = 1 43 | return arr 44 | 45 | def np_extract_one_equal(arr: np.ndarray): 46 | return arr == extract_one_label 47 | 48 | # EXTRACT LIST OF LABELS 49 | 50 | def np_extractlist(arr: np.ndarray): 51 | return np_extract_label(arr, extract_label, inplace=False) 52 | 53 | def np_extractlist_isin(arr: np.ndarray): 54 | arr = arr.copy() 55 | arr_msk = np.isin(arr, extract_label) 56 | arr[arr_msk] = 1 57 | arr[~arr_msk] = 0 58 | return arr 59 | 60 | def np_extractlist_for(arr: np.ndarray): 61 | arr = arr.copy() 62 | if 1 not in extract_label: 63 | arr[arr == 1] = 0 64 | for idx in extract_label: 65 | arr[arr == idx] = 1 66 | arr[arr != 1] = 0 67 | return arr 68 | 69 | speed_test( 70 | repeats=50, 71 | get_input_func=get_nii_array, 72 | functions=[np_extractlist, np_extractlist_isin, np_extractlist_for], 73 | assert_equal_function=lambda x, y: np.array_equal(x, 
y), # noqa: ARG005 74 | # np.all([x[i] == y[i] for i in range(len(x))]) 75 | ) 76 | # print(time_measures) 77 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_extract_label_loop.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from scipy.ndimage import center_of_mass 7 | 8 | from TPTBox.core.nii_wrapper import NII 9 | from TPTBox.core.np_utils import ( 10 | _to_labels, 11 | np_extract_label, 12 | ) 13 | from TPTBox.tests.speedtests.speedtest import speed_test 14 | from TPTBox.tests.test_utils import get_nii 15 | 16 | def get_nii_array(): 17 | num_points = random.randint(5, 10) 18 | nii, points, orientation, sizes = get_nii(x=(300, 300, 300), num_point=num_points) 19 | # nii.map_labels_({1: -1}, verbose=False) 20 | arr = nii.get_seg_array().astype(np.uint8) 21 | # arr[arr == 1] = -1 22 | # arr_r = arr.copy() 23 | return arr 24 | 25 | # def nii_extract(nii: NII): 26 | # return nii.extract_label([1, 2, 3, 4, 5]).get_seg_array() 27 | 28 | extract_label_one = 2 29 | extract_label = [2, 3, 4, 5] 30 | 31 | def dummy(arr_bin: np.ndarray): 32 | return arr_bin + arr_bin 33 | 34 | def np_extractloop(arr: np.ndarray): 35 | if 1 not in extract_label: 36 | arr[arr == 1] = 0 37 | for idx in extract_label: 38 | arr[arr == idx] = 1 39 | arr[arr != 1] = 0 40 | return arr 41 | 42 | def np_extractloop_indexing(arr: np.ndarray): 43 | arrl = arr.copy() 44 | for l in extract_label: 45 | arr += dummy(arrl == l) 46 | return arr 47 | 48 | def np_extractloop_indexing2(arr: np.ndarray): 49 | arrl = arr.copy() 50 | for l in extract_label: 51 | arrl[arrl != l] = 0 52 | arrl[arrl == l] = 1 53 | arr += dummy(arrl) 54 | return arr 55 | 56 | def np_extractloop_e(arr: np.ndarray): 57 | arrl = arr.copy() 58 | for l in extract_label: 59 | arr += dummy(np_extract_label(arrl, l)) 60 | return arr 
61 | 62 | speed_test( 63 | repeats=50, 64 | get_input_func=get_nii_array, 65 | functions=[np_extractloop, np_extractloop_indexing, np_extractloop_indexing2, np_extractloop_e], 66 | assert_equal_function=lambda x, y: np.array_equal(x, y), # noqa: ARG005 67 | # np.all([x[i] == y[i] for i in range(len(x))]) 68 | ) 69 | # print(time_measures) 70 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_extract_label_nii.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from scipy.ndimage import center_of_mass 7 | 8 | from TPTBox.core.nii_wrapper import NII 9 | from TPTBox.core.np_utils import ( 10 | _to_labels, 11 | np_extract_label, 12 | ) 13 | from TPTBox.tests.speedtests.speedtest import speed_test 14 | from TPTBox.tests.test_utils import get_nii 15 | 16 | def get_nii_array(): 17 | num_points = random.randint(5, 10) 18 | nii, points, orientation, sizes = get_nii(x=(300, 300, 300), num_point=num_points) 19 | # nii.map_labels_({1: -1}, verbose=False) 20 | # arr = nii.get_seg_array().astype(int) 21 | # arr[arr == 1] = -1 22 | # arr_r = arr.copy() 23 | return nii 24 | 25 | extract_label = [2, 3, 4, 5] 26 | 27 | def nii_extract(nii: NII): 28 | return nii.extract_label(extract_label) 29 | 30 | def nii_extract2(nii: NII): 31 | return nii.set_array(np_extract_label(nii.get_seg_array(), extract_label, inplace=False)) 32 | 33 | def nii_extract3(nii: NII): 34 | return nii.set_array(np_extract_label(nii.get_seg_array(), extract_label, inplace=True)) 35 | 36 | speed_test( 37 | repeats=100, 38 | get_input_func=get_nii_array, 39 | functions=[ 40 | nii_extract, 41 | nii_extract2, 42 | nii_extract3, 43 | ], 44 | # functions=[extractloop_e, extractloop_indexing], 45 | assert_equal_function=lambda x, y: np.array_equal(x, y), # noqa: ARG005 46 | # np.all([x[i] == y[i] for i in 
range(len(x))]) 47 | ) 48 | # print(time_measures) 49 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_fillholes.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import _fill, _to_labels, np_bbox_binary, np_connected_components, np_extract_label, np_fill_holes 11 | from TPTBox.tests.speedtests.speedtest import speed_test 12 | from TPTBox.tests.test_utils import get_nii 13 | 14 | def get_nii_array(): 15 | num_points = random.randint(1, 10) 16 | nii, points, orientation, sizes = get_nii(x=(150, 150, 150), num_point=num_points) 17 | # nii.map_labels_({1: -1}, verbose=False) 18 | arr = nii.get_seg_array().astype(np.uint) 19 | # arr[arr == 1] = -1 20 | # arr_r = arr.copy() 21 | return nii, arr 22 | 23 | def nii_fill_holes(nii: NII, arr): # noqa: ARG001 24 | return nii.fill_holes(use_crop=True).get_seg_array() 25 | 26 | def np_nfill_holes(nii: NII, arr): # noqa: ARG001 27 | return np_fill_holes(arr, use_crop=False) 28 | 29 | def np_nfill_holes_crop(nii, arr: np.ndarray): # noqa: ARG001 30 | return np_fill_holes(arr, use_crop=True) 31 | 32 | def np_fill_holes_extract2(nii, arr: np.ndarray): # noqa: ARG001 33 | slice_wise_dim = None 34 | assert 2 <= arr.ndim <= 3 35 | assert arr.ndim == 3 or slice_wise_dim is None, "slice_wise_dim set but array is 3D" 36 | labels: list[int] = _to_labels(arr, None) 37 | 38 | gcrop = np_bbox_binary(arr, px_dist=1) 39 | arrc = arr[gcrop] 40 | 41 | for l in labels: # type:ignore 42 | arr_l = np_extract_label(arrc, l) 43 | crop = np_bbox_binary(arr_l, px_dist=1) 44 | arr_lc = arr_l[crop] 45 | 46 | if slice_wise_dim is None: 47 | filled = _fill(arr_lc).astype(arr.dtype) 48 | 
else: 49 | assert 0 <= slice_wise_dim <= arr.ndim - 1, f"slice_wise_dim needs to be in range [0, {arr.ndim - 1}]" 50 | filled = np.swapaxes(arr_lc, 0, slice_wise_dim) 51 | filled = np.stack([_fill(x) for x in filled]) 52 | filled = np.swapaxes(filled, 0, slice_wise_dim) 53 | filled[filled != 0] = l 54 | arrc[crop][arrc[crop] == 0] = filled[arrc[crop] == 0] 55 | 56 | arr[gcrop] = arrc 57 | return arr 58 | 59 | def np_fill_holes_extract(nii, arr: np.ndarray): # noqa: ARG001 60 | slice_wise_dim = None 61 | assert 2 <= arr.ndim <= 3 62 | assert arr.ndim == 3 or slice_wise_dim is None, "slice_wise_dim set but array is 3D" 63 | labels: list[int] = _to_labels(arr, None) 64 | 65 | gcrop = np_bbox_binary(arr, px_dist=1) 66 | arrc = arr[gcrop] 67 | 68 | for l in labels: # type:ignore 69 | arr_l = arrc == l 70 | crop = np_bbox_binary(arr_l, px_dist=1) 71 | arr_lc = arr_l[crop] 72 | 73 | if slice_wise_dim is None: 74 | filled = _fill(arr_lc).astype(arr.dtype) 75 | else: 76 | assert 0 <= slice_wise_dim <= arr.ndim - 1, f"slice_wise_dim needs to be in range [0, {arr.ndim - 1}]" 77 | filled = np.swapaxes(arr_lc, 0, slice_wise_dim) 78 | filled = np.stack([_fill(x) for x in filled]) 79 | filled = np.swapaxes(filled, 0, slice_wise_dim) 80 | filled[filled != 0] = l 81 | 82 | arrc[crop][arrc[crop] == 0] = filled[arrc[crop] == 0] 83 | 84 | arr[gcrop] = arrc 85 | return arr 86 | 87 | speed_test( 88 | repeats=50, 89 | get_input_func=get_nii_array, 90 | functions=[np_nfill_holes, np_fill_holes_extract, np_nfill_holes_crop], 91 | assert_equal_function=lambda x, y: np.array_equal(x, y), # noqa: ARG005 92 | # np.all([x[i] == y[i] for i in range(len(x))]) 93 | ) 94 | # print(time_measures) 95 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_filter_connected_components.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 
4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import ( 11 | _connected_components, 12 | _to_labels, 13 | np_bbox_binary, 14 | np_calc_overlapping_labels, 15 | np_connected_components, 16 | np_extract_label, 17 | np_filter_connected_components, 18 | np_unique, 19 | np_unique_withoutzero, 20 | ) 21 | from TPTBox.tests.speedtests.speedtest import speed_test 22 | from TPTBox.tests.test_utils import get_nii 23 | 24 | def get_nii_array(): 25 | num_points = random.randint(10, 31) 26 | nii, points, orientation, sizes = get_nii(x=(500, 500, 350), num_point=num_points) 27 | # nii.map_labels_({1: -1}, verbose=False) 28 | arr = nii.get_seg_array().astype(np.uint8) 29 | # arr[arr == 1] = -1 30 | # arr_r = arr.copy() 31 | return arr 32 | 33 | def np_cc_labelwise1(arr: np.ndarray): 34 | return np_filter_connected_components(arr, min_volume=10, max_volume=50, largest_k_components=3) 35 | 36 | speed_test( 37 | repeats=50, 38 | get_input_func=get_nii_array, 39 | functions=[ 40 | np_cc_labelwise1, 41 | ], 42 | assert_equal_function=lambda x, y: True, # np.array_equal(x, y), # noqa: ARG005 43 | # np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005 44 | # np.all([x[i] == y[i] for i in range(len(x))]) 45 | ) 46 | # print(time_measures) 47 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_isempty.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import _to_labels, np_bbox_binary, np_count_nonzero, np_unique_withoutzero, np_volume 11 | from 
TPTBox.tests.speedtests.speedtest import speed_test 12 | from TPTBox.tests.test_utils import get_nii 13 | 14 | def get_nii_array(): 15 | num_points = random.randint(1, 10) 16 | nii, points, orientation, sizes = get_nii(x=(550, 550, 550), num_point=num_points) 17 | # nii.map_labels_({1: -1}, verbose=False) 18 | arr = nii.get_seg_array().astype(int) * 0 19 | # arr[arr == 1] = -1 20 | # arr_r = arr.copy() 21 | return arr 22 | 23 | def np_naive_count(arr: np.ndarray): 24 | return np.count_nonzero(arr) > 0 25 | 26 | def np_max(arr: np.ndarray): 27 | return arr.max() != 0 28 | 29 | def np_sum(arr: np.ndarray): 30 | return arr.sum() > 0 31 | 32 | def np_any(arr: np.ndarray): 33 | return np.any(arr) 34 | 35 | def np_nonzero(arr: np.ndarray): 36 | # super slow 37 | return arr.nonzero()[0].size != 0 38 | 39 | def np_sunique(arr: np.ndarray): 40 | # super slow 41 | return len(np_unique_withoutzero(arr)) != 0 42 | 43 | speed_test( 44 | repeats=100, 45 | get_input_func=get_nii_array, 46 | functions=[np_naive_count, np_max, np_any], 47 | assert_equal_function=lambda x, y: x == y, # noqa: ARG005 48 | # np.all([x[i] == y[i] for i in range(len(x))]) 49 | ) 50 | # print(time_measures) 51 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_maplabels.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from scipy.ndimage import center_of_mass 7 | 8 | from TPTBox.core.nii_wrapper import NII 9 | from TPTBox.core.np_utils import ( 10 | _to_labels, 11 | np_bbox_binary, 12 | np_extract_label, 13 | np_map_labels, 14 | ) 15 | from TPTBox.tests.speedtests.speedtest import speed_test 16 | from TPTBox.tests.test_utils import get_nii 17 | 18 | def get_nii_array(): 19 | num_points = random.randint(1, 15) 20 | nii, points, orientation, sizes = get_nii(x=(150, 150, 150), num_point=num_points) 21 | 
# nii.map_labels_({1: -1}, verbose=False) 22 | arr = nii.get_seg_array().astype(int) 23 | # arr[arr == 1] = -1 24 | # arr_r = arr.copy() 25 | return arr 26 | 27 | def map_labels(arr: np.ndarray): 28 | return np_map_labels(arr, {1: 2}) 29 | 30 | def map_labels2(arr: np.ndarray): 31 | crop = np_bbox_binary(arr == 1) 32 | arr2 = arr[crop] 33 | arr2 = np_map_labels(arr2, {1: 2}) 34 | arr[crop] = arr2 35 | return arr 36 | 37 | speed_test( 38 | repeats=50, 39 | get_input_func=get_nii_array, 40 | functions=[map_labels, map_labels2], 41 | assert_equal_function=lambda x, y: np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005 42 | # np.all([x[i] == y[i] for i in range(len(x))]) 43 | ) 44 | # print(time_measures) 45 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_npunique.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | if __name__ == "__main__": 4 | # speed test dilation 5 | import random 6 | 7 | import numpy as np 8 | from cc3d import statistics as cc3dstats 9 | from scipy.ndimage import center_of_mass 10 | 11 | from TPTBox.core.np_utils import ( 12 | _to_labels, 13 | np_bbox_binary, 14 | np_bounding_boxes, 15 | np_center_of_mass, 16 | np_map_labels, 17 | np_unique, 18 | np_unique_withoutzero, 19 | np_volume, 20 | ) 21 | from TPTBox.tests.speedtests.speedtest import speed_test 22 | from TPTBox.tests.test_utils import get_nii 23 | 24 | def get_nii_array(): 25 | num_points = random.randint(1, 30) 26 | nii, points, orientation, sizes = get_nii(x=(140, 140, 150), num_point=num_points) 27 | # nii.map_labels_({1: -1}, verbose=False) 28 | arr = nii.get_seg_array().astype(np.uint8) 29 | # arr[arr == 1] = -1 30 | arr_r = arr.copy() 31 | return arr_r 32 | 33 | speed_test( 34 | repeats=50, 35 | get_input_func=get_nii_array, 36 | functions=[np_unique, np.unique], 37 | assert_equal_function=lambda x, y: True, # 
np.all([x[i] == y[i] for i in range(len(x))]), # noqa: ARG005 38 | # np.all([x[i] == y[i] for i in range(len(x))]) 39 | ) 40 | # print(time_measures) 41 | -------------------------------------------------------------------------------- /TPTBox/tests/speedtests/speedtest_uncrop.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | 5 | import numpy as np 6 | from cc3d import statistics as cc3dstats 7 | from scipy.ndimage import center_of_mass 8 | 9 | from TPTBox.core.nii_wrapper import NII 10 | from TPTBox.core.np_utils import _to_labels, np_bbox_binary, np_extract_label, np_unique_withoutzero 11 | from TPTBox.tests.speedtests.speedtest import speed_test 12 | from TPTBox.tests.test_utils import get_nii 13 | 14 | def get_nii_array(): 15 | num_points = random.randint(50, 51) 16 | nii, points, orientation, sizes = get_nii(x=(350, 350, 150), num_point=num_points) 17 | # nii.map_labels_({1: -1}, verbose=False) 18 | arr = nii.get_seg_array().astype(np.uint) 19 | crop = np_bbox_binary(arr) 20 | arrc = arr[crop] 21 | labels = np_unique_withoutzero(arrc) 22 | # arr[arr == 1] = -1 23 | # arr_r = arr.copy() 24 | return arr, crop, labels, arrc 25 | 26 | def uncrop_naive(arr: np.ndarray, crop, labels, arrc: np.ndarray): # noqa: ARG001 27 | results = {} 28 | for l in labels: 29 | img_l = np_extract_label(arrc, l, inplace=False) 30 | lcrop = np_bbox_binary(img_l) 31 | img_lc = img_l[lcrop] 32 | # Process here 33 | # need to uncrop somehow 34 | arrn = np.zeros(arr.shape, dtype=arr.dtype) 35 | arrn[crop][lcrop] = img_lc 36 | results[l] = arrn 37 | return results[1] 38 | 39 | def copy_uncrop(arr: np.ndarray, crop, labels, arrc: np.ndarray): 40 | zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype) 41 | results = {} 42 | for idx, l in enumerate(labels): 43 | img_l = np_extract_label(arrc, l, inplace=False) 44 | lcrop = np_bbox_binary(img_l) 45 | img_lc = img_l[lcrop] 46 
| # Process here 47 | img_lc *= img_lc 48 | # need to uncrop somehow 49 | arrn = zarr[idx] 50 | arrn[crop][lcrop] = img_lc 51 | results[l] = arrn 52 | return results[1] 53 | 54 | def uncrop2(arr: np.ndarray, crop, labels, arrc: np.ndarray): 55 | zarr = np.zeros((len(labels), *arr.shape), dtype=arr.dtype) 56 | results = {} 57 | for idx, l in enumerate(labels): 58 | img_l = np_extract_label(arrc, l, inplace=False) 59 | # lcrop = np_bbox_binary(img_l) 60 | img_lc = img_l # [lcrop] 61 | # Process here 62 | img_lc *= img_lc 63 | # need to uncrop somehow 64 | arrn = zarr[idx] 65 | arrn[crop] = img_lc 66 | results[l] = arrn 67 | return results[1] 68 | 69 | speed_test( 70 | repeats=20, 71 | get_input_func=get_nii_array, 72 | functions=[uncrop2, copy_uncrop], 73 | assert_equal_function=lambda x, y: np.all([x[i] == y[i] for i in range(x.shape[0])]), # noqa: ARG005 74 | # np.all([x[i] == y[i] for i in range(len(x))]) 75 | ) 76 | # print(time_measures) 77 | -------------------------------------------------------------------------------- /TPTBox/tests/test_cc3d.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | # speed test dilation 3 | import random 4 | from time import perf_counter 5 | 6 | import numpy as np 7 | from cc3d import statistics as cc3dstats 8 | from scipy.ndimage import center_of_mass 9 | 10 | from TPTBox.core.nii_wrapper import NII 11 | from TPTBox.core.np_utils import _connected_components, _to_labels, np_bbox_binary, np_connected_components, np_extract_label, np_unique 12 | from TPTBox.tests.test_utils import get_nii 13 | 14 | arr = np.array( 15 | [ 16 | [0, 0, 0, 1, 1, 0], 17 | [0, 1, 0, 1, 0, 0], 18 | [0, 1, 2, 1, 0, 0], 19 | [0, 2, 2, 1, 3, 0], 20 | [0, 0, 0, 3, 3, 0], 21 | [0, 0, 0, 1, 3, 0], 22 | ] 23 | ) 24 | 25 | labels_out = _connected_components(arr, connectivity=4, return_N=False) 26 | print(arr) 27 | print(labels_out) 28 | 29 | # crop and uncrop 30 | # 31 | crop = 
np_bbox_binary(arr) 32 | arr_c = arr[crop] 33 | print(crop) 34 | start = perf_counter() 35 | -------------------------------------------------------------------------------- /examples/dicom_select/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/examples/dicom_select/__init__.py -------------------------------------------------------------------------------- /examples/nako/README.md: -------------------------------------------------------------------------------- 1 | # German National Cohort Dicom2Nii + Stitching example 2 | 3 | This folder contains sample scripts to export a subset of NAKO files to BIDS-compliant filenames and a simplified folder structure. 4 | 5 | We decided on a fixed naming schema for: 6 | - 3D_GRE (vibe / 2-point Dixon) 7 | - ME_vibe (Multi echo Vibe) 8 | - TSE SAG LWS/BWS/HWS (Sagittal spine images) 9 | - PD FS SPC COR (proton density) 10 | - T2 Haste composed 11 | 12 | We do not use "Sag T2 Spine" as it contains stitching from "TSE SAG" and the spine is deformed for better viewing but incorrect for evaluation. Use our Stitching pipeline instead. If you have other file types, you have to add them in the script, or the script will ask you what default naming schema you want to use. 13 | 14 | 15 | 16 | Install Python 3.10 or higher: 17 | ```bash 18 | pip install pydicom 19 | pip install dicom2nifti 20 | pip install func_timeout 21 | 22 | 23 | pip install TPTBox 24 | ``` 25 | 26 | Make the data folder. Use -c to run multiple parallel export processes (half of your CPU-core count is recommended): 27 | ```bash 28 | dicom2nii_bids.py -i [Path to the dicom folder] 29 | ``` 30 | Then we provide scripts to stitch T2w and Vibe images.
31 | ```bash 32 | stitching_T2w.py -i [Path to the bids folder (dataset-nako)] 33 | stitching_vibe.py -i [Path to the bids folder (dataset-nako)] 34 | ``` 35 | The stitched images are then found under [dataset-nako]/rawdata_stitched. 36 | Note: doing this for 30,000 scans may take a while. You can start multiple scripts at the same time in different consoles (google tmux). 37 | -------------------------------------------------------------------------------- /examples/nako/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/examples/nako/__init__.py -------------------------------------------------------------------------------- /examples/nako/stitching_T2w.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | # noqa: INP001 4 | import argparse 5 | import time 6 | from pathlib import Path 7 | 8 | import TPTBox.stitching.stitching_tools as st 9 | from TPTBox import BIDS_FILE, BIDS_Global_info, Print_Logger 10 | 11 | logger = Print_Logger() 12 | arg_parser = argparse.ArgumentParser() 13 | arg_parser.add_argument("-i", "--inputfolder", help="input folder (where the rawdata folder is located)", required=True) 14 | arg_parser.add_argument("-p", "--outparant", help="input folder (where the rawdata folder is located)", default="rawdata_stitched") 15 | arg_parser.add_argument("-s", "--sleep", type=float, default=0, help="sleep after each save") 16 | arg_parser.add_argument("-r", "--rawdata", type=str, default="rawdata", help="the rawdata folder to be searched") 17 | args = arg_parser.parse_args() 18 | print(f"args={args}") 19 | print(f"args.inputfolder={args.inputfolder}") 20 | print(f"args.outparant={args.outparant}") 21 | print(f"args.rawdata={args.rawdata}") 22 | 23 | bgi = BIDS_Global_info(datasets=[Path(args.inputfolder)], parents=[args.rawdata]) 24 | print() 25 |
skipped = [] 26 | skipped_no_t2w = [] 27 | skipped_to_many = [] 28 | already_stitched = 0 29 | new_stitched = 0 30 | for name, subj in bgi.enumerate_subjects(sort=True): 31 | q = subj.new_query() 32 | q.filter_format("T2w") 33 | q.filter("chunk", lambda x: str(x) in ["HWS", "BWS", "LWS"]) 34 | q.flatten() 35 | files: dict[str, BIDS_FILE] = {} 36 | to_few = False 37 | c = 0 38 | for chunk in ["HWS", "BWS", "LWS"]: 39 | q_tmp = q.copy() 40 | q_tmp.filter("chunk", chunk) 41 | l_t2w = list(q_tmp.loop_list()) 42 | c += len(l_t2w) 43 | if len(l_t2w) > 1: 44 | l_t2w = sorted(l_t2w, key=lambda x: x.get("sequ", -1))[-1] # type: ignore 45 | skipped_to_many.append(name) 46 | if len(l_t2w) != 1: 47 | to_few = True 48 | continue 49 | files[chunk] = l_t2w[0] 50 | if to_few: 51 | if c == 0: 52 | skipped_no_t2w.append(name) 53 | continue 54 | skipped.append((name, c)) 55 | continue 56 | out = files["HWS"].get_changed_path(info={"chunk": None, "sequ": "stitched"}, parent=args.outparant) 57 | try: 58 | if out.exists(): 59 | print(name, "exist", end="\r") 60 | already_stitched += 1 61 | continue 62 | print("Stich", out) 63 | nii = st.GNC_stitch_T2w(files["HWS"], files["BWS"], files["LWS"]) 64 | crop = nii.compute_crop() 65 | nii.apply_crop_(crop) 66 | nii.save(out) 67 | if args.sleep != 0: 68 | logger.print(f"Sleepy time {args.sleep} s; stitched {new_stitched}") 69 | time.sleep(args.sleep) 70 | new_stitched += 1 71 | except BaseException: 72 | out.unlink(missing_ok=True) 73 | skipped.append(name + "_FAIL") 74 | raise 75 | print("These subject have no T2w images:", len(skipped_no_t2w), skipped_no_t2w) if len(skipped_no_t2w) != 0 else None 76 | if len(skipped_to_many) != 0: 77 | logger.on_warning( 78 | "These subject have redundant T2w images:", 79 | len(skipped_to_many), 80 | skipped_to_many, 81 | "May be they have low quality data. 
Will use the highest sequence id, which is most likely the best.", 82 | ) 83 | logger.on_warning("These subject where skipped, because there are to few files or raised an error:", skipped) if len(skipped) != 0 else None 84 | print("Subject skipped:", len(skipped)) if len(skipped) != 0 else None 85 | print("Subject already stitched:", already_stitched) if (already_stitched) != 0 else None 86 | print("Subject stitched:", new_stitched) 87 | -------------------------------------------------------------------------------- /examples/nako/stitching_vibe.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | # noqa: INP001 4 | import argparse 5 | import random 6 | from pathlib import Path 7 | 8 | import TPTBox.stitching.stitching_tools as st 9 | from TPTBox import BIDS_FILE, BIDS_Global_info, Print_Logger 10 | from TPTBox.core.bids_constants import sequence_splitting_keys 11 | from TPTBox.core.compat import zip_strict 12 | 13 | logger = Print_Logger() 14 | arg_parser = argparse.ArgumentParser() 15 | arg_parser.add_argument("-i", "--inputfolder", help="input folder (where the rawdata folder is located)", required=True) 16 | arg_parser.add_argument("-p", "--outparant", help="input folder (where the rawdata folder is located)", default="rawdata_stitched") 17 | arg_parser.add_argument("-r", "--rawdata", type=str, default="rawdata", help="the rawdata folder to be searched") 18 | args = arg_parser.parse_args() 19 | print(f"args={args}") 20 | print(f"args.inputfolder={args.inputfolder}") 21 | print(f"args.rawdata={args.rawdata}") 22 | 23 | sequence_splitting_keys.remove("chunk") 24 | bgi = BIDS_Global_info(datasets=[Path(args.inputfolder)], parents=[args.rawdata], sequence_splitting_keys=sequence_splitting_keys) 25 | print() 26 | skipped = [] 27 | skipped_single = [] 28 | already_stitched = 0 29 | new_stitched = 0 30 | l = list(bgi.enumerate_subjects()) 31 | random.shuffle(l) 32 | 33 | 34 | def 
split_multi_scans(v: list[BIDS_FILE], out: BIDS_FILE): 35 | jsons = [x.open_json() for x in v] 36 | ids = [(j["SeriesNumber"], bids) for j, bids in zip_strict(jsons, v)] 37 | ids.sort() 38 | curr = [] 39 | curr_id = [] 40 | splits = [] 41 | splits_c = [] 42 | for _, bids in ids: 43 | chunk = bids.get("chunk") 44 | if chunk in curr_id: 45 | splits.append(curr) 46 | splits_c.append(curr_id) 47 | curr = [] 48 | curr_id = [] 49 | curr.append(bids) 50 | curr_id.append(chunk) 51 | splits.append(curr) 52 | splits_c.append(curr_id) 53 | 54 | # test if those are not patched scans 55 | multiple_full_scans = True 56 | for c in splits_c: 57 | if c != [str(s + 1) for s in list(range(len(c)))]: 58 | multiple_full_scans = False 59 | break 60 | if not multiple_full_scans: 61 | logger.on_warning( 62 | "[", 63 | v[0], 64 | "]; is patched with multiple partial scans. Some scans will probably have movement errors. Delete those duplicated files", 65 | splits_c, 66 | ) 67 | return False 68 | else: 69 | for number, v_list in enumerate(splits[::-1], start=1): 70 | out_new = out.get_changed_path(parent=out.get_parent(), info={"sequ": "stitched", "nameconflict": None, "run": str(number)}) 71 | if out_new.exists(): 72 | continue 73 | st.stitching(*v_list, out=out_new) 74 | return True 75 | 76 | 77 | for name, subj in l: 78 | q = subj.new_query() 79 | q.filter_format("vibe") 80 | q.filter("chunk", lambda _: True) # chunk key must be present. 
Stiched images do not have a chunk key, so they are skipped 81 | files: dict[str, BIDS_FILE] = {} 82 | for fam in q.loop_dict(key_addendum=["part"]): 83 | for v in fam.values(): 84 | part = v[0].get("part") 85 | out = v[0].get_changed_bids(parent=args.outparant, info={"chunk": None, "sequ": "stitched"}) 86 | # Check if there are multiple scans 87 | ids = {} 88 | skip = False 89 | for b in v: 90 | key = str(b.get("chunk")) + "-" + str(b.get("part")) 91 | ids.setdefault(key, 0) 92 | ids[key] += 1 93 | if ids[key] >= 2: 94 | skip = True 95 | if skip: 96 | succ = split_multi_scans(v, out) 97 | if not succ: 98 | skipped.append(name) 99 | 100 | continue 101 | try: 102 | if out.exists(): 103 | print(name, "exist", end="\r") 104 | already_stitched += 1 105 | continue 106 | if len(v) == 1: 107 | skipped_single.append(name) 108 | continue 109 | st.stitching(v, out=out) 110 | new_stitched += 1 111 | except BaseException: 112 | out.unlink(missing_ok=True) 113 | raise 114 | 115 | 116 | skipped = set(skipped) 117 | skipped_single = set(skipped_single) 118 | c = len(skipped) + len(skipped_single) 119 | logger.on_warning("These subject where skipped, because there multiple scans:", list(skipped)) if len(skipped) != 0 else None 120 | if len(skipped_single) != 0: 121 | logger.on_warning("These subject where skipped, because there ins only a single scans:", list(skipped_single)) 122 | logger.on_warning("Subject skipped:", c) if c != 0 else None 123 | print("Images already stitched:", already_stitched) if already_stitched != 0 else None 124 | print("Images stitched:", new_stitched) 125 | -------------------------------------------------------------------------------- /examples/registration/IVD_transfer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/examples/registration/IVD_transfer/__init__.py 
-------------------------------------------------------------------------------- /examples/registration/atlas_poi_transfer_leg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/examples/registration/atlas_poi_transfer_leg/__init__.py -------------------------------------------------------------------------------- /examples/registration/atlas_poi_transfer_leg/example.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from examples.registration.atlas_poi_transfer_leg.atlas_poi_transfer_leg_ct import ( 5 | Register_Point_Atlas, 6 | parse_coordinates_to_poi, 7 | prep_Atlas, 8 | ) 9 | from TPTBox import POI, to_nii 10 | from TPTBox.core.vert_constants import Full_Body_Instance, Lower_Body 11 | 12 | ########################################## 13 | # Settings 14 | text_file_is_left_leg = True 15 | file_text = "/DATA/NAS/tools/TPTBox/examples/atlas_poi_transfer_leg/010__left.txt" 16 | segmentation_path = "/DATA/NAS/datasets_processed/CT_fullbody/dataset-watrinet/source/Dataset001_all/0001/bone.nii.gz" 17 | out_folder = Path("/DATA/NAS/datasets_processed/CT_fullbody/dataset-watrinet/atlas") 18 | atlas_id = 1 19 | ########################################## 20 | # Load segmentation 21 | seg = to_nii(segmentation_path, True) 22 | 23 | if not text_file_is_left_leg: 24 | axis = seg.get_axis("R") 25 | if axis == 0: 26 | target = seg.set_array(seg.get_array()[::-1]).copy() 27 | elif axis == 1: 28 | target = seg.set_array(seg.get_array()[:, ::-1]).copy() 29 | elif axis == 2: 30 | target = seg.set_array(seg.get_array()[:, :, ::-1]).copy() 31 | assert text_file_is_left_leg, "Not implement: Flip NII and POI" 32 | # Prep atlas 33 | atlas_path = out_folder / f"atlas{atlas_id:03}.nii.gz" 34 | atlas_cms_poi_path = out_folder / f"atlas{atlas_id:03}_cms_poi.json" # Center 
of mass 35 | atlas_poi_path = out_folder / f"atlas{atlas_id:03}_poi.json" 36 | prep_Atlas(seg, atlas_path, atlas_cms_poi_path, text_file_is_left_leg) 37 | 38 | 39 | poi = parse_coordinates_to_poi(file_text, True).to_other(seg) if ".txt" in file_text else POI.load(file_text).resample_from_to(seg) 40 | if not text_file_is_left_leg: 41 | for k1, k2, (x, y, z) in poi.items(): 42 | axis = poi.get_axis("R") 43 | if axis == 0: 44 | poi[k1, k2] = (poi.shape[0] - 1 - x, y, z) 45 | elif axis == 1: 46 | poi[k1, k2] = (x, poi.shape[1] - 1 - y, z) 47 | elif axis == 2: 48 | poi[k1, k2] = (x, y, poi.shape[2] - 1 - z) 49 | else: 50 | raise ValueError(axis) 51 | poi.level_one_info = Full_Body_Instance 52 | poi.level_two_info = Lower_Body 53 | poi.to_global().save(atlas_poi_path) 54 | # Step 1 55 | 56 | ########################################## 57 | for i in range(500): 58 | # Settings 59 | target_seg_path = ( 60 | f"/DATA/NAS/datasets_processed/CT_fullbody/dataset-watrinet/source/Dataset001_all/{i:04}/bone.nii.gz" # TODO Path to target seg 61 | ) 62 | s = str(target_seg_path).split(".")[0] 63 | split_leg_path = s + "_seg-left-right-split_msk.nii.gz" 64 | out_new_pois = s + "_desc-leg_poi.json" 65 | out_new_pois_nii = s + "_desc-leg_poi.nii.gz" 66 | atlas_id = 1 67 | ddevice = "cuda" 68 | gpu = 0 69 | if not Path(target_seg_path).exists() or Path(out_new_pois_nii).exists(): 70 | continue 71 | ########################################## 72 | # Atlas 73 | atlas_p = out_folder / f"atlas{atlas_id:03}.nii.gz" 74 | atlas_centroids = out_folder / f"atlas{atlas_id:03}_cms_poi.json" # Center of mass 75 | atlas_poi_path = out_folder / f"atlas{atlas_id:03}_poi.json" 76 | # Load segmentation 77 | target = to_nii(target_seg_path, True) 78 | atlas = to_nii(atlas_p, True) 79 | 80 | # Creating this object will start the registration 81 | registration_obj = Register_Point_Atlas( 82 | target, atlas, split_leg_path=split_leg_path, atlas_centroids=atlas_centroids, gpu=gpu, ddevice=ddevice, verbose=0 83 
| ) 84 | out_poi = registration_obj.make_poi_from_poi(POI.load(atlas_poi_path), out_new_pois) 85 | nii = out_poi.make_point_cloud_nii()[1] + to_nii(split_leg_path, True) * 100 86 | nii.save(out_new_pois_nii) 87 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "TPTBox" 3 | version = "0.0.0" 4 | description = "A Torso Processing Toolbox capable of processing BIDS-compatible datasets, singular niftys, points of interests, segmentations, and much more." 5 | authors = [ 6 | "Robert Graf ", 7 | "Hendrik Möller ", 8 | ] 9 | repository = "https://github.com/Hendrik-code/TPTBox" 10 | license = "GNU AFFERO GENERAL PUBLIC LICENSE v3.0, 19 November 2007" 11 | readme = "README.md" 12 | packages = [{ include = "TPTBox" }] 13 | 14 | [tool.poetry.dependencies] 15 | python = "^3.9 || ^3.10 || ^3.11 || ^3.12" 16 | pathlib = "*" 17 | nibabel = "^5.2.0" 18 | numpy = "^1.26.3" 19 | typing-extensions = "^4.9.0" 20 | scipy = "^1.12.0" 21 | dataclasses = "*" 22 | SimpleITK = "^2.3.1" 23 | matplotlib = "^3.8.2" 24 | dill = "^0.3.7" 25 | scikit-image = "^0.22.0" 26 | fill-voids = "^2.0.6" 27 | connected-components-3d = "^3.12.3" 28 | tqdm = "*" 29 | joblib = "*" 30 | scikit-learn = "*" 31 | antspyx = "0.4.2" 32 | #hf-deepali = "*" 33 | 34 | [tool.poetry.dev-dependencies] 35 | pytest = ">=8.1.1" 36 | vtk = "*" 37 | pre-commit = "*" 38 | pyvista = "^0.43.2" 39 | coverage = ">=7.0.1" 40 | pytest-mock = "^3.6.0" 41 | 42 | 43 | 44 | [build-system] 45 | requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] 46 | build-backend = "poetry_dynamic_versioning.backend" 47 | 48 | 49 | [tool.poetry-dynamic-versioning] 50 | enable = true 51 | 52 | 53 | [tool.ruff] 54 | namespace-packages = ["datagen"] 55 | exclude = [ 56 | ".bzr", 57 | ".direnv", 58 | ".eggs", 59 | ".git", 60 | ".git-rewrite", 61 | ".hg", 62 | ".mypy_cache", 
63 | ".nox", 64 | ".pants.d", 65 | ".pytype", 66 | ".ruff_cache", 67 | ".svn", 68 | ".tox", 69 | ".venv", 70 | "__pypackages__", 71 | "_build", 72 | "buck-out", 73 | "build", 74 | "dist", 75 | "node_modules", 76 | "venv", 77 | ".toml", 78 | ] 79 | line-length = 140 80 | indent-width = 4 81 | target-version = "py310" 82 | 83 | [tool.ruff.lint] 84 | ## Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 85 | ## Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or 86 | ## McCabe complexity (`C901`) by default. 87 | # 88 | select = [ 89 | "E", 90 | "F", 91 | "W", 92 | "C901", 93 | "I", 94 | "N", 95 | "UP", 96 | "ASYNC", 97 | "BLE", 98 | "B", 99 | "A", 100 | "C4", 101 | "ICN", 102 | "G", 103 | "INP", 104 | "PIE", 105 | "PYI", 106 | #"RET", 107 | "SIM", 108 | "TID", 109 | "INT", 110 | "ARG", 111 | #"PTH", 112 | "TD005", 113 | "FIX003", 114 | "FIX004", 115 | #"ERA", For clean up 116 | #"D", Dockstring For clean up 117 | #"ANN", Annoation For clean up 118 | "PD", 119 | "PGH", 120 | "PL", 121 | "TRY", 122 | "FLY", 123 | "NPY", 124 | "AIR", 125 | "PERF", 126 | "FURB", 127 | "RUF", 128 | ] 129 | 130 | 131 | ignore = [ 132 | "RUF100", 133 | "F401", 134 | "BLE001", 135 | "E501", 136 | "N801", 137 | "NPY002", 138 | "PD002", 139 | "PERF203", 140 | "PTH123", 141 | "PGH003", 142 | "PLR0911", 143 | "PLR0912", 144 | "PLR0913", 145 | "PLR0915", 146 | "PLR2004", 147 | "SIM105", 148 | "TRY003", 149 | "UP038", 150 | "N999", 151 | "E741", 152 | "SIM118", # dictionay keys 153 | "N802", # function name lowercase 154 | "F811", 155 | "N803", 156 | "N806", 157 | "B905", # strict= in zip 158 | "UP007", # Union and "|" python 3.9 159 | ] 160 | 161 | # Allow fix for all enabled rules (when `--fix`) is provided. 
162 | fixable = ["ALL"] 163 | unfixable = [] 164 | ignore-init-module-imports = true 165 | extend-safe-fixes = ["RUF015", "C419", "C408", "B006"] 166 | #unnecessary-iterable-allocation-for-first-element = true 167 | 168 | 169 | # Allow unused variables when underscore-prefixed. 170 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 171 | 172 | [tool.ruff.lint.mccabe] 173 | # Flag errors (`C901`) whenever the complexity level exceeds 5. 174 | max-complexity = 20 175 | 176 | 177 | [tool.ruff.format] 178 | # Like Black, use double quotes for strings. 179 | quote-style = "double" 180 | 181 | # Like Black, indent with spaces, rather than tabs. 182 | indent-style = "space" 183 | 184 | # Like Black, respect magic trailing commas. 185 | skip-magic-trailing-comma = false 186 | 187 | # Enable reformatting of code snippets in docstrings. 188 | docstring-code-format = true 189 | 190 | 191 | # Like Black, automatically detect the appropriate line ending. 192 | line-ending = "auto" 193 | # Add this to your setting.json (user) 194 | # Ctrl+shift+P settings json 195 | #"[python]": { 196 | # "editor.formatOnSave": true, 197 | # "editor.defaultFormatter": "charliermarsh.ruff", 198 | # "editor.codeActionsOnSave": { 199 | # "source.fixAll": "explicit", 200 | # "source.organizeImports": "never" 201 | # } 202 | # }, 203 | # "notebook.formatOnSave.enabled": true, 204 | # "notebook.codeActionsOnSave": { 205 | # "source.fixAll": false, 206 | # "source.organizeImports": false 207 | # }, 208 | -------------------------------------------------------------------------------- /unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hendrik-code/TPTBox/c6cd01460d09df747fd6baa77680a8f575049c7c/unit_tests/__init__.py -------------------------------------------------------------------------------- /unit_tests/test_bids_dataset_parallel.py: 
-------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import os 8 | import sys 9 | import unittest 10 | from pathlib import Path 11 | 12 | if not os.path.isdir("test"): # noqa: PTH112 13 | sys.path.append("..") 14 | file = Path(__file__).resolve() 15 | sys.path.append(str(file.parents[2])) 16 | import random # noqa: E402 17 | import unittest.mock # noqa: E402 18 | 19 | from joblib import Parallel, delayed # noqa: E402 20 | 21 | from TPTBox import BIDS_Global_info # noqa: E402 22 | from TPTBox.tests.test_utils import get_tests_dir, repeats # noqa: E402 23 | 24 | 25 | class Test_bids_global_info(unittest.TestCase): 26 | def dataset_test_successfull_load(self, tests_path: Path, parent: str): 27 | bids_ds = BIDS_Global_info(datasets=[tests_path], parents=[parent]) 28 | 29 | expected_keys_mri = ["msk_seg-subreg_label-6", "msk_seg-vert_label-6", "T2w_label-6"] 30 | expected_keys_ct = ["ct_label-22", "msk_seg-vert_label-22", "msk_seg-subreg_label-22"] 31 | expected_keys = expected_keys_mri if parent == "sample_mri" else expected_keys_ct 32 | 33 | expected_subject_name = "mri" if parent == "sample_mri" else "ct" 34 | 35 | for name, subject in bids_ds.enumerate_subjects(sort=True): 36 | self.assertEqual(name, expected_subject_name) 37 | q = subject.new_query() 38 | families = q.loop_dict(sort=True) 39 | for f in families: 40 | print(f.family_id, f.get_key_len()) 41 | key_len_dict = f.get_key_len() 42 | for k in expected_keys: 43 | self.assertTrue(k in key_len_dict) 44 | self.assertTrue(key_len_dict[k] == 1) 45 | 46 | def test_load_same_dataset_multiple_times(self): 47 | tests_path = get_tests_dir() 48 | parent = "sample_ct" 49 | for _i in range(repeats): 50 | self.dataset_test_successfull_load(tests_path=tests_path, parent=parent) 51 | 52 | def 
test_load_different_datasets_multiple_times(self): 53 | tests_path = get_tests_dir() 54 | for _i in range(repeats): 55 | parent = "sample_ct" if random.random() < 0.5 else "sample_mri" 56 | self.dataset_test_successfull_load(tests_path=tests_path, parent=parent) 57 | 58 | def test_load_same_dataset_multiple_times_parallel(self): 59 | tests_path = get_tests_dir() 60 | parent = "sample_ct" 61 | Parallel(n_jobs=5, backend="threading")( 62 | delayed(self.dataset_test_successfull_load)(tests_path=tests_path, parent=parent) for i in range(repeats) 63 | ) 64 | 65 | def test_load_different_datasets_multiple_times_parallel(self): 66 | tests_path = get_tests_dir() 67 | Parallel(n_jobs=5, backend="threading")( 68 | delayed(self.dataset_test_successfull_load)( 69 | tests_path=tests_path, parent="sample_ct" if random.random() < 0.5 else "sample_mri" 70 | ) 71 | for i in range(repeats) 72 | ) 73 | -------------------------------------------------------------------------------- /unit_tests/test_bids_file.py: -------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import os 8 | import random 9 | import sys 10 | import unittest 11 | from pathlib import Path 12 | 13 | sys.path.append(str(Path(__file__).resolve().parents[2])) 14 | import TPTBox 15 | import TPTBox.core.bids_files as bids 16 | from TPTBox.tests.test_utils import a, get_BIDS_test 17 | 18 | 19 | class Test_bids_file(unittest.TestCase): 20 | @unittest.skipIf(not Path("/media/robert/Expansion/dataset-Testset").exists(), "requires real data to be opened") 21 | def test_non_flatten_dixon(self): 22 | global_info = get_BIDS_test() 23 | for subj_name, subject in global_info.enumerate_subjects(): 24 | with self.subTest(dataset=f"{subj_name} filter"): 25 | query = subject.new_query() 26 | # A nii.gz must exist 27 | query.filter("Filetype", 
"nii.gz") 28 | # It must exist a dixon and a msk 29 | query.filter("format", "dixon") 30 | query.filter("format", "msk") 31 | # Example of lamda function filtering 32 | query.filter("sequ", lambda x: int(x) == 303, required=True) # type: ignore 33 | with self.subTest(dataset=f"{subj_name} loop"): 34 | for sequences in query.loop_dict(): 35 | self.assertIsInstance(sequences, dict) 36 | self.assertTrue("dixon" in sequences) 37 | self.assertEqual(len(sequences["dixon"]), 3) 38 | # sequences. 39 | 40 | @unittest.skipIf(not Path("/media/robert/Expansion/dataset-Testset").exists(), "requires real data to be opened") 41 | def test_non_flatten_ct(self): 42 | global_info = get_BIDS_test() 43 | for subj_name, subject in global_info.enumerate_subjects(): 44 | subject: bids.Subject_Container 45 | with self.subTest(dataset=f"{subj_name} filter"): 46 | query = subject.new_query() 47 | # A nii.gz must exist 48 | query.filter("Filetype", "nii.gz") 49 | # It must exist a dixon and a msk 50 | query.filter("format", "ct") 51 | # Example of lamda function filtering 52 | query.filter("sequ", lambda x: x != "None" and isinstance(x, str) and int(x) == 203, required=True) 53 | with self.subTest(dataset=f"{subj_name} loop"): 54 | for sequences in query.loop_dict(): 55 | self.assertIsInstance(sequences, dict) 56 | self.assertTrue("ct" in sequences) 57 | self.assertTrue("vert" in sequences) 58 | self.assertTrue("subreg" in sequences) 59 | self.assertTrue("snp" in sequences) 60 | 61 | @unittest.skipIf(not Path("/media/robert/Expansion/dataset-Testset").exists(), "requires real data to be opened") 62 | def test_get_sequence_files(self): 63 | global_info = get_BIDS_test() 64 | for _, subject in global_info.enumerate_subjects(): 65 | if "20210111_301" in subject.sequences.keys(): 66 | sequences = subject.get_sequence_files("20210111_301") 67 | self.assertIsInstance(sequences, TPTBox.BIDS_Family) 68 | self.assertTrue("dixon" in sequences) 69 | self.assertTrue("msk" in sequences) 70 | 
self.assertTrue("snp" in sequences) 71 | self.assertTrue("ctd_subreg" in sequences) 72 | self.assertIsInstance(sequences["dixon"], list) 73 | self.assertTrue(len(sequences["dixon"]) == 3) 74 | self.assertTrue("dixon" in sequences) 75 | 76 | else: 77 | 78 | def mapping(x: bids.BIDS_FILE): 79 | if x.format == "ctd" and x.info["seg"] == "subreg": 80 | return "other_key_word" 81 | return None 82 | 83 | sequences = subject.get_sequence_files("20220517_406", key_transform=mapping) 84 | self.assertIsInstance(sequences, TPTBox.BIDS_Family) 85 | self.assertTrue("other_key_word" in sequences) 86 | self.assertTrue("snp" in sequences) 87 | self.assertTrue("msk_seg-subreg" in sequences) 88 | self.assertTrue("msk_vert" in sequences) 89 | self.assertTrue("ct" in sequences) 90 | 91 | @unittest.skipIf(not Path("/media/robert/Expansion/dataset-Testset").exists(), "requires real data to be opened") 92 | def test_conditional_action(self): 93 | global_info = get_BIDS_test() 94 | for subj_name, subject in global_info.enumerate_subjects(): 95 | with self.subTest(dataset=f"{subj_name} filter"): 96 | query = subject.new_query(flatten=True) 97 | # A nii.gz must exist 98 | query.filter("Filetype", "nii.gz") 99 | # It must exist a dixon 100 | query.filter("format", "dixon") 101 | with self.subTest(dataset=f"{subj_name} action"): 102 | # Find if the dixon is a inphase image by looking into the json and set "part" to real 103 | query.action( 104 | # Set Part key to real. Will be called if the filter = True 105 | action_fun=lambda x: x.set("part", "real"), 106 | # x is the json, because of the key="json". 
We return True if the json confirms that this is a real-part image 107 | filter_fun=lambda x: "IP" in x["ImageType"], # type: ignore 108 | key="json", 109 | # The json is required 110 | required=True, 111 | ) 112 | 113 | with self.subTest(dataset=f"{subj_name} loop"): 114 | real_count = 0 115 | dix_count = 0 116 | for sequences in query.loop_list(): 117 | self.assertIsInstance(sequences, bids.BIDS_FILE) 118 | if "part" in sequences.info: 119 | real_count += 1 120 | dix_count += 1 121 | self.assertEqual(dix_count / real_count, 3.0) 122 | 123 | 124 | if __name__ == "__main__": 125 | unittest.main() 126 | 127 | # @unittest.skipIf(condition, reason) 128 | # with self.subTest(i=i): 129 | -------------------------------------------------------------------------------- /unit_tests/test_bids_file_print.py: -------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import os 8 | import sys 9 | import unittest 10 | from pathlib import Path 11 | 12 | sys.path.append(str(Path(__file__).resolve().parents[2])) 13 | import io 14 | import unittest.mock 15 | 16 | import TPTBox.core.bids_files as bids 17 | from TPTBox import BIDS_FILE 18 | from TPTBox.tests.test_utils import a, get_BIDS_test 19 | 20 | 21 | class Test_bids_file(unittest.TestCase): 22 | @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) 23 | def test_non_verbose(self, mock_stdout): 24 | bids.validate_entities("xxx", "yyy", "zzz", False) 25 | self.assertEqual(mock_stdout.getvalue(), "") 26 | 27 | @unittest.mock.patch("sys.stdout", new_callable=io.StringIO) 28 | def test_set(self, mock_stdout): 29 | f = BIDS_FILE("sub-spinegan0026_ses-20210109_sequ-203_seg-subreg_ctd.json", "/media/robert/Expansion/dataset-Testset", verbose=True) 30 | f.set("seg", "mimi") 31 | self.assertEqual(mock_stdout.getvalue(), "") 32 | f.set("TUT", 
"mimi") 33 | self.assertNotEqual(mock_stdout.getvalue(), "") 34 | mock_stdout.truncate(0) 35 | self.assertEqual(f.get("seg"), "mimi") 36 | self.assertEqual(mock_stdout.getvalue(), "") 37 | self.assertEqual(f.get("TUT"), "mimi") 38 | self.assertEqual(mock_stdout.getvalue(), "") 39 | mock_stdout.truncate(0) 40 | self.assertEqual(f.get("sequ"), "203") 41 | self.assertEqual(mock_stdout.getvalue(), "") 42 | mock_stdout.truncate(0) 43 | self.assertEqual(f.get("b", default="999"), "999") 44 | self.assertEqual(mock_stdout.getvalue(), "") 45 | mock_stdout.truncate(0) 46 | try: 47 | f.get("x") 48 | self.assertFalse(True) 49 | except Exception: 50 | pass 51 | f.set("task", "LR") 52 | # sys.__stdout__.write(mock_stdout.getvalue()) 53 | self.assertEqual(mock_stdout.getvalue(), "") 54 | mock_stdout.truncate(0) 55 | f.set("dir", "LR123") 56 | self.assertEqual(mock_stdout.getvalue(), "") 57 | mock_stdout.truncate(0) 58 | f.set("task", "LR-123") 59 | self.assertNotEqual(mock_stdout.getvalue(), "") 60 | mock_stdout.truncate(0) 61 | # for key in ["run", "mod", "echo", "flip", "inv"]: 62 | # f.set(key, str(random.randint(0, 100000))) 63 | # self.assertEqual(mock_stdout.getvalue(), "", msg=f"{key},{f.get(key)}") 64 | # mock_stdout.truncate(0) 65 | # letters = string.ascii_lowercase 66 | # f.set(key, "".join(random.choice(letters) for i in range(10))) 67 | # self.assertNotEqual(mock_stdout.getvalue(), "") 68 | # mock_stdout.truncate(0) 69 | 70 | 71 | if __name__ == "__main__": 72 | unittest.main() 73 | 74 | # @unittest.skipIf(condition, reason) 75 | # with self.subTest(i=i): 76 | -------------------------------------------------------------------------------- /unit_tests/test_centroids_save.py: -------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import os 8 | import random 9 | import sys 10 | 
import tempfile 11 | import unittest 12 | from pathlib import Path 13 | 14 | import nibabel as nib 15 | import numpy as np 16 | 17 | import TPTBox.core.bids_files as bids 18 | from TPTBox import POI, POI_Global 19 | from TPTBox.core.nii_wrapper import AX_CODES, NII 20 | from TPTBox.core.vert_constants import conversion_poi2text 21 | 22 | repeats = 20 23 | 24 | 25 | def get_random_ax_code() -> AX_CODES: 26 | directions = [["R", "L"], ["S", "I"], ["A", "P"]] 27 | idx = [0, 1, 2] 28 | random.shuffle(idx) 29 | return tuple(directions[i][random.randint(0, 1)] for i in idx) # type: ignore 30 | 31 | 32 | def get_random_shape(): 33 | return tuple(int(1000 * random.random() + 10) for _ in range(3)) # type: ignore 34 | 35 | 36 | def get_centroids(x: tuple[int, int, int] = (50, 30, 40), num_point=3, default_sub=None): 37 | out_points: dict[tuple[int, int], tuple[float, float, float]] = {} 38 | 39 | for _ in range(num_point): 40 | point = tuple(random.randint(1, a * 100) / 100.0 for a in x) 41 | out_points[random.randint(1, 256), random.randint(1, 256) if default_sub is None else 50] = point 42 | return POI( 43 | centroids=out_points, 44 | orientation=get_random_ax_code(), 45 | zoom=(random.random() * 3, random.random() * 3, random.random() * 3), 46 | shape=x, 47 | origin=(0, 0, 0), 48 | rotation=np.eye(3), 49 | ) 50 | 51 | 52 | def get_centroids2(x: tuple[int, int, int] = (50, 30, 40), num_point=3): 53 | out_points: dict[tuple[int, int], tuple[float, float, float]] = {} 54 | l = list(conversion_poi2text.keys()) 55 | for _ in range(num_point): 56 | point = tuple(random.randint(1, a * 100) / 100.0 for a in x) 57 | out_points[random.randint(1, 27), l[random.randint(0, len(l) - 1)]] = point 58 | return POI(centroids=out_points, orientation=get_random_ax_code(), zoom=(1, 1, 1), shape=x) 59 | 60 | 61 | s = Path("BIDS/test/") 62 | if not s.exists(): 63 | s = Path() 64 | 65 | 66 | class Test_Centroids(unittest.TestCase): 67 | def test_save_0(self): 68 | for _ in range(repeats): 69 | 
p = Path(s, "test_save_0.json") 70 | cdt = get_centroids(x=(500, 700, 900), num_point=99) 71 | cdt.save(p, verbose=False) 72 | cdt2 = POI.load(p) 73 | self.assertEqual(cdt, cdt2) 74 | p.unlink() 75 | 76 | def test_save_1(self): 77 | for _ in range(repeats): 78 | p = Path(s, "test_save_1.json") 79 | cdt = get_centroids2(x=get_random_shape(), num_point=99) 80 | cdt.save(p, verbose=False, save_hint=1) 81 | cdt2 = POI.load(p) 82 | self.assertEqual(cdt, cdt2) 83 | p.unlink() 84 | 85 | def test_save_2(self): 86 | for _ in range(repeats): 87 | p = Path(s, "test_save_2.json") 88 | cdt = get_centroids(x=get_random_shape(), num_point=99) 89 | cdt.save(p, verbose=False, save_hint=2) 90 | cdt2 = POI.load(p) 91 | self.assertEqual(cdt, cdt2) 92 | Path(p).unlink() 93 | 94 | def test_save_10(self): 95 | for _ in range(repeats): 96 | p = Path(s, "test_save_10.json") 97 | cdt = get_centroids(x=get_random_shape(), num_point=2) 98 | cdt.save(p, verbose=False, save_hint=10) 99 | cdt2 = POI.load(p) 100 | cdt = cdt.rescale((1, 1, 1), verbose=False).reorient_(("R", "P", "I")) 101 | cdt.shape = None # type: ignore 102 | cdt.rotation = None # type: ignore 103 | self.assertEqual(cdt, cdt2) 104 | Path(p).unlink() 105 | 106 | def test_save_Glob(self): 107 | for _ in range(repeats): 108 | p = Path(s, "test_save_glob.json") 109 | cdt = get_centroids(x=get_random_shape(), num_point=20).to_global() 110 | cdt.save(p, verbose=False) 111 | cdt2 = POI_Global.load(p) 112 | self.assertEqual(cdt, cdt2) 113 | Path(p).unlink() 114 | 115 | def test_save_Glob_2(self): 116 | for _ in range(repeats): 117 | p = Path(s, "test_save_glob_2.json") 118 | cdt = get_centroids(x=get_random_shape(), num_point=20) 119 | glob_poi = cdt.to_global() 120 | cdt.save(p, verbose=False) 121 | glob_poi.save(p, verbose=False) 122 | cdt_a = POI_Global.load(p) 123 | cdt_b = POI_Global.load(p) 124 | self.assertEqual(cdt_a, cdt_b) 125 | Path(p).unlink() 126 | 127 | def test_save_all(self): 128 | for _ in range(repeats): 129 | p = 
Path(s, "test_save_all.json") 130 | cdt = get_centroids2(x=get_random_shape(), num_point=99) 131 | cdt2 = cdt 132 | for _ in range(5): 133 | cdt2.save(p, verbose=False, save_hint=random.randint(0, 2)) 134 | cdt2 = POI.load(p) 135 | self.assertEqual(cdt, cdt2) 136 | Path(p).unlink() 137 | 138 | 139 | if __name__ == "__main__": 140 | unittest.main() 141 | 142 | # @unittest.skipIf(condition, reason) 143 | # with self.subTest(i=i): 144 | -------------------------------------------------------------------------------- /unit_tests/test_compat.py: -------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import unittest 8 | 9 | from TPTBox.core.compat import zip_strict 10 | 11 | 12 | class TestZipStrict(unittest.TestCase): 13 | def test_equal_length_lists(self): 14 | a = [1, 2, 3] 15 | b = ['a', 'b', 'c'] 16 | expected = [(1, 'a'), (2, 'b'), (3, 'c')] 17 | result = list(zip_strict(a, b)) 18 | self.assertEqual(result, expected) 19 | 20 | def test_unequal_length_lists(self): 21 | a = [1, 2, 3] 22 | b = ['a', 'b'] 23 | with self.assertRaises(ValueError) as context: 24 | list(zip_strict(a, b)) 25 | self.assertIn("Length mismatch", str(context.exception)) 26 | 27 | def test_empty_iterables(self): 28 | a = [] 29 | b = [] 30 | result = list(zip_strict(a, b)) 31 | self.assertEqual(result, []) 32 | 33 | def test_multiple_iterables(self): 34 | a = [1, 2, 3] 35 | b = [4, 5, 6] 36 | c = [7, 8, 9] 37 | expected = [(1, 4, 7), (2, 5, 8), (3, 6, 9)] 38 | result = list(zip_strict(a, b, c)) 39 | self.assertEqual(result, expected) 40 | 41 | def test_generator_iterables(self): 42 | a = (x for x in range(3)) 43 | b = (chr(97 + x) for x in range(3)) 44 | expected = [(0, 'a'), (1, 'b'), (2, 'c')] 45 | result = list(zip_strict(a, b)) 46 | self.assertEqual(result, expected) 47 | 
-------------------------------------------------------------------------------- /unit_tests/test_poi_global.py: -------------------------------------------------------------------------------- 1 | # Generated by CodiumAI 2 | from __future__ import annotations 3 | 4 | import random 5 | import unittest 6 | from collections.abc import Sequence 7 | 8 | import numpy as np 9 | 10 | from TPTBox import AX_CODES, POI, POI_Global 11 | from TPTBox.core.poi_fun.poi_abstract import POI_Descriptor 12 | from TPTBox.tests.test_utils import get_poi, get_random_ax_code 13 | 14 | 15 | class TestPOI(unittest.TestCase): 16 | def test_glob_by_definition(self): 17 | poi = get_poi() 18 | poi.orientation = ("R", "A", "S") 19 | poi.zoom = (1, 1, 1) 20 | glob_poi = POI_Global(poi._get_centroids().copy()) 21 | poi = poi.reorient(get_random_ax_code()) 22 | poi.rescale_((3, 2, 1)) 23 | glob_poi.to_other_poi(poi) 24 | 25 | def test_not_implemented(self): 26 | self.assertRaises(NotImplementedError, POI_Global, None) 27 | 28 | def test_is_global(self): 29 | poi = get_poi() 30 | poi.orientation = ("L", "A", "S") 31 | poi.zoom = (1, 1, 0.5) 32 | self.assertFalse(poi.is_global) 33 | glob_poi = POI_Global(poi._get_centroids().copy()) 34 | self.assertTrue(glob_poi.is_global) 35 | 36 | def test_copy(self): 37 | poi = get_poi() 38 | glob_poi = POI_Global(poi) 39 | 40 | c = glob_poi.copy() 41 | self.assertEqual(c, glob_poi) 42 | c[1, 1] = (-5, -10, -100) 43 | self.assertNotEqual(c, glob_poi) 44 | 45 | c = glob_poi.copy() 46 | self.assertEqual(c.info, glob_poi.info) 47 | c.info["Meaning_of_live"] = 42 48 | self.assertNotEqual(c.info, glob_poi.info) 49 | 50 | c = glob_poi.copy() 51 | self.assertEqual(c.format, glob_poi.format) 52 | c.format = 10 53 | self.assertNotEqual(c.format, glob_poi.format) 54 | 55 | glob_poi.info["Question_of_the_meaning_of_live"] = "Unknown" 56 | glob_poi.format = 1 57 | self.assertEqual(c, glob_poi) 58 | glob_poi[3, 4] = (1, 2, 3) 59 | self.assertNotEqual(c, glob_poi) 60 | 61 | c 
= glob_poi.copy() 62 | self.assertEqual(c, glob_poi) 63 | c[1, 1] = (-5, -10, -100) 64 | self.assertNotEqual(c, glob_poi) 65 | 66 | c = glob_poi.copy() 67 | self.assertEqual(c.info, glob_poi.info) 68 | c.info["Meaning_of_live"] = 42 69 | self.assertNotEqual(c.info, glob_poi.info) 70 | 71 | c = glob_poi.copy() 72 | self.assertEqual(c.format, glob_poi.format) 73 | c.format = 10 74 | self.assertNotEqual(c.format, glob_poi.format) 75 | 76 | self.assertEqual(len(glob_poi.copy(centroids=POI_Descriptor())), 0) 77 | 78 | def test_subreg_labes(self): 79 | from TPTBox.core.vert_constants import vert_subreg_labels 80 | 81 | self.assertEqual(len(vert_subreg_labels()), 10) 82 | -------------------------------------------------------------------------------- /unit_tests/test_reg_seg.py: -------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import os 8 | import random 9 | import sys 10 | import unittest 11 | from pathlib import Path 12 | 13 | sys.path.append(str(Path(__file__).resolve().parents[2])) 14 | import nibabel as nib 15 | 16 | from TPTBox import to_nii 17 | 18 | test_data = [ 19 | "BIDS/test/test_data/sub-fxclass0001_seg-subreg_msk.nii.gz", 20 | "BIDS/test/test_data/sub-fxclass0001_seg-vert_msk.nii.gz", 21 | "BIDS/test/test_data/sub-fxclass0004_seg-subreg_msk.nii.gz", 22 | "BIDS/test/test_data/sub-fxclass0004_seg-vert_msk.nii.gz", 23 | ] 24 | out_name_sub = "BIDS/test/test_data/sub-fxclass0004_seg-subreg_reg-0001_msk.nii.gz" 25 | out_name_vert = "BIDS/test/test_data/sub-fxclass0004_seg-vert_reg-0001_msk.nii.gz" 26 | 27 | 28 | class Test_registration(unittest.TestCase): 29 | @unittest.skipIf(not Path(test_data[0]).exists(), "requires real data test data") 30 | def test_seg_registration(self): 31 | pass 32 | # TODO OUTDATED 33 | # t = ridged_segmentation_from_seg(*test_data, 
verbose=True, ids=list(range(40, 50)), exclusion=[19]) 34 | # slice = t.compute_crop(dist=20) 35 | # nii_out = t.transform_nii(moving_img_nii=(test_data[2], True), slice=slice) 36 | # nii_out.save(out_name_sub) 37 | # nii_out = t.transform_nii(moving_img_nii=(test_data[3], True), slice=slice) 38 | # nii_out.save(out_name_vert) 39 | 40 | 41 | if __name__ == "__main__": 42 | unittest.main() 43 | 44 | # @unittest.skipIf(condition, reason) 45 | # with self.subTest(i=i): 46 | -------------------------------------------------------------------------------- /unit_tests/test_slicing.py: -------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import random 8 | import unittest 9 | from pathlib import Path 10 | 11 | import nibabel as nib 12 | import numpy as np 13 | 14 | from TPTBox.core.compat import zip_strict 15 | from TPTBox.core.nii_wrapper import AX_CODES, NII 16 | from TPTBox.stitching import stitching_raw 17 | from TPTBox.tests.test_utils import overlap 18 | 19 | # TODO saving did not work with the test and I do not understand why. 
20 | 21 | 22 | def get_random_ax_code() -> AX_CODES: 23 | directions = [["R", "L"], ["S", "I"], ["A", "P"]] 24 | idx = [0, 1, 2] 25 | random.shuffle(idx) 26 | return tuple(directions[i][random.randint(0, 1)] for i in idx) # type: ignore 27 | 28 | 29 | def get_nii(x: tuple[int, int, int] | None = None, num_point=3, rotation=True): # type: ignore 30 | if x is None: 31 | x = (random.randint(20, 40), random.randint(20, 40), random.randint(20, 40)) 32 | a = np.zeros(x, dtype=np.uint16) 33 | points = [] 34 | out_points: dict[int, dict[int, tuple[float, float, float]]] = {} 35 | sizes = [] 36 | idx = 1 37 | while True: 38 | if num_point == len(points): 39 | break 40 | point = tuple(random.randint(1, a - 1) for a in x) 41 | size = tuple(random.randint(1, 1 + a) for a in [5, 5, 5]) 42 | if any(a - b < 0 for a, b in zip_strict(point, size)): 43 | continue 44 | if any(a + b > c - 1 for a, b, c in zip_strict(point, size, x)): 45 | continue 46 | skip = False 47 | for p2, s2 in zip_strict(points, sizes): 48 | if overlap(point, size, p2, s2): 49 | skip = True 50 | break 51 | if skip: 52 | continue 53 | a[ 54 | point[0] - size[0] : point[0] + size[0] + 1, 55 | point[1] - size[1] : point[1] + size[1] + 1, 56 | point[2] - size[2] : point[2] + size[2] + 1, 57 | ] = idx 58 | 59 | points.append(point) 60 | sizes.append(size) 61 | out_points[idx] = {50: tuple(float(a) for a in point)} 62 | 63 | idx += 1 64 | aff = np.eye(4) 65 | 66 | aff[0, 3] = random.randint(-100, 100) 67 | aff[1, 3] = random.randint(-100, 100) 68 | aff[2, 3] = random.randint(-100, 100) 69 | if rotation: 70 | m = 30 71 | from scipy.spatial.transform import Rotation 72 | 73 | r = Rotation.from_euler("xyz", (random.randint(-m, m), random.randint(-m, m), random.randint(-m, m)), degrees=True) 74 | aff[:3, :3] = r.as_matrix() 75 | n = NII(nib.Nifti1Image(a, aff), seg=True) 76 | n.reorient_(get_random_ax_code()) 77 | return n 78 | 79 | 80 | class TestNPInterOperability(unittest.TestCase): 81 | def test_get_int(self): 82 | # 
Define inputs for the function 83 | for _ in range(5): 84 | nii = get_nii() 85 | arr = nii.get_array() 86 | for _ in range(5): 87 | x, y, z = random.randint(0, nii.shape[0] - 1), random.randint(0, nii.shape[1] - 1), random.randint(0, nii.shape[2] - 1) 88 | assert nii[x, y, z] == arr[x, y, z], (nii[x, y, z], arr[x, y, z]) 89 | 90 | def test_assign_int(self): 91 | # Define inputs for the function 92 | for _ in range(5): 93 | nii = get_nii() 94 | for _ in range(5): 95 | v = random.randint(0, 255) 96 | x, y, z = random.randint(0, nii.shape[0] - 1), random.randint(0, nii.shape[1] - 1), random.randint(0, nii.shape[2] - 1) 97 | v_old = nii[x, y, z] 98 | nii[x, y, z] = v 99 | assert nii[x, y, z] == v, (nii[x, y, z], v) 100 | assert v == v_old or nii[x, y, z] != v_old 101 | 102 | def test_slice(self): 103 | for _ in range(5): 104 | nii = get_nii() 105 | sl_nii: NII = nii[:10, :10, :10] 106 | assert sl_nii.shape == (10, 10, 10) 107 | sl_nii = nii[::2, ::3, ::4] 108 | assert sl_nii.zoom == (2, 3, 4) 109 | sl_nii = nii[::-1, ::-1, ::-1] 110 | assert sl_nii.zoom == (1, 1, 1) 111 | for axis in ("R", "A", "S"): 112 | ax = sl_nii.get_axis(axis) 113 | assert ax == nii.get_axis(axis) 114 | assert sl_nii.orientation[ax] != nii.orientation[ax] 115 | 116 | def test_slice_assign(self): 117 | for _ in range(5): 118 | nii = get_nii() 119 | nii[:10, :10, :10] = 99 120 | assert nii[5, 5, 5] == 99 121 | 122 | def test_np(self): 123 | print(np.sum(get_nii())) 124 | print(np.sin(get_nii())) 125 | 126 | 127 | if __name__ == "__main__": 128 | unittest.main() 129 | -------------------------------------------------------------------------------- /unit_tests/test_stiching.py: -------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import random 8 | import unittest 9 | from pathlib import Path 10 | 11 | import 
nibabel as nib 12 | import numpy as np 13 | 14 | from TPTBox.core.compat import zip_strict 15 | from TPTBox.core.nii_wrapper import NII 16 | from TPTBox.stitching import stitching_raw 17 | from TPTBox.tests.test_utils import overlap 18 | 19 | # TODO saving did not work with the test and I do not understand why. 20 | 21 | 22 | def get_nii(x: tuple[int, int, int] | None = None, num_point=3, rotation=True): # type: ignore 23 | if x is None: 24 | x = (random.randint(20, 40), random.randint(20, 40), random.randint(20, 40)) 25 | a = np.zeros(x, dtype=np.uint16) 26 | points = [] 27 | out_points: dict[int, dict[int, tuple[float, float, float]]] = {} 28 | sizes = [] 29 | idx = 1 30 | while True: 31 | if num_point == len(points): 32 | break 33 | point = tuple(random.randint(1, a - 1) for a in x) 34 | size = tuple(random.randint(1, 1 + a) for a in [5, 5, 5]) 35 | if any(a - b < 0 for a, b in zip_strict(point, size)): 36 | continue 37 | if any(a + b > c - 1 for a, b, c in zip_strict(point, size, x)): 38 | continue 39 | skip = False 40 | for p2, s2 in zip_strict(points, sizes): 41 | if overlap(point, size, p2, s2): 42 | skip = True 43 | break 44 | if skip: 45 | continue 46 | a[ 47 | point[0] - size[0] : point[0] + size[0] + 1, 48 | point[1] - size[1] : point[1] + size[1] + 1, 49 | point[2] - size[2] : point[2] + size[2] + 1, 50 | ] = idx 51 | 52 | points.append(point) 53 | sizes.append(size) 54 | out_points[idx] = {50: tuple(float(a) for a in point)} 55 | 56 | idx += 1 57 | aff = np.eye(4) 58 | 59 | aff[0, 3] = random.randint(-100, 100) 60 | aff[1, 3] = random.randint(-100, 100) 61 | aff[2, 3] = random.randint(-100, 100) 62 | if rotation: 63 | m = 30 64 | from scipy.spatial.transform import Rotation 65 | 66 | r = Rotation.from_euler("xyz", (random.randint(-m, m), random.randint(-m, m), random.randint(-m, m)), degrees=True) 67 | aff[:3, :3] = r.as_matrix() 68 | n = NII(nib.Nifti1Image(a, aff), seg=True) 69 | 70 | return n, out_points, n.orientation, sizes 71 | 72 | 73 | class 
TestStitchingFunction(unittest.TestCase): 74 | def test_stitching( 75 | self, 76 | idx="C66EMBZJmy75n4XHv2YsSXVs", 77 | match_histogram=True, 78 | store_ramp=False, 79 | verbose=False, 80 | min_value=0, 81 | bias_field=True, 82 | crop_to_bias_field=False, 83 | crop_empty=False, 84 | histogram=None, 85 | ramp_edge_min_value=0, 86 | min_spacing=None, 87 | kick_out_fully_integrated_images=False, 88 | is_segmentation=False, 89 | dtype=float, 90 | save=False, 91 | ): 92 | # Define inputs for the function 93 | images = [get_nii()[0].nii, get_nii()[0].nii] 94 | output = Path(f"~/output{idx}.nii.gz") 95 | if save: 96 | print(output.absolute()) 97 | output.unlink(missing_ok=True) 98 | 99 | # Call the function 100 | result, _ = stitching_raw( 101 | images, 102 | str(output), 103 | match_histogram, 104 | store_ramp, 105 | verbose, 106 | min_value, 107 | bias_field, 108 | crop_to_bias_field, 109 | crop_empty, 110 | histogram, 111 | ramp_edge_min_value, 112 | min_spacing, 113 | kick_out_fully_integrated_images, 114 | is_segmentation, 115 | dtype, 116 | save, 117 | ) 118 | if save: 119 | self.assertTrue(output.parent.exists(), output) 120 | self.assertTrue(output.exists(), output) 121 | output.unlink(missing_ok=True) 122 | # Assertions 123 | self.assertIsInstance(result, nib.Nifti1Image) # Check if result is a Nifti1Image instance 124 | # Add more assertions based on your requirements 125 | 126 | def test_stitching2(self): 127 | self.test_stitching( 128 | idx="X", 129 | match_histogram=False, 130 | store_ramp=False, 131 | verbose=True, 132 | min_value=-1024, 133 | bias_field=False, 134 | crop_to_bias_field=False, 135 | crop_empty=True, 136 | histogram=None, 137 | ramp_edge_min_value=20, 138 | min_spacing=2, 139 | kick_out_fully_integrated_images=True, 140 | is_segmentation=False, 141 | dtype=float, 142 | save=False, 143 | ) 144 | 145 | def test_stitching3(self): 146 | self.test_stitching( 147 | idx="X6Fqat2JLZbJKCom6BUX7F84", 148 | match_histogram=False, 149 | min_value=0, 150 | 
bias_field=False, 151 | ramp_edge_min_value=0, 152 | is_segmentation=True, 153 | save=False, 154 | store_ramp=True, 155 | ) 156 | 157 | 158 | if __name__ == "__main__": 159 | unittest.main() 160 | -------------------------------------------------------------------------------- /unit_tests/test_vertconstants.py: -------------------------------------------------------------------------------- 1 | # Call 'python -m unittest' on this folder 2 | # coverage run -m unittest 3 | # coverage report 4 | # coverage html 5 | from __future__ import annotations 6 | 7 | import sys 8 | import unittest 9 | from pathlib import Path 10 | 11 | sys.path.append(str(Path(__file__).resolve().parents[2])) 12 | import numpy as np 13 | 14 | from TPTBox import Vertebra_Instance 15 | 16 | structures = ["rib", "ivd", "endplate"] 17 | 18 | 19 | class Test_Locations(unittest.TestCase): 20 | def test_vertebra_instance_uniqueness(self): 21 | name2idx = Vertebra_Instance.name2idx() 22 | 23 | print(name2idx) 24 | label_values = list(name2idx.values()) 25 | label_valuesset, counts = np.unique(label_values, return_counts=True) 26 | 27 | for idx, c in enumerate(counts): 28 | if c > 1: 29 | l = label_valuesset[idx] 30 | 31 | structures_overlap = {i: g for i, g in name2idx.items() if g == l} 32 | 33 | self.assertTrue(False, f"Overlapping labels in {structures_overlap}") 34 | 35 | def test_vertebra_instance_completeness(self): 36 | name2idx = Vertebra_Instance.name2idx() 37 | idx2name = Vertebra_Instance.idx2name() 38 | 39 | for start in name2idx: 40 | i = name2idx[start] 41 | i2 = idx2name[i] 42 | self.assertTrue(i2 == start) 43 | 44 | def test_vertebra_instance_correctness(self): 45 | for i in [ 46 | Vertebra_Instance.C7, 47 | Vertebra_Instance.T1, 48 | Vertebra_Instance.T2, 49 | Vertebra_Instance.T3, 50 | Vertebra_Instance.T4, 51 | Vertebra_Instance.T5, 52 | Vertebra_Instance.T12, 53 | Vertebra_Instance.T13, 54 | ]: 55 | self.assertIsNotNone(i.RIB) 56 | 57 | for i in [ 58 | Vertebra_Instance.C6, 59 | 
Vertebra_Instance.C1, 60 | Vertebra_Instance.C3, 61 | Vertebra_Instance.L3, 62 | Vertebra_Instance.L6, 63 | Vertebra_Instance.COCC, 64 | Vertebra_Instance.S1, 65 | Vertebra_Instance.S3, 66 | Vertebra_Instance.S6, 67 | ]: 68 | with self.assertRaises(AssertionError): 69 | self.assertIsNone(i.RIB) 70 | --------------------------------------------------------------------------------