├── .github └── workflows │ └── test_and_deploy.yml ├── .gitignore ├── .napari-hub └── DESCRIPTION.md ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── instructions ├── 3d_viewing.gif ├── copy-paste_labels.gif ├── copy_labels.gif ├── label_options.gif ├── napari-ndlabelcorrection_filter_by_size.gif └── select_labels_by_mask.png ├── pyproject.toml ├── setup.cfg ├── src └── napari_segmentation_correction │ ├── __init__.py │ ├── _tests │ ├── __init__.py │ └── test_copy_label.py │ ├── _widget.py │ ├── connected_components.py │ ├── copy_label_widget.py │ ├── cross_widget.py │ ├── custom_table_widget.py │ ├── erosion_dilation_widget.py │ ├── icons │ ├── Back.png │ ├── Forward.png │ ├── Home.png │ ├── Pan.png │ ├── Pan_checked.png │ ├── Zoom.png │ ├── Zoom_checked.png │ ├── configure_subplots.png │ ├── edit_parameters.png │ └── save_figure.png │ ├── image_calculator.py │ ├── label_interpolator.py │ ├── layer_controls.py │ ├── layer_dropdown.py │ ├── layer_manager.py │ ├── napari.yaml │ ├── plot_widget.py │ ├── prop_filter_widget.py │ ├── regionprops_extended.py │ ├── regionprops_widget.py │ ├── save_labels_widget.py │ ├── select_delete_widget.py │ ├── smoothing_widget.py │ └── threshold_widget.py └── tox.ini /.github/workflows/test_and_deploy.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: tests 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | - npe2 11 | tags: 12 | - "v*" # Push events to matching v*, i.e. 
v1.0, v20.15.10 13 | pull_request: 14 | branches: 15 | - main 16 | - npe2 17 | workflow_dispatch: 18 | 19 | jobs: 20 | test: 21 | name: ${{ matrix.platform }} py${{ matrix.python-version }} 22 | runs-on: ${{ matrix.platform }} 23 | strategy: 24 | matrix: 25 | platform: [ubuntu-latest, windows-latest, macos-latest] 26 | python-version: ['3.10', '3.11'] 27 | 28 | steps: 29 | - uses: actions/checkout@v3 30 | 31 | - name: Set up Python ${{ matrix.python-version }} 32 | uses: actions/setup-python@v4 33 | with: 34 | python-version: ${{ matrix.python-version }} 35 | 36 | # these libraries enable testing on Qt on linux 37 | - uses: tlambert03/setup-qt-libs@v1 38 | 39 | # strategy borrowed from vispy for installing opengl libs on windows 40 | - name: Install Windows OpenGL 41 | if: runner.os == 'Windows' 42 | run: | 43 | git clone --depth 1 https://github.com/pyvista/gl-ci-helpers.git 44 | powershell gl-ci-helpers/appveyor/install_opengl.ps1 45 | 46 | # note: if you need dependencies from conda, considering using 47 | # setup-miniconda: https://github.com/conda-incubator/setup-miniconda 48 | # and 49 | # tox-conda: https://github.com/tox-dev/tox-conda 50 | - name: Install dependencies 51 | run: | 52 | python -m pip install --upgrade pip 53 | python -m pip install setuptools tox tox-gh-actions 54 | 55 | # this runs the platform-specific tests declared in tox.ini 56 | - name: Test with tox 57 | uses: aganders3/headless-gui@v1 58 | with: 59 | run: python -m tox 60 | env: 61 | PLATFORM: ${{ matrix.platform }} 62 | 63 | - name: Coverage 64 | uses: codecov/codecov-action@v3 65 | 66 | deploy: 67 | # this will run when you have tagged a commit, starting with "v*" 68 | # and requires that you have put your twine API key in your 69 | # github secrets (see readme for details) 70 | needs: [test] 71 | runs-on: ubuntu-latest 72 | if: contains(github.ref, 'tags') 73 | steps: 74 | - uses: actions/checkout@v3 75 | - name: Set up Python 76 | uses: actions/setup-python@v4 77 | with: 78 | 
python-version: "3.x" 79 | - name: Install dependencies 80 | run: | 81 | python -m pip install --upgrade pip 82 | pip install -U setuptools setuptools_scm wheel twine build 83 | - name: Build and publish 84 | env: 85 | TWINE_USERNAME: __token__ 86 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 87 | run: | 88 | git tag 89 | python -m build . 90 | twine upload dist/* 91 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | .napari_cache 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask instance folder 58 | instance/ 59 | 60 | # Sphinx documentation 61 | docs/_build/ 62 | 63 | # MkDocs documentation 64 | /site/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # Pycharm and VSCode 70 | .idea/ 71 | venv/ 72 | .vscode/ 73 | 74 | # IPython Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # OS 81 | .DS_Store 82 | 83 | # written by setuptools_scm 84 | **/_version.py 85 | -------------------------------------------------------------------------------- /.napari-hub/DESCRIPTION.md: -------------------------------------------------------------------------------- 1 | 8 | 9 | The developer has not yet provided a napari-hub specific description. 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-docstring-first 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | exclude: ^\.napari-hub/.* 9 | - id: check-yaml # checks for correct yaml syntax for github actions ex. 
10 | - repo: https://github.com/astral-sh/ruff-pre-commit 11 | rev: v0.12.5 12 | hooks: 13 | - id: ruff 14 | args: [--fix] 15 | - id: ruff-format 16 | # - repo: https://github.com/pre-commit/mirrors-mypy 17 | # rev: v1.17.0 18 | # hooks: 19 | # - id: mypy 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2023, Anniek Stokkermans 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | * Neither the name of copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | 4 | recursive-exclude * __pycache__ 5 | recursive-exclude * *.py[co] 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # napari-segmentation-correction 2 | 3 | [![License BSD-3](https://img.shields.io/pypi/l/napari-segmentation-correction.svg?color=green)](https://github.com/AnniekStok/napari-segmentation-correction/raw/main/LICENSE) 4 | [![PyPI](https://img.shields.io/pypi/v/napari-segmentation-correction.svg?color=green)](https://pypi.org/project/napari-segmentation-correction) 5 | [![Python Version](https://img.shields.io/pypi/pyversions/napari-segmentation-correction.svg?color=green)](https://python.org) 6 | [![tests](https://github.com/AnniekStok/napari-segmentation-correction/workflows/tests/badge.svg)](https://github.com/AnniekStok/napari-segmentation-correction/actions) 7 | [![codecov](https://codecov.io/gh/AnniekStok/napari-segmentation-correction/branch/main/graph/badge.svg)](https://codecov.io/gh/AnniekStok/napari-segmentation-correction) 8 | [![napari 
hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-segmentation-correction)](https://napari-hub.org/plugins/napari-segmentation-correction) 9 | 10 | Toolbox for viewing, analyzing and correcting (cell) segmentation in 2D, 3D or 4D (t, z, y, x) (virtual) arrays. 11 | ---------------------------------- 12 | 13 | This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template. 14 | 15 | 22 | 23 | ## Installation 24 | 25 | You can install `napari-segmentation-correction` via [pip]. 26 | 27 | To install the latest development version: 28 | 29 | pip install git+https://github.com/AnniekStok/napari-segmentation-correction.git 30 | 31 | ## Usage 32 | This plugin is a toolbox for correcting segmentation results. 33 | Features: 34 | - Orthogonal views for 3D data. 35 | - Copy labels from a 2-5 dimensional array with multiple segmentation options to your current 2-5 dimensional label layer. 36 | - Label connected components, keep the largest connected cluster of labels, and keep the largest fragment per label. 37 | - Smooth labels using a median filter. 38 | - Erode/dilate labels (scipy.ndimage and scikit-image). 39 | - Binarize an image or labels layer by applying an intensity threshold. 40 | - Image calculator for mathematical operations on two images. 41 | - Select/delete labels using a mask. 42 | - Binary mask interpolation in the z or time dimension. 43 | - Explore label properties (scikit-image regionprops) in a table widget and a Matplotlib plot. 44 | - Filter labels by properties. 45 | 46 | ### Copy labels between different labels layers 47 | ![copy_labels](https://github.com/user-attachments/assets/4f6a638d-c6bc-4a61-bdcd-6cc29b6f817e) 48 | 49 | 50 | 51 |
52 | 2D/3D/4D labels can be copied from a source layer to a target layer via SHIFT+CLICK.

53 | The data in the source layer should have the same shape as the target layer, but can optionally have one extra dimension (e.g. stack multiple segmentation solutions as channels).

54 | To copy labels, select a 'source' and a 'target' labels layer in the dropdown. By default, the source layer will be displayed as contours.

55 | Select whether to copy a slice, a volume, or a series across time.

56 | Checking Use source label value keeps the original label values.

57 | Selecting Preserve target labels only allows copying into background (0) regions. Otherwise, SHIFT+CLICK replaces the existing label region. 58 |
60 | copy_labels 61 |
64 | 65 | ### Connected component analysis 66 | There are shortcut buttons for connected component labeling, keeping the largest cluster of connected labels, and keeping the largest fragment per label. 67 | 68 | conncomp 69 | 70 | ### Select / delete labels that overlap with a binary mask 71 | All labels that overlap the mask by at least one pixel are selected or removed. 72 | select_delete 73 | 74 | ### Binary mask interpolation 75 | You can interpolate a 3D or 4D mask to fill in the regions between annotated slices (3D) or time points (4D). In 3D, this creates a 3D volume from slices; in 4D, it creates a time series of a volume that linearly 'morphs' into a different shape. 76 | 77 | 78 | 79 |
80 | interpolate 81 | 83 | labelinterpolation gif 84 |
87 | 88 | ### Measuring label properties 89 | You can measure label properties, including intensity (if a matching image layer is provided), area/volume, perimeter/surface area, circularity/sphericity, and ellipse/ellipsoid axes, in the 'Region properties' tab. Use the '3D data' checkbox to distinguish between measuring in 2D + time, 3D, and 3D + time, depending on your layer dimensions (2D to 4D). Once finished, a table displays the measurements, and a filter widget allows you to select objects matching a condition. The measurements are also displayed in the 'Plot' tab for each layer for which you ran the region properties calculation. 90 | 91 | ![propfilter](https://github.com/user-attachments/assets/ab9c6b61-4366-4ad1-b813-7465aa183988) 92 | 93 | 94 | ## Contributing 95 | 96 | Contributions are very welcome. Tests can be run with [tox]; please ensure 97 | the coverage at least stays the same before you submit a pull request. 98 | 99 | ## License 100 | 101 | Distributed under the terms of the [BSD-3] license, 102 | "napari-segmentation-correction" is free and open source software. 103 | 104 | ## Issues 105 | 106 | If you encounter any problems, please [file an issue] along with a detailed description.
107 | 108 | [napari]: https://github.com/napari/napari 109 | [Cookiecutter]: https://github.com/audreyr/cookiecutter 110 | [@napari]: https://github.com/napari 111 | [MIT]: http://opensource.org/licenses/MIT 112 | [BSD-3]: http://opensource.org/licenses/BSD-3-Clause 113 | [GNU GPL v3.0]: http://www.gnu.org/licenses/gpl-3.0.txt 114 | [GNU LGPL v3.0]: http://www.gnu.org/licenses/lgpl-3.0.txt 115 | [Apache Software License 2.0]: http://www.apache.org/licenses/LICENSE-2.0 116 | [Mozilla Public License 2.0]: https://www.mozilla.org/media/MPL/2.0/index.txt 117 | [cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin 118 | 119 | [file an issue]: https://github.com/AnniekStok/napari-segmentation-correction/issues 120 | 121 | [napari]: https://github.com/napari/napari 122 | [tox]: https://tox.readthedocs.io/en/latest/ 123 | [pip]: https://pypi.org/project/pip/ 124 | [PyPI]: https://pypi.org/ 125 | -------------------------------------------------------------------------------- /instructions/3d_viewing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/3d_viewing.gif -------------------------------------------------------------------------------- /instructions/copy-paste_labels.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/copy-paste_labels.gif -------------------------------------------------------------------------------- /instructions/copy_labels.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/copy_labels.gif 
-------------------------------------------------------------------------------- /instructions/label_options.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/label_options.gif -------------------------------------------------------------------------------- /instructions/napari-ndlabelcorrection_filter_by_size.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/napari-ndlabelcorrection_filter_by_size.gif -------------------------------------------------------------------------------- /instructions/select_labels_by_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/select_labels_by_mask.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42.0.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.ruff] 6 | line-length = 88 7 | target-version = "py310" 8 | fix = true 9 | 10 | lint.select = [ 11 | "E", "F", "W", #flake8 12 | "UP", # pyupgrade 13 | "I", # isort 14 | "BLE", # flake8-blind-exception 15 | "B", # flake8-bugbear 16 | "A", # flake8-builtins 17 | "C4", # flake8-comprehensions 18 | "ISC", # flake8-implicit-str-concat 19 | "G", # flake8-logging-format 20 | "PIE", # flake8-pie 21 | "SIM", # flake8-simplify 22 | ] 23 | 24 | lint.ignore = [ 25 | "UP006", "UP007", # type annotation. As using magicgui require runtime type annotation then we disable this. 
26 | "ISC001", # implicit string concatenation 27 | "E501", # line too long 28 | ] 29 | 30 | lint.per-file-ignores = { "scripts/*.py" = ["F"] } 31 | 32 | # Paths ruff skips in addition to its built-in default excludes (globs, not regexes). 33 | extend-exclude = [ 34 | ".bzr", 35 | ".direnv", 36 | ".eggs", 37 | ".git", 38 | ".mypy_cache", 39 | ".pants.d", 40 | ".ruff_cache", 41 | ".svn", 42 | ".tox", 43 | ".venv", 44 | "__pypackages__", 45 | "_build", 46 | "buck-out", 47 | "build", 48 | "dist", 49 | "node_modules", 50 | "venv", 51 | "*vendored*", 52 | "*_vendor*", 53 | ] 54 | 55 | # https://docs.astral.sh/ruff/formatter/ 56 | [tool.ruff.format] 57 | 58 | [tool.mypy] 59 | ignore_missing_imports = true 60 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = napari-segmentation-correction 3 | version = attr: napari_segmentation_correction.__version__ 4 | description = A plugin for manually correcting cell segmentation in 2D, 3D (z, y, x) or 4D (t, z, y, x) (virtual) arrays.
5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | url = https://github.com/AnniekStok/napari-segmentation-correction 8 | author = Anniek Stokkermans 9 | author_email = anniek.stokkermans@gmail.com 10 | license = BSD-3-Clause 11 | license_files = LICENSE 12 | classifiers = 13 | Development Status :: 2 - Pre-Alpha 14 | Framework :: napari 15 | Intended Audience :: Developers 16 | License :: OSI Approved :: BSD License 17 | Operating System :: OS Independent 18 | Programming Language :: Python 19 | Programming Language :: Python :: 3 20 | Programming Language :: Python :: 3 :: Only 21 | Programming Language :: Python :: 3.10 22 | Programming Language :: Python :: 3.11 23 | Topic :: Scientific/Engineering :: Image Processing 24 | project_urls = 25 | Bug Tracker = https://github.com/AnniekStok/napari-segmentation-correction/issues 26 | Documentation = https://github.com/AnniekStok/napari-segmentation-correction#README.md 27 | Source Code = https://github.com/AnniekStok/napari-segmentation-correction 28 | User Support = https://github.com/AnniekStok/napari-segmentation-correction/issues 29 | 30 | [options] 31 | packages = find: 32 | install_requires = 33 | napari >= 0.6.0 34 | numpy 35 | scipy 36 | scikit-image 37 | dask_image 38 | dask 39 | matplotlib 40 | imagecodecs 41 | tifffile 42 | qtpy 43 | napari-plane-sliders @ git+https://github.com/AnniekStok/napari-plane-sliders.git@main 44 | napari-orthogonal-views @ git+https://github.com/AnniekStok/napari-orthogonal-views.git@v0.0.4 45 | 46 | python_requires = >=3.10 47 | include_package_data = True 48 | package_dir = 49 | =src 50 | 51 | [options.packages.find] 52 | where = src 53 | 54 | [options.entry_points] 55 | napari.manifest = 56 | napari-segmentation-correction = napari_segmentation_correction:napari.yaml 57 | 58 | [options.extras_require] 59 | testing = 60 | tox 61 | pytest #
https://docs.pytest.org/en/latest/contents.html 62 | pytest-cov # https://pytest-cov.readthedocs.io/en/latest/ 63 | pytest-qt 64 | 65 | [options.package_data] 66 | * = *.yaml 67 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | from ._widget import AnnotateLabelsND 3 | 4 | __all__ = ("AnnotateLabelsND",) 5 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/_tests/__init__.py -------------------------------------------------------------------------------- /src/napari_segmentation_correction/_tests/test_copy_label.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from napari.layers import Labels 4 | 5 | from napari_segmentation_correction.copy_label_widget import CopyLabelWidget 6 | 7 | 8 | @pytest.fixture 9 | def make_event(): 10 | def _make_event(position, modifiers=("Shift",), dims_displayed=(0, 1)): 11 | class DummyEvent: 12 | def __init__(self): 13 | self.type = "mouse_press" 14 | self.position = position 15 | self.modifiers = modifiers 16 | self.dims_displayed = dims_displayed 17 | self.view_direction = None 18 | 19 | # mimic napari Event 20 | 21 | return DummyEvent() 22 | 23 | return _make_event 24 | 25 | 26 | @pytest.fixture 27 | def make_layers(): 28 | def _make_layers(shape): 29 | src = Labels(np.zeros(shape, dtype=np.uint16)) 30 | tgt = Labels(np.zeros(shape, dtype=np.uint16)) 31 | return src, tgt 32 | 33 | return _make_layers 34 | 35 | 36 | def test_copy_2d(make_event, 
make_napari_viewer): 37 | viewer = make_napari_viewer() 38 | src = viewer.add_labels(np.zeros((64, 64), dtype=np.uint16)) 39 | tgt = viewer.add_labels(np.zeros((64, 64), dtype=np.uint16)) 40 | src.data[30:40, 30:40] = 1 41 | 42 | widget = CopyLabelWidget(viewer) 43 | widget.source_layer = src 44 | widget.target_layer = tgt 45 | widget.dims_widget.slice.setChecked(True) 46 | 47 | event = make_event([32, 32]) 48 | widget.copy_label(event) 49 | 50 | expected = np.zeros((64, 64), dtype=np.uint8) 51 | expected[30:40, 30:40] = 1 52 | np.testing.assert_array_equal(tgt.data, expected) 53 | 54 | 55 | def test_copy_3d(make_event, make_napari_viewer): 56 | viewer = make_napari_viewer() 57 | src = viewer.add_labels(np.zeros((10, 64, 64), dtype=np.uint16)) 58 | tgt = viewer.add_labels(np.zeros((10, 64, 64), dtype=np.uint16)) 59 | src.data[5:8, 20:30, 20:30] = 10 60 | 61 | widget = CopyLabelWidget(viewer) 62 | widget.source_layer = src 63 | widget.target_layer = tgt 64 | widget.dims_widget.volume.setChecked(True) 65 | 66 | viewer.dims.current_step = (5, 0, 0) 67 | 68 | event = make_event([5, 25, 25], dims_displayed=[1, 2]) 69 | widget.copy_label(event) 70 | 71 | # check copying volume 72 | expected = np.zeros((10, 64, 64), dtype=np.uint8) 73 | expected[5:8, 20:30, 20:30] = 1 74 | np.testing.assert_array_equal(tgt.data, expected) 75 | 76 | # check preserving label value 77 | widget.preserve_label_value.setChecked(True) 78 | widget.copy_label(event) 79 | expected[5:8, 20:30, 20:30] = 10 80 | np.testing.assert_array_equal(tgt.data, expected) 81 | 82 | # check replacing a slice only 83 | widget.dims_widget.slice.setChecked(True) 84 | widget.preserve_label_value.setChecked(False) 85 | widget.copy_label(event) 86 | expected[5, 20:30, 20:30] = 11 87 | np.testing.assert_array_equal(tgt.data, expected) 88 | 89 | 90 | def test_copy_4d(make_event, make_napari_viewer): 91 | viewer = make_napari_viewer() 92 | src = viewer.add_labels(np.zeros((3, 10, 64, 64), dtype=np.uint16)) 93 | tgt = 
viewer.add_labels(np.zeros((3, 10, 64, 64), dtype=np.uint16)) 94 | src.data[1:2, 5:8, 20:30, 20:30] = 10 95 | 96 | widget = CopyLabelWidget(viewer) 97 | widget.source_layer = src 98 | widget.target_layer = tgt 99 | widget.dims_widget.series.setChecked(True) 100 | 101 | viewer.dims.current_step = (1, 5, 0, 0) 102 | 103 | event = make_event([1, 5, 25, 25], dims_displayed=[2, 3]) 104 | widget.copy_label(event) 105 | expected = np.zeros((3, 10, 64, 64), dtype=np.uint8) 106 | expected[1:2, 5:8, 20:30, 20:30] = 1 107 | np.testing.assert_array_equal(tgt.data, expected) 108 | 109 | widget.preserve_label_value.setChecked(True) 110 | widget.copy_label(event) 111 | expected[1:2, 5:8, 20:30, 20:30] = 10 112 | np.testing.assert_array_equal(tgt.data, expected) 113 | 114 | 115 | def test_copy_4d_slice_to_2d(make_event, make_napari_viewer): 116 | viewer = make_napari_viewer() 117 | src = viewer.add_labels(np.zeros((3, 10, 64, 64), dtype=np.uint16)) 118 | tgt = Labels(np.zeros((64, 64), dtype=np.uint16)) 119 | src.data[1:2, 5:8, 20:30, 20:30] = 10 120 | 121 | widget = CopyLabelWidget(viewer) 122 | widget.source_layer = src 123 | widget.target_layer = tgt 124 | widget.dims_widget.slice.setChecked(True) 125 | 126 | viewer.dims.current_step = (1, 5, 0, 0) 127 | 128 | event = make_event([1, 5, 25, 25], dims_displayed=[2, 3]) 129 | widget.copy_label(event) 130 | expected = np.zeros((64, 64), dtype=np.uint8) 131 | expected[20:30, 20:30] = 1 132 | np.testing.assert_array_equal(tgt.data, expected) 133 | 134 | widget.preserve_label_value.setChecked(True) 135 | widget.copy_label(event) 136 | expected[20:30, 20:30] = 10 137 | np.testing.assert_array_equal(tgt.data, expected) 138 | 139 | 140 | def test_copy_2d_slice_to_4d(make_event, make_napari_viewer): 141 | viewer = make_napari_viewer() 142 | src = Labels(np.zeros((64, 64), dtype=np.uint16)) 143 | tgt = viewer.add_labels(np.zeros((3, 10, 64, 64), dtype=np.uint16)) 144 | src.data[20:30, 20:30] = 10 145 | 146 | widget = 
CopyLabelWidget(viewer) 147 | widget.source_layer = src 148 | widget.target_layer = tgt 149 | widget.dims_widget.slice.setChecked(True) 150 | 151 | viewer.dims.current_step = (2, 5, 0, 0) 152 | 153 | event = make_event([2, 5, 25, 25], dims_displayed=[2, 3]) 154 | widget.copy_label(event) 155 | expected = np.zeros((3, 10, 64, 64), dtype=np.uint8) 156 | expected[2, 5, 20:30, 20:30] = 1 157 | np.testing.assert_array_equal(tgt.data, expected) 158 | 159 | widget.preserve_label_value.setChecked(True) 160 | widget.copy_label(event) 161 | expected[2, 5, 20:30, 20:30] = 10 162 | np.testing.assert_array_equal(tgt.data, expected) 163 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/_widget.py: -------------------------------------------------------------------------------- 1 | """ 2 | Napari plugin widget for editing N-dimensional label data 3 | """ 4 | 5 | import napari 6 | from napari.layers import Labels 7 | from napari_orthogonal_views.ortho_view_manager import _get_manager 8 | from qtpy.QtWidgets import ( 9 | QScrollArea, 10 | QTabWidget, 11 | QVBoxLayout, 12 | QWidget, 13 | ) 14 | 15 | from .connected_components import ConnectedComponents 16 | from .erosion_dilation_widget import ErosionDilationWidget 17 | from .image_calculator import ImageCalculator 18 | from .label_interpolator import InterpolationWidget 19 | from .layer_controls import LayerControlsWidget 20 | from .layer_manager import LayerManager 21 | from .plot_widget import PlotWidget 22 | from .regionprops_widget import RegionPropsWidget 23 | from .select_delete_widget import SelectDeleteMask 24 | from .smoothing_widget import SmoothingWidget 25 | from .threshold_widget import ThresholdWidget 26 | 27 | 28 | class AnnotateLabelsND(QWidget): 29 | """Widget for manual correction of label data, for example to prepare ground truth data for training a segmentation model""" 30 | 31 | def __init__(self, viewer: "napari.viewer.Viewer") -> None: 32 | 
super().__init__() 33 | self.viewer = viewer 34 | self.source_labels = None 35 | self.target_labels = None 36 | self.points = None 37 | self.copy_points = None 38 | self.edit_layout = QVBoxLayout() 39 | self.tab_widget = QTabWidget(self) 40 | self.option_labels = None 41 | 42 | ### Add label manager 43 | self.label_manager = LayerManager(self.viewer) 44 | 45 | ### Add layer controls widget 46 | self.layer_controls = LayerControlsWidget(self.viewer, self.label_manager) 47 | 48 | ### activate orthogonal views and register custom function 49 | def label_options_click_hook(orig_layer, copied_layer): 50 | copied_layer.mouse_drag_callbacks.append( 51 | lambda layer, event: self.layer_controls.copy_label_widget.sync_click( 52 | orig_layer, layer, event 53 | ) 54 | ) 55 | 56 | orth_view_manager = _get_manager(self.viewer) 57 | orth_view_manager.register_layer_hook(Labels, label_options_click_hook) 58 | 59 | ### Add widget for connected component labeling 60 | conn_comp_widget = ConnectedComponents(self.viewer, self.label_manager) 61 | self.edit_layout.addWidget(conn_comp_widget) 62 | 63 | ### Add widget for smoothing labels 64 | smooth_widget = SmoothingWidget(self.viewer, self.label_manager) 65 | self.edit_layout.addWidget(smooth_widget) 66 | 67 | ### Add widget for eroding/dilating labels 68 | erode_dilate_widget = ErosionDilationWidget(self.viewer, self.label_manager) 69 | self.edit_layout.addWidget(erode_dilate_widget) 70 | 71 | ### Threshold image 72 | threshold_widget = ThresholdWidget(self.viewer) 73 | self.edit_layout.addWidget(threshold_widget) 74 | 75 | # Add image calculator 76 | image_calc = ImageCalculator(self.viewer) 77 | self.edit_layout.addWidget(image_calc) 78 | 79 | # Add widget for selecting/deleting by mask 80 | select_del = SelectDeleteMask(self.viewer) 81 | self.edit_layout.addWidget(select_del) 82 | 83 | # Add widget for interpolating masks 84 | interpolation_widget = InterpolationWidget(self.viewer, self.label_manager) 85 | 
self.edit_layout.addWidget(interpolation_widget) 86 | 87 | ### Add layer controls widget to tab 88 | controls_scroll_area = QScrollArea() 89 | controls_scroll_area.setWidget(self.layer_controls) 90 | controls_scroll_area.setWidgetResizable(True) 91 | self.tab_widget.addTab(controls_scroll_area, "Layer Controls") 92 | self.tab_widget.setCurrentIndex(1) 93 | 94 | ### add combined editing widgets widgets 95 | self.edit_widgets = QWidget() 96 | self.edit_widgets.setLayout(self.edit_layout) 97 | scroll_area = QScrollArea() 98 | scroll_area.setWidget(self.edit_widgets) 99 | scroll_area.setWidgetResizable(True) 100 | self.tab_widget.addTab(scroll_area, "Editing") 101 | 102 | ### Add widget for reginproperties 103 | self.regionprops_widget = RegionPropsWidget(self.viewer, self.label_manager) 104 | props_scroll_area = QScrollArea() 105 | props_scroll_area.setWidget(self.regionprops_widget) 106 | props_scroll_area.setWidgetResizable(True) 107 | self.tab_widget.addTab(props_scroll_area, "Region properties") 108 | 109 | ### Add widget for displaying plot with regionprops 110 | self.plot_widget = PlotWidget(self.label_manager) 111 | self.tab_widget.addTab(self.plot_widget, "Plot") 112 | 113 | # Add the tab widget to the main layout 114 | self.main_layout = QVBoxLayout() 115 | self.main_layout.addWidget(self.tab_widget) 116 | self.setLayout(self.main_layout) 117 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/connected_components.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import dask.array as da 5 | import napari 6 | import numpy as np 7 | import tifffile 8 | from qtpy.QtWidgets import ( 9 | QFileDialog, 10 | QGroupBox, 11 | QPushButton, 12 | QVBoxLayout, 13 | QWidget, 14 | ) 15 | from scipy import ndimage 16 | from skimage.io import imread 17 | from skimage.measure import label 18 | 19 | from .layer_manager import LayerManager 
def keep_largest_fragment_per_label(
    img: np.ndarray, labels: list[int] | np.ndarray
) -> np.ndarray:
    """Keep only the largest connected component per label in `labels`.

    Args:
        img: label image.
        labels: label values to process (e.g. the output of ``np.unique(img)``).
            The value 0 and values not present in `img` are skipped.

    Returns:
        A new array shaped like `img` (same dtype) in which every processed label
        is reduced to its largest connected fragment; all other pixels are 0.
    """
    out = np.zeros_like(img)
    for label_value in labels:
        if label_value == 0:
            continue
        mask = img == label_value
        if not np.any(mask):
            continue

        # mask is non-empty here, so ndimage.label finds at least one fragment
        fragments, n_fragments = ndimage.label(mask)
        if n_fragments == 1:
            # Single fragment: nothing to discard, copy the mask directly.
            out[mask] = label_value
            continue

        sizes = np.bincount(fragments.ravel())[1:]  # skip background
        largest = 1 + np.argmax(sizes)  # fragment ids are 1-based
        out[fragments == largest] = label_value

    return out


def _largest_cluster(stack: np.ndarray) -> np.ndarray:
    """Zero every pixel of `stack` outside its largest foreground (> 0) cluster.

    Clusters are found with skimage's `label` on the binary foreground mask, so
    touching labels form one cluster. Label values inside the kept cluster are
    preserved.
    """
    clusters = label(stack > 0)
    # minlength=1 keeps index 0 valid even when the stack is empty
    sizes = np.bincount(clusters.ravel(), minlength=1)
    sizes[0] = 0  # ignore background
    return (clusters == sizes.argmax()) * stack


def _largest_fragments(stack: np.ndarray) -> np.ndarray:
    """Apply `keep_largest_fragment_per_label` to every value present in `stack`."""
    return keep_largest_fragment_per_label(stack, np.unique(stack))


def _as_label_dtype(data: np.ndarray) -> np.ndarray:
    """Cast label data for TIFF export without truncating large label values.

    uint16 (the historical export dtype) is used whenever the maximum value fits;
    otherwise the smallest wider unsigned type is used, so labels above 65535 are
    no longer silently wrapped around.
    """
    max_value = int(data.max()) if data.size else 0
    for dtype in (np.uint16, np.uint32, np.uint64):
        if max_value <= np.iinfo(dtype).max:
            return data.astype(dtype)
    return data


def _apply_per_frame(data, func):
    """Apply `func` to `data`, frame by frame along axis 0 when `data` is >3D.

    For >3D input the per-frame results are written into an array created with
    ``np.zeros_like(data)`` (so they take `data`'s dtype); otherwise the result
    of ``func(data)`` is returned as is.
    """
    if len(data.shape) > 3:
        result = np.zeros_like(data)
        for i in range(data.shape[0]):
            result[i] = func(data[i])
        return result
    return func(data)


class ConnectedComponents(QWidget):
    """Widget to run connected component analysis on the selected labels layer.

    Each action adds a new labels layer (which becomes the selected layer):
    relabeling by connected components, keeping the largest foreground cluster,
    or keeping the largest fragment of every label. Dask-backed layers are
    processed frame by frame along axis 0; every frame is written as a TIFF to a
    user-chosen output folder and the results are read back.
    """

    def __init__(
        self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
    ) -> None:
        super().__init__()

        self.viewer = viewer
        self.label_manager = label_manager
        # Base folder for dask results; chosen by the user on first use.
        self.outputdir = None

        conn_comp_box = QGroupBox("Connected Component Analysis")
        conn_comp_box_layout = QVBoxLayout()

        self.conncomp_btn = QPushButton("Find connected components")
        self.conncomp_btn.setToolTip(
            "Run connected component analysis to (re)label the labels layer"
        )
        self.conncomp_btn.clicked.connect(self._conn_comp)
        conn_comp_box_layout.addWidget(self.conncomp_btn)

        self.keep_largest_btn = QPushButton("Keep largest component cluster")
        self.keep_largest_btn.setToolTip(
            "Keep only the labels part of the largest non-zero connected component"
        )
        self.keep_largest_btn.clicked.connect(self._keep_largest_cluster)
        conn_comp_box_layout.addWidget(self.keep_largest_btn)

        self.keep_largest_fragment_btn = QPushButton("Keep largest fragment per label")
        self.keep_largest_fragment_btn.setToolTip(
            "For each label, keep only the largest connected fragment"
        )
        self.keep_largest_fragment_btn.clicked.connect(self._keep_largest_fragment)
        conn_comp_box_layout.addWidget(self.keep_largest_fragment_btn)

        # Buttons are only usable while a Labels layer is selected.
        self._update_button_state()
        self.label_manager.layer_update.connect(self._update_button_state)

        conn_comp_box.setLayout(conn_comp_box_layout)
        main_layout = QVBoxLayout()
        main_layout.addWidget(conn_comp_box)
        self.setLayout(main_layout)

    def _update_button_state(self):
        """Enable the action buttons only while a Labels layer is selected."""
        is_labels = isinstance(self.label_manager.selected_layer, napari.layers.Labels)
        self.conncomp_btn.setEnabled(is_labels)
        self.keep_largest_btn.setEnabled(is_labels)
        self.keep_largest_fragment_btn.setEnabled(is_labels)

    def _get_output_dir(self) -> str | None:
        """Return the base output folder, asking the user if none is set yet.

        Returns None when the folder dialog is cancelled. A cancelled dialog is
        not remembered, so the user is asked again next time instead of results
        silently landing in the current working directory.
        """
        if not self.outputdir:
            selected = QFileDialog.getExistingDirectory(self, "Select Output Folder")
            if not selected:
                return None
            self.outputdir = selected
        return self.outputdir

    def _process_dask_layer(self, func, dir_suffix, file_tag, layer_suffix) -> None:
        """Run `func` on each axis-0 frame of the selected dask-backed layer.

        Results are written to ``<outputdir>/<name><dir_suffix>/`` (the folder is
        deleted and recreated first) as ``<name><file_tag><index>.tif``, then read
        back, stacked and added as a labels layer named ``<name><layer_suffix>``.
        Nothing happens if the output folder dialog is cancelled.
        """
        layer = self.label_manager.selected_layer
        base_dir = self._get_output_dir()
        if base_dir is None:
            return

        outputdir = os.path.join(base_dir, layer.name + dir_suffix)
        if os.path.exists(outputdir):
            shutil.rmtree(outputdir)
        os.mkdir(outputdir)

        for i in range(layer.data.shape[0]):  # Loop over the first dimension
            current_stack = layer.data[i].compute()  # Compute the current stack
            tifffile.imwrite(
                os.path.join(
                    outputdir, layer.name + file_tag + str(i).zfill(4) + ".tif"
                ),
                func(current_stack),
            )

        file_list = sorted(
            os.path.join(outputdir, fname)
            for fname in os.listdir(outputdir)
            if fname.endswith(".tif")
        )
        self.label_manager.selected_layer = self.viewer.add_labels(
            da.stack([imread(fname) for fname in file_list]),
            name=layer.name + layer_suffix,
            scale=layer.scale,
        )

    def _add_result_layer(self, data, layer_suffix: str) -> None:
        """Add `data` as a labels layer named after the selected layer and select it."""
        layer = self.label_manager.selected_layer
        self.label_manager.selected_layer = self.viewer.add_labels(
            data,
            name=layer.name + layer_suffix,
            scale=layer.scale,
        )

    def _keep_largest_cluster(self):
        """Keep only the labels part of the largest non-zero connected component"""
        data = self.label_manager.selected_layer.data
        if isinstance(data, da.core.Array):
            self._process_dask_layer(
                lambda stack: _as_label_dtype(_largest_cluster(stack)),
                dir_suffix="_largest_cluster",
                file_tag="_largest_cluster_TP",
                layer_suffix="_largest_cluster",
            )
        else:
            self._add_result_layer(
                _apply_per_frame(data, _largest_cluster), "_largest_cluster"
            )

    def _keep_largest_fragment(self):
        """Keep only the largest fragment per label"""
        data = self.label_manager.selected_layer.data
        if isinstance(data, da.core.Array):
            self._process_dask_layer(
                _largest_fragments,
                dir_suffix="_largest_fragment",
                file_tag="_largest_fragments_TP",
                layer_suffix="_largest_fragments",
            )
        else:
            self._add_result_layer(
                _apply_per_frame(data, _largest_fragments), "_largest_fragments"
            )

    def _conn_comp(self):
        """Run connected component analysis to (re) label the labels array"""
        data = self.label_manager.selected_layer.data
        if isinstance(data, da.core.Array):
            self._process_dask_layer(
                lambda stack: _as_label_dtype(label(stack)),
                dir_suffix="_conncomp",
                file_tag="_conn_comp_TP",
                layer_suffix="_conn_comp",
            )
        else:
            self._add_result_layer(_apply_per_frame(data, label), "_conn_comp")


# --------------------------------------------------------------------------------
# /src/napari_segmentation_correction/copy_label_widget.py:
# --------------------------------------------------------------------------------
import dask.array as da
import napari
import numpy as np
from napari.layers import Labels
from napari.utils.events import Event
from qtpy.QtWidgets import (
    QButtonGroup,
    QCheckBox,
    QGroupBox,
    QHBoxLayout,
    QLabel,
    QMessageBox,
    QPushButton,
    QRadioButton,
    QVBoxLayout,
    QWidget,
)

from .layer_dropdown import LayerDropdown


def check_value_dtype(value, dtype):
    """Check whether `value` fits in integer `dtype`.

    Returns:
        ``(within_range, next_dtype)``. `next_dtype` is the smallest unsigned
        numpy integer type able to hold a non-negative out-of-range `value`, and
        None when `value` fits, is negative, or exceeds uint64.
    """
    # Plain int avoids fixed-width numpy comparison surprises (e.g. with uint64).
    value = int(value)
    info = np.iinfo(dtype)
    within_range = info.min <= value <= info.max

    next_dtype = None
    if not within_range and value >= 0:
        for candidate in (np.uint8, np.uint16, np.uint32, np.uint64):
            if np.iinfo(candidate).max >= value:
                next_dtype = candidate
                break

    return within_range, next_dtype


class DimsRadioButtons(QWidget):
    """Radio buttons choosing how many trailing dims a copy covers (2, 3 or 4)."""

    def __init__(self) -> None:
        super().__init__()

        label = QLabel("Copy dimensions:")

        # Parent the group to this widget so it outlives __init__.
        self.button_group = QButtonGroup(self)
        self.slice = QRadioButton("Slice (last 2 dims)")
        self.slice.setEnabled(False)
        self.slice.setChecked(False)
        self.volume = QRadioButton("Volume (last 3 dims)")
        self.volume.setChecked(False)
        self.volume.setEnabled(False)
        self.series = QRadioButton("Series (last 4 dims)")
        self.series.setChecked(False)
        self.series.setEnabled(False)

        button_layout = QHBoxLayout()
        for button in (self.slice, self.volume, self.series):
            self.button_group.addButton(button)
            button_layout.addWidget(button)

        layout = QVBoxLayout()
        layout.addWidget(label)
        layout.addLayout(button_layout)
        self.setLayout(layout)


class CopyLabelWidget(QWidget):
    """Widget to copy labels from a source layer to a target layer."""

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        super().__init__()

        self.viewer = viewer

        self.source_layer = None
        self.target_layer = None
        self._source_callback = None

        copy_labels_box = QGroupBox("Copy-paste labels")
        copy_labels_layout = QVBoxLayout()

        # instruction label
        label = QLabel(
            "Use shift + click on the source layer to copy labels to the target layer."
        )
        label.setWordWrap(True)
        font = label.font()
        font.setItalic(True)
        label.setFont(font)
        copy_labels_layout.addWidget(label)

        # Whether or not to preserve the source layer label value when copying or to use
        # the next available label in the target layer
        self.preserve_label_value = QCheckBox("Use source label value")
        self.preserve_existing_labels = QCheckBox("Preserve target labels")
        option_layout = QHBoxLayout()
        option_layout.addWidget(self.preserve_label_value)
        option_layout.addWidget(self.preserve_existing_labels)
        copy_labels_layout.addLayout(option_layout)

        # Source layer and target layer dropdowns. They are decoupled from the
        # viewer's layer selection so that selecting a layer does not change them.
        image_layout = QVBoxLayout()
        source_layout = QHBoxLayout()
        source_layout.addWidget(QLabel("Source labels"))
        self.source_dropdown = LayerDropdown(self.viewer, (Labels), allow_none=True)
        self.source_dropdown.viewer.layers.selection.events.changed.disconnect(
            self.source_dropdown._on_selection_changed
        )
        self.source_dropdown.layer_changed.connect(self._update_source)
        source_layout.addWidget(self.source_dropdown)

        target_layout = QHBoxLayout()
        target_layout.addWidget(QLabel("Target labels"))
        self.target_dropdown = LayerDropdown(self.viewer, (Labels), allow_none=True)
        self.target_dropdown.viewer.layers.selection.events.changed.disconnect(
            self.target_dropdown._on_selection_changed
        )
        self.target_dropdown.layer_changed.connect(self._update_target)
        target_layout.addWidget(self.target_dropdown)

        image_layout.addLayout(source_layout)
        image_layout.addLayout(target_layout)

        copy_labels_layout.addLayout(image_layout)

        # Radiobuttons for selecting whether to copy a slice/volume/series
        self.dims_widget = DimsRadioButtons()
        copy_labels_layout.addWidget(self.dims_widget)

        # Undo the last copy action if possible
        self.prev_state = None
        self.coords_clipped = None
        self.target_slices = None
        self.undo_btn = QPushButton("Undo last copy")
        self.undo_btn.setEnabled(False)
        self.undo_btn.clicked.connect(self.undo)
        copy_labels_layout.addWidget(self.undo_btn)

        # assemble the layout
        copy_labels_box.setLayout(copy_labels_layout)
        layout = QVBoxLayout()
        layout.addWidget(copy_labels_box)
        self.setLayout(layout)

    def _update_source(self, selected_layer: str) -> None:
        """Set the 'source labels' layer that labels are copied from.

        Detaches the shift + click callback (and contour display) from the
        previous source layer and attaches them to the new one. An empty string
        clears the source.
        """
        if self.source_layer is not None and self._source_callback is not None:
            try:
                self.source_layer.mouse_drag_callbacks.remove(self._source_callback)
            except ValueError:
                pass  # callback was already removed
            # Reset the contour even when the callback was already gone.
            self.source_layer.contour = 0
        if selected_layer == "":
            self.source_layer = None
            self._source_callback = None
        else:
            self.source_layer = self.viewer.layers[selected_layer]
            self.source_layer.contour = 1
            self.source_dropdown.setCurrentText(selected_layer)
            self._source_callback = self._make_copy_label_callback(self.source_layer)
            self.source_layer.mouse_drag_callbacks.append(self._source_callback)

        self.update_radiobuttons()

    def _update_target(self, selected_layer: str) -> None:
        """Set the layer to copy labels to (empty string clears it)."""
        if selected_layer == "":
            self.target_layer = None
        else:
            self.target_layer = self.viewer.layers[selected_layer]
            self.target_dropdown.setCurrentText(selected_layer)

        self.update_radiobuttons()

    def update_radiobuttons(self) -> None:
        """Update the dimension radio buttons from the source and target layers.

        Options up to the smaller ndim of the two layers are enabled, and undo
        is disabled because the stored state belongs to the previous layers.
        """
        # All buttons off
        self.dims_widget.slice.setEnabled(False)
        self.dims_widget.slice.setChecked(False)
        self.dims_widget.volume.setEnabled(False)
        self.dims_widget.volume.setChecked(False)
        self.dims_widget.series.setEnabled(False)
        self.dims_widget.series.setChecked(False)
        self.undo_btn.setEnabled(False)

        if self.source_layer is not None and self.target_layer is not None:
            # Set the highest possible option based on the number of dimensions of the
            # source and target layers
            source_dims = self.source_layer.data.ndim
            target_dims = self.target_layer.data.ndim
            dims = min(source_dims, target_dims)

            if dims >= 4:
                self.dims_widget.series.setEnabled(True)
                self.dims_widget.volume.setEnabled(True)
                self.dims_widget.slice.setEnabled(True)
                self.dims_widget.volume.setChecked(True)
            elif dims >= 3:
                self.dims_widget.volume.setEnabled(True)
                self.dims_widget.slice.setEnabled(True)
                self.dims_widget.volume.setChecked(True)
            elif dims >= 2:
                self.dims_widget.slice.setEnabled(True)
                self.dims_widget.slice.setChecked(True)

    def _make_copy_label_callback(self, layer: Labels) -> callable:
        """Create a mouse callback that copies the clicked label on shift + click
        from the source layer to the target layer."""

        def callback(layer, event):
            if event.type == "mouse_press" and "Shift" in event.modifiers:
                self.copy_label(event)

        return callback

    def copy_label(self, event: Event, source_layer: Labels | None = None) -> None:
        """Copy a 2D/3D/4D label from the source layer to the target layer.

        Args:
            event: napari mouse event of the shift + click.
            source_layer: layer to read the clicked label value from (an
                orthoview copy of the source); defaults to `self.source_layer`.

        Returns:
            None, or False when the copied dimensions of source and target do not
            match (a warning dialog is shown in that case).
        """
        if self.source_layer is None or self.target_layer is None:
            return

        # Check whether to copy a slice/volume/series label according to the
        # radiobutton choice
        if self.dims_widget.series.isChecked():
            n_dims_copied = 4
        elif self.dims_widget.volume.isChecked():
            n_dims_copied = 3
        else:
            n_dims_copied = 2

        if n_dims_copied == 4 and (
            isinstance(self.source_layer.data, da.core.Array)
            or isinstance(self.target_layer.data, da.core.Array)
        ):
            msg = QMessageBox()
            msg.setWindowTitle("Warning")
            msg.setText(
                "Copying labels in 4D dimensions between dask arrays is slow, are you sure you want to continue?"
            )
            msg.setIcon(QMessageBox.Information)
            msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel)
            result = msg.exec_()
            if result != QMessageBox.Ok:
                return

        # extract label value from source layer (orthoview or self.source_layer)
        source_layer = source_layer if source_layer is not None else self.source_layer
        selected_label = source_layer.get_value(
            event.position,
            view_direction=event.view_direction,
            dims_displayed=event.dims_displayed,
            world=True,
        )

        # Ignore background clicks. get_value may also return None (e.g. outside
        # the data); continuing would erase the target label without copying.
        if selected_label is None or selected_label == 0:
            return

        # Check that the dims to copy match BEFORE modifying the target layer
        # (e.g. converting its dtype below).
        source_shape = self.source_layer.data.shape
        labels_shape = self.target_layer.data.shape
        source_dims_to_copy = source_shape[-n_dims_copied:]
        target_dims_to_copy = labels_shape[-n_dims_copied:]
        if source_dims_to_copy != target_dims_to_copy:
            msg = QMessageBox()
            msg.setWindowTitle("Invalid dimensions!")
            msg.setText(
                f"The dimensions of the source layer and the target layer do not match.\n"
                f"Label source layer has shape {source_shape}, target layer has shape {labels_shape}.\n"
                f"Compared last {n_dims_copied} dims: {source_dims_to_copy} vs {target_dims_to_copy}."
            )
            msg.setIcon(QMessageBox.Information)
            msg.setStandardButtons(QMessageBox.Ok)
            msg.exec_()
            return False

        # extract coords from click position
        coords = [int(c) for c in self.source_layer.world_to_data(event.position)]

        # Get dimensions of source and target layers
        dims_displayed = event.dims_displayed
        ndims_source = len(source_shape)
        ndims_target = len(labels_shape)

        # Use the source label value or the next free value in the target layer.
        # int() evaluates dask reductions and avoids fixed-width overflow
        # (np.uint8(255) + 1 wraps to 0 under NumPy 2 promotion rules).
        if self.preserve_label_value.isChecked():
            target_label = selected_label
        else:
            target_label = int(np.max(self.target_layer.data)) + 1

        # Check if the target label is within the dtype range of the target layer, if not
        # suggest converting to a larger dtype
        within_range, next_dtype = check_value_dtype(
            target_label, self.target_layer.data.dtype
        )
        if not within_range:
            msg = QMessageBox()
            msg.setWindowTitle("Invalid label!")
            msg.setIcon(QMessageBox.Information)
            if next_dtype is None:
                # No unsigned type can hold the value: abort instead of writing a
                # wrapped-around label.
                msg.setText(f"Label {target_label} exceeds bit depth!")
                msg.setStandardButtons(QMessageBox.Ok)
                msg.exec_()
                return
            msg.setText(
                f"Label {target_label} exceeds bit depth! Convert to {next_dtype}?"
            )
            msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel)
            if msg.exec_() != QMessageBox.Ok:
                return
            self.target_layer.data = self.target_layer.data.astype(next_dtype)

        # Create source_slices for all dimensions of the source layer
        source_slices = [slice(None)] * ndims_source

        # When copying 2D labels, we need to check the dimensions displayed, in case
        # the user is copying from one of the orthoviews.
        dims_difference = ndims_source - ndims_target
        if n_dims_copied == 2:
            if dims_difference < 0:
                # map viewer dims onto the (lower-dimensional) source layer dims
                dims_displayed = [dim + dims_difference for dim in dims_displayed]
            for i in range(ndims_source):
                if i not in dims_displayed:
                    # Replace the slice with a specific coordinate for slider dims
                    source_slices[i] = coords[i]
        else:
            # Fix the leading (non-copied) dims to the clicked coordinates
            remaining_coords = coords[:-n_dims_copied]
            for i, coord in enumerate(remaining_coords):
                source_slices[i] = coord

        # Map coords and slices onto the target's dimensionality
        if dims_difference > 0:
            coords_clipped = coords[dims_difference:]
            target_slices = source_slices[dims_difference:]
        elif dims_difference < 0:
            coords_clipped = [
                *self.viewer.dims.current_step[: abs(dims_difference)],
                *coords,
            ]
            target_slices = [
                *self.viewer.dims.current_step[: abs(dims_difference)],
                *source_slices,
            ]
        else:
            coords_clipped = coords
            target_slices = source_slices

        # Create mask
        if isinstance(self.source_layer.data, da.core.Array):
            mask = (
                self.source_layer.data[tuple(source_slices)].compute() == selected_label
            )
        else:
            mask = self.source_layer.data[tuple(source_slices)] == selected_label

        # Select the correct stack for 2D/3D/4D data
        orig_label = self.target_layer.data[tuple(coords_clipped)]
        if isinstance(orig_label, da.core.Array):
            orig_label = orig_label.compute()
        sliced_data = self.target_layer.data[tuple(target_slices)]
        if isinstance(sliced_data, da.core.Array):
            sliced_data = sliced_data.compute()

        # Store previous state for undo
        self.prev_state = np.copy(sliced_data)
        self.target_slices = np.copy(target_slices)
        self.coords_clipped = np.copy(coords_clipped)

        # Replace label in target layer data
        orig_mask = sliced_data == orig_label
        if not self.preserve_existing_labels.isChecked():
            sliced_data[orig_mask] = 0
            sliced_data[mask] = target_label
        else:
            sliced_data[orig_mask & (sliced_data == 0) & mask] = target_label

        self.target_layer.data[tuple(target_slices)] = sliced_data
        self.undo_btn.setEnabled(True)

        # refresh the layer
        self.target_layer.data = self.target_layer.data

    def undo(self) -> None:
        """Undo the last label copy operation (one level of undo only)."""
        if getattr(self, "prev_state", None) is not None:
            self.target_layer.data[tuple(self.target_slices)] = self.prev_state

            # refresh the layer
            self.target_layer.data = self.target_layer.data

            # set back to None and disable button
            self.prev_state = None
            self.coords_clipped = None
            self.target_slices = None
            self.undo_btn.setEnabled(False)

    def sync_click(
        self, orig_layer: Labels, copied_layer: Labels, event: Event
    ) -> None:
        """Forward the click event from orthogonal views"""
        if (
            orig_layer is self.source_layer
            and event.type == "mouse_press"
            and "Shift" in event.modifiers
        ):
            # pass the copied layer for extracting the label value, because an orthoview
            # was used
            self.copy_label(event, copied_layer)


# --------------------------------------------------------------------------------
# /src/napari_segmentation_correction/cross_widget.py:
# --------------------------------------------------------------------------------
import napari
import numpy as np
from napari.components.layerlist import Extent
from napari.components.viewer_model import ViewerModel
from napari.layers import Vectors
from napari.utils.action_manager import action_manager
from napari.utils.notifications import show_info
from qtpy.QtWidgets import (
    QCheckBox,
)
from superqt.utils import qthrottled


def center_cross_on_mouse(
    viewer_model: napari.components.viewer_model.ViewerModel,
):
    """Move the dims point (and thus the cross) to the mouse position.

    The cursor position is clamped to each dimension's range and divided by the
    step size to obtain the new `current_step`.
    """
    if not getattr(viewer_model, "mouse_over_canvas", True):
        # There is no way for napari 0.4.15 to check if mouse is over sending canvas.
        show_info("Mouse is not over the canvas. You may need to click on the canvas.")
        return

    viewer_model.dims.current_step = tuple(
        np.round(
            [
                max(min_, min(p, max_)) / step
                for p, (min_, max_, step) in zip(
                    viewer_model.cursor.position, viewer_model.dims.range, strict=False
                )
            ]
        ).astype(int)
    )


action_manager.register_action(
    name="napari:move_point",
    command=center_cross_on_mouse,
    description="Move dims point to mouse position",
    keymapprovider=ViewerModel,
)


class CrossWidget(QCheckBox):
    """
    Widget to control the cross layer. because of the performance reason
    the cross update is throttled
    """

    def __init__(self, viewer: napari.Viewer):
        super().__init__("Add cross layer")
        self.viewer = viewer
        self.setChecked(False)
        self.stateChanged.connect(self._update_cross_visibility)
        self.layer = None
        self.color = "red"  # single color re-applied to recreated cross layers
        self.viewer.dims.events.order.connect(self.update_cross)
        self.viewer.dims.events.ndim.connect(self._update_ndim)
        self.viewer.dims.events.current_step.connect(self.update_cross)
        self._extent = None

        self._update_extent()
        self.viewer.dims.events.connect(self._update_extent)

    @qthrottled(leading=False)
    def _update_extent(self):
        """
        Calculate the extent of the data.

        Ignores the cross layer itself in calculating the extent.
        """
        extent_list = [
            layer.extent for layer in self.viewer.layers if layer is not self.layer
        ]
        self._extent = Extent(
            data=None,
            world=self.viewer.layers._get_extent_world(extent_list),
            step=self.viewer.layers._get_step_size(extent_list),
        )
        self.update_cross()

    def _make_layer(self, ndim: int) -> Vectors:
        """Create a cross layer with `ndim` dims, drawn as lines in `self.color`."""
        layer = Vectors(name=".cross", ndim=ndim)
        layer.vector_style = "line"
        layer.edge_width = 2
        layer.edge_color = self.color
        layer.events.edge_color.connect(self._set_color)
        return layer

    def _update_ndim(self, event):
        """Replace the cross layer with one matching the new viewer ndim.

        The replacement is not added to the viewer; `update_cross` then unchecks
        the checkbox so the UI reflects that the cross is hidden.
        """
        if self.layer in self.viewer.layers:
            self.viewer.layers.remove(self.layer)
        self.layer = self._make_layer(event.value)
        self.update_cross()

    def _set_color(self, event):
        """Remember the cross color so it survives layer recreation."""
        colors = np.asarray(self.layer.edge_color)
        # Store a single RGBA row: a per-vector array would not match the
        # vector count of a recreated layer.
        if colors.ndim == 2 and len(colors):
            self.color = colors[0]

    def _update_cross_visibility(self, state):
        """Add the cross layer to the viewer when checked, remove it otherwise."""
        if state:
            if self.layer is None:
                self.layer = self._make_layer(self.viewer.dims.ndim)
            if self.layer not in self.viewer.layers:
                self.viewer.layers.append(self.layer)
        elif self.layer in self.viewer.layers:
            # The layer may already be gone (deleted by the user or replaced after
            # an ndim change); removing it again would raise ValueError.
            self.viewer.layers.remove(self.layer)
        if self.layer is None:
            return
        self.update_cross()
        if not np.any(self.layer.edge_color):
            self.layer.edge_color = self.color
        self.layer.vector_style = "line"

    def update_cross(self):
        """Redraw the cross through the current dims point.

        One line vector is created per dimension spanning more than one step.
        Unchecks the checkbox when the cross layer is not in the viewer.
        """
        if self.layer not in self.viewer.layers:
            self.setChecked(False)
            return
        if self._extent is None:
            # _update_extent is throttled and may not have run yet
            return

        # Block dims events so updating the cross layer does not re-trigger us.
        with self.viewer.dims.events.blocker():
            point = self.viewer.dims.current_step
            step = self._extent.step
            vec = []
            for i, (lower, upper) in enumerate(self._extent.world.T):
                if (upper - lower) / step[i] == 1:
                    continue  # single-step dimension: no line to draw
                point1 = list(point)
                point1[i] = (lower + step[i] / 2) / step[i]
                point2 = [0 for _ in point]
                point2[i] = (upper - lower) / step[i]
                vec.append((point1, point2))
            if np.any(self.layer.scale != step):
                self.layer.scale = step
            self.layer.data = vec
-------------------------------------------------------------------------------- /src/napari_segmentation_correction/custom_table_widget.py: -------------------------------------------------------------------------------- 1 | import napari 2 | import pandas as pd 3 | from matplotlib.colors import to_rgb 4 | from qtpy.QtGui import QColor 5 | from qtpy.QtWidgets import ( 6 | QFileDialog, 7 | QHBoxLayout, 8 | QPushButton, 9 | QStyledItemDelegate, 10 | QTableWidget, 11 | QTableWidgetItem, 12 | QVBoxLayout, 13 | QWidget, 14 | ) 15 | 16 | 17 | class FloatDelegate(QStyledItemDelegate): 18 | def __init__(self, decimals, parent=None): 19 | super().__init__(parent) 20 | self.nDecimals = decimals 21 | 22 | def displayText(self, value, locale): 23 | try: 24 | number = float(value) 25 | except (ValueError, TypeError): 26 | return str(value) 27 | 28 | if number.is_integer(): 29 | return str(int(number)) 30 | return f"{number:.{self.nDecimals}f}" 31 | 32 | 33 | class ColoredTableWidget(QWidget): 34 | """Customized table widget with colored rows based on label colors in a napari Labels layer""" 35 | 36 | def __init__(self, layer: "napari.layers.Layer", viewer: "napari.Viewer" = None): 37 | super().__init__() 38 | 39 | self._layer = layer 40 | self._viewer = viewer 41 | self._table_widget = QTableWidget() 42 | 43 | self._layer.events.colormap.connect(self._set_label_colors_to_rows) 44 | if hasattr(layer, "properties"): 45 | self._set_data(layer.properties) 46 | else: 47 | self._set_data({}) 48 | self.ascending = False # for choosing whether to sort ascending or descending 49 | 50 | # Reconnect the clicked signal to your custom method. 51 | self._table_widget.clicked.connect(self._clicked_table) 52 | 53 | # Connect to single click in the header to sort the table. 
54 | self._table_widget.horizontalHeader().sectionClicked.connect(self._sort_table) 55 | 56 | copy_button = QPushButton("Copy to clipboard") 57 | copy_button.clicked.connect(self._copy_table) 58 | 59 | save_button = QPushButton("Save as csv") 60 | save_button.clicked.connect(self._save_table) 61 | 62 | button_layout = QHBoxLayout() 63 | button_layout.addWidget(copy_button) 64 | button_layout.addWidget(save_button) 65 | main_layout = QVBoxLayout() 66 | main_layout.addLayout(button_layout) 67 | main_layout.addWidget(self._table_widget) 68 | self.setLayout(main_layout) 69 | self.setMinimumHeight(300) 70 | 71 | def _set_data(self, table: dict) -> None: 72 | """Set the content of the table from a dictionary""" 73 | 74 | self._table = table 75 | self._layer.properties = table 76 | 77 | self._table_widget.clear() 78 | try: 79 | self._table_widget.setRowCount(len(next(iter(table.values())))) 80 | self._table_widget.setColumnCount(len(table)) 81 | except StopIteration: 82 | pass 83 | 84 | for i, column in enumerate(table): 85 | self._table_widget.setHorizontalHeaderItem(i, QTableWidgetItem(column)) 86 | for j, value in enumerate(table.get(column)): 87 | self._table_widget.setItem(j, i, QTableWidgetItem(str(value))) 88 | 89 | self._table_widget.setItemDelegate(FloatDelegate(3, self._table_widget)) 90 | 91 | self._set_label_colors_to_rows() 92 | 93 | def _set_label_colors_to_rows(self) -> None: 94 | """Apply the colors of the napari label image to the table""" 95 | 96 | for i in range(self._table_widget.rowCount()): 97 | label = self._table["label"][i] 98 | label_color = to_rgb(self._layer.get_color(label)) 99 | scaled_color = ( 100 | int(label_color[0] * 255), 101 | int(label_color[1] * 255), 102 | int(label_color[2] * 255), 103 | ) 104 | for j in range(self._table_widget.columnCount()): 105 | self._table_widget.item(i, j).setBackground(QColor(*scaled_color)) 106 | 107 | def _save_table(self) -> None: 108 | """Save table to csv file""" 109 | filename, _ = 
QFileDialog.getSaveFileName(self, "Save as csv", ".", "*.csv") 110 | pd.DataFrame(self._table).to_csv(filename) 111 | 112 | def _copy_table(self) -> None: 113 | """Copy table to clipboard""" 114 | pd.DataFrame(self._table).to_clipboard() 115 | 116 | def _clicked_table(self): 117 | """Set the current viewer dims to the label that was clicked on.""" 118 | 119 | row = self._table_widget.currentRow() 120 | spatial_columns = sorted([key for key in self._table if "centroid" in key]) 121 | spatial_coords = [int(self._table[col][row]) for col in spatial_columns] 122 | 123 | if "time_point" in self._table: 124 | t = int(self._table["time_point"][row]) 125 | new_step = (t, *spatial_coords) 126 | else: 127 | new_step = spatial_coords 128 | self._viewer.dims.current_step = new_step 129 | 130 | def _sort_table(self): 131 | """Sorts the table in ascending or descending order""" 132 | 133 | selected_column = list(self._table.keys())[self._table_widget.currentColumn()] 134 | df = pd.DataFrame(self._table).sort_values( 135 | by=selected_column, ascending=self.ascending 136 | ) 137 | self.ascending = not self.ascending 138 | 139 | self._set_data(df.to_dict(orient="list")) 140 | self._set_label_colors_to_rows() 141 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/erosion_dilation_widget.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import dask.array as da 5 | import napari 6 | import numpy as np 7 | import tifffile 8 | from qtpy.QtWidgets import ( 9 | QFileDialog, 10 | QGroupBox, 11 | QHBoxLayout, 12 | QLabel, 13 | QPushButton, 14 | QSpinBox, 15 | QVBoxLayout, 16 | QWidget, 17 | ) 18 | from scipy import ndimage 19 | from scipy.ndimage import binary_erosion 20 | from skimage.io import imread 21 | from skimage.segmentation import expand_labels 22 | 23 | from .layer_manager import LayerManager 24 | 25 | 26 | class 
ErosionDilationWidget(QWidget): 27 | """Widget to perform erosion/dilation on label images""" 28 | 29 | def __init__( 30 | self, viewer: "napari.viewer.Viewer", label_manager: LayerManager 31 | ) -> None: 32 | super().__init__() 33 | 34 | self.viewer = viewer 35 | self.label_manager = label_manager 36 | self.outputdir = None 37 | 38 | dil_erode_box = QGroupBox("Erode/dilate labels") 39 | dil_erode_box_layout = QVBoxLayout() 40 | 41 | radius_layout = QHBoxLayout() 42 | str_element_diameter_label = QLabel("Structuring element diameter") 43 | str_element_diameter_label.setFixedWidth(200) 44 | self.structuring_element_diameter = QSpinBox() 45 | self.structuring_element_diameter.setMaximum(100) 46 | self.structuring_element_diameter.setValue(1) 47 | radius_layout.addWidget(str_element_diameter_label) 48 | radius_layout.addWidget(self.structuring_element_diameter) 49 | 50 | iterations_layout = QHBoxLayout() 51 | iterations_label = QLabel("Iterations") 52 | iterations_label.setFixedWidth(200) 53 | self.iterations = QSpinBox() 54 | self.iterations.setMaximum(100) 55 | self.iterations.setValue(1) 56 | iterations_layout.addWidget(iterations_label) 57 | iterations_layout.addWidget(self.iterations) 58 | 59 | shrink_dilate_buttons_layout = QHBoxLayout() 60 | self.erode_btn = QPushButton("Erode") 61 | self.dilate_btn = QPushButton("Dilate") 62 | self.erode_btn.clicked.connect(self._erode_labels) 63 | self.dilate_btn.clicked.connect(self._dilate_labels) 64 | shrink_dilate_buttons_layout.addWidget(self.erode_btn) 65 | shrink_dilate_buttons_layout.addWidget(self.dilate_btn) 66 | 67 | self.erode_btn.setEnabled( 68 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 69 | ) 70 | self.label_manager.layer_update.connect( 71 | lambda: self.erode_btn.setEnabled( 72 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 73 | ) 74 | ) 75 | self.dilate_btn.setEnabled( 76 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 77 | ) 78 | 
self.label_manager.layer_update.connect( 79 | lambda: self.dilate_btn.setEnabled( 80 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 81 | ) 82 | ) 83 | 84 | dil_erode_box_layout.addLayout(radius_layout) 85 | dil_erode_box_layout.addLayout(iterations_layout) 86 | dil_erode_box_layout.addLayout(shrink_dilate_buttons_layout) 87 | 88 | dil_erode_box.setLayout(dil_erode_box_layout) 89 | 90 | layout = QVBoxLayout() 91 | layout.addWidget(dil_erode_box) 92 | self.setLayout(layout) 93 | 94 | def _erode_labels(self): 95 | """Shrink oversized labels through erosion""" 96 | 97 | diam = self.structuring_element_diameter.value() 98 | iterations = self.iterations.value() 99 | 100 | if self.label_manager.selected_layer.data.ndim == 2: 101 | structuring_element = np.ones( 102 | (diam, diam), dtype=bool 103 | ) # Define a 3x3 structuring element for 2D erosion 104 | else: 105 | structuring_element = np.ones( 106 | (diam, diam, diam), dtype=bool 107 | ) # Define a 3x3x3 structuring element for 3D erosion 108 | 109 | if isinstance(self.label_manager.selected_layer.data, da.core.Array): 110 | if self.outputdir is None: 111 | self.outputdir = QFileDialog.getExistingDirectory( 112 | self, "Select Output Folder" 113 | ) 114 | 115 | outputdir = os.path.join( 116 | self.outputdir, 117 | (self.label_manager.selected_layer.name + "_eroded"), 118 | ) 119 | if os.path.exists(outputdir): 120 | shutil.rmtree(outputdir) 121 | os.mkdir(outputdir) 122 | 123 | for i in range( 124 | self.label_manager.selected_layer.data.shape[0] 125 | ): # Loop over the first dimension 126 | current_stack = self.label_manager.selected_layer.data[ 127 | i 128 | ].compute() # Compute the current stack 129 | mask = current_stack > 0 130 | filled_mask = ndimage.binary_fill_holes(mask) 131 | eroded_mask = binary_erosion( 132 | filled_mask, 133 | structure=structuring_element, 134 | iterations=iterations, 135 | ) 136 | eroded = np.where(eroded_mask, current_stack, 0) 137 | tifffile.imwrite( 138 | 
os.path.join( 139 | outputdir, 140 | ( 141 | self.label_manager.selected_layer.name 142 | + "_eroded_TP" 143 | + str(i).zfill(4) 144 | + ".tif" 145 | ), 146 | ), 147 | np.array(eroded, dtype="uint16"), 148 | ) 149 | 150 | file_list = [ 151 | os.path.join(outputdir, fname) 152 | for fname in os.listdir(outputdir) 153 | if fname.endswith(".tif") 154 | ] 155 | self.label_manager.selected_layer = self.viewer.add_labels( 156 | da.stack([imread(fname) for fname in sorted(file_list)]), 157 | name=self.label_manager.selected_layer.name + "_eroded", 158 | scale=self.label_manager.selected_layer.scale, 159 | ) 160 | 161 | else: 162 | if len(self.label_manager.selected_layer.data.shape) == 4: 163 | stack = [] 164 | for i in range(self.label_manager.selected_layer.data.shape[0]): 165 | mask = self.label_manager.selected_layer.data[i] > 0 166 | filled_mask = ndimage.binary_fill_holes(mask) 167 | eroded_mask = binary_erosion( 168 | filled_mask, 169 | structure=structuring_element, 170 | iterations=iterations, 171 | ) 172 | stack.append( 173 | np.where( 174 | eroded_mask, 175 | self.label_manager.selected_layer.data[i], 176 | 0, 177 | ) 178 | ) 179 | self.label_manager.selected_layer = self.viewer.add_labels( 180 | np.stack(stack, axis=0), 181 | name=self.label_manager.selected_layer.name + "_eroded", 182 | scale=self.label_manager.selected_layer.scale, 183 | ) 184 | 185 | elif self.label_manager.selected_layer.data.ndim in (2, 3): 186 | mask = self.label_manager.selected_layer.data > 0 187 | filled_mask = ndimage.binary_fill_holes(mask) 188 | eroded_mask = binary_erosion( 189 | filled_mask, 190 | structure=structuring_element, 191 | iterations=iterations, 192 | ) 193 | self.label_manager.selected_layer = self.viewer.add_labels( 194 | np.where(eroded_mask, self.label_manager.selected_layer.data, 0), 195 | name=self.label_manager.selected_layer.name + "_eroded", 196 | scale=self.label_manager.selected_layer.scale, 197 | ) 198 | else: 199 | print("4D, 3D, or 2D array required!") 
200 | 201 | def _dilate_labels(self): 202 | """Dilate labels in the selected layer.""" 203 | 204 | diam = self.structuring_element_diameter.value() 205 | iterations = self.iterations.value() 206 | 207 | if isinstance(self.label_manager.selected_layer.data, da.core.Array): 208 | if self.outputdir is None: 209 | self.outputdir = QFileDialog.getExistingDirectory( 210 | self, "Select Output Folder" 211 | ) 212 | 213 | outputdir = os.path.join( 214 | self.outputdir, 215 | (self.label_manager.selected_layer.name + "_dilated"), 216 | ) 217 | if os.path.exists(outputdir): 218 | shutil.rmtree(outputdir) 219 | os.mkdir(outputdir) 220 | 221 | for i in range( 222 | self.label_manager.selected_layer.data.shape[0] 223 | ): # Loop over the first dimension 224 | expanded_labels = self.label_manager.selected_layer.data[ 225 | i 226 | ].compute() # Compute the current stack 227 | for _j in range(iterations): 228 | expanded_labels = expand_labels(expanded_labels, distance=diam) 229 | tifffile.imwrite( 230 | os.path.join( 231 | outputdir, 232 | ( 233 | self.label_manager.selected_layer.name 234 | + "_dilated_TP" 235 | + str(i).zfill(4) 236 | + ".tif" 237 | ), 238 | ), 239 | np.array(expanded_labels, dtype="uint16"), 240 | ) 241 | 242 | file_list = [ 243 | os.path.join(outputdir, fname) 244 | for fname in os.listdir(outputdir) 245 | if fname.endswith(".tif") 246 | ] 247 | self.label_manager.selected_layer = self.viewer.add_labels( 248 | da.stack([imread(fname) for fname in sorted(file_list)]), 249 | name=self.label_manager.selected_layer.name + "_dilated", 250 | scale=self.label_manager.selected_layer.scale, 251 | ) 252 | 253 | else: 254 | if len(self.label_manager.selected_layer.data.shape) == 4: 255 | stack = [] 256 | for i in range(self.label_manager.selected_layer.data.shape[0]): 257 | expanded_labels = self.label_manager.selected_layer.data[i] 258 | for _j in range(iterations): 259 | expanded_labels = expand_labels(expanded_labels, distance=diam) 260 | 
stack.append(expanded_labels) 261 | self.label_manager.selected_layer = self.viewer.add_labels( 262 | np.stack(stack, axis=0), 263 | name=self.label_manager.selected_layer.name + "_dilated", 264 | scale=self.label_manager.selected_layer.scale, 265 | ) 266 | 267 | elif self.label_manager.selected_layer.data.ndim in (2, 3): 268 | expanded_labels = self.label_manager.selected_layer.data 269 | for _i in range(iterations): 270 | expanded_labels = expand_labels(expanded_labels, distance=diam) 271 | 272 | self.label_manager.selected_layer = self.viewer.add_labels( 273 | expanded_labels, 274 | name=self.label_manager.selected_layer.name + "_dilated", 275 | scale=self.label_manager.selected_layer.scale, 276 | ) 277 | else: 278 | print("input should be a 2D, 3D or 4D label image.") 279 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/Back.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Back.png -------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/Forward.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Forward.png -------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/Home.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Home.png 
-------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/Pan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Pan.png -------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/Pan_checked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Pan_checked.png -------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/Zoom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Zoom.png -------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/Zoom_checked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Zoom_checked.png -------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/configure_subplots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/configure_subplots.png 
-------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/edit_parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/edit_parameters.png -------------------------------------------------------------------------------- /src/napari_segmentation_correction/icons/save_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/save_figure.png -------------------------------------------------------------------------------- /src/napari_segmentation_correction/image_calculator.py: -------------------------------------------------------------------------------- 1 | import dask.array as da 2 | import napari 3 | import numpy as np 4 | from napari.layers import Image, Labels 5 | from qtpy.QtWidgets import ( 6 | QComboBox, 7 | QGroupBox, 8 | QHBoxLayout, 9 | QLabel, 10 | QMessageBox, 11 | QPushButton, 12 | QVBoxLayout, 13 | QWidget, 14 | ) 15 | 16 | from .layer_dropdown import LayerDropdown 17 | 18 | 19 | class ImageCalculator(QWidget): 20 | """Widget to perform calculations between two images""" 21 | 22 | def __init__(self, viewer: "napari.viewer.Viewer") -> None: 23 | super().__init__() 24 | 25 | self.viewer = viewer 26 | 27 | ### Add one image to another 28 | image_calc_box = QGroupBox("Image Calculator") 29 | image_calc_box_layout = QVBoxLayout() 30 | 31 | image1_layout = QHBoxLayout() 32 | image1_layout.addWidget(QLabel("Label image 1")) 33 | self.image1_dropdown = LayerDropdown(self.viewer, (Image, Labels)) 34 | self.image1_dropdown.layer_changed.connect(self._update_image1) 35 | 
image1_layout.addWidget(self.image1_dropdown) 36 | 37 | image2_layout = QHBoxLayout() 38 | image2_layout.addWidget(QLabel("Label image 2")) 39 | self.image2_dropdown = LayerDropdown(self.viewer, (Image, Labels)) 40 | self.image2_dropdown.layer_changed.connect(self._update_image2) 41 | image2_layout.addWidget(self.image2_dropdown) 42 | 43 | image_calc_box_layout.addLayout(image1_layout) 44 | image_calc_box_layout.addLayout(image2_layout) 45 | 46 | operation_layout = QHBoxLayout() 47 | self.operation = QComboBox() 48 | self.operation.addItem("Add") 49 | self.operation.addItem("Subtract") 50 | self.operation.addItem("Multiply") 51 | self.operation.addItem("Divide") 52 | self.operation.addItem("AND") 53 | self.operation.addItem("OR") 54 | operation_layout.addWidget(QLabel("Operation")) 55 | operation_layout.addWidget(self.operation) 56 | image_calc_box_layout.addLayout(operation_layout) 57 | 58 | add_images_btn = QPushButton("Run") 59 | add_images_btn.clicked.connect(self._calculate_images) 60 | add_images_btn.setEnabled(self.image1_dropdown.selected_layer is not None and self.image2_dropdown.selected_layer is not None) 61 | self.image1_dropdown.layer_changed.connect(lambda: add_images_btn.setEnabled(self.image1_dropdown.selected_layer is not None and self.image2_dropdown.selected_layer is not None)) 62 | self.image2_dropdown.layer_changed.connect(lambda: add_images_btn.setEnabled(self.image1_dropdown.selected_layer is not None and self.image2_dropdown.selected_layer is not None)) 63 | 64 | image_calc_box_layout.addWidget(add_images_btn) 65 | 66 | image_calc_box.setLayout(image_calc_box_layout) 67 | main_layout = QVBoxLayout() 68 | main_layout.addWidget(image_calc_box) 69 | self.setLayout(main_layout) 70 | 71 | def _update_image1(self, selected_layer: str) -> None: 72 | """Update the layer that is set to be the 'source labels' layer for copying labels from.""" 73 | 74 | if selected_layer == "": 75 | self.image1_layer = None 76 | else: 77 | self.image1_layer = 
self.viewer.layers[selected_layer] 78 | self.image1_dropdown.setCurrentText(selected_layer) 79 | 80 | def _update_image2(self, selected_layer: str) -> None: 81 | """Update the layer that is set to be the 'source labels' layer for copying labels from.""" 82 | 83 | if selected_layer == "": 84 | self.image2_layer = None 85 | else: 86 | self.image2_layer = self.viewer.layers[selected_layer] 87 | self.image2_dropdown.setCurrentText(selected_layer) 88 | 89 | def _calculate_images(self): 90 | """Add label image 2 to label image 1""" 91 | 92 | if isinstance(self.image1_layer, da.core.Array) or isinstance( 93 | self.image2_layer, da.core.Array 94 | ): 95 | msg = QMessageBox() 96 | msg.setWindowTitle( 97 | "Cannot yet run image calculator on dask arrays" 98 | ) 99 | msg.setText("Cannot yet run image calculator on dask arrays") 100 | msg.setIcon(QMessageBox.Information) 101 | msg.setStandardButtons(QMessageBox.Ok) 102 | msg.exec_() 103 | return False 104 | if self.image1_layer.data.shape != self.image2_layer.data.shape: 105 | msg = QMessageBox() 106 | msg.setWindowTitle("Images must have the same shape") 107 | msg.setText("Images must have the same shape") 108 | msg.setIcon(QMessageBox.Information) 109 | msg.setStandardButtons(QMessageBox.Ok) 110 | msg.exec_() 111 | return False 112 | 113 | if self.operation.currentText() == "Add": 114 | self.viewer.add_image( 115 | np.add(self.image1_layer.data, self.image2_layer.data), 116 | scale = self.image1_layer.scale, 117 | ) 118 | if self.operation.currentText() == "Subtract": 119 | self.viewer.add_image( 120 | np.subtract(self.image1_layer.data, self.image2_layer.data), 121 | scale = self.image1_layer.scale, 122 | ) 123 | if self.operation.currentText() == "Multiply": 124 | self.viewer.add_image( 125 | np.multiply(self.image1_layer.data, self.image2_layer.data), 126 | scale = self.image1_layer.scale, 127 | ) 128 | if self.operation.currentText() == "Divide": 129 | self.viewer.add_image( 130 | np.divide( 131 | self.image1_layer.data, 
132 | self.image2_layer.data, 133 | out=np.zeros_like(self.image1_layer.data, dtype=float), 134 | where=self.image2_layer.data != 0, 135 | ), 136 | scale = self.image1_layer.scale, 137 | ) 138 | if self.operation.currentText() == "AND": 139 | self.viewer.add_labels( 140 | np.logical_and( 141 | self.image1_layer.data != 0, self.image2_layer.data != 0 142 | ).astype(int), 143 | scale = self.image1_layer.scale, 144 | ) 145 | if self.operation.currentText() == "OR": 146 | self.viewer.add_labels( 147 | np.logical_or( 148 | self.image1_layer.data != 0, self.image2_layer.data != 0 149 | ).astype(int), 150 | scale = self.image1_layer.scale, 151 | ) 152 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/label_interpolator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import dask.array as da 5 | import napari 6 | import numpy as np 7 | import tifffile 8 | from qtpy.QtWidgets import ( 9 | QFileDialog, 10 | QGroupBox, 11 | QPushButton, 12 | QVBoxLayout, 13 | QWidget, 14 | ) 15 | from scipy.ndimage import distance_transform_edt 16 | from skimage.io import imread 17 | 18 | from .layer_manager import LayerManager 19 | 20 | 21 | def signed_distance_transform(mask): 22 | mask = mask.astype(bool) 23 | dist_out = distance_transform_edt(~mask) 24 | dist_in = distance_transform_edt(mask) 25 | return dist_out - dist_in 26 | 27 | 28 | def interpolate_binary_mask(mask): 29 | """ 30 | Interpolates a sparse binary mask array using SDTs along the first axis. 
31 | Args: 32 | mask (ndarray): Binary array of shape (T, X, Y, Z) or (X, Y, Z), etc., 33 | where some slices along 'axis' contain valid masks 34 | Returns: 35 | ndarray: Binary array of same shape with interpolated masks along 'axis' 36 | """ 37 | 38 | output = np.zeros_like(mask, dtype=np.uint8) 39 | 40 | # Find slices along axis that have any nonzero values 41 | valid_idxs = [i for i in range(mask.shape[0]) if np.any(mask[i])] 42 | 43 | for i in range(len(valid_idxs) - 1): 44 | i_start, i_end = valid_idxs[i], valid_idxs[i + 1] 45 | sdt_start = signed_distance_transform(mask[i_start]) 46 | sdt_end = signed_distance_transform(mask[i_end]) 47 | 48 | for j in range(i_start, i_end + 1): 49 | alpha = (j - i_start) / (i_end - i_start) 50 | sdt_interp = (1 - alpha) * sdt_start + alpha * sdt_end 51 | output[j] = (sdt_interp < 0).astype(np.uint8) 52 | 53 | return output 54 | 55 | 56 | class InterpolationWidget(QWidget): 57 | """Widget to interpolate between nonzero pixels in a label layer using signed distance transforms.""" 58 | 59 | def __init__( 60 | self, viewer: "napari.viewer.Viewer", label_manager: LayerManager 61 | ) -> None: 62 | super().__init__() 63 | 64 | self.viewer = viewer 65 | self.label_manager = label_manager 66 | self.outputdir = None 67 | 68 | interpolator_box = QGroupBox("Interpolate mask") 69 | interpolator_box_layout = QVBoxLayout() 70 | 71 | run_btn = QPushButton("Run interpolation along first axis") 72 | run_btn.clicked.connect(self._interpolate) 73 | run_btn.setEnabled(self.label_manager.selected_layer is not None) 74 | self.label_manager.layer_update.connect( 75 | lambda: run_btn.setEnabled(self.label_manager.selected_layer is not None) 76 | ) 77 | interpolator_box_layout.addWidget(run_btn) 78 | 79 | interpolator_box.setLayout(interpolator_box_layout) 80 | main_layout = QVBoxLayout() 81 | main_layout.addWidget(interpolator_box) 82 | self.setLayout(main_layout) 83 | 84 | def _interpolate(self): 85 | """Interpolate between the nonzero pixels in 
the selected layer""" 86 | 87 | if isinstance(self.label_manager.selected_layer.data, da.core.Array): 88 | if self.outputdir is None: 89 | self.outputdir = QFileDialog.getExistingDirectory( 90 | self, "Select Output Folder" 91 | ) 92 | 93 | outputdir = os.path.join( 94 | self.outputdir, 95 | (self.label_manager.selected_layer.name + "_interpolated"), 96 | ) 97 | if os.path.exists(outputdir): 98 | shutil.rmtree(outputdir) 99 | os.mkdir(outputdir) 100 | 101 | in_memory_stack = [] 102 | for i in range( 103 | self.label_manager.selected_layer.data.shape[0] 104 | ): # Loop over the first dimension 105 | current_stack = self.label_manager.selected_layer.data[ 106 | i 107 | ].compute() # Compute the current stack 108 | 109 | in_memory_stack.append(current_stack) 110 | 111 | in_memory_stack = np.stack(in_memory_stack, axis=0) 112 | interpolated = interpolate_binary_mask(in_memory_stack) 113 | 114 | for i in range(interpolated.shape[0]): 115 | tifffile.imwrite( 116 | os.path.join( 117 | outputdir, 118 | ( 119 | self.label_manager.selected_layer.name 120 | + "_interpolation_TP" 121 | + str(i).zfill(4) 122 | + ".tif" 123 | ), 124 | ), 125 | np.array(interpolated[i], dtype="uint8"), 126 | ) 127 | 128 | file_list = [ 129 | os.path.join(outputdir, fname) 130 | for fname in os.listdir(outputdir) 131 | if fname.endswith(".tif") 132 | ] 133 | self.label_manager.selected_layer = self.viewer.add_labels( 134 | da.stack([imread(fname) for fname in sorted(file_list)]), 135 | name=self.label_manager.selected_layer.name + "_interpolated", 136 | scale=self.label_manager.selected_layer.scale, 137 | ) 138 | else: 139 | interpolated = interpolate_binary_mask( 140 | self.label_manager.selected_layer.data 141 | ) 142 | 143 | self.label_manager.selected_layer = self.viewer.add_labels( 144 | interpolated, 145 | name=self.label_manager.selected_layer.name + "_interpolated", 146 | scale=self.label_manager.selected_layer.scale, 147 | ) 148 | 
-------------------------------------------------------------------------------- /src/napari_segmentation_correction/layer_controls.py: -------------------------------------------------------------------------------- 1 | import napari 2 | from napari_plane_sliders import PlaneSliderWidget 3 | from qtpy.QtWidgets import ( 4 | QGroupBox, 5 | QVBoxLayout, 6 | QWidget, 7 | ) 8 | 9 | from .copy_label_widget import CopyLabelWidget 10 | from .layer_manager import LayerManager 11 | from .save_labels_widget import SaveLabelsWidget 12 | 13 | 14 | class LayerControlsWidget(QWidget): 15 | """Widget showing region props as a table and plot widget""" 16 | 17 | def __init__( 18 | self, viewer: "napari.viewer.Viewer", label_manager: LayerManager 19 | ) -> None: 20 | super().__init__() 21 | 22 | self.viewer = viewer 23 | self.label_manager = label_manager 24 | 25 | layout = QVBoxLayout() 26 | 27 | ### create the dropdown for selecting label images 28 | layout.addWidget(self.label_manager) 29 | 30 | ### plane sliders 31 | plane_slider_box = QGroupBox("Plane Sliders") 32 | plane_slider_layout = QVBoxLayout() 33 | self.plane_sliders = PlaneSliderWidget(self.viewer) 34 | plane_slider_layout.addWidget(self.plane_sliders) 35 | plane_slider_box.setLayout(plane_slider_layout) 36 | layout.addWidget(plane_slider_box) 37 | 38 | ### Add widget for copy-pasting labels from one layer to another 39 | self.copy_label_widget = CopyLabelWidget(self.viewer) 40 | layout.addWidget(self.copy_label_widget) 41 | 42 | ### Add widget to save labels 43 | save_labels = SaveLabelsWidget(self.viewer, self.label_manager) 44 | layout.addWidget(save_labels) 45 | 46 | self.setLayout(layout) 47 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/layer_dropdown.py: -------------------------------------------------------------------------------- 1 | import napari 2 | from napari.utils.events import Event 3 | from PyQt5.QtCore import pyqtSignal 4 | from 
qtpy.QtWidgets import QComboBox 5 | 6 | 7 | class LayerDropdown(QComboBox): 8 | """QComboBox widget with functions for updating the selected layer and to update the list of options when the list of layers is modified.""" 9 | 10 | layer_changed = pyqtSignal(str) # signal to emit the selected layer name 11 | 12 | def __init__( 13 | self, viewer: napari.Viewer, layer_type: tuple, allow_none: bool = False 14 | ): 15 | super().__init__() 16 | self.viewer = viewer 17 | self.layer_type = layer_type 18 | self.allow_none = allow_none 19 | self.selected_layer = None 20 | self.viewer.layers.events.inserted.connect(self._on_insert) 21 | self.viewer.layers.events.changed.connect(self._update_dropdown) 22 | self.viewer.layers.events.removed.connect(self._update_dropdown) 23 | self.viewer.layers.selection.events.changed.connect(self._on_selection_changed) 24 | self.currentTextChanged.connect(self._emit_layer_changed) 25 | self._update_dropdown() 26 | 27 | def _on_insert(self, event) -> None: 28 | """Update dropdown and make new layer responsive to name changes""" 29 | 30 | layer = event.value 31 | if isinstance(layer, self.layer_type): 32 | 33 | @layer.events.name.connect 34 | def _on_rename(name_event): 35 | self._update_dropdown() 36 | 37 | self._update_dropdown() 38 | 39 | def _on_selection_changed(self) -> None: 40 | """Request signal emission if the user changes the layer selection.""" 41 | 42 | if ( 43 | len(self.viewer.layers.selection) == 1 44 | ): # Only consider single layer selection 45 | selected_layer = self.viewer.layers.selection.active 46 | if ( 47 | isinstance(selected_layer, self.layer_type) 48 | and selected_layer != self.selected_layer 49 | ): 50 | self.setCurrentText(selected_layer.name) 51 | self._emit_layer_changed() 52 | 53 | def _update_dropdown(self, event: Event | None = None) -> None: 54 | """Update the list of options in the dropdown menu whenever the list of layers is changed""" 55 | 56 | if ( 57 | event is None 58 | or not hasattr(event, "value") 59 
| or isinstance(event.value, self.layer_type) 60 | ): 61 | selected_layer = self.currentText() 62 | self.clear() 63 | layers = [ 64 | layer 65 | for layer in self.viewer.layers 66 | if isinstance(layer, self.layer_type) 67 | ] 68 | items = [] 69 | if self.allow_none: 70 | self.addItem("No selection") 71 | items.append("No selection") 72 | 73 | for layer in layers: 74 | self.addItem(layer.name) 75 | items.append(layer.name) 76 | layer.events.name.connect(self._update_dropdown) 77 | 78 | # In case the currently selected layer is one of the available items, set it again to the current value of the dropdown. 79 | if selected_layer in items: 80 | self.setCurrentText(selected_layer) 81 | 82 | def _emit_layer_changed(self) -> None: 83 | """Emit a signal holding the currently selected layer""" 84 | 85 | selected_layer_name = self.currentText() 86 | if ( 87 | selected_layer_name != "No selection" 88 | and selected_layer_name in self.viewer.layers 89 | ): 90 | self.selected_layer = self.viewer.layers[selected_layer_name] 91 | else: 92 | self.selected_layer = None 93 | selected_layer_name = "" 94 | 95 | self.layer_changed.emit(selected_layer_name) 96 | 97 | def deleteLater(self) -> None: 98 | """Ensure all connections are disconnected before deletion.""" 99 | self.viewer.layers.events.inserted.disconnect(self._on_insert) 100 | self.viewer.layers.events.changed.disconnect(self._update_dropdown) 101 | self.viewer.layers.events.removed.disconnect(self._update_dropdown) 102 | self.viewer.layers.selection.events.changed.disconnect( 103 | self._on_selection_changed 104 | ) 105 | self.currentTextChanged.disconnect(self._emit_layer_changed) 106 | super().deleteLater() 107 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/layer_manager.py: -------------------------------------------------------------------------------- 1 | import dask.array as da 2 | import napari 3 | import numpy as np 4 | from psygnal import Signal 
5 | from qtpy.QtWidgets import ( 6 | QPushButton, 7 | QVBoxLayout, 8 | QWidget, 9 | ) 10 | 11 | from .layer_dropdown import LayerDropdown 12 | 13 | 14 | class LayerManager(QWidget): 15 | """QComboBox widget with functions for updating the selected layer and to update the list of options when the list of layers is modified.""" 16 | 17 | layer_update = Signal() 18 | 19 | def __init__(self, viewer: napari.Viewer): 20 | super().__init__() 21 | 22 | self.viewer = viewer 23 | self._selected_layer = None 24 | self.label_dropdown = LayerDropdown(self.viewer, (napari.layers.Labels)) 25 | self.label_dropdown.layer_changed.connect(self.set_active_layer) 26 | 27 | ### Add option to convert dask array to in-memory array 28 | self.convert_to_array_btn = QPushButton("Convert to in-memory array") 29 | self.convert_to_array_btn.setEnabled( 30 | self.selected_layer is not None 31 | and isinstance(self.selected_layer.data, da.core.Array) 32 | ) 33 | self.convert_to_array_btn.clicked.connect(self._convert_to_array) 34 | 35 | layout = QVBoxLayout() 36 | layout.addWidget(self.label_dropdown) 37 | layout.addWidget(self.convert_to_array_btn) 38 | self.setLayout(layout) 39 | 40 | # Send a signal on the layer dropdown to update the active layer 41 | self.label_dropdown._emit_layer_changed() 42 | 43 | @property 44 | def selected_layer(self): 45 | return self._selected_layer 46 | 47 | @selected_layer.setter 48 | def selected_layer(self, layer): 49 | if layer != self._selected_layer: 50 | self._selected_layer = layer 51 | 52 | def set_active_layer(self, selected_layer) -> None: 53 | """Update the layer that is set to be the 'labels' layer that is being edited.""" 54 | 55 | if selected_layer == "": 56 | self._selected_layer = None 57 | else: 58 | self.selected_layer = self.viewer.layers[selected_layer] 59 | self.label_dropdown.setCurrentText(selected_layer) 60 | self.convert_to_array_btn.setEnabled( 61 | isinstance(self._selected_layer.data, da.core.Array) 62 | ) 63 | 64 | 
self.layer_update.emit() 65 | 66 | def _convert_to_array(self) -> None: 67 | """Convert from dask array to in-memory array. This is necessary for manual editing using the label tools (brush, eraser, fill bucket).""" 68 | 69 | if isinstance(self._selected_layer.data, da.core.Array): 70 | stack = [] 71 | for i in range(self._selected_layer.data.shape[0]): 72 | current_stack = self._selected_layer.data[i].compute() 73 | stack.append(current_stack) 74 | self._selected_layer.data = np.stack(stack, axis=0) 75 | self.convert_to_array_btn.setEnabled(False) 76 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/napari.yaml: -------------------------------------------------------------------------------- 1 | name: napari-segmentation-correction 2 | display_name: Manual Segmentation Correction 3 | visibility: public 4 | categories: ["Annotation", "Segmentation"] 5 | contributions: 6 | commands: 7 | - id: napari-segmentation-correction.annotateND 8 | python_name: napari_segmentation_correction._widget:AnnotateLabelsND 9 | title: Manual Segmentation Correction 10 | widgets: 11 | - command: napari-segmentation-correction.annotateND 12 | display_name: Manual Segmentation Correction 13 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/plot_widget.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | from matplotlib.backends.backend_qt5agg import ( 6 | FigureCanvas, 7 | NavigationToolbar2QT, 8 | ) 9 | from matplotlib.colors import ListedColormap, to_rgb 10 | from qtpy.QtCore import Qt 11 | from qtpy.QtGui import QIcon 12 | from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QVBoxLayout, QWidget 13 | 14 | from .layer_manager import LayerManager 15 | 16 | ICON_ROOT = Path(__file__).parent / "icons" 17 | 18 | 19 | class 
PlotWidget(QWidget): 20 | """Matplotlib widget that displays features of the selected labels layer""" 21 | 22 | def __init__(self, label_manager: LayerManager): 23 | super().__init__() 24 | 25 | self.label_manager = label_manager 26 | self.label_manager.layer_update.connect(self._layer_update) 27 | 28 | # Main plot. 29 | self.fig = plt.figure(constrained_layout=True) 30 | self.plot_canvas = FigureCanvas(self.fig) 31 | self.ax = self.plot_canvas.figure.subplots() 32 | self.toolbar = NavigationToolbar2QT(self.plot_canvas) 33 | 34 | # Specify plot customizations. 35 | self.fig.patch.set_facecolor("#262930") 36 | self.ax.tick_params(colors="white") 37 | self.ax.set_facecolor("#262930") 38 | self.ax.xaxis.label.set_color("white") 39 | self.ax.yaxis.label.set_color("white") 40 | self.ax.spines["bottom"].set_color("white") 41 | self.ax.spines["top"].set_color("white") 42 | self.ax.spines["right"].set_color("white") 43 | self.ax.spines["left"].set_color("white") 44 | for action_name in self.toolbar._actions: 45 | action = self.toolbar._actions[action_name] 46 | icon_path = os.path.join(ICON_ROOT, action_name + ".png") 47 | action.setIcon(QIcon(icon_path)) 48 | 49 | # Create a dropdown window for selecting what to plot on the axes. 
50 | x_axis_layout = QHBoxLayout() 51 | self.x_combo = QComboBox() 52 | x_axis_layout.addWidget(QLabel("x-axis")) 53 | x_axis_layout.addWidget(self.x_combo) 54 | 55 | y_axis_layout = QHBoxLayout() 56 | self.y_combo = QComboBox() 57 | y_axis_layout.addWidget(QLabel("y-axis")) 58 | y_axis_layout.addWidget(self.y_combo) 59 | 60 | color_group_layout = QHBoxLayout() 61 | self.group_combo = QComboBox() 62 | color_group_layout.addWidget(QLabel("Group color")) 63 | color_group_layout.addWidget(self.group_combo) 64 | 65 | self.x_combo.currentIndexChanged.connect(self._update_plot) 66 | self.y_combo.currentIndexChanged.connect(self._update_plot) 67 | self.group_combo.currentIndexChanged.connect(self._update_plot) 68 | 69 | dropdown_layout = QVBoxLayout() 70 | dropdown_layout.addLayout(x_axis_layout) 71 | dropdown_layout.addLayout(y_axis_layout) 72 | dropdown_layout.addLayout(color_group_layout) 73 | dropdown_widget = QWidget() 74 | dropdown_widget.setLayout(dropdown_layout) 75 | dropdown_layout.setAlignment(Qt.AlignTop) 76 | 77 | # Create and apply a horizontal layout for the dropdown widget, toolbar and canvas. 
78 | plotting_layout = QVBoxLayout() 79 | plotting_layout.addWidget(dropdown_widget) 80 | plotting_layout.addWidget(self.toolbar) 81 | plotting_layout.addWidget(self.plot_canvas) 82 | plotting_layout.setAlignment(Qt.AlignTop) 83 | 84 | self.setLayout(plotting_layout) 85 | self.setMinimumHeight(500) 86 | 87 | def _layer_update(self) -> None: 88 | """Connect events to plot updates""" 89 | 90 | if self.label_manager.selected_layer is not None: 91 | self.label_manager.selected_layer.events.show_selected_label.connect( 92 | self._update_plot 93 | ) 94 | self.label_manager.selected_layer.events.selected_label.connect( 95 | self._update_plot 96 | ) 97 | self.label_manager.selected_layer.events.features.connect( 98 | self._update_dropdown 99 | ) 100 | 101 | self._update_dropdown() 102 | self._update_plot() 103 | 104 | def _update_dropdown(self) -> None: 105 | """Update the dropdowns with the column headers""" 106 | 107 | if ( 108 | self.label_manager.selected_layer is not None 109 | and len(self.label_manager.selected_layer.features) > 0 110 | ): 111 | # temporarily disconnect listening to updates in the comboboxes 112 | self.x_combo.currentIndexChanged.disconnect(self._update_plot) 113 | self.y_combo.currentIndexChanged.disconnect(self._update_plot) 114 | self.group_combo.currentIndexChanged.disconnect(self._update_plot) 115 | 116 | prev_index = self.x_combo.currentIndex() if self.x_combo.count() > 0 else 0 117 | self.x_combo.clear() 118 | self.x_combo.addItems( 119 | [ 120 | item 121 | for item in self.label_manager.selected_layer.features.columns 122 | if item != "index" 123 | ] 124 | ) 125 | self.x_combo.setCurrentIndex(prev_index) 126 | 127 | prev_index = self.y_combo.currentIndex() if self.y_combo.count() > 0 else 1 128 | self.y_combo.clear() 129 | self.y_combo.addItems( 130 | [ 131 | item 132 | for item in self.label_manager.selected_layer.features.columns 133 | if item != "index" 134 | ] 135 | ) 136 | self.y_combo.setCurrentIndex(prev_index) 137 | 138 | prev_index 
= ( 139 | self.group_combo.currentIndex() if self.group_combo.count() > 0 else 0 140 | ) 141 | self.group_combo.clear() 142 | self.group_combo.addItems( 143 | [ 144 | item 145 | for item in self.label_manager.selected_layer.features.columns 146 | if item != "index" 147 | ] 148 | ) 149 | self.group_combo.setCurrentIndex(prev_index) 150 | 151 | # reconnect to updates in the comboboxes 152 | self.x_combo.currentIndexChanged.connect(self._update_plot) 153 | self.y_combo.currentIndexChanged.connect(self._update_plot) 154 | self.group_combo.currentIndexChanged.connect(self._update_plot) 155 | 156 | self._update_plot() 157 | 158 | def _update_plot(self) -> None: 159 | """Update the plot by plotting the features selected by the user.""" 160 | 161 | x_axis_property = self.x_combo.currentText() 162 | y_axis_property = self.y_combo.currentText() 163 | group = self.group_combo.currentText() 164 | 165 | # Clear data points, and reset the axis scaling and labels. 166 | for artist in self.ax.lines + self.ax.collections: 167 | artist.remove() 168 | self.ax.set_xlabel(x_axis_property) 169 | self.ax.set_ylabel(y_axis_property) 170 | self.ax.relim() # Recalculate limits for the current data 171 | self.ax.autoscale_view() # Update the view to include the new limits 172 | 173 | if ( 174 | self.label_manager.selected_layer is not None 175 | and len(self.label_manager.selected_layer.features) > 0 176 | and (x_axis_property != "" and y_axis_property != "" and group != "") 177 | ): 178 | if group == "label": 179 | if self.label_manager.selected_layer.show_selected_label: 180 | label = self.label_manager.selected_layer.selected_label 181 | plotting_data = self.label_manager.selected_layer.features[ 182 | self.label_manager.selected_layer.features["label"] == label 183 | ] 184 | unique_labels = [label] 185 | else: 186 | plotting_data = self.label_manager.selected_layer.features 187 | unique_labels = plotting_data["label"].unique() 188 | 189 | # Create consistent label-to-color mapping 190 | 
colormap = self.label_manager.selected_layer.colormap 191 | label_colors = { 192 | label: to_rgb(colormap.map(label)) for label in unique_labels 193 | } 194 | 195 | # Scatter plot 196 | self.ax.scatter( 197 | plotting_data[x_axis_property], 198 | plotting_data[y_axis_property], 199 | c=plotting_data["label"].map(label_colors), 200 | cmap=ListedColormap( 201 | list(label_colors.values()) 202 | ), # Consistent colormap 203 | s=10, 204 | ) 205 | 206 | # Line plot for time_point x-axis 207 | if x_axis_property == "time_point": 208 | for label, color in label_colors.items(): 209 | label_data = plotting_data[ 210 | plotting_data["label"] == label 211 | ].sort_values(by=x_axis_property) 212 | self.ax.plot( 213 | label_data[x_axis_property], 214 | label_data[y_axis_property], 215 | linestyle="-", 216 | color=color, 217 | linewidth=1, 218 | ) 219 | 220 | else: 221 | # Continuous colormap for other grouping 222 | self.ax.scatter( 223 | self.label_manager.selected_layer.features[x_axis_property], 224 | self.label_manager.selected_layer.features[y_axis_property], 225 | c=self.label_manager.selected_layer.features[group], 226 | cmap="summer", 227 | s=10, 228 | ) 229 | 230 | self.plot_canvas.draw() 231 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/prop_filter_widget.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import shutil 4 | from warnings import warn 5 | 6 | import dask.array as da 7 | import napari 8 | import numpy as np 9 | import pandas as pd 10 | import tifffile 11 | from qtpy.QtWidgets import ( 12 | QComboBox, 13 | QDoubleSpinBox, 14 | QFileDialog, 15 | QGroupBox, 16 | QHBoxLayout, 17 | QPushButton, 18 | QVBoxLayout, 19 | QWidget, 20 | ) 21 | from skimage.io import imread 22 | 23 | from .layer_manager import LayerManager 24 | 25 | 26 | class PropertyFilterWidget(QWidget): 27 | """Widget to filter objects by numerical 
property""" 28 | 29 | def __init__( 30 | self, viewer: "napari.viewer.Viewer", label_manager: LayerManager 31 | ) -> None: 32 | super().__init__() 33 | 34 | self.viewer = viewer 35 | self.label_manager = label_manager 36 | self.outputdir = None 37 | 38 | filterbox = QGroupBox("Filter objects by property") 39 | filter_layout = QHBoxLayout() 40 | 41 | self.property = QComboBox() 42 | self.property.currentIndexChanged.connect(self.update_min_max_value) 43 | self.label_manager.layer_update.connect(self.set_properties) 44 | self.operation = QComboBox() 45 | self.operation.addItems([">", "<", ">=", "<="]) 46 | self.operation.setToolTip("Operation to apply for filtering") 47 | self.value = QDoubleSpinBox() 48 | self.value.setValue(0.0) 49 | self.value.setSingleStep(0.1) 50 | self.value.setToolTip("Threshold value for the selected property") 51 | self.keep_delete = QComboBox() 52 | self.keep_delete.addItems(["Keep", "Delete"]) 53 | self.keep_delete.setToolTip( 54 | "Choose whether to keep or delete objects matching the criteria" 55 | ) 56 | self.run_btn = QPushButton("Run") 57 | self.run_btn.clicked.connect(self.filter_by_property) 58 | 59 | filter_layout.addWidget(self.property) 60 | filter_layout.addWidget(self.value) 61 | filter_layout.addWidget(self.operation) 62 | filter_layout.addWidget(self.keep_delete) 63 | 64 | main_layout = QVBoxLayout() 65 | main_layout.addLayout(filter_layout) 66 | main_layout.addWidget(self.run_btn) 67 | 68 | filterbox.setLayout(main_layout) 69 | 70 | layout = QVBoxLayout() 71 | layout.addWidget(filterbox) 72 | self.setLayout(layout) 73 | 74 | def set_properties(self) -> None: 75 | """Set available properties for the selected layer""" 76 | current_prop = self.property.currentText() 77 | if self.label_manager.selected_layer is not None: 78 | props = list(self.label_manager.selected_layer.properties.keys()) 79 | self.property.clear() 80 | self.property.addItems( 81 | [p for p in props if p not in ("label", "time_point")] 82 | ) 83 | if 
current_prop in props: 84 | self.property.setCurrentText(current_prop) 85 | self.run_btn.setEnabled(True) if len( 86 | props 87 | ) > 0 else self.run_btn.setEnabled(False) 88 | else: 89 | self.run_btn.setEnabled(False) 90 | 91 | def update_min_max_value(self) -> None: 92 | """Update min and max values for the threshold spinbox based on selected property""" 93 | prop = self.property.currentText() 94 | if prop in self.label_manager.selected_layer.properties: 95 | values = self.label_manager.selected_layer.properties[prop] 96 | self.value.setMinimum(np.min(values)) 97 | self.value.setMaximum(np.max(values)) 98 | if self.value.value() < np.min(values) or self.value.value() > np.max( 99 | values 100 | ): 101 | self.value.setValue(np.min(values)) 102 | 103 | def filter_by_property(self) -> None: 104 | """Filter objects by selected property and threshold value""" 105 | 106 | prop = self.property.currentText() 107 | value = self.value.value() 108 | operation = self.operation.currentText() 109 | keep_delete = self.keep_delete.currentText() 110 | 111 | if self.label_manager.selected_layer is not None: 112 | if isinstance(self.label_manager.selected_layer.data, da.core.Array): 113 | if self.outputdir is None: 114 | self.outputdir = QFileDialog.getExistingDirectory( 115 | self, "Select Output Folder" 116 | ) 117 | 118 | outputdir = os.path.join( 119 | self.outputdir, 120 | (self.label_manager.selected_layer.name + "_filtered"), 121 | ) 122 | if os.path.exists(outputdir): 123 | shutil.rmtree(outputdir) 124 | os.mkdir(outputdir) 125 | 126 | df = pd.DataFrame(self.label_manager.selected_layer.properties) 127 | if "time_point" in df.columns and prop in df.columns: 128 | filtered_label_imgs = [] 129 | for time_point in range( 130 | self.label_manager.selected_layer.data.shape[0] 131 | ): 132 | df_subset = df.loc[df["time_point"] == time_point] 133 | 134 | if operation == ">": 135 | filtered_labels = df_subset.loc[ 136 | df_subset[prop] > value, "label" 137 | ] 138 | elif operation 
== "<": 139 | filtered_labels = df_subset.loc[ 140 | df_subset[prop] < value, "label" 141 | ] 142 | elif operation == "<=": 143 | filtered_labels = df_subset.loc[ 144 | df_subset[prop] <= value, "label" 145 | ] 146 | elif operation == ">=": 147 | filtered_labels = df_subset.loc[ 148 | df_subset[prop] >= value, "label" 149 | ] 150 | 151 | if isinstance( 152 | self.label_manager.selected_layer.data, da.core.Array 153 | ): 154 | labels = np.array( 155 | self.label_manager.selected_layer.data[time_point].compute() 156 | ) 157 | else: 158 | labels = np.array( 159 | self.label_manager.selected_layer.data[time_point] 160 | ) 161 | 162 | if len(filtered_labels) == 0: 163 | mask = np.zeros_like(labels, dtype=bool) 164 | else: 165 | mask = functools.reduce( 166 | np.logical_or, (labels == val for val in filtered_labels) 167 | ) 168 | if keep_delete == "Delete": 169 | new_labels = np.where(~mask, labels, 0) 170 | else: 171 | new_labels = np.where(mask, labels, 0) 172 | 173 | if isinstance( 174 | self.label_manager.selected_layer.data, da.core.Array 175 | ): 176 | tifffile.imwrite( 177 | os.path.join( 178 | outputdir, 179 | ( 180 | self.label_manager.selected_layer.name 181 | + "_filtered_TP" 182 | + str(time_point).zfill(4) 183 | + ".tif" 184 | ), 185 | ), 186 | np.array(new_labels, dtype="uint16"), 187 | ) 188 | else: 189 | filtered_label_imgs.append(new_labels) 190 | 191 | if isinstance(self.label_manager.selected_layer.data, da.core.Array): 192 | file_list = [ 193 | os.path.join(outputdir, fname) 194 | for fname in os.listdir(outputdir) 195 | if fname.endswith(".tif") 196 | ] 197 | self.label_manager.selected_layer = self.viewer.add_labels( 198 | da.stack([imread(fname) for fname in sorted(file_list)]), 199 | name=self.label_manager.selected_layer.name + "_filtered", 200 | scale=self.label_manager.selected_layer.scale, 201 | ) 202 | else: 203 | result = np.stack(filtered_label_imgs) 204 | self.label_manager.selected_layer = self.viewer.add_labels( 205 | result, 206 | 
name=self.label_manager.selected_layer.name + "_filtered", 207 | scale=self.label_manager.selected_layer.scale, 208 | ) 209 | 210 | elif prop in df.columns: 211 | if operation == ">": 212 | filtered_labels = df.loc[df[prop] > value, "label"] 213 | elif operation == "<": 214 | filtered_labels = df.loc[df[prop] < value, "label"] 215 | elif operation == "<=": 216 | filtered_labels = df.loc[df[prop] <= value, "label"] 217 | elif operation == ">=": 218 | filtered_labels = df.loc[df[prop] >= value, "label"] 219 | 220 | labels = np.array(self.label_manager.selected_layer.data) 221 | mask = functools.reduce( 222 | np.logical_or, 223 | (labels == val for val in filtered_labels), 224 | ) 225 | if keep_delete == "Delete": 226 | result = np.where(~mask, labels, 0) 227 | else: 228 | result = np.where(mask, labels, 0) 229 | 230 | self.label_manager.selected_layer = self.viewer.add_labels( 231 | result, 232 | name=self.label_manager.selected_layer.name + "_filtered", 233 | scale=self.label_manager.selected_layer.scale, 234 | ) 235 | 236 | else: 237 | warn(f"Property {prop} not found in layer properties", stacklevel=2) 238 | return 239 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/regionprops_extended.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from skimage.measure import marching_cubes, mesh_surface_area, regionprops 6 | from skimage.measure._regionprops import RegionProperties 7 | 8 | 9 | class ExtendedRegionProperties(RegionProperties): 10 | """Adding additional properties to skimage.measure._regionprops following the logic from the porespy package with some modifications to include the spacing information.""" 11 | 12 | @property 13 | def ellipse_axes(self): 14 | return (self.axis_major_length, self.axis_minor_length) 15 | 16 | @property 17 | def ellipsoid_axes(self): 18 | """ 19 | Calculate the 
three axes radii of the fitted ellipsoid. 20 | 21 | This method calculates the principal axes of inertia for the label region, which are used to determine the lengths of the axes radii of an ellipsoid that approximates the shape of the region. 22 | 23 | Returns: 24 | tuple: A tuple containing the lengths of the three principal axes radii (longr, midr, shortr). 25 | 26 | """ 27 | 28 | # Extract the coordinates of the region's voxels 29 | cell = np.where(self.image) 30 | voxel_count = self.voxel_count 31 | 32 | z, y, x = cell 33 | 34 | # Center the coordinates and apply spacing 35 | z = (z - np.mean(z)) * self._spacing[0] 36 | y = (y - np.mean(y)) * self._spacing[1] 37 | x = (x - np.mean(x)) * self._spacing[2] 38 | 39 | # Calculate the elements of the inertia tensor 40 | i_xx = np.sum(y**2 + z**2) 41 | i_yy = np.sum(x**2 + z**2) 42 | i_zz = np.sum(x**2 + y**2) 43 | i_xy = np.sum(x * y) 44 | i_xz = np.sum(x * z) 45 | i_yz = np.sum(y * z) 46 | i = np.array([[i_xx, -i_xy, -i_xz], [-i_xy, i_yy, -i_yz], [-i_xz, -i_yz, i_zz]]) 47 | 48 | # Compute the eigenvalues and eigenvectors of the inertia tensor. The eigenvalues of the inertia tensor represent the principal moments of inertia, and the eigenvectors represent the directions of the principal axes. 
49 | eig = np.linalg.eig(i) 50 | eigval = eig[0] 51 | 52 | # Identify the principal axes 53 | longaxis = np.where(np.min(eigval) == eigval)[0][0] 54 | shortaxis = np.where(np.max(eigval) == eigval)[0][0] 55 | midaxis = ( 56 | 0 57 | if shortaxis != 0 and longaxis != 0 58 | else 1 59 | if shortaxis != 1 and longaxis != 1 60 | else 2 61 | ) 62 | 63 | # Calculate the lengths of the principal axes 64 | longr = math.sqrt( 65 | 5.0 66 | / 2.0 67 | * (eigval[midaxis] + eigval[shortaxis] - eigval[longaxis]) 68 | / voxel_count 69 | ) 70 | midr = math.sqrt( 71 | 5.0 72 | / 2.0 73 | * (eigval[shortaxis] + eigval[longaxis] - eigval[midaxis]) 74 | / voxel_count 75 | ) 76 | shortr = math.sqrt( 77 | 5.0 78 | / 2.0 79 | * (eigval[longaxis] + eigval[midaxis] - eigval[shortaxis]) 80 | / voxel_count 81 | ) 82 | 83 | return (longr, midr, shortr) 84 | 85 | @property 86 | def circularity(self): 87 | """ 88 | Calculate the circularity of the region. 89 | 90 | Circularity is defined as 4π * (Area / Perimeter^2). 91 | 92 | Returns: 93 | float: The circularity of the region. 94 | """ 95 | return 4 * math.pi * self.area / self.perimeter**2 96 | 97 | @property 98 | def pixel_count(self): 99 | """ 100 | Get the number of pixels in the region. 101 | 102 | Returns: 103 | int: The number of pixels in the region. 104 | """ 105 | return self.voxel_count 106 | 107 | @property 108 | def surface_area(self): 109 | """ 110 | Calculate the surface area of the region using the marching cubes algorithm. 111 | 112 | The marching cubes algorithm is used to extract a 2D surface mesh from a 3D volume. 113 | The mesh surface area is calculated with skimage.measure.mesh_surface_area. 114 | 115 | Returns: 116 | float: The surface area of the region. 
117 | """ 118 | verts, faces, _, _ = marching_cubes( 119 | self.image, level=0.5, spacing=self._spacing 120 | ) 121 | surface_area = mesh_surface_area(verts, faces) 122 | return surface_area 123 | 124 | @property 125 | def sphericity(self): 126 | """ 127 | Calculate the sphericity of the region. 128 | 129 | Sphericity is defined as the ratio of the surface area of a sphere with the same volume as the region to the surface area of the region. 130 | 131 | Returns: 132 | float: The sphericity of the region. 133 | """ 134 | vol = self.volume 135 | r = (3 / 4 / np.pi * vol) ** (1 / 3) 136 | a_equiv = 4 * np.pi * r**2 137 | a_region = self.surface_area 138 | return a_equiv / a_region 139 | 140 | @property 141 | def volume(self): 142 | """ 143 | Calculate the volume of the region. 144 | 145 | The volume is calculated as the number of voxels in the region multiplied by the product of the spacing in each dimension. 146 | 147 | Returns: 148 | float: The volume of the region. 149 | """ 150 | vol = np.sum(self.image) * np.prod(self._spacing) 151 | return vol 152 | 153 | @property 154 | def voxel_count(self): 155 | """ 156 | Get the number of voxels in the region. 157 | 158 | Returns: 159 | int: The number of voxels in the region. 160 | """ 161 | voxel_count = int(np.sum(self.image)) 162 | return voxel_count 163 | 164 | 165 | def regionprops_extended( 166 | img: np.ndarray, spacing: tuple[float], intensity_image: np.ndarray | None = None 167 | ) -> list[ExtendedRegionProperties]: 168 | """ 169 | Create instances of ExtendedRegionProperties that extend skimage.measure.RegionProperties. 170 | 171 | Args: 172 | img (np.ndarray): The labeled image. 173 | spacing (tuple[float]): The spacing between voxels in each dimension. 174 | intensity_image (np.ndarray, optional): The intensity image. 175 | 176 | Returns: 177 | list[ExtendedRegionProperties]: A list of ExtendedRegionProperties instances. 
178 | """ 179 | results = regionprops(img, intensity_image=intensity_image, spacing=spacing) 180 | for i, _ in enumerate(results): 181 | a = results[i] 182 | b = ExtendedRegionProperties( 183 | slice=a.slice, 184 | label=a.label, 185 | label_image=a._label_image, 186 | intensity_image=a._intensity_image, 187 | cache_active=a._cache_active, 188 | spacing=a._spacing, 189 | ) 190 | results[i] = b 191 | 192 | return results 193 | 194 | 195 | def props_to_dataframe(regionprops, selected_properties=None) -> pd.DataFrame: 196 | """Convert ExtendedRegionProperties instance to pandas dataframe, following the logical from porespy.metrics._regionprops.props_to_dataframe""" 197 | 198 | if selected_properties is None: 199 | selected_properties = regionprops[0].__dict__() 200 | 201 | new_props = ["label"] 202 | # need to check if any of the props return multiple values 203 | for item in selected_properties: 204 | if isinstance(getattr(regionprops[0], item), tuple): 205 | for i, _ in enumerate(getattr(regionprops[0], item)): 206 | new_props.append(item + "-" + str(i + 1)) 207 | else: 208 | new_props.append(item) 209 | 210 | # get the measurements for all properties of interest 211 | d = {} 212 | for k in new_props: 213 | if "-" in k: 214 | # If property is a tuple, extract the tuple element 215 | prop_name, idx = k.split("-") 216 | idx = int(idx) - 1 217 | d[k] = np.array([getattr(r, prop_name)[idx] for r in regionprops]) 218 | else: 219 | d[k] = np.array([getattr(r, k) for r in regionprops]) 220 | 221 | # Create pandas data frame an return 222 | df = pd.DataFrame(d) 223 | 224 | return df 225 | 226 | 227 | def calculate_extended_props( 228 | image: np.ndarray, 229 | properties: list[str], 230 | spacing: list[float], 231 | intensity_image: np.ndarray | None = None, 232 | ) -> pd.DataFrame: 233 | """Create regionproperties, and convert to pandas dataframe""" 234 | 235 | props = regionprops_extended( 236 | image, spacing=spacing, intensity_image=intensity_image 237 | ) 238 | return 
props_to_dataframe(props, properties) if len(props) > 0 else pd.DataFrame() 239 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/regionprops_widget.py: -------------------------------------------------------------------------------- 1 | import dask.array as da 2 | import napari 3 | import numpy as np 4 | import pandas as pd 5 | from qtpy.QtCore import Qt 6 | from qtpy.QtWidgets import ( 7 | QCheckBox, 8 | QDoubleSpinBox, 9 | QGridLayout, 10 | QGroupBox, 11 | QHBoxLayout, 12 | QLabel, 13 | QMessageBox, 14 | QPushButton, 15 | QSizePolicy, 16 | QVBoxLayout, 17 | QWidget, 18 | ) 19 | from tqdm import tqdm 20 | 21 | from .custom_table_widget import ColoredTableWidget 22 | from .layer_dropdown import LayerDropdown 23 | from .layer_manager import LayerManager 24 | from .prop_filter_widget import PropertyFilterWidget 25 | from .regionprops_extended import calculate_extended_props 26 | 27 | 28 | class RegionPropsWidget(QWidget): 29 | """Widget showing region props as a table and plot widget""" 30 | 31 | def __init__( 32 | self, viewer: "napari.viewer.Viewer", label_manager: LayerManager 33 | ) -> None: 34 | super().__init__() 35 | 36 | self.viewer = viewer 37 | self.label_manager = label_manager 38 | self.label_manager.layer_update.connect(self.update_dims) 39 | self.table = None 40 | self.ndims = 2 41 | self.feature_dims = 2 42 | self.axis_widgets = [] 43 | 44 | dim_box = QGroupBox("Dimensions") 45 | grid = QGridLayout() 46 | 47 | # headers 48 | grid.addWidget(QLabel("Axis"), 0, 0) 49 | grid.addWidget(QLabel("Index"), 0, 1) 50 | grid.addWidget(QLabel("Pixel scaling"), 0, 2) 51 | 52 | # Z row 53 | self.z_label = QLabel("Z") 54 | self.z_label.setVisible(False) 55 | self.z_axis = QLabel("0") 56 | self.z_axis.setVisible(False) 57 | self.axis_widgets.append(self.z_axis) 58 | self.z_scale = QDoubleSpinBox() 59 | self.z_scale.setValue(1.0) 60 | self.z_scale.setSingleStep(0.1) 61 | self.z_scale.setMinimum(0.01) 
62 | self.z_scale.setToolTip("Voxel size along Z axis") 63 | self.z_scale.setVisible(False) 64 | self.z_scale.setDecimals(3) 65 | 66 | grid.addWidget(self.z_label, 1, 0) 67 | grid.addWidget(self.z_axis, 1, 1) 68 | grid.addWidget(self.z_scale, 1, 2) 69 | 70 | # Y row 71 | self.y_label = QLabel("Y") 72 | self.y_axis = QLabel("1") 73 | self.axis_widgets.append(self.y_axis) 74 | self.y_scale = QDoubleSpinBox() 75 | self.y_scale.setValue(1.0) 76 | self.y_scale.setSingleStep(0.1) 77 | self.y_scale.setMinimum(0.01) 78 | self.y_scale.setToolTip("Voxel size along Y axis") 79 | self.y_scale.setDecimals(3) 80 | 81 | grid.addWidget(self.y_label, 2, 0) 82 | grid.addWidget(self.y_axis, 2, 1) 83 | grid.addWidget(self.y_scale, 2, 2) 84 | 85 | # X row 86 | self.x_label = QLabel("X") 87 | self.x_axis = QLabel("2") 88 | self.axis_widgets.append(self.x_axis) 89 | self.x_scale = QDoubleSpinBox() 90 | self.x_scale.setValue(1.0) 91 | self.x_scale.setSingleStep(0.1) 92 | self.x_scale.setMinimum(0.01) 93 | self.x_scale.setToolTip("Voxel size along X axis") 94 | self.x_scale.setDecimals(3) 95 | 96 | grid.addWidget(self.x_label, 3, 0) 97 | grid.addWidget(self.x_axis, 3, 1) 98 | grid.addWidget(self.x_scale, 3, 2) 99 | 100 | # add "use z" checkbox above grid 101 | main_layout = QVBoxLayout() 102 | self.use_z = QCheckBox("3D data (use Z axis)") 103 | self.use_z.setVisible(False) 104 | self.use_z.setEnabled(False) 105 | self.use_z.stateChanged.connect(self.update_use_z) 106 | main_layout.addWidget(self.use_z) 107 | main_layout.addLayout(grid) 108 | 109 | dim_box.setLayout(main_layout) 110 | dim_box.setMaximumHeight(200) 111 | 112 | # features widget 113 | self.feature_properties = [ 114 | { 115 | "region_prop_name": "intensity_mean", 116 | "display_name": "Mean intensity", 117 | "enabled": False, 118 | "selected": False, 119 | "dims": [2, 3], 120 | }, 121 | { 122 | "region_prop_name": "area", 123 | "display_name": "Area", 124 | "enabled": True, 125 | "selected": True, 126 | "dims": [2], 127 | }, 
128 | { 129 | "region_prop_name": "perimeter", 130 | "display_name": "Perimeter", 131 | "enabled": True, 132 | "selected": False, 133 | "dims": [2], 134 | }, 135 | { 136 | "region_prop_name": "circularity", 137 | "display_name": "Circularity", 138 | "enabled": True, 139 | "selected": False, 140 | "dims": [2], 141 | }, 142 | { 143 | "region_prop_name": "ellipse_axes", 144 | "display_name": "Ellipse axes", 145 | "enabled": True, 146 | "selected": False, 147 | "dims": [2], 148 | }, 149 | { 150 | "region_prop_name": "volume", 151 | "display_name": "Volume", 152 | "enabled": True, 153 | "selected": True, 154 | "dims": [3], 155 | }, 156 | { 157 | "region_prop_name": "surface_area", 158 | "display_name": "Surface area", 159 | "enabled": True, 160 | "selected": False, 161 | "dims": [3], 162 | }, 163 | { 164 | "region_prop_name": "sphericity", 165 | "display_name": "Sphericity", 166 | "enabled": True, 167 | "selected": False, 168 | "dims": [3], 169 | }, 170 | { 171 | "region_prop_name": "ellipsoid_axes", 172 | "display_name": "Ellipsoid axes", 173 | "enabled": True, 174 | "selected": False, 175 | "dims": [3], 176 | }, 177 | ] 178 | 179 | feature_box = QGroupBox("Features to measure") 180 | feature_box.setMaximumHeight(250) 181 | self.checkbox_layout = QVBoxLayout() 182 | self.checkboxes = [] 183 | self.intensity_image_dropdown = None 184 | 185 | feature_box.setLayout(self.checkbox_layout) 186 | 187 | # Push button to measure features 188 | self.measure_btn = QPushButton("Measure properties") 189 | self.measure_btn.clicked.connect(self.measure) 190 | 191 | ### Add widget for property filtering 192 | self.prop_filter_widget = PropertyFilterWidget(self.viewer, self.label_manager) 193 | self.prop_filter_widget.setVisible(False) 194 | 195 | # Add table layout 196 | self.regionprops_layout = QVBoxLayout() 197 | 198 | # Assemble layout 199 | main_box = QGroupBox("Region properties") 200 | main_layout = QVBoxLayout() 201 | main_layout.addWidget(dim_box) 202 | 
main_layout.addWidget(feature_box) 203 | main_layout.addWidget(self.measure_btn) 204 | main_layout.addWidget(self.prop_filter_widget) 205 | main_layout.addLayout(self.regionprops_layout) 206 | main_box.setLayout(main_layout) 207 | 208 | layout = QVBoxLayout() 209 | layout.addWidget(main_box) 210 | 211 | layout.setAlignment(Qt.AlignTop) 212 | 213 | self.setLayout(layout) 214 | 215 | # refresh dimensions based on current label layer, if any 216 | self.update_dims() 217 | 218 | def update_properties(self) -> None: 219 | if hasattr(self, "intensity_image_dropdown") and self.intensity_image_dropdown: 220 | self.intensity_image_dropdown.layer_changed.disconnect() 221 | self.intensity_image_dropdown.deleteLater() 222 | while self.checkbox_layout.count(): 223 | item = self.checkbox_layout.takeAt(0) 224 | widget = item.widget() 225 | if widget is not None: 226 | widget.deleteLater() 227 | self.intensity_image_dropdown = None 228 | self.checkboxes = [] 229 | 230 | # create checkbox for each feature 231 | self.properties = [ 232 | f for f in self.feature_properties if self.feature_dims in f["dims"] 233 | ] 234 | self.checkbox_state = { 235 | prop["region_prop_name"]: prop["selected"] for prop in self.properties 236 | } 237 | for prop in self.properties: 238 | if self.feature_dims in prop["dims"]: 239 | checkbox = QCheckBox(prop["display_name"]) 240 | checkbox.setEnabled(prop["enabled"]) 241 | checkbox.setStyleSheet("QCheckBox:disabled { color: grey }") 242 | checkbox.setChecked(self.checkbox_state[prop["region_prop_name"]]) 243 | checkbox.stateChanged.connect( 244 | lambda state, prop=prop: self.checkbox_state.update( 245 | {prop["region_prop_name"]: state == 2} 246 | ) 247 | ) 248 | self.checkboxes.append( 249 | {"region_prop_name": prop["region_prop_name"], "checkbox": checkbox} 250 | ) 251 | 252 | if prop["region_prop_name"] == "intensity_mean": 253 | self.intensity_image_dropdown = LayerDropdown( 254 | self.viewer, napari.layers.Image 255 | ) 256 | 
self.intensity_image_dropdown.layer_changed.connect( 257 | self._update_intensity_checkbox 258 | ) 259 | if self.intensity_image_dropdown.selected_layer is not None: 260 | checkbox.setEnabled(True) 261 | int_layout = QHBoxLayout() 262 | int_layout.addWidget(checkbox) 263 | int_layout.addWidget(self.intensity_image_dropdown) 264 | int_layout.setContentsMargins(0, 0, 0, 0) 265 | int_layout.setAlignment(Qt.AlignLeft | Qt.AlignVCenter) 266 | 267 | int_widget = QWidget() 268 | int_widget.setLayout(int_layout) 269 | 270 | # Enforce same height behavior as a normal checkbox 271 | int_widget.setSizePolicy(checkbox.sizePolicy()) 272 | self.intensity_image_dropdown.setSizePolicy( 273 | QSizePolicy.Expanding, QSizePolicy.Fixed 274 | ) 275 | 276 | self.checkbox_layout.addWidget(int_widget) 277 | else: 278 | self.checkbox_layout.addWidget(checkbox) 279 | 280 | def _update_intensity_checkbox(self) -> None: 281 | """Enable or disable the intensity_mean checkbox based on the selected layer.""" 282 | if self.intensity_image_dropdown is not None: 283 | checkbox = next( 284 | ( 285 | cb["checkbox"] 286 | for cb in self.checkboxes 287 | if cb["region_prop_name"] == "intensity_mean" 288 | ), 289 | None, 290 | ) 291 | if checkbox is not None: 292 | checkbox.setEnabled( 293 | isinstance( 294 | self.intensity_image_dropdown.selected_layer, 295 | napari.layers.Image, 296 | ) 297 | ) 298 | 299 | def update_use_z(self, state: int) -> None: 300 | self.z_label.setVisible(state == 2) 301 | self.z_axis.setVisible(state == 2) 302 | self.z_scale.setVisible(state == 2) 303 | self.z_axis.setEnabled(state == 2) 304 | self.z_scale.setEnabled(state == 2) 305 | self.feature_dims = 3 if state == 2 else 2 306 | self.update_dims() 307 | 308 | def update_dims(self) -> None: 309 | """Update the number of dimensions to measure based on the selected checkboxes""" 310 | 311 | if self.label_manager.selected_layer is not None: 312 | self.measure_btn.setEnabled(True) 313 | self.ndims = 
self.label_manager.selected_layer.ndim 314 | self.use_z.setVisible(self.ndims == 3) 315 | self.use_z.setEnabled(self.ndims == 3) 316 | if self.ndims == 4: 317 | self.feature_dims = 3 318 | self.z_axis.setVisible(True) 319 | self.z_label.setVisible(True) 320 | self.z_scale.setVisible(True) 321 | self.z_axis.setEnabled(True) 322 | self.z_scale.setEnabled(True) 323 | 324 | ax_names = [str(ax) for ax in range(self.ndims)] 325 | if len(ax_names) > 0: 326 | for i, widget in enumerate(self.axis_widgets): 327 | if self.ndims == 4: 328 | widget.setText(ax_names[i + 1]) 329 | elif self.ndims == 2: 330 | widget.setText(ax_names[i - 1]) 331 | else: 332 | widget.setText(ax_names[i]) 333 | 334 | self.update_properties() 335 | self.update_table() 336 | 337 | self.z_scale.setValue(self.label_manager.selected_layer.scale[-3]) if len( 338 | self.label_manager.selected_layer.scale 339 | ) >= 3 else self.z_scale.setValue(1.0) 340 | self.y_scale.setValue(self.label_manager.selected_layer.scale[-2]) 341 | self.x_scale.setValue(self.label_manager.selected_layer.scale[-1]) 342 | 343 | else: 344 | self.measure_btn.setEnabled(False) 345 | 346 | def measure(self): 347 | if self.use_z.isChecked() or self.ndims == 4: 348 | spacing = (self.z_scale.value(), self.y_scale.value(), self.x_scale.value()) 349 | else: 350 | spacing = (self.y_scale.value(), self.x_scale.value()) 351 | 352 | # ensure spacing is applied to the layer and the viewer step is updated 353 | layer_scale = list(self.label_manager.selected_layer.scale) 354 | layer_scale[-1] = spacing[-1] 355 | layer_scale[-2] = spacing[-2] 356 | if len(layer_scale) > 3: 357 | layer_scale[-3] = spacing[-3] 358 | old_step = list(self.viewer.dims.current_step) 359 | step_size = [dim_range.step for dim_range in self.viewer.dims.range] 360 | new_step = [ 361 | step * step_size 362 | for step, step_size in zip(old_step, step_size, strict=False) 363 | ] 364 | self.label_manager.selected_layer.scale = layer_scale 365 | self.viewer.reset_view() 366 | 
self.viewer.dims.current_step = new_step 367 | 368 | features = self.get_selected_features() 369 | 370 | if ( 371 | isinstance( 372 | self.intensity_image_dropdown.selected_layer, napari.layers.Image 373 | ) 374 | and "intensity_mean" in features 375 | ): 376 | intensity_image = self.intensity_image_dropdown.selected_layer.data 377 | else: 378 | intensity_image = None 379 | 380 | data = self.label_manager.selected_layer.data 381 | if intensity_image is not None and intensity_image.shape != data.shape: 382 | msg = QMessageBox() 383 | msg.setWindowTitle("Shape mismatch") 384 | msg.setText( 385 | f"Label layer and intensity image must have the same shape. Got {self.label_manager.selected_layer.data.shape} and {intensity_image.shape}." 386 | ) 387 | msg.setIcon(QMessageBox.Critical) 388 | msg.setStandardButtons(QMessageBox.Ok) 389 | msg.exec_() 390 | return 391 | if (self.use_z.isChecked() and self.ndims == 3) or ( 392 | not self.use_z.isChecked() and self.ndims == 2 393 | ): 394 | props = calculate_extended_props( 395 | data, 396 | intensity_image=intensity_image, 397 | properties=features, 398 | spacing=spacing, 399 | ) 400 | else: 401 | for i in tqdm(range(data.shape[0])): 402 | d = data[i].compute() if isinstance(data, da.core.Array) else data[i] 403 | if isinstance(intensity_image, da.core.Array): 404 | int_img = intensity_image[i].compute() 405 | elif isinstance(intensity_image, np.ndarray): 406 | int_img = intensity_image[i] 407 | else: 408 | int_img = None 409 | props_slice = calculate_extended_props( 410 | d, 411 | intensity_image=int_img, 412 | properties=features, 413 | spacing=spacing, 414 | ) 415 | props_slice["time_point"] = i 416 | if i == 0: 417 | props = props_slice 418 | else: 419 | props = pd.concat([props, props_slice], ignore_index=True) 420 | 421 | if hasattr(self.label_manager.selected_layer, "properties"): 422 | self.label_manager.selected_layer.properties = props 423 | self.prop_filter_widget.set_properties() 424 | self.update_table() 425 | 426 
| def update_table(self) -> None: 427 | if self.table is not None: 428 | self.table.hide() 429 | self.prop_filter_widget.setVisible(False) 430 | 431 | if ( 432 | self.viewer is not None 433 | and self.label_manager.selected_layer is not None 434 | and len(self.label_manager.selected_layer.properties) > 0 435 | ): 436 | self.table = ColoredTableWidget( 437 | self.label_manager.selected_layer, self.viewer 438 | ) 439 | self.table.setMinimumWidth(500) 440 | self.regionprops_layout.addWidget(self.table) 441 | self.prop_filter_widget.setVisible(True) 442 | 443 | def get_selected_features(self) -> list[str]: 444 | """Return a list of the features that have been selected""" 445 | 446 | selected_features = [ 447 | key for key in self.checkbox_state if self.checkbox_state[key] 448 | ] 449 | if "intensity_mean" in selected_features and not isinstance( 450 | self.intensity_image_dropdown.selected_layer, napari.layers.Image 451 | ): 452 | selected_features.remove("intensity_mean") 453 | selected_features.append("label") # always include label 454 | selected_features.append("centroid") 455 | return selected_features 456 | 457 | def set_selected_features(self, features: list[str]) -> None: 458 | """Set the selected features based on the input list""" 459 | 460 | for checkbox in self.checkboxes: 461 | checkbox["checkbox"].setChecked(checkbox["region_prop_name"] in features) 462 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/save_labels_widget.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dask.array as da 4 | import napari 5 | import numpy as np 6 | import tifffile 7 | from napari.layers import Labels 8 | from qtpy.QtWidgets import ( 9 | QCheckBox, 10 | QComboBox, 11 | QFileDialog, 12 | QGroupBox, 13 | QHBoxLayout, 14 | QLineEdit, 15 | QPushButton, 16 | QVBoxLayout, 17 | QWidget, 18 | ) 19 | 20 | from .layer_manager import LayerManager 21 | 22 
| 23 | class SaveLabelsWidget(QWidget): 24 | """Widget for saving label data with options for datatype, compression, and whether to split time points.""" 25 | 26 | def __init__( 27 | self, viewer: "napari.viewer.Viewer", label_manager: LayerManager 28 | ) -> None: 29 | super().__init__() 30 | 31 | self.viewer = viewer 32 | self.label_manager = label_manager 33 | 34 | # Select datatype 35 | self.select_dtype = QComboBox() 36 | self.select_dtype.addItem("np.uint8") 37 | self.select_dtype.addItem("np.uint16") 38 | self.select_dtype.addItem("np.uint32") 39 | self.select_dtype.addItem("np.uint64") 40 | self.select_dtype.setToolTip("File bit depth for saving.") 41 | 42 | # Split time points 43 | self.split_time_points = QCheckBox("Split time points") 44 | self.split_time_points.setEnabled( 45 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 46 | and self.label_manager.selected_layer.data.ndim >= 3 47 | ) 48 | self.label_manager.layer_update.connect( 49 | lambda: self.split_time_points.setEnabled( 50 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 51 | and self.label_manager.selected_layer.data.ndim >= 3 52 | ) 53 | ) 54 | self.split_time_points.setToolTip( 55 | "Saves each time point to a separate file. Assumes that the time dimension is along the first axis." 56 | ) 57 | 58 | # Use compression 59 | self.use_compression = QCheckBox("Use compression") 60 | self.use_compression.setChecked(True) 61 | self.use_compression.setToolTip( 62 | "Use zstd compression? This may take a bit longer to save." 
63 | ) 64 | 65 | # Filename 66 | self.filename = QLineEdit() 67 | self.filename.setPlaceholderText("File name") 68 | self.label_manager.layer_update.connect(self.update_filename) 69 | self.filename.setToolTip("Filename for saving labels") 70 | 71 | ## Add save button 72 | self.save_btn = QPushButton("Save labels") 73 | self.save_btn.clicked.connect(self._save_labels) 74 | self.save_btn.setEnabled(isinstance(self.label_manager.selected_layer, Labels)) 75 | self.label_manager.layer_update.connect( 76 | lambda: self.save_btn.setEnabled( 77 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 78 | ) 79 | ) 80 | 81 | # Combine layouts 82 | save_box = QGroupBox("Save labels") 83 | layout = QVBoxLayout() 84 | 85 | settings_layout = QHBoxLayout() 86 | settings_layout.addWidget(self.select_dtype) 87 | settings_layout.addWidget(self.split_time_points) 88 | settings_layout.addWidget(self.use_compression) 89 | layout.addLayout(settings_layout) 90 | layout.addWidget(self.filename) 91 | layout.addWidget(self.save_btn) 92 | save_box.setLayout(layout) 93 | 94 | main_layout = QVBoxLayout() 95 | main_layout.addWidget(save_box) 96 | self.setLayout(main_layout) 97 | 98 | def update_filename(self) -> None: 99 | """Update the default filename""" 100 | 101 | if isinstance(self.label_manager.selected_layer, napari.layers.Labels): 102 | self.filename.setText(self.label_manager.selected_layer.name) 103 | 104 | def _save_labels(self) -> None: 105 | """Save the currently active labels layer. 
If it consists of multiple timepoints, they are written to multiple 3D stacks.""" 106 | 107 | data = self.label_manager.selected_layer.data 108 | ndim = data.ndim 109 | split_time_points = ndim >= 3 and self.split_time_points.isChecked() 110 | dtype_map = { 111 | "np.uint8": np.uint8, 112 | "np.uint16": np.uint16, 113 | "np.uint32": np.uint32, 114 | "np.uint64": np.uint64, 115 | } 116 | dtype = dtype_map[self.select_dtype.currentText()] 117 | use_compression = self.use_compression.isChecked() 118 | filename = self.filename.text() 119 | 120 | destination = QFileDialog.getExistingDirectory(self, "Select Output Folder") 121 | 122 | if ndim >= 3 and split_time_points: 123 | for i in range(data.shape[0]): 124 | if isinstance(data, da.core.Array): 125 | current_stack = data[i].compute().astype(dtype) 126 | else: 127 | current_stack = data[i].astype(dtype) 128 | 129 | tifffile.imwrite( 130 | ( 131 | os.path.join( 132 | destination, 133 | ( 134 | filename.split(".tif")[0] 135 | + "_TP" 136 | + str(i).zfill(4) 137 | + ".tif" 138 | ), 139 | ) 140 | ), 141 | current_stack, 142 | compression="zstd" if use_compression else None, 143 | ) 144 | 145 | else: 146 | tifffile.imwrite( 147 | os.path.join(destination, (filename + ".tif")), 148 | data.astype(dtype), 149 | compression="zstd" if use_compression else None, 150 | ) 151 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/select_delete_widget.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | import shutil 5 | 6 | import dask.array as da 7 | import napari 8 | import numpy as np 9 | import tifffile 10 | from napari.layers import Labels 11 | from qtpy.QtWidgets import ( 12 | QCheckBox, 13 | QFileDialog, 14 | QGroupBox, 15 | QHBoxLayout, 16 | QLabel, 17 | QMessageBox, 18 | QPushButton, 19 | QVBoxLayout, 20 | QWidget, 21 | ) 22 | from skimage.io import imread 23 | 24 | from 
.layer_dropdown import LayerDropdown 25 | 26 | 27 | class SelectDeleteMask(QWidget): 28 | """Widget to select labels to keep or to delete based on overlap with a mask.""" 29 | 30 | def __init__(self, viewer: "napari.viewer.Viewer") -> None: 31 | super().__init__() 32 | 33 | self.viewer = viewer 34 | self.image1_layer = None 35 | self.mask_layer = None 36 | self.outputdir = None 37 | 38 | ### Add one image to another 39 | select_delete_box = QGroupBox("Select / Delete labels by mask") 40 | select_delete_box_layout = QVBoxLayout() 41 | 42 | image1_layout = QHBoxLayout() 43 | image1_layout.addWidget(QLabel("Labels")) 44 | self.image1_dropdown = LayerDropdown(self.viewer, (Labels)) 45 | self.image1_dropdown.layer_changed.connect(self._update_image1) 46 | image1_layout.addWidget(self.image1_dropdown) 47 | 48 | image2_layout = QHBoxLayout() 49 | image2_layout.addWidget(QLabel("Mask")) 50 | self.mask_dropdown = LayerDropdown(self.viewer, (Labels)) 51 | self.mask_dropdown.layer_changed.connect(self._update_image2) 52 | image2_layout.addWidget(self.mask_dropdown) 53 | 54 | select_delete_box_layout.addLayout(image1_layout) 55 | select_delete_box_layout.addLayout(image2_layout) 56 | 57 | self.stack_checkbox = QCheckBox("Apply 3D mask to all time points in 4D array") 58 | self.stack_checkbox.setEnabled(False) 59 | select_delete_box_layout.addWidget(self.stack_checkbox) 60 | 61 | self.select_btn = QPushButton("Select labels") 62 | self.select_btn.clicked.connect(self.select_labels) 63 | select_delete_box_layout.addWidget(self.select_btn) 64 | 65 | self.delete_btn = QPushButton("Delete labels") 66 | self.delete_btn.clicked.connect(self.delete_labels) 67 | select_delete_box_layout.addWidget(self.delete_btn) 68 | 69 | self.image1_dropdown.layer_changed.connect(self._update_buttons) 70 | self.mask_dropdown.layer_changed.connect(self._update_buttons) 71 | self._update_buttons() 72 | 73 | select_delete_box.setLayout(select_delete_box_layout) 74 | main_layout = QVBoxLayout() 75 | 
main_layout.addWidget(select_delete_box) 76 | self.setLayout(main_layout) 77 | 78 | def _update_buttons(self) -> None: 79 | """Update button state according to whether image layers are present""" 80 | 81 | active = ( 82 | self.image1_dropdown.selected_layer is not None 83 | and self.mask_dropdown.selected_layer is not None 84 | ) 85 | self.select_btn.setEnabled(active) 86 | self.delete_btn.setEnabled(active) 87 | 88 | def _update_image1(self, selected_layer: str) -> None: 89 | """Update the layer that is set to be the 'source labels' layer for copying labels from.""" 90 | 91 | if selected_layer == "": 92 | self.image1_layer = None 93 | else: 94 | self.image1_layer = self.viewer.layers[selected_layer] 95 | self.image1_dropdown.setCurrentText(selected_layer) 96 | 97 | # update the checkbox and buttons as needed needed 98 | if self.mask_layer is not None and self.image1_layer is not None: 99 | self.select_btn.setEnabled(True) 100 | self.delete_btn.setEnabled(True) 101 | if len(self.image1_layer.data.shape) == len(self.mask_layer.data.shape) + 1: 102 | self.stack_checkbox.setEnabled(True) 103 | else: 104 | self.stack_checkbox.setEnabled(False) 105 | self.stack_checkbox.setCheckState(False) 106 | else: 107 | self.select_btn.setEnabled(False) 108 | self.delete_btn.setEnabled(False) 109 | self.stack_checkbox.setEnabled(False) 110 | self.stack_checkbox.setCheckState(False) 111 | 112 | def _update_image2(self, selected_layer: str) -> None: 113 | """Update the layer that is set to be the 'source labels' layer for copying labels from.""" 114 | 115 | if selected_layer == "": 116 | self.mask_layer = None 117 | else: 118 | self.mask_layer = self.viewer.layers[selected_layer] 119 | self.mask_dropdown.setCurrentText(selected_layer) 120 | 121 | # update the checkbox and buttons as needed 122 | if self.mask_layer is not None and self.image1_layer is not None: 123 | self.select_btn.setEnabled(True) 124 | self.delete_btn.setEnabled(True) 125 | if len(self.image1_layer.data.shape) == 
len(self.mask_layer.data.shape) + 1: 126 | self.stack_checkbox.setEnabled(True) 127 | else: 128 | self.stack_checkbox.setEnabled(False) 129 | self.stack_checkbox.setCheckState(False) 130 | else: 131 | self.select_btn.setEnabled(False) 132 | self.delete_btn.setEnabled(False) 133 | self.stack_checkbox.setEnabled(False) 134 | self.stack_checkbox.setCheckState(False) 135 | 136 | def select_labels(self): 137 | # check data dimensions first 138 | image_shape = self.image1_layer.data.shape 139 | mask_shape = self.mask_layer.data.shape 140 | 141 | if len(image_shape) == len(mask_shape) + 1 and image_shape[1:] == mask_shape: 142 | # apply mask to single time point or to full stack depending on checkbox state 143 | if self.stack_checkbox.isChecked(): 144 | # loop over all time points 145 | print("applying the mask to all time points") 146 | # check if the data is a dask array 147 | if isinstance(self.image1_layer.data, da.core.Array): 148 | if self.outputdir is None: 149 | self.outputdir = QFileDialog.getExistingDirectory( 150 | self, "Select Output Folder" 151 | ) 152 | 153 | outputdir = os.path.join( 154 | self.outputdir, 155 | (self.image1_layer.name + "_filtered_labels"), 156 | ) 157 | if os.path.exists(outputdir): 158 | shutil.rmtree(outputdir) 159 | os.mkdir(outputdir) 160 | 161 | for i in range( 162 | self.image1_layer.data.shape[0] 163 | ): # Loop over the first dimension 164 | current_stack = self.image1_layer.data[ 165 | i 166 | ].compute() # Compute the current stack 167 | 168 | to_keep = np.unique(current_stack[self.mask_layer.data > 0]) 169 | filtered_mask = functools.reduce( 170 | np.logical_or, (current_stack == val for val in to_keep) 171 | ) 172 | filtered_data = np.where(filtered_mask, current_stack, 0) 173 | 174 | tifffile.imwrite( 175 | os.path.join( 176 | outputdir, 177 | ( 178 | self.image1_layer.name 179 | + "_filtered_labels_TP" 180 | + str(i).zfill(4) 181 | + ".tif" 182 | ), 183 | ), 184 | np.array(filtered_data, dtype="uint16"), 185 | ) 186 | 187 | 
file_list = [ 188 | os.path.join(outputdir, fname) 189 | for fname in os.listdir(outputdir) 190 | if fname.endswith(".tif") 191 | ] 192 | self.image1_layer = self.viewer.add_labels( 193 | da.stack([imread(fname) for fname in sorted(file_list)]), 194 | name=self.image1_layer.name + "_filtered_labels", 195 | scale=self.image1_layer.scale, 196 | ) 197 | 198 | else: 199 | for tp in range(self.image1_layer.data.shape[0]): 200 | to_keep = np.unique( 201 | self.image1_layer.data[tp][self.mask_layer.data > 0] 202 | ) 203 | filtered_mask = functools.reduce( 204 | np.logical_or, 205 | (self.image1_layer.data[tp] == val for val in to_keep), 206 | ) 207 | filtered_data_tp = np.where( 208 | filtered_mask, self.image1_layer.data[tp], 0 209 | ) 210 | self.image1_layer.data[tp] = filtered_data_tp 211 | 212 | else: 213 | tp = self.viewer.dims.current_step[0] 214 | print("applying the mask to the current time point only", tp) 215 | if isinstance(self.image1_layer.data, da.core.Array): 216 | outputdir = QFileDialog.getExistingDirectory( 217 | self, 218 | "Please select the directory that holds the images. Data will be changed here. 
Selecting a new empty directory will create a copy of all data", 219 | ) 220 | 221 | if len(os.listdir(outputdir)) == 0: 222 | for i in range( 223 | self.image1_layer.data.shape[0] 224 | ): # Loop over the first dimension 225 | current_stack = self.image1_layer.data[ 226 | i 227 | ].compute() # Compute the current stack 228 | 229 | if i == tp: 230 | to_keep = np.unique( 231 | current_stack[self.mask_layer.data > 0] 232 | ) 233 | filtered_mask = functools.reduce( 234 | np.logical_or, 235 | (current_stack == val for val in to_keep), 236 | ) 237 | current_stack = np.where( 238 | filtered_mask, current_stack, 0 239 | ) 240 | tifffile.imwrite( 241 | os.path.join( 242 | outputdir, 243 | ( 244 | self.image1_layer.name 245 | + "_filtered_labels_TP" 246 | + str(i).zfill(4) 247 | + ".tif" 248 | ), 249 | ), 250 | np.array(current_stack, dtype="uint16"), 251 | ) 252 | 253 | file_list = sorted( 254 | [ 255 | os.path.join(outputdir, fname) 256 | for fname in os.listdir(outputdir) 257 | if fname.endswith(".tif") 258 | ] 259 | ) 260 | else: 261 | current_stack = self.image1_layer.data[ 262 | tp 263 | ].compute() # Compute the current stack 264 | 265 | to_keep = np.unique(current_stack[self.mask_layer.data > 0]) 266 | filtered_mask = functools.reduce( 267 | np.logical_or, (current_stack == val for val in to_keep) 268 | ) 269 | current_stack = np.where(filtered_mask, current_stack, 0) 270 | 271 | file_list = sorted( 272 | [ 273 | os.path.join(outputdir, fname) 274 | for fname in os.listdir(outputdir) 275 | if fname.endswith(".tif") 276 | ] 277 | ) 278 | 279 | tifffile.imwrite( 280 | file_list[tp], 281 | np.array(current_stack, dtype="uint16"), 282 | ) 283 | 284 | self.image1_layer = self.viewer.add_labels( 285 | da.stack([imread(fname) for fname in file_list]), 286 | name=self.image1_layer.name + "_filtered_labels", 287 | scale=self.image1_layer.scale, 288 | ) 289 | 290 | else: 291 | tp = self.viewer.dims.current_step[0] 292 | to_keep = np.unique( 293 | 
self.image1_layer.data[tp][self.mask_layer.data > 0] 294 | ) 295 | filtered_mask = functools.reduce( 296 | np.logical_or, 297 | (self.image1_layer.data[tp] == val for val in to_keep), 298 | ) 299 | filtered_data_tp = np.where( 300 | filtered_mask, self.image1_layer.data[tp], 0 301 | ) 302 | self.image1_layer.data[tp] = filtered_data_tp 303 | 304 | elif image_shape == mask_shape: 305 | if isinstance(self.image1_layer.data, da.core.Array): 306 | if self.outputdir is None: 307 | self.outputdir = QFileDialog.getExistingDirectory( 308 | self, "Select Output Folder" 309 | ) 310 | 311 | outputdir = os.path.join( 312 | self.outputdir, 313 | (self.image1_layer.name + "_filtered_labels"), 314 | ) 315 | if os.path.exists(outputdir): 316 | shutil.rmtree(outputdir) 317 | os.mkdir(outputdir) 318 | 319 | for i in range( 320 | self.image1_layer.data.shape[0] 321 | ): # Loop over the first dimension 322 | current_stack = self.image1_layer.data[ 323 | i 324 | ].compute() # Compute the current stack 325 | 326 | if isinstance(self.mask_layer.data, da.core.Array): 327 | to_keep = np.unique( 328 | current_stack[self.mask_layer.data[i].compute() > 0] 329 | ) 330 | else: 331 | to_keep = np.unique(current_stack[self.mask_layer.data[i] > 0]) 332 | 333 | filtered_mask = functools.reduce( 334 | np.logical_or, (current_stack == val for val in to_keep) 335 | ) 336 | filtered_data_tp = np.where(filtered_mask, current_stack, 0) 337 | 338 | tifffile.imwrite( 339 | os.path.join( 340 | outputdir, 341 | ( 342 | self.image1_layer.name 343 | + "_filtered_labels_TP" 344 | + str(i).zfill(4) 345 | + ".tif" 346 | ), 347 | ), 348 | np.array(filtered_data_tp, dtype="uint16"), 349 | ) 350 | 351 | file_list = [ 352 | os.path.join(outputdir, fname) 353 | for fname in os.listdir(outputdir) 354 | if fname.endswith(".tif") 355 | ] 356 | self.image1_layer = self.viewer.add_labels( 357 | da.stack([imread(fname) for fname in sorted(file_list)]), 358 | name=self.image1_layer.name + "_filtered_labels", 359 | 
scale=self.image1_layer.scale, 360 | ) 361 | else: 362 | to_keep = np.unique(self.image1_layer.data[self.mask_layer.data > 0]) 363 | filtered_mask = functools.reduce( 364 | np.logical_or, (self.image1_layer.data == val for val in to_keep) 365 | ) 366 | self.viewer.add_labels( 367 | np.where(filtered_mask, self.image1_layer.data, 0), 368 | name="selected labels", 369 | scale=self.image1_layer.scale, 370 | ) 371 | 372 | else: 373 | msg = QMessageBox() 374 | msg.setWindowTitle("Images do not have compatible shapes") 375 | msg.setText("Please provide images that have matching dimensions") 376 | msg.setIcon(QMessageBox.Information) 377 | msg.setStandardButtons(QMessageBox.Ok) 378 | msg.exec_() 379 | return False 380 | 381 | def delete_labels(self): 382 | """Delete labels that overlap with given mask. If the shape of the mask has 1 dimension less than the image, the mask will be applied to the current time point (index in the first dimension) of the image data.""" 383 | 384 | if isinstance(self.mask_layer.data, da.core.Array): 385 | msg = QMessageBox() 386 | msg.setWindowTitle("Please provide a mask that is not a Dask array") 387 | msg.setText("Please provide a mask that is not a Dask array") 388 | msg.setIcon(QMessageBox.Information) 389 | msg.setStandardButtons(QMessageBox.Ok) 390 | msg.exec_() 391 | return False 392 | 393 | # check data dimensions first 394 | image_shape = self.image1_layer.data.shape 395 | mask_shape = self.mask_layer.data.shape 396 | 397 | if len(image_shape) == len(mask_shape) + 1 and image_shape[1:] == mask_shape: 398 | # apply mask to single time point or to full stack depending on checkbox state 399 | if self.stack_checkbox.isChecked(): 400 | # loop over all time points 401 | print("applying the mask to all time points") 402 | # check if the data is a dask array 403 | if isinstance(self.image1_layer.data, da.core.Array): 404 | if self.outputdir is None: 405 | self.outputdir = QFileDialog.getExistingDirectory( 406 | self, "Select Output Folder" 
407 | ) 408 | 409 | outputdir = os.path.join( 410 | self.outputdir, 411 | (self.image1_layer.name + "_filtered_labels"), 412 | ) 413 | if os.path.exists(outputdir): 414 | shutil.rmtree(outputdir) 415 | os.mkdir(outputdir) 416 | 417 | for i in range( 418 | self.image1_layer.data.shape[0] 419 | ): # Loop over the first dimension 420 | current_stack = self.image1_layer.data[ 421 | i 422 | ].compute() # Compute the current stack 423 | 424 | to_delete = np.unique(current_stack[self.mask_layer.data > 0]) 425 | for label in to_delete: 426 | current_stack[current_stack == label] = 0 427 | tifffile.imwrite( 428 | os.path.join( 429 | outputdir, 430 | ( 431 | self.image1_layer.name 432 | + "_filtered_labels_TP" 433 | + str(i).zfill(4) 434 | + ".tif" 435 | ), 436 | ), 437 | np.array(current_stack, dtype="uint16"), 438 | ) 439 | 440 | file_list = [ 441 | os.path.join(outputdir, fname) 442 | for fname in os.listdir(outputdir) 443 | if fname.endswith(".tif") 444 | ] 445 | self.image1_layer = self.viewer.add_labels( 446 | da.stack([imread(fname) for fname in sorted(file_list)]), 447 | name=self.image1_layer.name + "_filtered_labels", 448 | scale=self.image1_layer.scale, 449 | ) 450 | 451 | else: 452 | for tp in range(self.image1_layer.data.shape[0]): 453 | to_delete = np.unique( 454 | self.image1_layer.data[tp][self.mask_layer.data > 0] 455 | ) 456 | for label in to_delete: 457 | self.image1_layer.data[tp][ 458 | self.image1_layer.data[tp] == label 459 | ] = 0 460 | 461 | else: 462 | tp = self.viewer.dims.current_step[0] 463 | if isinstance(self.image1_layer.data, da.core.Array): 464 | outputdir = QFileDialog.getExistingDirectory( 465 | self, 466 | "Please select the directory that holds the images. Data will be changed here. 
Selecting a new empty directory will create a copy of all data", 467 | ) 468 | 469 | if len(os.listdir(outputdir)) == 0: 470 | for i in range( 471 | self.image1_layer.data.shape[0] 472 | ): # Loop over the first dimension 473 | current_stack = self.image1_layer.data[ 474 | i 475 | ].compute() # Compute the current stack 476 | 477 | if i == tp: 478 | to_delete = np.unique( 479 | current_stack[self.mask_layer.data > 0] 480 | ) 481 | for label in to_delete: 482 | current_stack[current_stack == label] = 0 483 | tifffile.imwrite( 484 | os.path.join( 485 | outputdir, 486 | ( 487 | self.image1_layer.name 488 | + "_filtered_labels_TP" 489 | + str(i).zfill(4) 490 | + ".tif" 491 | ), 492 | ), 493 | np.array(current_stack, dtype="uint16"), 494 | ) 495 | 496 | file_list = sorted( 497 | [ 498 | os.path.join(outputdir, fname) 499 | for fname in os.listdir(outputdir) 500 | if fname.endswith(".tif") 501 | ] 502 | ) 503 | else: 504 | current_stack = self.image1_layer.data[ 505 | tp 506 | ].compute() # Compute the current stack 507 | to_delete = np.unique(current_stack[self.mask_layer.data > 0]) 508 | for label in to_delete: 509 | current_stack[current_stack == label] = 0 510 | 511 | file_list = sorted( 512 | [ 513 | os.path.join(outputdir, fname) 514 | for fname in os.listdir(outputdir) 515 | if fname.endswith(".tif") 516 | ] 517 | ) 518 | 519 | tifffile.imwrite( 520 | file_list[tp], 521 | np.array(current_stack, dtype="uint16"), 522 | ) 523 | 524 | self.image1_layer = self.viewer.add_labels( 525 | da.stack([imread(fname) for fname in file_list]), 526 | name=self.image1_layer.name + "_filtered_labels", 527 | scale=self.image1_layer.scale, 528 | ) 529 | 530 | else: 531 | tp = self.viewer.dims.current_step[0] 532 | to_delete = np.unique( 533 | self.image1_layer.data[tp][self.mask_layer.data > 0] 534 | ) 535 | for label in to_delete: 536 | self.image1_layer.data[tp][ 537 | self.image1_layer.data[tp] == label 538 | ] = 0 539 | 540 | elif image_shape == mask_shape: 541 | if 
isinstance(self.image1_layer.data, da.core.Array): 542 | if self.outputdir is None: 543 | self.outputdir = QFileDialog.getExistingDirectory( 544 | self, "Select Output Folder" 545 | ) 546 | 547 | outputdir = os.path.join( 548 | self.outputdir, 549 | (self.image1_layer.name + "_filtered_labels"), 550 | ) 551 | if os.path.exists(outputdir): 552 | shutil.rmtree(outputdir) 553 | os.mkdir(outputdir) 554 | 555 | for i in range( 556 | self.image1_layer.data.shape[0] 557 | ): # Loop over the first dimension 558 | current_stack = self.image1_layer.data[ 559 | i 560 | ].compute() # Compute the current stack 561 | 562 | to_delete = np.unique(current_stack[self.mask_layer.data[tp] > 0]) 563 | for label in to_delete: 564 | current_stack[current_stack == label] = 0 565 | tifffile.imwrite( 566 | os.path.join( 567 | outputdir, 568 | ( 569 | self.image1_layer.name 570 | + "_filtered_labels_TP" 571 | + str(i).zfill(4) 572 | + ".tif" 573 | ), 574 | ), 575 | np.array(current_stack, dtype="uint16"), 576 | ) 577 | 578 | file_list = [ 579 | os.path.join(outputdir, fname) 580 | for fname in os.listdir(outputdir) 581 | if fname.endswith(".tif") 582 | ] 583 | self.image1_layer = self.viewer.add_labels( 584 | da.stack([imread(fname) for fname in sorted(file_list)]), 585 | name=self.image1_layer.name + "_filtered_labels", 586 | scale=self.image1_layer.scale, 587 | ) 588 | else: 589 | to_delete = np.unique(self.image1_layer.data[self.mask_layer.data > 0]) 590 | selected_labels = copy.deepcopy(self.image1_layer.data) 591 | for label in to_delete: 592 | selected_labels[selected_labels == label] = 0 593 | self.viewer.add_labels( 594 | selected_labels, 595 | name=self.image1_layer.name + "_filtered_labels", 596 | scale=self.image1_layer.scale, 597 | ) 598 | 599 | else: 600 | msg = QMessageBox() 601 | msg.setWindowTitle("Images do not have compatible shapes") 602 | msg.setText("Please provide images that have matching dimensions") 603 | msg.setIcon(QMessageBox.Information) 604 | 
msg.setStandardButtons(QMessageBox.Ok) 605 | msg.exec_() 606 | return False 607 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/smoothing_widget.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import dask.array as da 5 | import napari 6 | import numpy as np 7 | import tifffile 8 | from qtpy.QtWidgets import ( 9 | QGroupBox, 10 | QHBoxLayout, 11 | QLabel, 12 | QMessageBox, 13 | QPushButton, 14 | QSpinBox, 15 | QVBoxLayout, 16 | QWidget, 17 | ) 18 | from scipy import ndimage 19 | from skimage.io import imread 20 | 21 | from .layer_manager import LayerManager 22 | 23 | 24 | class SmoothingWidget(QWidget): 25 | """Widget that 'smooths' labels by applying a median filter""" 26 | 27 | def __init__( 28 | self, viewer: "napari.viewer.Viewer", label_manager: LayerManager 29 | ) -> None: 30 | super().__init__() 31 | 32 | self.viewer = viewer 33 | self.label_manager = label_manager 34 | self.outputdir = None 35 | 36 | smoothbox = QGroupBox("Smooth objects") 37 | smooth_boxlayout = QVBoxLayout() 38 | 39 | smooth_layout = QHBoxLayout() 40 | self.median_radius_field = QSpinBox() 41 | self.median_radius_field.setMaximum(100) 42 | self.smooth_btn = QPushButton("Smooth") 43 | smooth_layout.addWidget(self.median_radius_field) 44 | smooth_layout.addWidget(self.smooth_btn) 45 | 46 | smooth_boxlayout.addWidget(QLabel("Median filter radius")) 47 | smooth_boxlayout.addLayout(smooth_layout) 48 | 49 | self.smooth_btn.clicked.connect(self._smooth_objects) 50 | self.smooth_btn.setEnabled( 51 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 52 | ) 53 | self.label_manager.layer_update.connect( 54 | lambda: self.smooth_btn.setEnabled( 55 | isinstance(self.label_manager.selected_layer, napari.layers.Labels) 56 | ) 57 | ) 58 | 59 | smoothbox.setLayout(smooth_boxlayout) 60 | layout = QVBoxLayout() 61 | layout.addWidget(smoothbox) 
62 | self.setLayout(layout) 63 | 64 | def _smooth_objects(self) -> None: 65 | """Smooth objects by using a median filter.""" 66 | 67 | if isinstance(self.label_manager.selected_layer.data, da.core.Array): 68 | if self.outputdir is None: 69 | msg = QMessageBox() 70 | msg.setWindowTitle("No output directory selected") 71 | msg.setText("Please specify an output directory first!") 72 | msg.setIcon(QMessageBox.Information) 73 | msg.setStandardButtons(QMessageBox.Ok) 74 | msg.exec_() 75 | return False 76 | 77 | else: 78 | outputdir = os.path.join( 79 | self.outputdir, 80 | (self.label_manager.selected_layer.name + "_smoothed"), 81 | ) 82 | if os.path.exists(outputdir): 83 | shutil.rmtree(outputdir) 84 | os.mkdir(outputdir) 85 | 86 | for i in range( 87 | self.label_manager.selected_layer.data.shape[0] 88 | ): # Loop over the first dimension 89 | current_stack = self.label_manager.selected_layer.data[ 90 | i 91 | ].compute() # Compute the current stack 92 | smoothed = ndimage.median_filter( 93 | current_stack, size=self.median_radius_field.value() 94 | ) 95 | tifffile.imwrite( 96 | os.path.join( 97 | outputdir, 98 | ( 99 | self.label_manager.selected_layer.name 100 | + "_smoothed_TP" 101 | + str(i).zfill(4) 102 | + ".tif" 103 | ), 104 | ), 105 | np.array(smoothed, dtype="uint16"), 106 | ) 107 | 108 | file_list = [ 109 | os.path.join(outputdir, fname) 110 | for fname in os.listdir(outputdir) 111 | if fname.endswith(".tif") 112 | ] 113 | self.label_manager.selected_layer = self.viewer.add_labels( 114 | da.stack([imread(fname) for fname in sorted(file_list)]), 115 | name=self.label_manager.selected_layer.name + "_smoothed", 116 | scale=self.label_manager.selected_layer.scale, 117 | ) 118 | 119 | else: 120 | if len(self.label_manager.selected_layer.data.shape) == 4: 121 | stack = [] 122 | for i in range(self.label_manager.selected_layer.data.shape[0]): 123 | smoothed = ndimage.median_filter( 124 | self.label_manager.selected_layer.data[i], 125 | 
size=self.median_radius_field.value(), 126 | ) 127 | stack.append(smoothed) 128 | self.label_manager.selected_layer = self.viewer.add_labels( 129 | np.stack(stack, axis=0), 130 | name=self.label_manager.selected_layer.name + "_smoothed", 131 | scale=self.label_manager.selected_layer.scale, 132 | ) 133 | 134 | elif self.label_manager.selected_layer.data.ndim in (2, 3): 135 | self.label_manager.selected_layer = self.viewer.add_labels( 136 | ndimage.median_filter( 137 | self.label_manager.selected_layer.data, 138 | size=self.median_radius_field.value(), 139 | ), 140 | name=self.label_manager.selected_layer.name + "_smoothed", 141 | scale=self.label_manager.selected_layer.scale, 142 | ) 143 | else: 144 | print("input should be a 2D, 3D or 4D array") 145 | -------------------------------------------------------------------------------- /src/napari_segmentation_correction/threshold_widget.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import dask.array as da 5 | import napari 6 | import numpy as np 7 | import tifffile 8 | from napari.layers import Image, Labels 9 | from qtpy.QtWidgets import ( 10 | QFileDialog, 11 | QGroupBox, 12 | QHBoxLayout, 13 | QLabel, 14 | QPushButton, 15 | QSpinBox, 16 | QVBoxLayout, 17 | QWidget, 18 | ) 19 | from skimage.io import imread 20 | 21 | from .layer_dropdown import LayerDropdown 22 | 23 | 24 | class ThresholdWidget(QWidget): 25 | """Widget that applies a threshold to an image or labels layer""" 26 | 27 | def __init__(self, viewer: "napari.viewer.Viewer") -> None: 28 | super().__init__() 29 | 30 | self.viewer = viewer 31 | self.outputdir = None 32 | self.threshold_layer = None 33 | 34 | threshold_box = QGroupBox("Threshold") 35 | threshold_box_layout = QVBoxLayout() 36 | 37 | self.threshold_layer_dropdown = LayerDropdown( 38 | self.viewer, (Image, Labels) 39 | ) 40 | self.threshold_layer_dropdown.layer_changed.connect( 41 | self._update_threshold_layer 42 | ) 
43 | threshold_box_layout.addWidget(self.threshold_layer_dropdown) 44 | 45 | min_threshold_layout = QHBoxLayout() 46 | min_threshold_layout.addWidget(QLabel("Min value")) 47 | self.min_threshold = QSpinBox() 48 | self.min_threshold.setMaximum(65535) 49 | min_threshold_layout.addWidget(self.min_threshold) 50 | 51 | max_threshold_layout = QHBoxLayout() 52 | max_threshold_layout.addWidget(QLabel("Max value")) 53 | self.max_threshold = QSpinBox() 54 | self.max_threshold.setMaximum(65535) 55 | self.max_threshold.setValue(65535) 56 | max_threshold_layout.addWidget(self.max_threshold) 57 | 58 | threshold_box_layout.addLayout(min_threshold_layout) 59 | threshold_box_layout.addLayout(max_threshold_layout) 60 | threshold_btn = QPushButton("Run") 61 | threshold_btn.clicked.connect(self._threshold) 62 | 63 | threshold_btn.setEnabled(isinstance(self.threshold_layer_dropdown.selected_layer, napari.layers.Labels | napari.layers.Image)) 64 | self.threshold_layer_dropdown.layer_changed.connect( 65 | lambda: threshold_btn.setEnabled(isinstance(self.threshold_layer_dropdown.selected_layer, napari.layers.Labels | napari.layers.Image)) 66 | ) 67 | 68 | threshold_box_layout.addWidget(threshold_btn) 69 | threshold_box.setLayout(threshold_box_layout) 70 | 71 | layout = QVBoxLayout() 72 | layout.addWidget(threshold_box) 73 | self.setLayout(layout) 74 | 75 | def _update_threshold_layer(self, selected_layer) -> None: 76 | """Update the layer that is set to be the 'source labels' layer for copying labels from.""" 77 | 78 | if selected_layer == "": 79 | self.threshold_layer = None 80 | else: 81 | self.threshold_layer = self.viewer.layers[selected_layer] 82 | self.threshold_layer_dropdown.setCurrentText(selected_layer) 83 | 84 | def _threshold(self): 85 | """Threshold the selected label or intensity image""" 86 | 87 | if isinstance(self.threshold_layer.data, da.core.Array): 88 | if self.outputdir is None: 89 | self.outputdir = QFileDialog.getExistingDirectory(self, "Select Output Folder") 90 | 
91 | outputdir = os.path.join( 92 | self.outputdir, 93 | (self.threshold_layer.name + "_threshold"), 94 | ) 95 | if os.path.exists(outputdir): 96 | shutil.rmtree(outputdir) 97 | os.mkdir(outputdir) 98 | 99 | for i in range( 100 | self.threshold_layer.data.shape[0] 101 | ): # Loop over the first dimension 102 | data = self.threshold_layer.data[ 103 | i 104 | ].compute() # Compute the current stack 105 | 106 | thresholded = ( 107 | data >= int(self.min_threshold.value()) 108 | ) & (data <= int(self.max_threshold.value())) 109 | 110 | tifffile.imwrite( 111 | os.path.join( 112 | outputdir, 113 | ( 114 | self.threshold_layer.name 115 | + "_thresholded_TP" 116 | + str(i).zfill(4) 117 | + ".tif" 118 | ), 119 | ), 120 | np.array(thresholded, dtype="uint8"), 121 | ) 122 | 123 | file_list = [ 124 | os.path.join(outputdir, fname) 125 | for fname in os.listdir(outputdir) 126 | if fname.endswith(".tif") 127 | ] 128 | self.viewer.add_labels( 129 | da.stack([imread(fname) for fname in sorted(file_list)]), 130 | name=self.threshold_layer.name + "_thresholded", 131 | scale=self.threshold_layer.scale 132 | ) 133 | 134 | else: 135 | thresholded = ( 136 | self.threshold_layer.data >= int(self.min_threshold.value()) 137 | ) & (self.threshold_layer.data <= int(self.max_threshold.value())) 138 | self.viewer.add_labels( 139 | thresholded, name=self.threshold_layer.name + "_thresholded", 140 | scale=self.threshold_layer.scale 141 | ) 142 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/ 2 | [tox] 3 | envlist = py{310, 311}-{linux,macos,windows} 4 | isolated_build=true 5 | 6 | [gh-actions] 7 | python = 8 | 3.10: py310 9 | 10 | [gh-actions:env] 11 | PLATFORM = 12 | ubuntu-latest: linux 13 | macos-latest: macos 14 | windows-latest: windows 15 | 16 | [testenv] 17 | platform = 18 | macos: 
darwin 19 | linux: linux 20 | windows: win32 21 | passenv = 22 | CI 23 | GITHUB_ACTIONS 24 | DISPLAY 25 | XAUTHORITY 26 | NUMPY_EXPERIMENTAL_ARRAY_FUNCTION 27 | PYVISTA_OFF_SCREEN 28 | extras = 29 | testing 30 | commands = pytest -v --color=yes --cov=napari_segmentation_correction --cov-report=xml 31 | --------------------------------------------------------------------------------