├── .github
└── workflows
│ └── test_and_deploy.yml
├── .gitignore
├── .napari-hub
└── DESCRIPTION.md
├── .pre-commit-config.yaml
├── LICENSE
├── MANIFEST.in
├── README.md
├── instructions
├── 3d_viewing.gif
├── copy-paste_labels.gif
├── copy_labels.gif
├── label_options.gif
├── napari-ndlabelcorrection_filter_by_size.gif
└── select_labels_by_mask.png
├── pyproject.toml
├── setup.cfg
├── src
└── napari_segmentation_correction
│ ├── __init__.py
│ ├── _tests
│ ├── __init__.py
│ └── test_copy_label.py
│ ├── _widget.py
│ ├── connected_components.py
│ ├── copy_label_widget.py
│ ├── cross_widget.py
│ ├── custom_table_widget.py
│ ├── erosion_dilation_widget.py
│ ├── icons
│ ├── Back.png
│ ├── Forward.png
│ ├── Home.png
│ ├── Pan.png
│ ├── Pan_checked.png
│ ├── Zoom.png
│ ├── Zoom_checked.png
│ ├── configure_subplots.png
│ ├── edit_parameters.png
│ └── save_figure.png
│ ├── image_calculator.py
│ ├── label_interpolator.py
│ ├── layer_controls.py
│ ├── layer_dropdown.py
│ ├── layer_manager.py
│ ├── napari.yaml
│ ├── plot_widget.py
│ ├── prop_filter_widget.py
│ ├── regionprops_extended.py
│ ├── regionprops_widget.py
│ ├── save_labels_widget.py
│ ├── select_delete_widget.py
│ ├── smoothing_widget.py
│ └── threshold_widget.py
└── tox.ini
/.github/workflows/test_and_deploy.yml:
--------------------------------------------------------------------------------
1 | # This workflows will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: tests
5 |
6 | on:
7 | push:
8 | branches:
9 | - main
10 | - npe2
11 | tags:
12 | - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
13 | pull_request:
14 | branches:
15 | - main
16 | - npe2
17 | workflow_dispatch:
18 |
19 | jobs:
20 | test:
21 | name: ${{ matrix.platform }} py${{ matrix.python-version }}
22 | runs-on: ${{ matrix.platform }}
23 | strategy:
24 | matrix:
25 | platform: [ubuntu-latest, windows-latest, macos-latest]
26 | python-version: ['3.10', '3.11']
27 |
28 | steps:
29 | - uses: actions/checkout@v3
30 |
31 | - name: Set up Python ${{ matrix.python-version }}
32 | uses: actions/setup-python@v4
33 | with:
34 | python-version: ${{ matrix.python-version }}
35 |
36 | # these libraries enable testing on Qt on linux
37 | - uses: tlambert03/setup-qt-libs@v1
38 |
39 | # strategy borrowed from vispy for installing opengl libs on windows
40 | - name: Install Windows OpenGL
41 | if: runner.os == 'Windows'
42 | run: |
43 | git clone --depth 1 https://github.com/pyvista/gl-ci-helpers.git
44 | powershell gl-ci-helpers/appveyor/install_opengl.ps1
45 |
46 | # note: if you need dependencies from conda, considering using
47 | # setup-miniconda: https://github.com/conda-incubator/setup-miniconda
48 | # and
49 | # tox-conda: https://github.com/tox-dev/tox-conda
50 | - name: Install dependencies
51 | run: |
52 | python -m pip install --upgrade pip
53 | python -m pip install setuptools tox tox-gh-actions
54 |
55 | # this runs the platform-specific tests declared in tox.ini
56 | - name: Test with tox
57 | uses: aganders3/headless-gui@v1
58 | with:
59 | run: python -m tox
60 | env:
61 | PLATFORM: ${{ matrix.platform }}
62 |
63 | - name: Coverage
64 | uses: codecov/codecov-action@v3
65 |
66 | deploy:
67 | # this will run when you have tagged a commit, starting with "v*"
68 | # and requires that you have put your twine API key in your
69 | # github secrets (see readme for details)
70 | needs: [test]
71 | runs-on: ubuntu-latest
72 | if: contains(github.ref, 'tags')
73 | steps:
74 | - uses: actions/checkout@v3
75 | - name: Set up Python
76 | uses: actions/setup-python@v4
77 | with:
78 | python-version: "3.x"
79 | - name: Install dependencies
80 | run: |
81 | python -m pip install --upgrade pip
82 | pip install -U setuptools setuptools_scm wheel twine build
83 | - name: Build and publish
84 | env:
85 | TWINE_USERNAME: __token__
86 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }}
87 | run: |
88 | git tag
89 | python -m build .
90 | twine upload dist/*
91 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 | .napari_cache
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask instance folder
58 | instance/
59 |
60 | # Sphinx documentation
61 | docs/_build/
62 |
63 | # MkDocs documentation
64 | /site/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # Pycharm and VSCode
70 | .idea/
71 | venv/
72 | .vscode/
73 |
74 | # IPython Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # OS
81 | .DS_Store
82 |
83 | # written by setuptools_scm
84 | **/_version.py
85 |
--------------------------------------------------------------------------------
/.napari-hub/DESCRIPTION.md:
--------------------------------------------------------------------------------
1 |
8 |
9 | The developer has not yet provided a napari-hub specific description.
10 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: check-docstring-first
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 | exclude: ^\.napari-hub/.*
9 | - id: check-yaml # checks for correct yaml syntax for github actions ex.
10 | - repo: https://github.com/astral-sh/ruff-pre-commit
11 | rev: v0.12.5
12 | hooks:
13 | - id: ruff
14 | args: [--fix]
15 | - id: ruff-format
16 | # - repo: https://github.com/pre-commit/mirrors-mypy
17 | # rev: v1.17.0
18 | # hooks:
19 | # - id: mypy
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Copyright (c) 2023, Anniek Stokkermans
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | * Neither the name of copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include README.md
3 |
4 | recursive-exclude * __pycache__
5 | recursive-exclude * *.py[co]
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # napari-segmentation-correction
2 |
3 | [](https://github.com/AnniekStok/napari-segmentation-correction/raw/main/LICENSE)
4 | [](https://pypi.org/project/napari-segmentation-correction)
5 | [](https://python.org)
6 | [](https://github.com/AnniekStok/napari-segmentation-correction/actions)
7 | [](https://codecov.io/gh/AnniekStok/napari-segmentation-correction)
8 | [](https://napari-hub.org/plugins/napari-segmentation-correction)
9 |
10 | Toolbox for viewing, analyzing and correcting (cell) segmentation in 2D, 3D or 4D (t, z, y, x) (virtual) arrays.
11 | ----------------------------------
12 |
13 | This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookiecutter-napari-plugin] template.
14 |
15 |
22 |
23 | ## Installation
24 |
25 | You can install `napari-segmentation-correction` via [pip]:
26 |
27 | To install latest development version :
28 |
29 | pip install git+https://github.com/AnniekStok/napari-segmentation-correction.git
30 |
31 | ## Usage
32 | This plugin serves as a toolbox aiming to help with correcting segmentation results.
33 | Functionalities:
34 | - Orthogonal views for 3D data.
35 | - Copy labels from a 2-5 dimensional array with multiple segmentation options to your current 2-5 dimensional label layer.
36 | - Label connected components, keep the largest connected cluster of labels, keep the largest fragment per label.
37 | - Smooth labels using a median filter.
38 | - Erode/dilate labels (scipy.ndimage and scikit-image)
39 | - Binarize an image or labels layer by applying an intensity threshold
40 | - Image calculator for mathematical operations on two images
41 | - Select/delete labels using a mask.
42 | - Binary mask interpolation in the z or time dimension.
43 | - Explore label properties (scikit-image regionprops) in a table widget and a Matplotlib plot.
44 | - Filter labels by properties.
45 |
46 | ### Copy labels between different labels layers
47 | 
48 |
49 |
50 |
51 |
52 | 2D/3D/4D labels can be copied from a source layer to a target layer via SHIFT+CLICK.
53 | The data in the source layer should have the same shape as the target layer, but can optionally have one extra dimension (e.g. stack multiple segmentation solutions as channels).
54 | To copy labels, select a 'source' and a 'target' labels layer in the dropdown. By default, the source layer will be displayed as contours.
55 | Select whether to copy a slice, a volume, or a series across time.
56 | Checking Use source label value keeps the original label values.
57 | Selecting Preserve target labels only allows copying into background (0) regions. Otherwise, SHIFT+CLICK replaces the existing label region.
58 | |
59 |
60 |
61 | |
62 |
63 |
64 |
65 | ### Connected component analysis
66 | There are shortcut buttons for connected components labeling, keeping the largest cluster of connected labels, and to keep the largest fragment per label.
67 |
68 |
69 |
70 | ### Select / delete labels that overlap with a binary mask
71 | All labels that share any pixel overlap with the mask are selected or removed.
72 |
73 |
74 | ### Binary mask interpolation
75 | It is possible to interpolate a 3D or 4D mask to fill in the region in between. In 3D, this means creating a 3D volume from slices, in 4D this means creating a time series of a volume that linearly 'morphs' into a different shape.
76 |
77 |
78 |
79 |
80 |
81 | |
82 |
83 |
84 | |
85 |
86 |
87 |
88 | ### Measuring label properties
89 | You can measure label properties, including intensity (if a matching image layer is provided), area/volume, perimeter/surface area, circularity/sphericity, ellipse/ellipsoid axes in the 'Region Properties' tab. Use the '3D data' checkbox to distinguish between measuring in 2D + time, 3D, and 3D + time, depending on your layer dimensions (2D to 4D). Once finished, a table displays the measurements, and a filter widget allows you to select objects matching a condition. The measurements are also displayed in the 'Plot'-tab for each layer for which you ran the region properties calculation.
90 |
91 | 
92 |
93 |
94 | ## Contributing
95 |
96 | Contributions are very welcome. Tests can be run with [tox], please ensure
97 | the coverage at least stays the same before you submit a pull request.
98 |
99 | ## License
100 |
101 | Distributed under the terms of the [BSD-3] license,
102 | "napari-segmentation-correction" is free and open source software
103 |
104 | ## Issues
105 |
106 | If you encounter any problems, please [file an issue] along with a detailed description.
107 |
108 | [napari]: https://github.com/napari/napari
109 | [Cookiecutter]: https://github.com/audreyr/cookiecutter
110 | [@napari]: https://github.com/napari
111 | [MIT]: http://opensource.org/licenses/MIT
112 | [BSD-3]: http://opensource.org/licenses/BSD-3-Clause
113 | [GNU GPL v3.0]: http://www.gnu.org/licenses/gpl-3.0.txt
114 | [GNU LGPL v3.0]: http://www.gnu.org/licenses/lgpl-3.0.txt
115 | [Apache Software License 2.0]: http://www.apache.org/licenses/LICENSE-2.0
116 | [Mozilla Public License 2.0]: https://www.mozilla.org/media/MPL/2.0/index.txt
117 | [cookiecutter-napari-plugin]: https://github.com/napari/cookiecutter-napari-plugin
118 |
119 | [file an issue]: https://github.com/AnniekStok/napari-segmentation-correction/issues
120 |
121 | [napari]: https://github.com/napari/napari
122 | [tox]: https://tox.readthedocs.io/en/latest/
123 | [pip]: https://pypi.org/project/pip/
124 | [PyPI]: https://pypi.org/
125 |
--------------------------------------------------------------------------------
/instructions/3d_viewing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/3d_viewing.gif
--------------------------------------------------------------------------------
/instructions/copy-paste_labels.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/copy-paste_labels.gif
--------------------------------------------------------------------------------
/instructions/copy_labels.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/copy_labels.gif
--------------------------------------------------------------------------------
/instructions/label_options.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/label_options.gif
--------------------------------------------------------------------------------
/instructions/napari-ndlabelcorrection_filter_by_size.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/napari-ndlabelcorrection_filter_by_size.gif
--------------------------------------------------------------------------------
/instructions/select_labels_by_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/instructions/select_labels_by_mask.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42.0.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.ruff]
6 | line-length = 88
7 | target-version = "py310"
8 | fix = true
9 |
10 | lint.select = [
11 | "E", "F", "W", #flake8
12 | "UP", # pyupgrade
13 | "I", # isort
14 | "BLE", # flake8-blind-exception
15 | "B", # flake8-bugbear
16 | "A", # flake8-builtins
17 | "C4", # flake8-comprehensions
18 | "ISC", # flake8-implicit-str-concat
19 | "G", # flake8-logging-format
20 | "PIE", # flake8-pie
21 | "SIM", # flake8-simplify
22 | ]
23 |
24 | lint.ignore = [
25 | "UP006", "UP007", # type annotation. As using magicgui require runtime type annotation then we disable this.
26 | "ISC001", # implicit string concatenation
27 | "E501", # line too long
28 | ]
29 |
30 | lint.per-file-ignores = { "scripts/*.py" = ["F"] }
31 |
32 | # https://docs.astral.sh/ruff/formatter/
33 | [tool.ruff.format]
34 |
35 | [tool.mypy]
36 | ignore_missing_imports = true
37 |
38 | exclude = [
39 | ".bzr",
40 | ".direnv",
41 | ".eggs",
42 | ".git",
43 | ".mypy_cache",
44 | ".pants.d",
45 | ".ruff_cache",
46 | ".svn",
47 | ".tox",
48 | ".venv",
49 | "__pypackages__",
50 | "_build",
51 | "buck-out",
52 | "build",
53 | "dist",
54 | "node_modules",
55 | "venv",
56 | "*vendored*",
57 | "*_vendor*",
58 | ]
59 |
60 | target-version = "py38"
61 | fix = true
62 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = napari-segmentation-correction
3 | version = attr: napari_segmentation_correction.__version__
4 | description = A plugin for manually correcting cell segmentation in 3D (z, y, x) or 4D (t, z, y, x) (virtual) arrays.
5 | long_description = file: README.md
6 | long_description_content_type = text/markdown
7 | url = https://github.com/AnniekStok/napari-segmentation-correction
8 | author = Anniek Stokkermans
9 | author_email = anniek.stokkermans@gmail.com
10 | license = BSD-3-Clause
11 | license_files = LICENSE
12 | classifiers =
13 | Development Status :: 2 - Pre-Alpha
14 | Framework :: napari
15 | Intended Audience :: Developers
16 | License :: OSI Approved :: BSD License
17 | Operating System :: OS Independent
18 | Programming Language :: Python
19 | Programming Language :: Python :: 3
20 | Programming Language :: Python :: 3 :: Only
21 | Programming Language :: Python :: 3.8
22 | Programming Language :: Python :: 3.9
23 | Programming Language :: Python :: 3.10
24 | Topic :: Scientific/Engineering :: Image Processing
25 | project_urls =
26 | Bug Tracker = https://github.com/AnniekStok/napari-segmentation-correction/issues
27 | Documentation = https://github.com/AnniekStok/napari-segmentation-correction#README.md
28 | Source Code = https://github.com/AnniekStok/napari-segmentation-correction
29 | User Support = https://github.com/AnniekStok/napari-segmentation-correction/issues
30 |
31 | [options]
32 | packages = find:
33 | install_requires =
34 | napari >= 0.6.0
35 | numpy
36 | scikit-image
37 | dask_image
38 | dask
39 | matplotlib
40 | imagecodecs
41 | napari-plane-sliders @ git+https://github.com/AnniekStok/napari-plane-sliders.git@main
42 | napari-orthogonal-views @ git+https://github.com/AnniekStok/napari-orthogonal-views.git@v0.0.4
43 |
44 | python_requires = >=3.8
45 | include_package_data = True
46 | package_dir =
47 | =src
48 |
49 | # add your package requirements here
50 |
51 | [options.packages.find]
52 | where = src
53 |
54 | [options.entry_points]
55 | napari.manifest =
56 | napari-segmentation-correction = napari_segmentation_correction:napari.yaml
57 |
58 | [options.extras_require]
59 | testing =
60 | tox
61 | pytest # https://docs.pytest.org/en/latest/contents.html
62 | pytest-cov # https://pytest-cov.readthedocs.io/en/latest/
63 | pytest-qt
64 |
65 | [options.package_data]
66 | * = *.yaml
67 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 | from ._widget import AnnotateLabelsND
3 |
4 | __all__ = ("AnnotateLabelsND",)
5 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/_tests/__init__.py
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/_tests/test_copy_label.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from napari.layers import Labels
4 |
5 | from napari_segmentation_correction.copy_label_widget import CopyLabelWidget
6 |
7 |
8 | @pytest.fixture
9 | def make_event():
10 | def _make_event(position, modifiers=("Shift",), dims_displayed=(0, 1)):
11 | class DummyEvent:
12 | def __init__(self):
13 | self.type = "mouse_press"
14 | self.position = position
15 | self.modifiers = modifiers
16 | self.dims_displayed = dims_displayed
17 | self.view_direction = None
18 |
19 | # mimic napari Event
20 |
21 | return DummyEvent()
22 |
23 | return _make_event
24 |
25 |
26 | @pytest.fixture
27 | def make_layers():
28 | def _make_layers(shape):
29 | src = Labels(np.zeros(shape, dtype=np.uint16))
30 | tgt = Labels(np.zeros(shape, dtype=np.uint16))
31 | return src, tgt
32 |
33 | return _make_layers
34 |
35 |
36 | def test_copy_2d(make_event, make_napari_viewer):
37 | viewer = make_napari_viewer()
38 | src = viewer.add_labels(np.zeros((64, 64), dtype=np.uint16))
39 | tgt = viewer.add_labels(np.zeros((64, 64), dtype=np.uint16))
40 | src.data[30:40, 30:40] = 1
41 |
42 | widget = CopyLabelWidget(viewer)
43 | widget.source_layer = src
44 | widget.target_layer = tgt
45 | widget.dims_widget.slice.setChecked(True)
46 |
47 | event = make_event([32, 32])
48 | widget.copy_label(event)
49 |
50 | expected = np.zeros((64, 64), dtype=np.uint8)
51 | expected[30:40, 30:40] = 1
52 | np.testing.assert_array_equal(tgt.data, expected)
53 |
54 |
55 | def test_copy_3d(make_event, make_napari_viewer):
56 | viewer = make_napari_viewer()
57 | src = viewer.add_labels(np.zeros((10, 64, 64), dtype=np.uint16))
58 | tgt = viewer.add_labels(np.zeros((10, 64, 64), dtype=np.uint16))
59 | src.data[5:8, 20:30, 20:30] = 10
60 |
61 | widget = CopyLabelWidget(viewer)
62 | widget.source_layer = src
63 | widget.target_layer = tgt
64 | widget.dims_widget.volume.setChecked(True)
65 |
66 | viewer.dims.current_step = (5, 0, 0)
67 |
68 | event = make_event([5, 25, 25], dims_displayed=[1, 2])
69 | widget.copy_label(event)
70 |
71 | # check copying volume
72 | expected = np.zeros((10, 64, 64), dtype=np.uint8)
73 | expected[5:8, 20:30, 20:30] = 1
74 | np.testing.assert_array_equal(tgt.data, expected)
75 |
76 | # check preserving label value
77 | widget.preserve_label_value.setChecked(True)
78 | widget.copy_label(event)
79 | expected[5:8, 20:30, 20:30] = 10
80 | np.testing.assert_array_equal(tgt.data, expected)
81 |
82 | # check replacing a slice only
83 | widget.dims_widget.slice.setChecked(True)
84 | widget.preserve_label_value.setChecked(False)
85 | widget.copy_label(event)
86 | expected[5, 20:30, 20:30] = 11
87 | np.testing.assert_array_equal(tgt.data, expected)
88 |
89 |
90 | def test_copy_4d(make_event, make_napari_viewer):
91 | viewer = make_napari_viewer()
92 | src = viewer.add_labels(np.zeros((3, 10, 64, 64), dtype=np.uint16))
93 | tgt = viewer.add_labels(np.zeros((3, 10, 64, 64), dtype=np.uint16))
94 | src.data[1:2, 5:8, 20:30, 20:30] = 10
95 |
96 | widget = CopyLabelWidget(viewer)
97 | widget.source_layer = src
98 | widget.target_layer = tgt
99 | widget.dims_widget.series.setChecked(True)
100 |
101 | viewer.dims.current_step = (1, 5, 0, 0)
102 |
103 | event = make_event([1, 5, 25, 25], dims_displayed=[2, 3])
104 | widget.copy_label(event)
105 | expected = np.zeros((3, 10, 64, 64), dtype=np.uint8)
106 | expected[1:2, 5:8, 20:30, 20:30] = 1
107 | np.testing.assert_array_equal(tgt.data, expected)
108 |
109 | widget.preserve_label_value.setChecked(True)
110 | widget.copy_label(event)
111 | expected[1:2, 5:8, 20:30, 20:30] = 10
112 | np.testing.assert_array_equal(tgt.data, expected)
113 |
114 |
115 | def test_copy_4d_slice_to_2d(make_event, make_napari_viewer):
116 | viewer = make_napari_viewer()
117 | src = viewer.add_labels(np.zeros((3, 10, 64, 64), dtype=np.uint16))
118 | tgt = Labels(np.zeros((64, 64), dtype=np.uint16))
119 | src.data[1:2, 5:8, 20:30, 20:30] = 10
120 |
121 | widget = CopyLabelWidget(viewer)
122 | widget.source_layer = src
123 | widget.target_layer = tgt
124 | widget.dims_widget.slice.setChecked(True)
125 |
126 | viewer.dims.current_step = (1, 5, 0, 0)
127 |
128 | event = make_event([1, 5, 25, 25], dims_displayed=[2, 3])
129 | widget.copy_label(event)
130 | expected = np.zeros((64, 64), dtype=np.uint8)
131 | expected[20:30, 20:30] = 1
132 | np.testing.assert_array_equal(tgt.data, expected)
133 |
134 | widget.preserve_label_value.setChecked(True)
135 | widget.copy_label(event)
136 | expected[20:30, 20:30] = 10
137 | np.testing.assert_array_equal(tgt.data, expected)
138 |
139 |
140 | def test_copy_2d_slice_to_4d(make_event, make_napari_viewer):
141 | viewer = make_napari_viewer()
142 | src = Labels(np.zeros((64, 64), dtype=np.uint16))
143 | tgt = viewer.add_labels(np.zeros((3, 10, 64, 64), dtype=np.uint16))
144 | src.data[20:30, 20:30] = 10
145 |
146 | widget = CopyLabelWidget(viewer)
147 | widget.source_layer = src
148 | widget.target_layer = tgt
149 | widget.dims_widget.slice.setChecked(True)
150 |
151 | viewer.dims.current_step = (2, 5, 0, 0)
152 |
153 | event = make_event([2, 5, 25, 25], dims_displayed=[2, 3])
154 | widget.copy_label(event)
155 | expected = np.zeros((3, 10, 64, 64), dtype=np.uint8)
156 | expected[2, 5, 20:30, 20:30] = 1
157 | np.testing.assert_array_equal(tgt.data, expected)
158 |
159 | widget.preserve_label_value.setChecked(True)
160 | widget.copy_label(event)
161 | expected[2, 5, 20:30, 20:30] = 10
162 | np.testing.assert_array_equal(tgt.data, expected)
163 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/_widget.py:
--------------------------------------------------------------------------------
1 | """
2 | Napari plugin widget for editing N-dimensional label data
3 | """
4 |
5 | import napari
6 | from napari.layers import Labels
7 | from napari_orthogonal_views.ortho_view_manager import _get_manager
8 | from qtpy.QtWidgets import (
9 | QScrollArea,
10 | QTabWidget,
11 | QVBoxLayout,
12 | QWidget,
13 | )
14 |
15 | from .connected_components import ConnectedComponents
16 | from .erosion_dilation_widget import ErosionDilationWidget
17 | from .image_calculator import ImageCalculator
18 | from .label_interpolator import InterpolationWidget
19 | from .layer_controls import LayerControlsWidget
20 | from .layer_manager import LayerManager
21 | from .plot_widget import PlotWidget
22 | from .regionprops_widget import RegionPropsWidget
23 | from .select_delete_widget import SelectDeleteMask
24 | from .smoothing_widget import SmoothingWidget
25 | from .threshold_widget import ThresholdWidget
26 |
27 |
28 | class AnnotateLabelsND(QWidget):
29 | """Widget for manual correction of label data, for example to prepare ground truth data for training a segmentation model"""
30 |
31 | def __init__(self, viewer: "napari.viewer.Viewer") -> None:
32 | super().__init__()
33 | self.viewer = viewer
34 | self.source_labels = None
35 | self.target_labels = None
36 | self.points = None
37 | self.copy_points = None
38 | self.edit_layout = QVBoxLayout()
39 | self.tab_widget = QTabWidget(self)
40 | self.option_labels = None
41 |
42 | ### Add label manager
43 | self.label_manager = LayerManager(self.viewer)
44 |
45 | ### Add layer controls widget
46 | self.layer_controls = LayerControlsWidget(self.viewer, self.label_manager)
47 |
48 | ### activate orthogonal views and register custom function
49 | def label_options_click_hook(orig_layer, copied_layer):
50 | copied_layer.mouse_drag_callbacks.append(
51 | lambda layer, event: self.layer_controls.copy_label_widget.sync_click(
52 | orig_layer, layer, event
53 | )
54 | )
55 |
56 | orth_view_manager = _get_manager(self.viewer)
57 | orth_view_manager.register_layer_hook(Labels, label_options_click_hook)
58 |
59 | ### Add widget for connected component labeling
60 | conn_comp_widget = ConnectedComponents(self.viewer, self.label_manager)
61 | self.edit_layout.addWidget(conn_comp_widget)
62 |
63 | ### Add widget for smoothing labels
64 | smooth_widget = SmoothingWidget(self.viewer, self.label_manager)
65 | self.edit_layout.addWidget(smooth_widget)
66 |
67 | ### Add widget for eroding/dilating labels
68 | erode_dilate_widget = ErosionDilationWidget(self.viewer, self.label_manager)
69 | self.edit_layout.addWidget(erode_dilate_widget)
70 |
71 | ### Threshold image
72 | threshold_widget = ThresholdWidget(self.viewer)
73 | self.edit_layout.addWidget(threshold_widget)
74 |
75 | # Add image calculator
76 | image_calc = ImageCalculator(self.viewer)
77 | self.edit_layout.addWidget(image_calc)
78 |
79 | # Add widget for selecting/deleting by mask
80 | select_del = SelectDeleteMask(self.viewer)
81 | self.edit_layout.addWidget(select_del)
82 |
83 | # Add widget for interpolating masks
84 | interpolation_widget = InterpolationWidget(self.viewer, self.label_manager)
85 | self.edit_layout.addWidget(interpolation_widget)
86 |
87 | ### Add layer controls widget to tab
88 | controls_scroll_area = QScrollArea()
89 | controls_scroll_area.setWidget(self.layer_controls)
90 | controls_scroll_area.setWidgetResizable(True)
91 | self.tab_widget.addTab(controls_scroll_area, "Layer Controls")
92 | self.tab_widget.setCurrentIndex(1)
93 |
94 | ### add combined editing widgets widgets
95 | self.edit_widgets = QWidget()
96 | self.edit_widgets.setLayout(self.edit_layout)
97 | scroll_area = QScrollArea()
98 | scroll_area.setWidget(self.edit_widgets)
99 | scroll_area.setWidgetResizable(True)
100 | self.tab_widget.addTab(scroll_area, "Editing")
101 |
102 | ### Add widget for reginproperties
103 | self.regionprops_widget = RegionPropsWidget(self.viewer, self.label_manager)
104 | props_scroll_area = QScrollArea()
105 | props_scroll_area.setWidget(self.regionprops_widget)
106 | props_scroll_area.setWidgetResizable(True)
107 | self.tab_widget.addTab(props_scroll_area, "Region properties")
108 |
109 | ### Add widget for displaying plot with regionprops
110 | self.plot_widget = PlotWidget(self.label_manager)
111 | self.tab_widget.addTab(self.plot_widget, "Plot")
112 |
113 | # Add the tab widget to the main layout
114 | self.main_layout = QVBoxLayout()
115 | self.main_layout.addWidget(self.tab_widget)
116 | self.setLayout(self.main_layout)
117 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/connected_components.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import dask.array as da
5 | import napari
6 | import numpy as np
7 | import tifffile
8 | from qtpy.QtWidgets import (
9 | QFileDialog,
10 | QGroupBox,
11 | QPushButton,
12 | QVBoxLayout,
13 | QWidget,
14 | )
15 | from scipy import ndimage
16 | from skimage.io import imread
17 | from skimage.measure import label
18 |
19 | from .layer_manager import LayerManager
20 |
21 |
def keep_largest_fragment_per_label(img: np.ndarray, labels: list[int]) -> np.ndarray:
    """Return a new label image in which each label keeps only its biggest piece.

    For every non-zero value in ``labels``, the pixels of ``img`` carrying that
    value are split into connected fragments with :func:`scipy.ndimage.label`
    (default structuring element). Only the fragment with the most pixels is
    written to the output, under the same value. On a tie, the fragment with
    the lowest component index wins.

    Args:
        img: Label image (any number of dimensions).
        labels: Label values to process. ``0`` and values absent from ``img``
            are skipped. Labels not listed here do not appear in the result.

    Returns:
        Array with the shape and dtype of ``img``; all other pixels are 0.
    """
    output = np.zeros_like(img)
    for value in (v for v in labels if v != 0):
        region = img == value
        if not np.any(region):
            continue

        fragments, n_fragments = ndimage.label(region)
        if n_fragments == 1:
            # a single fragment is trivially the largest one
            output[region] = value
        elif n_fragments > 1:
            # pixel count per fragment, background bin (index 0) dropped
            sizes = np.bincount(fragments.ravel())[1:]
            output[fragments == 1 + np.argmax(sizes)] = value

    return output
44 |
45 |
class ConnectedComponents(QWidget):
    """Widget to run connected component analysis on the selected labels layer.

    Three actions are offered. Each adds a new labels layer to the viewer:

    * "Find connected components": relabel with :func:`skimage.measure.label`
      (layer suffix ``_conn_comp``).
    * "Keep largest component cluster": keep the labels that lie inside the
      largest connected non-zero region (layer suffix ``_largest_cluster``).
    * "Keep largest fragment per label": keep, for every label, only its
      largest connected piece (layer suffix ``_largest_fragments``).

    Numpy data with more than 3 dimensions is processed one index of the first
    axis at a time. Dask data is always processed one index of the first axis
    at a time. Each result is written as a TIFF file into a user-chosen output
    folder, and the files are read back and stacked into the new layer.
    """

    def __init__(
        self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
    ) -> None:
        super().__init__()

        self.viewer = viewer
        self.label_manager = label_manager
        # Parent folder for the results of dask-backed layers; asked for on
        # first use.
        self.outputdir = None

        conn_comp_box = QGroupBox("Connected Component Analysis")
        conn_comp_box_layout = QVBoxLayout()

        self.conncomp_btn = QPushButton("Find connected components")
        self.conncomp_btn.setToolTip(
            "Run connected component analysis to (re)label the labels layer"
        )
        self.conncomp_btn.clicked.connect(self._conn_comp)
        conn_comp_box_layout.addWidget(self.conncomp_btn)

        self.keep_largest_btn = QPushButton("Keep largest component cluster")
        self.keep_largest_btn.setToolTip(
            "Keep only the labels part of the largest non-zero connected component"
        )
        self.keep_largest_btn.clicked.connect(self._keep_largest_cluster)
        conn_comp_box_layout.addWidget(self.keep_largest_btn)

        self.keep_largest_fragment_btn = QPushButton("Keep largest fragment per label")
        self.keep_largest_fragment_btn.setToolTip(
            "For each label, keep only the largest connected fragment"
        )
        self.keep_largest_fragment_btn.clicked.connect(self._keep_largest_fragment)
        conn_comp_box_layout.addWidget(self.keep_largest_fragment_btn)

        # The buttons are only usable while a Labels layer is selected.
        self._update_button_state()
        self.label_manager.layer_update.connect(self._update_button_state)

        conn_comp_box.setLayout(conn_comp_box_layout)
        main_layout = QVBoxLayout()
        main_layout.addWidget(conn_comp_box)
        self.setLayout(main_layout)

    def _update_button_state(self):
        """Enable the buttons only when the selected layer is a Labels layer."""
        enabled = isinstance(self.label_manager.selected_layer, napari.layers.Labels)
        for button in (
            self.conncomp_btn,
            self.keep_largest_btn,
            self.keep_largest_fragment_btn,
        ):
            button.setEnabled(enabled)

    @staticmethod
    def _largest_cluster(stack: np.ndarray) -> np.ndarray:
        """Return ``stack`` with every pixel outside its largest connected
        non-zero region set to 0."""
        labeled = label(stack > 0)
        if not labeled.any():
            # No foreground (or an empty array): nothing to keep.
            return np.zeros_like(stack)
        sizes = np.bincount(labeled.ravel())
        sizes[0] = 0  # ignore background
        return (labeled == sizes.argmax()) * stack

    @staticmethod
    def _largest_fragments(stack: np.ndarray) -> np.ndarray:
        """Keep only the largest connected fragment of every label in ``stack``."""
        return keep_largest_fragment_per_label(stack, np.unique(stack))

    def _get_output_dir(self, suffix: str) -> str | None:
        """Create an empty folder ``<outputdir>/<layer name><suffix>`` and return it.

        The parent folder is asked for once and cached in ``self.outputdir``.
        An existing folder with the same name is deleted first. Returns None
        when the folder dialog is cancelled: the dialog then returns an empty
        string, which would otherwise resolve to (and wipe) a folder in the
        current working directory.
        """
        if not self.outputdir:
            self.outputdir = QFileDialog.getExistingDirectory(
                self, "Select Output Folder"
            )
            if not self.outputdir:
                # Cancelled: do not cache "", so the user is asked again next time.
                self.outputdir = None
                return None

        outputdir = os.path.join(
            self.outputdir,
            (self.label_manager.selected_layer.name + suffix),
        )
        if os.path.exists(outputdir):
            shutil.rmtree(outputdir)
        os.mkdir(outputdir)
        return outputdir

    def _run_on_dask_layer(
        self, func, dir_suffix: str, file_suffix: str, layer_suffix: str, dtype=None
    ) -> None:
        """Apply ``func`` to each index of the first axis of the selected dask
        layer, save the results as TIFF files and add them as a new layer.

        Args:
            func: Callable mapping one computed numpy stack to its result.
            dir_suffix: Suffix of the output folder name.
            file_suffix: Suffix placed before the zero-padded index in file names.
            layer_suffix: Suffix of the new layer name.
            dtype: If given, results are cast to this dtype before saving.
                NOTE(review): casting to "uint16" silently wraps values above
                65535.
        """
        layer = self.label_manager.selected_layer
        outputdir = self._get_output_dir(dir_suffix)
        if outputdir is None:
            return

        for i in range(layer.data.shape[0]):  # Loop over the first dimension
            result = func(layer.data[i].compute())  # Compute the current stack
            if dtype is not None:
                result = np.array(result, dtype=dtype)
            tifffile.imwrite(
                os.path.join(
                    outputdir,
                    layer.name + file_suffix + str(i).zfill(4) + ".tif",
                ),
                result,
            )

        file_list = [
            os.path.join(outputdir, fname)
            for fname in os.listdir(outputdir)
            if fname.endswith(".tif")
        ]
        self.label_manager.selected_layer = self.viewer.add_labels(
            da.stack([imread(fname) for fname in sorted(file_list)]),
            name=layer.name + layer_suffix,
            scale=layer.scale,
        )

    def _run_on_numpy_layer(self, func, layer_suffix: str) -> None:
        """Apply ``func`` to the selected in-memory layer and add the result.

        Data with more than 3 dimensions is processed per index of the first
        axis; the results are stored in an array of the input's dtype.
        """
        layer = self.label_manager.selected_layer
        data = layer.data
        if len(data.shape) > 3:
            result = np.zeros_like(data)
            for i in range(data.shape[0]):
                result[i] = func(data[i])
        else:
            result = func(data)

        self.label_manager.selected_layer = self.viewer.add_labels(
            result,
            name=layer.name + layer_suffix,
            scale=layer.scale,
        )

    def _keep_largest_cluster(self):
        """Keep only the labels part of the largest non-zero connected component"""
        if isinstance(self.label_manager.selected_layer.data, da.core.Array):
            self._run_on_dask_layer(
                self._largest_cluster,
                dir_suffix="_largest_cluster",
                file_suffix="_largest_cluster_TP",
                layer_suffix="_largest_cluster",
                dtype="uint16",
            )
        else:
            self._run_on_numpy_layer(self._largest_cluster, "_largest_cluster")

    def _keep_largest_fragment(self):
        """Keep only the largest fragment per label"""
        if isinstance(self.label_manager.selected_layer.data, da.core.Array):
            self._run_on_dask_layer(
                self._largest_fragments,
                dir_suffix="_largest_fragment",
                file_suffix="_largest_fragments_TP",
                layer_suffix="_largest_fragments",
            )
        else:
            self._run_on_numpy_layer(self._largest_fragments, "_largest_fragments")

    def _conn_comp(self):
        """Run connected component analysis to (re) label the labels array"""
        if isinstance(self.label_manager.selected_layer.data, da.core.Array):
            self._run_on_dask_layer(
                label,
                dir_suffix="_conncomp",
                file_suffix="_conn_comp_TP",
                layer_suffix="_conn_comp",
                dtype="uint16",
            )
        else:
            self._run_on_numpy_layer(label, "_conn_comp")
329 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/copy_label_widget.py:
--------------------------------------------------------------------------------
1 | import dask.array as da
2 | import napari
3 | import numpy as np
4 | from napari.layers import Labels
5 | from napari.utils.events import Event
6 | from qtpy.QtWidgets import (
7 | QButtonGroup,
8 | QCheckBox,
9 | QGroupBox,
10 | QHBoxLayout,
11 | QLabel,
12 | QMessageBox,
13 | QPushButton,
14 | QRadioButton,
15 | QVBoxLayout,
16 | QWidget,
17 | )
18 |
19 | from .layer_dropdown import LayerDropdown
20 |
21 |
def check_value_dtype(value, dtype):
    """Check whether ``value`` can be stored in the integer ``dtype``.

    Args:
        value: Label value to store (Python or NumPy integer).
        dtype: Integer dtype of the target array (anything ``np.iinfo`` accepts).

    Returns:
        ``(within_range, next_dtype)``.

        ``within_range`` is True when ``value`` lies between the minimum and
        maximum of ``dtype``.

        When it does not, ``next_dtype`` is the smallest unsigned NumPy dtype
        that can hold ``value``. It is None when no such dtype exists: negative
        values, or values above the uint64 maximum. ``next_dtype`` is always
        None when ``within_range`` is True.
    """
    # Get min and max for the dtype
    info = np.iinfo(dtype)
    within_range = bool(info.min <= value <= info.max)

    # If not in range, find the next suitable unsigned dtype. Negative values
    # fit in no unsigned type; suggesting one would make the caller convert
    # and then wrap the value around.
    next_dtype = None
    if not within_range and value >= 0:
        for candidate in (np.uint8, np.uint16, np.uint32, np.uint64):
            if value <= np.iinfo(candidate).max:
                next_dtype = candidate
                break

    return within_range, next_dtype
37 |
38 |
class DimsRadioButtons(QWidget):
    """Row of mutually exclusive radio buttons that choose how many trailing
    dimensions are copied: 2 (slice), 3 (volume) or 4 (series).

    All buttons start disabled and unchecked; they are enabled from outside
    this widget via the ``slice``, ``volume`` and ``series`` attributes.
    """

    def __init__(self) -> None:
        super().__init__()

        label = QLabel("Copy dimensions:")

        # Parent the group to this widget and keep a reference. A parentless
        # QButtonGroup held only by a local variable is owned by Python and can
        # be deleted as soon as __init__ returns, silently dissolving the group.
        self.button_group = QButtonGroup(self)

        self.slice = QRadioButton("Slice (last 2 dims)")
        self.volume = QRadioButton("Volume (last 3 dims)")
        self.series = QRadioButton("Series (last 4 dims)")

        button_layout = QHBoxLayout()
        for button in (self.slice, self.volume, self.series):
            button.setChecked(False)
            button.setEnabled(False)
            self.button_group.addButton(button)
            button_layout.addWidget(button)

        layout = QVBoxLayout()
        layout.addWidget(label)
        layout.addLayout(button_layout)
        self.setLayout(layout)
69 |
70 |
class CopyLabelWidget(QWidget):
    """Widget to copy labels from a source layer to a target layer.

    Shift + click on the source layer copies the clicked label into the target
    layer, as a 2D slice, 3D volume or 4D series of the trailing dimensions
    (chosen with ``DimsRadioButtons``). The last copy can be undone.
    """

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        super().__init__()

        self.viewer = viewer

        self.source_layer = None
        self.target_layer = None
        self._source_callback = None  # mouse callback attached to source_layer

        copy_labels_box = QGroupBox("Copy-paste labels")
        copy_labels_layout = QVBoxLayout()

        # instruction label
        label = QLabel(
            "Use shift + click on the source layer to copy labels to the target layer."
        )
        label.setWordWrap(True)
        font = label.font()
        font.setItalic(True)
        label.setFont(font)
        copy_labels_layout.addWidget(label)

        # Whether to keep the source label value when copying (otherwise the next
        # free label of the target layer is used), and whether to protect existing
        # target labels from being overwritten
        self.preserve_label_value = QCheckBox("Use source label value")
        self.preserve_existing_labels = QCheckBox("Preserve target labels")
        option_layout = QHBoxLayout()
        option_layout.addWidget(self.preserve_label_value)
        option_layout.addWidget(self.preserve_existing_labels)
        copy_labels_layout.addLayout(option_layout)

        # Source layer and target layer dropdowns. Both are disconnected from the
        # viewer's layer-selection events (LayerDropdown._on_selection_changed),
        # presumably so that selecting a layer in the layer list does not change them.
        image_layout = QVBoxLayout()
        source_layout = QHBoxLayout()
        source_layout.addWidget(QLabel("Source labels"))
        self.source_dropdown = LayerDropdown(self.viewer, (Labels), allow_none=True)
        self.source_dropdown.viewer.layers.selection.events.changed.disconnect(
            self.source_dropdown._on_selection_changed
        )
        self.source_dropdown.layer_changed.connect(self._update_source)
        source_layout.addWidget(self.source_dropdown)

        target_layout = QHBoxLayout()
        target_layout.addWidget(QLabel("Target labels"))
        self.target_dropdown = LayerDropdown(self.viewer, (Labels), allow_none=True)
        self.target_dropdown.viewer.layers.selection.events.changed.disconnect(
            self.target_dropdown._on_selection_changed
        )
        self.target_dropdown.layer_changed.connect(self._update_target)
        target_layout.addWidget(self.target_dropdown)

        image_layout.addLayout(source_layout)
        image_layout.addLayout(target_layout)

        copy_labels_layout.addLayout(image_layout)

        # Radiobuttons for selecting whether to copy a slice/volume/series
        self.dims_widget = DimsRadioButtons()
        copy_labels_layout.addWidget(self.dims_widget)

        # Undo the last copy action if possible
        self.prev_state = None
        self.coords_clipped = None
        self.target_slices = None
        self.undo_btn = QPushButton("Undo last copy")
        self.undo_btn.setEnabled(False)
        self.undo_btn.clicked.connect(self.undo)
        copy_labels_layout.addWidget(self.undo_btn)

        # assemble the layout
        copy_labels_box.setLayout(copy_labels_layout)
        layout = QVBoxLayout()
        layout.addWidget(copy_labels_box)
        self.setLayout(layout)

    def _update_source(self, selected_layer: str) -> None:
        """Update the layer that is set to be the 'source labels' layer for copying
        labels from.

        Detaches the shift-click callback and contour display from the previous
        source layer and attaches them to the new one. An empty string clears
        the source layer.
        """

        if self.source_layer is not None and self._source_callback is not None:
            try:
                self.source_layer.mouse_drag_callbacks.remove(self._source_callback)
            except ValueError:
                pass  # callback was already removed
            # reset the contour even when the callback was already gone
            self.source_layer.contour = 0
        if selected_layer == "":
            self.source_layer = None
            self._source_callback = None
        else:
            self.source_layer = self.viewer.layers[selected_layer]
            self.source_layer.contour = 1
            self.source_dropdown.setCurrentText(selected_layer)
            self._source_callback = self._make_copy_label_callback(self.source_layer)
            self.source_layer.mouse_drag_callbacks.append(self._source_callback)

        self.update_radiobuttons()

    def _update_target(self, selected_layer: str) -> None:
        """Update the layer to copy labels to (an empty string clears it)."""

        if selected_layer == "":
            self.target_layer = None
        else:
            self.target_layer = self.viewer.layers[selected_layer]
            self.target_dropdown.setCurrentText(selected_layer)

        self.update_radiobuttons()

    def update_radiobuttons(self) -> None:
        """Update the state of the dimension checkboxes based on the source and target
        layers.

        Only options with at most min(source ndim, target ndim) dimensions are
        enabled; "volume" is preselected when available, otherwise "slice".
        The undo button is disabled because the layers may have changed.
        """

        # All buttons off
        self.dims_widget.slice.setEnabled(False)
        self.dims_widget.slice.setChecked(False)
        self.dims_widget.volume.setEnabled(False)
        self.dims_widget.volume.setChecked(False)
        self.dims_widget.series.setEnabled(False)
        self.dims_widget.series.setChecked(False)
        self.undo_btn.setEnabled(False)

        if self.source_layer is not None and self.target_layer is not None:
            # Set the highest possible option based on the number of dimensions of the
            # source and target layers
            source_dims = self.source_layer.data.ndim
            target_dims = self.target_layer.data.ndim
            dims = min(source_dims, target_dims)

            if dims >= 4:
                self.dims_widget.series.setEnabled(True)
                self.dims_widget.volume.setEnabled(True)
                self.dims_widget.slice.setEnabled(True)
                self.dims_widget.volume.setChecked(True)
            elif dims >= 3:
                self.dims_widget.volume.setEnabled(True)
                self.dims_widget.slice.setEnabled(True)
                self.dims_widget.volume.setChecked(True)
            elif dims >= 2:
                self.dims_widget.slice.setEnabled(True)
                self.dims_widget.slice.setChecked(True)

    def _make_copy_label_callback(self, layer: Labels) -> callable:
        """Create a mouse callback that copies the clicked label on shift + press."""

        def callback(layer, event):
            if event.type == "mouse_press" and "Shift" in event.modifiers:
                self.copy_label(event)

        return callback

    def copy_label(self, event: Event, source_layer: Labels | None = None) -> None:
        """Copy a 2D/3D/4D label from the source layer to the target layer.

        Args:
            event: napari mouse event. Its ``position``, ``view_direction`` and
                ``dims_displayed`` are used to locate the clicked label.
            source_layer: Layer to read the clicked label value from. Defaults
                to ``self.source_layer``; ``sync_click`` passes the
                orthogonal-view copy instead.

        Returns ``False`` when the source and target shapes do not match for the
        chosen number of dimensions, otherwise ``None``.
        """

        if self.source_layer is None or self.target_layer is None:
            return

        # Check whether to copy a slice/volume/series label according to the
        # radiobutton choice
        if self.dims_widget.series.isChecked():
            n_dims_copied = 4
        elif self.dims_widget.volume.isChecked():
            n_dims_copied = 3
        else:
            n_dims_copied = 2

        if n_dims_copied == 4 and (
            isinstance(self.source_layer.data, da.core.Array)
            or isinstance(self.target_layer.data, da.core.Array)
        ):
            msg = QMessageBox()
            msg.setWindowTitle("Warning")
            msg.setText(
                "Copying labels in 4D dimensions between dask arrays is slow, are you sure you want to continue?"
            )
            msg.setIcon(QMessageBox.Information)
            msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel)
            result = msg.exec_()
            if result != QMessageBox.Ok:
                return

        # extract label value from source layer (orthoview or self.source_layer)
        source_layer = source_layer if source_layer is not None else self.source_layer
        selected_label = source_layer.get_value(
            event.position,
            view_direction=event.view_direction,
            dims_displayed=event.dims_displayed,
            world=True,
        )

        # do not process clicks on the background or outside the layer (None)
        if selected_label is None or selected_label == 0:
            return

        # Select dims to copy and check that they match. This is checked before
        # the label range check, so that the target is never converted to a
        # larger dtype for a copy that cannot happen.
        source_shape = self.source_layer.data.shape
        labels_shape = self.target_layer.data.shape
        source_dims_to_copy = source_shape[-n_dims_copied:]
        target_dims_to_copy = labels_shape[-n_dims_copied:]

        if source_dims_to_copy != target_dims_to_copy:
            msg = QMessageBox()
            msg.setWindowTitle("Invalid dimensions!")
            msg.setText(
                f"The dimensions of the source layer and the target layer do not match.\n"
                f"Label source layer has shape {source_shape}, target layer has shape {labels_shape}.\n"
                f"Compared last {n_dims_copied} dims: {source_dims_to_copy} vs {target_dims_to_copy}."
            )
            msg.setIcon(QMessageBox.Information)
            msg.setStandardButtons(QMessageBox.Ok)
            msg.exec_()
            return False

        # Extract coords from the click position. Pixel centres sit at integer
        # data coordinates, so round to the nearest pixel instead of truncating.
        coords = self.source_layer.world_to_data(event.position)
        coords = [int(np.round(c)) for c in coords]

        # Get dimensions of source and target layers
        dims_displayed = event.dims_displayed
        ndims_source = len(self.source_layer.data.shape)
        ndims_target = len(self.target_layer.data.shape)

        # Use the source label value, or the next free label of the target layer.
        # int() is required: np.max returns a scalar of the target dtype, and e.g.
        # np.uint8(255) + 1 wraps around to 0, which would slip past the range
        # check below. It also computes the result for dask arrays.
        if self.preserve_label_value.isChecked():
            target_label = selected_label
        else:
            target_label = int(np.max(self.target_layer.data)) + 1

        # Check if the target label is within the dtype range of the target layer,
        # if not suggest converting to a larger dtype
        within_range, next_dtype = check_value_dtype(
            target_label, self.target_layer.data.dtype
        )
        if not within_range:
            msg = QMessageBox()
            msg.setWindowTitle("Invalid label!")
            if next_dtype is None:
                # no unsigned dtype can hold the value: refuse instead of letting it wrap
                msg.setText(
                    f"Label {target_label} cannot be stored in the target layer: "
                    f"no supported unsigned integer type can hold it."
                )
                msg.setIcon(QMessageBox.Warning)
                msg.setStandardButtons(QMessageBox.Ok)
                msg.exec_()
                return
            msg.setText(
                f"Label {target_label} exceeds bit depth! Convert to {next_dtype}?"
            )
            msg.setIcon(QMessageBox.Information)
            msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel)
            result = msg.exec_()
            if result != QMessageBox.Ok:
                return
            self.target_layer.data = self.target_layer.data.astype(next_dtype)

        # Create source_slices for all dimensions of the source layer
        source_slices = [slice(None)] * ndims_source

        # When copying 2D labels, we need to check the dimensions displayed, in case
        # the user is copying from one of the orthoviews.
        dims_difference = ndims_source - ndims_target
        if n_dims_copied == 2:
            if dims_difference < 0:
                dims_displayed = [dim + dims_difference for dim in dims_displayed]
            for i in range(ndims_source):
                if i not in dims_displayed:
                    # Replace the slice with a specific coordinate for slider dims
                    source_slices[i] = coords[i]
        else:
            # Calculate the coords for the remaining dims
            remaining_coords = coords[:-n_dims_copied]
            for i, coord in enumerate(remaining_coords):
                source_slices[i] = coord

        # Clip coords to the shape of the target data
        if dims_difference > 0:
            coords_clipped = coords[dims_difference:]
            target_slices = source_slices[dims_difference:]
        elif dims_difference < 0:
            coords_clipped = [
                *self.viewer.dims.current_step[: abs(dims_difference)],
                *coords,
            ]
            target_slices = [
                *self.viewer.dims.current_step[: abs(dims_difference)],
                *source_slices,
            ]
        else:
            coords_clipped = coords
            target_slices = source_slices

        # Create mask
        if isinstance(self.source_layer.data, da.core.Array):
            mask = (
                self.source_layer.data[tuple(source_slices)].compute() == selected_label
            )
        else:
            mask = self.source_layer.data[tuple(source_slices)] == selected_label

        # Label of the target at the clicked position; compute it for dask data so
        # that the comparisons below produce plain numpy masks
        orig_label = self.target_layer.data[tuple(coords_clipped)]
        if isinstance(orig_label, da.core.Array):
            orig_label = orig_label.compute()

        # Select the correct stack for 2D/3D/4D data
        sliced_data = self.target_layer.data[tuple(target_slices)]
        if isinstance(sliced_data, da.core.Array):
            sliced_data = sliced_data.compute()

        # Store previous state for undo
        self.prev_state = np.copy(sliced_data)
        self.target_slices = np.copy(target_slices)
        self.coords_clipped = np.copy(coords_clipped)

        # Replace label in target layer data
        orig_mask = sliced_data == orig_label
        if not self.preserve_existing_labels.isChecked():
            # erase the clicked target label, then paste the source label
            sliced_data[orig_mask] = 0
            sliced_data[mask] = target_label
        else:
            # NOTE(review): orig_mask & (sliced_data == 0) is non-empty only when
            # orig_label == 0, so nothing is pasted when the clicked target pixel
            # already carries a label. Confirm this is intended.
            sliced_data[orig_mask & (sliced_data == 0) & mask] = target_label

        self.target_layer.data[tuple(target_slices)] = sliced_data
        self.undo_btn.setEnabled(True)

        # refresh the layer
        self.target_layer.data = self.target_layer.data

    def undo(self) -> None:
        """Undo the last label copy operation by restoring the stored target slice."""

        if self.prev_state is not None:
            self.target_layer.data[tuple(self.target_slices)] = self.prev_state

            # refresh the layer
            self.target_layer.data = self.target_layer.data

        # set back to None and disable button
        self.prev_state = None
        self.coords_clipped = None
        self.target_slices = None
        self.undo_btn.setEnabled(False)

    def sync_click(
        self, orig_layer: Labels, copied_layer: Labels, event: Event
    ) -> None:
        """Forward a shift + click from an orthogonal view of the source layer."""

        if (
            orig_layer is self.source_layer
            and event.type == "mouse_press"
            and "Shift" in event.modifiers
        ):
            # pass the copied layer for extracting the label value, because an orthoview
            # was used
            self.copy_label(event, copied_layer)
421 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/cross_widget.py:
--------------------------------------------------------------------------------
1 | import napari
2 | import numpy as np
3 | from napari.components.layerlist import Extent
4 | from napari.components.viewer_model import ViewerModel
5 | from napari.layers import Vectors
6 | from napari.utils.action_manager import action_manager
7 | from napari.utils.notifications import show_info
8 | from qtpy.QtWidgets import (
9 | QCheckBox,
10 | )
11 | from superqt.utils import qthrottled
12 |
13 |
def center_cross_on_mouse(
    viewer_model: napari.components.viewer_model.ViewerModel,
):
    """Move the dims point (and with it the cross) to the cursor position.

    Each cursor coordinate is clamped to the ``(min, max, step)`` entry of
    ``dims.range`` for that axis, divided by ``step`` and rounded to the nearest
    integer to produce the new ``dims.current_step``.
    """
    # Viewers without a `mouse_over_canvas` attribute are assumed to be hovered.
    over_canvas = getattr(viewer_model, "mouse_over_canvas", True)
    if not over_canvas:
        # There is no way for napari 0.4.15 to check if mouse is over sending canvas.
        show_info("Mouse is not over the canvas. You may need to click on the canvas.")
        return

    step_positions = []
    for position, (low, high, step) in zip(
        viewer_model.cursor.position, viewer_model.dims.range, strict=False
    ):
        clamped = max(low, min(position, high))
        step_positions.append(clamped / step)

    viewer_model.dims.current_step = tuple(np.round(step_positions).astype(int))
34 |
35 |
# Register center_cross_on_mouse with napari's action manager under the name
# "napari:move_point", with ViewerModel as keymap provider (presumably so that a
# keyboard shortcut can be bound to it).
action_manager.register_action(
    name="napari:move_point",
    description="Move dims point to mouse position",
    command=center_cross_on_mouse,
    keymapprovider=ViewerModel,
)


class CrossWidget(QCheckBox):
    """
    Checkbox that shows/hides a ".cross" Vectors layer. The layer draws lines
    through the current dims point along every axis whose extent spans more
    than one step. For performance reasons the extent computation is throttled.
    """

    def __init__(self, viewer: napari.Viewer):
        super().__init__("Add cross layer")
        self.viewer = viewer
        self.setChecked(False)
        self.stateChanged.connect(self._update_cross_visibility)
        self.layer = None
        self.color = "red"  # colour applied to the cross layer when it has none
        self.viewer.dims.events.order.connect(self.update_cross)
        self.viewer.dims.events.ndim.connect(self._update_ndim)
        self.viewer.dims.events.current_step.connect(self.update_cross)
        self._extent = None  # set by the (throttled) _update_extent

        self._update_extent()
        self.viewer.dims.events.connect(self._update_extent)

    @qthrottled(leading=False)
    def _update_extent(self):
        """
        Calculate the extent of the data.

        Ignores the cross layer itself in calculating the extent.
        """

        extent_list = [
            layer.extent for layer in self.viewer.layers if layer is not self.layer
        ]
        self._extent = Extent(
            data=None,
            world=self.viewer.layers._get_extent_world(extent_list),
            step=self.viewer.layers._get_step_size(extent_list),
        )
        self.update_cross()

    def _update_ndim(self, event):
        """Replace the cross layer with one matching the new number of dims.

        The new layer is not added to the viewer. update_cross then notices the
        missing layer and unchecks the checkbox.
        """
        if self.layer in self.viewer.layers:
            self.viewer.layers.remove(self.layer)
        self.layer = Vectors(name=".cross", ndim=event.value)
        self.layer.vector_style = "line"
        self.layer.edge_width = 2
        self.layer.edge_color = self.color
        self.update_cross()
        self.layer.events.edge_color.connect(self._set_color)

    def _set_color(self, event):
        """Remember the cross colour chosen by the user."""
        self.color = self.layer.edge_color

    def _update_cross_visibility(self, state):
        """Add the cross layer to the viewer when checked, remove it when unchecked."""
        if state:
            if self.layer is None:
                self.layer = Vectors(name=".cross", ndim=self.viewer.dims.ndim)
                self.layer.vector_style = "line"
                self.layer.edge_width = 2
            self.viewer.layers.append(self.layer)
        elif self.layer in self.viewer.layers:
            # The layer may already be gone: deleted by the user, or replaced in
            # _update_ndim. Either way update_cross unchecks the box, and an
            # unconditional remove() would raise ValueError.
            self.viewer.layers.remove(self.layer)
        self.update_cross()
        if not np.any(self.layer.edge_color):
            self.layer.edge_color = self.color
        self.layer.vector_style = "line"

    def update_cross(self):
        """Redraw the cross through the current dims point.

        Unchecks the checkbox when the cross layer is not in the viewer.
        """
        if self.layer not in self.viewer.layers:
            self.setChecked(False)
            return
        if self._extent is None:
            # The throttled _update_extent has not run yet; it calls
            # update_cross once the extent is known.
            return

        with self.viewer.dims.events.blocker():

            point = self.viewer.dims.current_step
            vec = []
            for i, (lower, upper) in enumerate(self._extent.world.T):
                if (upper - lower) / self._extent.step[i] == 1:
                    continue  # single-step axis: no line to draw
                point1 = list(point)
                point1[i] = (lower + self._extent.step[i] / 2) / self._extent.step[i]
                point2 = [0 for _ in point]
                point2[i] = (upper - lower) / self._extent.step[i]
                vec.append((point1, point2))
            if np.any(self.layer.scale != self._extent.step):
                self.layer.scale = self._extent.step
            self.layer.data = vec
128 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/custom_table_widget.py:
--------------------------------------------------------------------------------
1 | import napari
2 | import pandas as pd
3 | from matplotlib.colors import to_rgb
4 | from qtpy.QtGui import QColor
5 | from qtpy.QtWidgets import (
6 | QFileDialog,
7 | QHBoxLayout,
8 | QPushButton,
9 | QStyledItemDelegate,
10 | QTableWidget,
11 | QTableWidgetItem,
12 | QVBoxLayout,
13 | QWidget,
14 | )
15 |
16 |
class FloatDelegate(QStyledItemDelegate):
    """Item delegate that renders numeric cell text with a fixed precision.

    Integral numbers are shown without decimals; other numbers are shown with
    ``decimals`` digits after the point. Text that cannot be parsed as a float
    is shown unchanged.
    """

    def __init__(self, decimals, parent=None):
        super().__init__(parent)
        # number of digits shown after the decimal point for non-integral values
        self.nDecimals = decimals

    def displayText(self, value, locale):
        """Return the display string for ``value`` (``locale`` is not used)."""
        try:
            number = float(value)
        except (ValueError, TypeError):
            return str(value)

        if not number.is_integer():
            return f"{number:.{self.nDecimals}f}"
        return str(int(number))
31 |
32 |
class ColoredTableWidget(QWidget):
    """Customized table widget with colored rows based on label colors in a napari Labels layer.

    The table content is a dict mapping column names to equally long sequences
    of values. It is shown in a QTableWidget and mirrored into
    ``layer.properties``. Each row is colored with the color of the label in
    its ``"label"`` column.
    """

    def __init__(self, layer: "napari.layers.Layer", viewer: "napari.Viewer" = None):
        super().__init__()

        self._layer = layer
        self._viewer = viewer
        self._table_widget = QTableWidget()
        # Install the number-formatting delegate once. setItemDelegate does not
        # delete a previously installed delegate, so re-creating it on every
        # data update would pile up delegates owned by the table widget.
        self._table_widget.setItemDelegate(FloatDelegate(3, self._table_widget))

        self._layer.events.colormap.connect(self._set_label_colors_to_rows)
        if hasattr(layer, "properties"):
            self._set_data(layer.properties)
        else:
            self._set_data({})
        self.ascending = False  # for choosing whether to sort ascending or descending

        # Jump to the label of the clicked cell.
        self._table_widget.clicked.connect(self._clicked_table)

        # A single click in the header sorts the table by that column.
        self._table_widget.horizontalHeader().sectionClicked.connect(self._sort_table)

        copy_button = QPushButton("Copy to clipboard")
        copy_button.clicked.connect(self._copy_table)

        save_button = QPushButton("Save as csv")
        save_button.clicked.connect(self._save_table)

        button_layout = QHBoxLayout()
        button_layout.addWidget(copy_button)
        button_layout.addWidget(save_button)
        main_layout = QVBoxLayout()
        main_layout.addLayout(button_layout)
        main_layout.addWidget(self._table_widget)
        self.setLayout(main_layout)
        self.setMinimumHeight(300)

    def _set_data(self, table: dict) -> None:
        """Set the content of the table from a dict of equally long columns.

        The dict is also written to the layer's ``properties``.
        """
        self._table = table
        self._layer.properties = table

        self._table_widget.clear()
        # clear() only removes the items. Reset the dimensions explicitly so an
        # empty table does not keep stale, item-less rows and columns around.
        n_rows = len(next(iter(table.values()))) if table else 0
        self._table_widget.setRowCount(n_rows)
        self._table_widget.setColumnCount(len(table))

        for col_idx, column in enumerate(table):
            self._table_widget.setHorizontalHeaderItem(
                col_idx, QTableWidgetItem(str(column))
            )
            for row_idx, value in enumerate(table[column]):
                self._table_widget.setItem(row_idx, col_idx, QTableWidgetItem(str(value)))

        self._set_label_colors_to_rows()

    def _set_label_colors_to_rows(self, event=None) -> None:
        """Apply the colors of the napari label image to the table rows.

        Rows are matched to labels through the ``"label"`` column. Without that
        column there is nothing to color and the table is left unchanged.

        Args:
            event: Unused. Accepted so the method can be connected to layer events.
        """
        if "label" not in self._table:
            return
        labels = self._table["label"]
        for row in range(self._table_widget.rowCount()):
            layer_color = self._layer.get_color(labels[row])
            if layer_color is None:
                # NOTE(review): assumes get_color may return None for some labels
                # (e.g. background) in some napari versions; leave such rows uncolored.
                continue
            red, green, blue = to_rgb(layer_color)
            color = QColor(int(red * 255), int(green * 255), int(blue * 255))
            for col in range(self._table_widget.columnCount()):
                item = self._table_widget.item(row, col)
                if item is not None:
                    item.setBackground(color)

    def _save_table(self) -> None:
        """Save table to a csv file; does nothing if the dialog is cancelled."""
        filename, _ = QFileDialog.getSaveFileName(self, "Save as csv", ".", "*.csv")
        if not filename:
            return
        pd.DataFrame(self._table).to_csv(filename)

    def _copy_table(self) -> None:
        """Copy table to clipboard"""
        pd.DataFrame(self._table).to_clipboard()

    def _clicked_table(self) -> None:
        """Set the current viewer dims to the label that was clicked on.

        The step is taken from the row's ``centroid*`` columns, rounded to the
        nearest index. When present, the ``time_point`` column is prepended.
        """
        row = self._table_widget.currentRow()
        if self._viewer is None or row < 0:
            return

        spatial_columns = sorted(key for key in self._table if "centroid" in key)
        spatial_coords = [
            int(round(float(self._table[col][row]))) for col in spatial_columns
        ]

        if "time_point" in self._table:
            t = int(self._table["time_point"][row])
            new_step = (t, *spatial_coords)
        else:
            new_step = spatial_coords
        self._viewer.dims.current_step = new_step

    def _sort_table(self, column_index=None) -> None:
        """Sort the table by a column, alternating descending and ascending order.

        Args:
            column_index: Index of the clicked header section, as emitted by
                ``sectionClicked``. When None, the table's current column is used.
        """
        if column_index is None:
            column_index = self._table_widget.currentColumn()
        columns = list(self._table.keys())
        if not 0 <= column_index < len(columns):
            return

        df = pd.DataFrame(self._table).sort_values(
            by=columns[column_index], ascending=self.ascending
        )
        self.ascending = not self.ascending

        # _set_data also re-applies the label colors.
        self._set_data(df.to_dict(orient="list"))
141 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/erosion_dilation_widget.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import dask.array as da
5 | import napari
6 | import numpy as np
7 | import tifffile
8 | from qtpy.QtWidgets import (
9 | QFileDialog,
10 | QGroupBox,
11 | QHBoxLayout,
12 | QLabel,
13 | QPushButton,
14 | QSpinBox,
15 | QVBoxLayout,
16 | QWidget,
17 | )
18 | from scipy import ndimage
19 | from scipy.ndimage import binary_erosion
20 | from skimage.io import imread
21 | from skimage.segmentation import expand_labels
22 |
23 | from .layer_manager import LayerManager
24 |
25 |
class ErosionDilationWidget(QWidget):
    """Widget to perform erosion/dilation on label images.

    Results are added to the viewer as new Labels layers (suffixed "_eroded" or
    "_dilated") and become the label manager's selected layer. Dask-backed
    layers are processed one item along the first axis at a time, and the
    results are written as tif files into a user-chosen output folder.
    """

    def __init__(
        self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
    ) -> None:
        super().__init__()

        self.viewer = viewer
        self.label_manager = label_manager
        # Output folder for results of dask (on-disk) inputs; asked for on first use.
        self.outputdir = None

        dil_erode_box = QGroupBox("Erode/dilate labels")
        dil_erode_box_layout = QVBoxLayout()

        radius_layout = QHBoxLayout()
        str_element_diameter_label = QLabel("Structuring element diameter")
        str_element_diameter_label.setFixedWidth(200)
        self.structuring_element_diameter = QSpinBox()
        # A diameter of 0 would yield an empty structuring element.
        self.structuring_element_diameter.setMinimum(1)
        self.structuring_element_diameter.setMaximum(100)
        self.structuring_element_diameter.setValue(1)
        radius_layout.addWidget(str_element_diameter_label)
        radius_layout.addWidget(self.structuring_element_diameter)

        iterations_layout = QHBoxLayout()
        iterations_label = QLabel("Iterations")
        iterations_label.setFixedWidth(200)
        self.iterations = QSpinBox()
        # scipy's binary_erosion repeats until the result stops changing when
        # iterations < 1, which would erase all labels; require at least one.
        self.iterations.setMinimum(1)
        self.iterations.setMaximum(100)
        self.iterations.setValue(1)
        iterations_layout.addWidget(iterations_label)
        iterations_layout.addWidget(self.iterations)

        shrink_dilate_buttons_layout = QHBoxLayout()
        self.erode_btn = QPushButton("Erode")
        self.dilate_btn = QPushButton("Dilate")
        self.erode_btn.clicked.connect(self._erode_labels)
        self.dilate_btn.clicked.connect(self._dilate_labels)
        shrink_dilate_buttons_layout.addWidget(self.erode_btn)
        shrink_dilate_buttons_layout.addWidget(self.dilate_btn)

        self._update_buttons()
        self.label_manager.layer_update.connect(lambda: self._update_buttons())

        dil_erode_box_layout.addLayout(radius_layout)
        dil_erode_box_layout.addLayout(iterations_layout)
        dil_erode_box_layout.addLayout(shrink_dilate_buttons_layout)

        dil_erode_box.setLayout(dil_erode_box_layout)

        layout = QVBoxLayout()
        layout.addWidget(dil_erode_box)
        self.setLayout(layout)

    def _update_buttons(self) -> None:
        """Enable the erode/dilate buttons only when a Labels layer is selected."""
        is_labels = isinstance(self.label_manager.selected_layer, napari.layers.Labels)
        self.erode_btn.setEnabled(is_labels)
        self.dilate_btn.setEnabled(is_labels)

    def _erode_labels(self):
        """Shrink oversized labels through erosion.

        The foreground (all nonzero labels combined) is hole-filled and eroded
        with a box-shaped structuring element. Label values are kept wherever
        the eroded mask remains.
        """
        diam = self.structuring_element_diameter.value()
        iterations = self.iterations.value()

        def erode(labels):
            # The structuring element must have the dimensionality of the array
            # actually being eroded (e.g. 2D frames of a 3D dask stack).
            structure = np.ones((diam,) * labels.ndim, dtype=bool)
            filled_mask = ndimage.binary_fill_holes(labels > 0)
            eroded_mask = binary_erosion(
                filled_mask, structure=structure, iterations=iterations
            )
            return np.where(eroded_mask, labels, 0)

        self._apply_to_selected_layer(
            erode, "_eroded", "4D, 3D, or 2D array required!"
        )

    def _dilate_labels(self):
        """Dilate labels in the selected layer.

        Applies skimage's ``expand_labels`` with ``distance=diameter``,
        ``iterations`` times.
        """
        diam = self.structuring_element_diameter.value()
        iterations = self.iterations.value()

        def dilate(labels):
            expanded = labels
            for _ in range(iterations):
                expanded = expand_labels(expanded, distance=diam)
            return expanded

        self._apply_to_selected_layer(
            dilate, "_dilated", "input should be a 2D, 3D or 4D label image."
        )

    def _apply_to_selected_layer(self, transform, suffix: str, error_message: str) -> None:
        """Apply ``transform`` to the selected layer and add the result as a Labels layer.

        Dask arrays are processed per item along the first axis and stored on
        disk. 4D numpy arrays are processed per item along the first axis, and
        2D/3D numpy arrays are processed as a whole. Any other dimensionality
        prints ``error_message``.
        """
        layer = self.label_manager.selected_layer
        data = layer.data

        if isinstance(data, da.core.Array):
            result = self._apply_to_dask(layer, transform, suffix)
            if result is None:  # no output folder chosen
                return
        elif data.ndim == 4:
            result = np.stack(
                [transform(data[i]) for i in range(data.shape[0])], axis=0
            )
        elif data.ndim in (2, 3):
            result = transform(data)
        else:
            print(error_message)
            return

        self.label_manager.selected_layer = self.viewer.add_labels(
            result,
            name=layer.name + suffix,
            scale=layer.scale,
        )

    def _apply_to_dask(self, layer, transform, suffix: str):
        """Process a dask-backed layer item by item, writing one tif per item.

        Returns:
            A dask array stacking the written results in index order, or None
            when the user cancelled the output folder dialog.
        """
        if self.outputdir is None:
            chosen = QFileDialog.getExistingDirectory(self, "Select Output Folder")
            if not chosen:
                # Cancelled: do not fall back to writing into the working directory.
                return None
            self.outputdir = chosen

        outputdir = os.path.join(self.outputdir, layer.name + suffix)
        if os.path.exists(outputdir):
            shutil.rmtree(outputdir)
        os.mkdir(outputdir)

        # Track written paths explicitly so frame order does not depend on
        # lexicographic sorting of the file names.
        file_list = []
        for i in range(layer.data.shape[0]):  # Loop over the first dimension
            frame = layer.data[i].compute()
            processed = np.asarray(transform(frame))
            fname = os.path.join(
                outputdir, layer.name + suffix + "_TP" + str(i).zfill(4) + ".tif"
            )
            # Keep the input dtype so label ids are not truncated (was uint16).
            tifffile.imwrite(fname, processed.astype(frame.dtype, copy=False))
            file_list.append(fname)

        return da.stack([imread(fname) for fname in file_list])
279 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/Back.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Back.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/Forward.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Forward.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/Home.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Home.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/Pan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Pan.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/Pan_checked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Pan_checked.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/Zoom.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Zoom.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/Zoom_checked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/Zoom_checked.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/configure_subplots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/configure_subplots.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/edit_parameters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/edit_parameters.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/icons/save_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AnniekStok/napari-segmentation-correction/ece844476d64050a1f2ce54421aef53689088b45/src/napari_segmentation_correction/icons/save_figure.png
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/image_calculator.py:
--------------------------------------------------------------------------------
1 | import dask.array as da
2 | import napari
3 | import numpy as np
4 | from napari.layers import Image, Labels
5 | from qtpy.QtWidgets import (
6 | QComboBox,
7 | QGroupBox,
8 | QHBoxLayout,
9 | QLabel,
10 | QMessageBox,
11 | QPushButton,
12 | QVBoxLayout,
13 | QWidget,
14 | )
15 |
16 | from .layer_dropdown import LayerDropdown
17 |
18 |
class ImageCalculator(QWidget):
    """Widget to perform calculations between two images.

    Add/Subtract/Multiply/Divide add an Image layer. AND/OR add a binary
    Labels layer computed from the nonzero pixels of both inputs. The result
    uses the scale of image 1.
    """

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        super().__init__()

        self.viewer = viewer

        ### Add one image to another
        image_calc_box = QGroupBox("Image Calculator")
        image_calc_box_layout = QVBoxLayout()

        image1_layout = QHBoxLayout()
        image1_layout.addWidget(QLabel("Label image 1"))
        self.image1_dropdown = LayerDropdown(self.viewer, (Image, Labels))
        self.image1_dropdown.layer_changed.connect(self._update_image1)
        image1_layout.addWidget(self.image1_dropdown)

        image2_layout = QHBoxLayout()
        image2_layout.addWidget(QLabel("Label image 2"))
        self.image2_dropdown = LayerDropdown(self.viewer, (Image, Labels))
        self.image2_dropdown.layer_changed.connect(self._update_image2)
        image2_layout.addWidget(self.image2_dropdown)

        # The dropdowns emit their initial layer_changed signal while they are
        # constructed, i.e. before the connections above exist. Seed the layers
        # explicitly so they are defined when Run is clicked.
        self.image1_layer = self.image1_dropdown.selected_layer
        self.image2_layer = self.image2_dropdown.selected_layer

        image_calc_box_layout.addLayout(image1_layout)
        image_calc_box_layout.addLayout(image2_layout)

        operation_layout = QHBoxLayout()
        self.operation = QComboBox()
        self.operation.addItems(["Add", "Subtract", "Multiply", "Divide", "AND", "OR"])
        operation_layout.addWidget(QLabel("Operation"))
        operation_layout.addWidget(self.operation)
        image_calc_box_layout.addLayout(operation_layout)

        self.add_images_btn = QPushButton("Run")
        self.add_images_btn.clicked.connect(self._calculate_images)
        self._update_run_button()
        self.image1_dropdown.layer_changed.connect(self._update_run_button)
        self.image2_dropdown.layer_changed.connect(self._update_run_button)

        image_calc_box_layout.addWidget(self.add_images_btn)

        image_calc_box.setLayout(image_calc_box_layout)
        main_layout = QVBoxLayout()
        main_layout.addWidget(image_calc_box)
        self.setLayout(main_layout)

    def _update_run_button(self, *_args) -> None:
        """Enable Run only when both dropdowns have a layer selected."""
        self.add_images_btn.setEnabled(
            self.image1_dropdown.selected_layer is not None
            and self.image2_dropdown.selected_layer is not None
        )

    def _update_image1(self, selected_layer: str) -> None:
        """Store the layer chosen as image 1 (None for an empty selection)."""
        if selected_layer == "":
            self.image1_layer = None
        else:
            self.image1_layer = self.viewer.layers[selected_layer]
            self.image1_dropdown.setCurrentText(selected_layer)

    def _update_image2(self, selected_layer: str) -> None:
        """Store the layer chosen as image 2 (None for an empty selection)."""
        if selected_layer == "":
            self.image2_layer = None
        else:
            self.image2_layer = self.viewer.layers[selected_layer]
            self.image2_dropdown.setCurrentText(selected_layer)

    def _show_message(self, text: str) -> None:
        """Show a modal information box using ``text`` as both title and message."""
        msg = QMessageBox()
        msg.setWindowTitle(text)
        msg.setText(text)
        msg.setIcon(QMessageBox.Information)
        msg.setStandardButtons(QMessageBox.Ok)
        msg.exec_()

    def _calculate_images(self):
        """Combine image 1 and image 2 with the selected operation.

        Returns:
            False when no calculation was done (missing layers, dask data or
            mismatched shapes), otherwise None.
        """
        if self.image1_layer is None or self.image2_layer is None:
            return False

        data1 = self.image1_layer.data
        data2 = self.image2_layer.data

        # Check the layer *data*: the layer objects themselves are never dask arrays.
        if isinstance(data1, da.core.Array) or isinstance(data2, da.core.Array):
            self._show_message("Cannot yet run image calculator on dask arrays")
            return False
        if data1.shape != data2.shape:
            self._show_message("Images must have the same shape")
            return False

        operation = self.operation.currentText()
        scale = self.image1_layer.scale

        if operation == "Add":
            self.viewer.add_image(np.add(data1, data2), scale=scale)
        elif operation == "Subtract":
            self.viewer.add_image(np.subtract(data1, data2), scale=scale)
        elif operation == "Multiply":
            self.viewer.add_image(np.multiply(data1, data2), scale=scale)
        elif operation == "Divide":
            # Division by zero yields 0 instead of inf/nan.
            self.viewer.add_image(
                np.divide(
                    data1,
                    data2,
                    out=np.zeros_like(data1, dtype=float),
                    where=data2 != 0,
                ),
                scale=scale,
            )
        elif operation == "AND":
            self.viewer.add_labels(
                np.logical_and(data1 != 0, data2 != 0).astype(int),
                scale=scale,
            )
        elif operation == "OR":
            self.viewer.add_labels(
                np.logical_or(data1 != 0, data2 != 0).astype(int),
                scale=scale,
            )
152 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/label_interpolator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import dask.array as da
5 | import napari
6 | import numpy as np
7 | import tifffile
8 | from qtpy.QtWidgets import (
9 | QFileDialog,
10 | QGroupBox,
11 | QPushButton,
12 | QVBoxLayout,
13 | QWidget,
14 | )
15 | from scipy.ndimage import distance_transform_edt
16 | from skimage.io import imread
17 |
18 | from .layer_manager import LayerManager
19 |
20 |
def signed_distance_transform(mask):
    """Return the signed Euclidean distance transform of a mask.

    Nonzero elements of ``mask`` are foreground. Background elements get their
    positive distance to the nearest foreground element. Foreground elements get
    minus their distance to the nearest background element, so the foreground
    is where the result is negative.
    """
    mask = mask.astype(bool)
    dist_out = distance_transform_edt(~mask)
    dist_in = distance_transform_edt(mask)
    return dist_out - dist_in


def interpolate_binary_mask(mask):
    """Interpolate a sparsely annotated mask along its first axis using SDTs.

    Slices along axis 0 that contain any nonzero value count as annotated and
    are copied to the output as binary masks. Each slice lying strictly between
    two consecutive annotated slices is filled as follows: the signed distance
    transforms of those two slices are blended linearly, and the blend is
    thresholded at zero. Slices before the first or after the last annotated
    slice stay empty.

    Args:
        mask (ndarray): Array of shape (T, X, Y, Z) or (X, Y, Z), etc. Nonzero
            values are treated as foreground; distinct label ids are merged.

    Returns:
        ndarray: uint8 array with the same shape as ``mask`` holding 0/1 values.
    """
    output = np.zeros_like(mask, dtype=np.uint8)

    # Indices of the annotated slices along the first axis.
    valid_idxs = [i for i in range(mask.shape[0]) if np.any(mask[i])]

    # Copy annotated slices directly, so a lone annotated slice is kept too
    # (the pairwise interpolation below needs at least two).
    for i in valid_idxs:
        output[i] = mask[i] != 0

    if len(valid_idxs) < 2:
        return output

    sdt_end = signed_distance_transform(mask[valid_idxs[0]])
    for i_start, i_end in zip(valid_idxs[:-1], valid_idxs[1:]):
        # The previous segment's end slice is this segment's start: reuse its SDT.
        sdt_start = sdt_end
        sdt_end = signed_distance_transform(mask[i_end])
        span = i_end - i_start
        for j in range(i_start + 1, i_end):
            alpha = (j - i_start) / span
            sdt_interp = (1 - alpha) * sdt_start + alpha * sdt_end
            output[j] = sdt_interp < 0

    return output
54 |
55 |
class InterpolationWidget(QWidget):
    """Widget to interpolate between nonzero pixels in a label layer using signed distance transforms."""

    def __init__(
        self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
    ) -> None:
        super().__init__()

        self.viewer = viewer
        self.label_manager = label_manager
        # Output folder for results of dask (on-disk) inputs; asked for on first use.
        self.outputdir = None

        interpolator_box = QGroupBox("Interpolate mask")
        interpolator_box_layout = QVBoxLayout()

        run_btn = QPushButton("Run interpolation along first axis")
        run_btn.clicked.connect(self._interpolate)
        run_btn.setEnabled(self.label_manager.selected_layer is not None)
        self.label_manager.layer_update.connect(
            lambda: run_btn.setEnabled(self.label_manager.selected_layer is not None)
        )
        interpolator_box_layout.addWidget(run_btn)

        interpolator_box.setLayout(interpolator_box_layout)
        main_layout = QVBoxLayout()
        main_layout.addWidget(interpolator_box)
        self.setLayout(main_layout)

    def _interpolate(self):
        """Interpolate the selected layer's mask along its first axis.

        The result is added as a new Labels layer (suffixed "_interpolated") and
        becomes the label manager's selected layer. Dask-backed layers are
        loaded fully into memory, interpolated, written as one uint8 tif per
        index along the first axis into a user-chosen folder, and read back as
        a dask array.
        """
        layer = self.label_manager.selected_layer

        if isinstance(layer.data, da.core.Array):
            if self.outputdir is None:
                chosen = QFileDialog.getExistingDirectory(self, "Select Output Folder")
                if not chosen:
                    # Cancelled: do not fall back to writing into the working directory.
                    return
                self.outputdir = chosen

            outputdir = os.path.join(self.outputdir, layer.name + "_interpolated")
            if os.path.exists(outputdir):
                shutil.rmtree(outputdir)
            os.mkdir(outputdir)

            # Interpolation couples neighbouring slices, so the whole stack is
            # needed in memory.
            in_memory_stack = np.stack(
                [layer.data[i].compute() for i in range(layer.data.shape[0])],
                axis=0,
            )
            interpolated = interpolate_binary_mask(in_memory_stack)

            # Track written paths so frame order does not depend on name sorting.
            file_list = []
            for i in range(interpolated.shape[0]):
                fname = os.path.join(
                    outputdir,
                    layer.name + "_interpolation_TP" + str(i).zfill(4) + ".tif",
                )
                tifffile.imwrite(fname, np.array(interpolated[i], dtype="uint8"))
                file_list.append(fname)

            result = da.stack([imread(fname) for fname in file_list])
        else:
            result = interpolate_binary_mask(layer.data)

        self.label_manager.selected_layer = self.viewer.add_labels(
            result,
            name=layer.name + "_interpolated",
            scale=layer.scale,
        )
148 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/layer_controls.py:
--------------------------------------------------------------------------------
1 | import napari
2 | from napari_plane_sliders import PlaneSliderWidget
3 | from qtpy.QtWidgets import (
4 | QGroupBox,
5 | QVBoxLayout,
6 | QWidget,
7 | )
8 |
9 | from .copy_label_widget import CopyLabelWidget
10 | from .layer_manager import LayerManager
11 | from .save_labels_widget import SaveLabelsWidget
12 |
13 |
class LayerControlsWidget(QWidget):
    """Widget grouping the layer controls.

    Stacks vertically: the label manager (label layer selection), the plane
    sliders, a widget for copy-pasting labels between layers, and a widget for
    saving labels.
    """

    def __init__(
        self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
    ) -> None:
        super().__init__()

        self.viewer = viewer
        self.label_manager = label_manager

        layout = QVBoxLayout()

        ### create the dropdown for selecting label images
        layout.addWidget(self.label_manager)

        ### plane sliders
        plane_slider_box = QGroupBox("Plane Sliders")
        plane_slider_layout = QVBoxLayout()
        self.plane_sliders = PlaneSliderWidget(self.viewer)
        plane_slider_layout.addWidget(self.plane_sliders)
        plane_slider_box.setLayout(plane_slider_layout)
        layout.addWidget(plane_slider_box)

        ### Add widget for copy-pasting labels from one layer to another
        self.copy_label_widget = CopyLabelWidget(self.viewer)
        layout.addWidget(self.copy_label_widget)

        ### Add widget to save labels
        save_labels = SaveLabelsWidget(self.viewer, self.label_manager)
        layout.addWidget(save_labels)

        self.setLayout(layout)
47 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/layer_dropdown.py:
--------------------------------------------------------------------------------
1 | import napari
2 | from napari.utils.events import Event
3 | from PyQt5.QtCore import pyqtSignal
4 | from qtpy.QtWidgets import QComboBox
5 |
6 |
7 | class LayerDropdown(QComboBox):
8 | """QComboBox widget with functions for updating the selected layer and to update the list of options when the list of layers is modified."""
9 |
10 | layer_changed = pyqtSignal(str) # signal to emit the selected layer name
11 |
12 | def __init__(
13 | self, viewer: napari.Viewer, layer_type: tuple, allow_none: bool = False
14 | ):
15 | super().__init__()
16 | self.viewer = viewer
17 | self.layer_type = layer_type
18 | self.allow_none = allow_none
19 | self.selected_layer = None
20 | self.viewer.layers.events.inserted.connect(self._on_insert)
21 | self.viewer.layers.events.changed.connect(self._update_dropdown)
22 | self.viewer.layers.events.removed.connect(self._update_dropdown)
23 | self.viewer.layers.selection.events.changed.connect(self._on_selection_changed)
24 | self.currentTextChanged.connect(self._emit_layer_changed)
25 | self._update_dropdown()
26 |
27 | def _on_insert(self, event) -> None:
28 | """Update dropdown and make new layer responsive to name changes"""
29 |
30 | layer = event.value
31 | if isinstance(layer, self.layer_type):
32 |
33 | @layer.events.name.connect
34 | def _on_rename(name_event):
35 | self._update_dropdown()
36 |
37 | self._update_dropdown()
38 |
39 | def _on_selection_changed(self) -> None:
40 | """Request signal emission if the user changes the layer selection."""
41 |
42 | if (
43 | len(self.viewer.layers.selection) == 1
44 | ): # Only consider single layer selection
45 | selected_layer = self.viewer.layers.selection.active
46 | if (
47 | isinstance(selected_layer, self.layer_type)
48 | and selected_layer != self.selected_layer
49 | ):
50 | self.setCurrentText(selected_layer.name)
51 | self._emit_layer_changed()
52 |
53 | def _update_dropdown(self, event: Event | None = None) -> None:
54 | """Update the list of options in the dropdown menu whenever the list of layers is changed"""
55 |
56 | if (
57 | event is None
58 | or not hasattr(event, "value")
59 | or isinstance(event.value, self.layer_type)
60 | ):
61 | selected_layer = self.currentText()
62 | self.clear()
63 | layers = [
64 | layer
65 | for layer in self.viewer.layers
66 | if isinstance(layer, self.layer_type)
67 | ]
68 | items = []
69 | if self.allow_none:
70 | self.addItem("No selection")
71 | items.append("No selection")
72 |
73 | for layer in layers:
74 | self.addItem(layer.name)
75 | items.append(layer.name)
76 | layer.events.name.connect(self._update_dropdown)
77 |
78 | # In case the currently selected layer is one of the available items, set it again to the current value of the dropdown.
79 | if selected_layer in items:
80 | self.setCurrentText(selected_layer)
81 |
82 | def _emit_layer_changed(self) -> None:
83 | """Emit a signal holding the currently selected layer"""
84 |
85 | selected_layer_name = self.currentText()
86 | if (
87 | selected_layer_name != "No selection"
88 | and selected_layer_name in self.viewer.layers
89 | ):
90 | self.selected_layer = self.viewer.layers[selected_layer_name]
91 | else:
92 | self.selected_layer = None
93 | selected_layer_name = ""
94 |
95 | self.layer_changed.emit(selected_layer_name)
96 |
def deleteLater(self) -> None:
    """Undo the signal connections made in ``__init__``, then schedule deletion.

    The viewer outlives this widget, so its layer-list and selection events
    must stop calling into the widget before Qt destroys it.
    """
    layer_events = self.viewer.layers.events
    connections = (
        (layer_events.inserted, self._on_insert),
        (layer_events.changed, self._update_dropdown),
        (layer_events.removed, self._update_dropdown),
        (self.viewer.layers.selection.events.changed, self._on_selection_changed),
        (self.currentTextChanged, self._emit_layer_changed),
    )
    for signal, slot in connections:
        signal.disconnect(slot)
    super().deleteLater()
107 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/layer_manager.py:
--------------------------------------------------------------------------------
1 | import dask.array as da
2 | import napari
3 | import numpy as np
4 | from psygnal import Signal
5 | from qtpy.QtWidgets import (
6 | QPushButton,
7 | QVBoxLayout,
8 | QWidget,
9 | )
10 |
11 | from .layer_dropdown import LayerDropdown
12 |
13 |
class LayerManager(QWidget):
    """Widget that tracks the Labels layer currently being edited.

    It combines a ``LayerDropdown`` restricted to napari Labels layers with a
    button that converts a dask-backed layer into an in-memory numpy array.
    ``layer_update`` is emitted every time the active layer is (re)set.
    """

    layer_update = Signal()

    def __init__(self, viewer: napari.Viewer):
        super().__init__()

        self.viewer = viewer
        self._selected_layer = None
        # LayerDropdown expects a tuple of layer types for its isinstance checks.
        self.label_dropdown = LayerDropdown(self.viewer, (napari.layers.Labels,))
        self.label_dropdown.layer_changed.connect(self.set_active_layer)

        ### Add option to convert dask array to in-memory array
        self.convert_to_array_btn = QPushButton("Convert to in-memory array")
        self.convert_to_array_btn.clicked.connect(self._convert_to_array)
        self._update_convert_button()

        layout = QVBoxLayout()
        layout.addWidget(self.label_dropdown)
        layout.addWidget(self.convert_to_array_btn)
        self.setLayout(layout)

        # Send a signal on the layer dropdown to update the active layer
        self.label_dropdown._emit_layer_changed()

    @property
    def selected_layer(self):
        """The Labels layer currently being edited, or ``None``."""
        return self._selected_layer

    @selected_layer.setter
    def selected_layer(self, layer):
        if layer != self._selected_layer:
            self._selected_layer = layer

    def _is_dask_backed(self) -> bool:
        """Return True when a layer is selected and its data is a dask array."""
        return self._selected_layer is not None and isinstance(
            self._selected_layer.data, da.core.Array
        )

    def _update_convert_button(self) -> None:
        """Enable the conversion button only for dask-backed layers."""
        self.convert_to_array_btn.setEnabled(self._is_dask_backed())

    def set_active_layer(self, selected_layer) -> None:
        """Update the layer that is set to be the 'labels' layer that is being edited.

        Args:
            selected_layer: Name of the layer to activate, or ``""`` to clear
                the selection.
        """
        if selected_layer == "":
            self.selected_layer = None
        else:
            self.selected_layer = self.viewer.layers[selected_layer]
            self.label_dropdown.setCurrentText(selected_layer)

        # Also refresh on deselection: otherwise the button could stay enabled
        # and a click would dereference ``None.data``.
        self._update_convert_button()
        self.layer_update.emit()

    def _convert_to_array(self) -> None:
        """Convert from dask array to in-memory array. This is necessary for manual editing using the label tools (brush, eraser, fill bucket).

        The data is computed one slice along the first axis at a time and the
        slices are stacked back together.
        """
        if not self._is_dask_backed():
            return

        data = self._selected_layer.data
        stack = [data[i].compute() for i in range(data.shape[0])]
        self._selected_layer.data = np.stack(stack, axis=0)
        self._update_convert_button()
76 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/napari.yaml:
--------------------------------------------------------------------------------
1 | name: napari-segmentation-correction
2 | display_name: Manual Segmentation Correction
3 | visibility: public
4 | categories: ["Annotation", "Segmentation"]
5 | contributions:
6 | commands:
7 | - id: napari-segmentation-correction.annotateND
8 | python_name: napari_segmentation_correction._widget:AnnotateLabelsND
9 | title: Manual Segmentation Correction
10 | widgets:
11 | - command: napari-segmentation-correction.annotateND
12 | display_name: Manual Segmentation Correction
13 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/plot_widget.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import matplotlib.pyplot as plt
5 | from matplotlib.backends.backend_qt5agg import (
6 | FigureCanvas,
7 | NavigationToolbar2QT,
8 | )
9 | from matplotlib.colors import ListedColormap, to_rgb
10 | from qtpy.QtCore import Qt
11 | from qtpy.QtGui import QIcon
12 | from qtpy.QtWidgets import QComboBox, QHBoxLayout, QLabel, QVBoxLayout, QWidget
13 |
14 | from .layer_manager import LayerManager
15 |
16 | ICON_ROOT = Path(__file__).parent / "icons"
17 |
18 |
class PlotWidget(QWidget):
    """Matplotlib widget that displays features of the selected labels layer.

    Three dropdowns choose the x-axis column, the y-axis column and the column
    used to colour the points. The choices come from the layer's ``features``
    table, excluding ``index``.

    When grouping by ``label``, each point gets the layer's own colour for that
    label. If the x-axis is ``time_point``, each label is also drawn as a line.
    Any other grouping column is mapped through the ``summer`` colormap.
    """

    def __init__(self, label_manager: LayerManager):
        super().__init__()

        self.label_manager = label_manager
        # Layer whose events are currently connected to this widget; tracked so
        # that the connections can be removed when another layer becomes active.
        self._connected_layer = None
        self.label_manager.layer_update.connect(self._layer_update)

        # Main plot.
        self.fig = plt.figure(constrained_layout=True)
        self.plot_canvas = FigureCanvas(self.fig)
        self.ax = self.plot_canvas.figure.subplots()
        self.toolbar = NavigationToolbar2QT(self.plot_canvas)

        # Specify plot customizations (dark background, white axes).
        self.fig.patch.set_facecolor("#262930")
        self.ax.tick_params(colors="white")
        self.ax.set_facecolor("#262930")
        self.ax.xaxis.label.set_color("white")
        self.ax.yaxis.label.set_color("white")
        for side in ("bottom", "top", "right", "left"):
            self.ax.spines[side].set_color("white")
        # Replace the toolbar icons with the bundled icons/<action>.png files.
        for action_name, action in self.toolbar._actions.items():
            icon_path = os.path.join(ICON_ROOT, action_name + ".png")
            action.setIcon(QIcon(icon_path))

        # Create a dropdown window for selecting what to plot on the axes.
        self.x_combo = QComboBox()
        self.y_combo = QComboBox()
        self.group_combo = QComboBox()

        dropdown_layout = QVBoxLayout()
        for caption, combo in (
            ("x-axis", self.x_combo),
            ("y-axis", self.y_combo),
            ("Group color", self.group_combo),
        ):
            row = QHBoxLayout()
            row.addWidget(QLabel(caption))
            row.addWidget(combo)
            dropdown_layout.addLayout(row)
            combo.currentIndexChanged.connect(self._update_plot)

        dropdown_widget = QWidget()
        dropdown_widget.setLayout(dropdown_layout)
        dropdown_layout.setAlignment(Qt.AlignTop)

        # Stack the dropdown widget, toolbar and canvas vertically.
        plotting_layout = QVBoxLayout()
        plotting_layout.addWidget(dropdown_widget)
        plotting_layout.addWidget(self.toolbar)
        plotting_layout.addWidget(self.plot_canvas)
        plotting_layout.setAlignment(Qt.AlignTop)

        self.setLayout(plotting_layout)
        self.setMinimumHeight(500)

    def _disconnect_layer_events(self) -> None:
        """Remove the event connections made to the previously active layer."""
        layer = self._connected_layer
        if layer is not None:
            layer.events.show_selected_label.disconnect(self._update_plot)
            layer.events.selected_label.disconnect(self._update_plot)
            layer.events.features.disconnect(self._update_dropdown)
        self._connected_layer = None

    def _layer_update(self) -> None:
        """Connect events to plot updates"""
        layer = self.label_manager.selected_layer
        if layer is not self._connected_layer:
            # Stop listening to the old layer so it no longer redraws this plot.
            self._disconnect_layer_events()
            if layer is not None:
                layer.events.show_selected_label.connect(self._update_plot)
                layer.events.selected_label.connect(self._update_plot)
                layer.events.features.connect(self._update_dropdown)
            self._connected_layer = layer

        self._update_dropdown()
        self._update_plot()

    def _feature_columns(self) -> list:
        """Return the plottable feature columns of the selected layer ([] if none)."""
        layer = self.label_manager.selected_layer
        if layer is None or len(layer.features) == 0:
            return []
        return [item for item in layer.features.columns if item != "index"]

    def _update_dropdown(self) -> None:
        """Update the dropdowns with the column headers.

        The previously chosen column is kept when it still exists. Otherwise
        x/group fall back to the first column and y to the second column,
        clamped to the available columns. Without features the dropdowns are
        emptied, so no stale column names remain.
        """
        columns = self._feature_columns()

        for combo, default_index in (
            (self.x_combo, 0),
            (self.y_combo, 1),
            (self.group_combo, 0),
        ):
            previous = combo.currentText()
            # Avoid one redraw per intermediate change while repopulating.
            was_blocked = combo.blockSignals(True)
            combo.clear()
            combo.addItems(columns)
            if previous in columns:
                combo.setCurrentText(previous)
            elif columns:
                combo.setCurrentIndex(min(default_index, len(columns) - 1))
            combo.blockSignals(was_blocked)

        self._update_plot()

    def _plot_by_label(self, layer, x_axis_property, y_axis_property) -> None:
        """Scatter the features coloured by each label's colour in ``layer``.

        With ``show_selected_label`` enabled, only the selected label is shown.
        If the x-axis is ``time_point``, each label's points are also connected
        by a line sorted along x.
        """
        features = layer.features
        if layer.show_selected_label:
            label = layer.selected_label
            plotting_data = features[features["label"] == label]
            unique_labels = [label]
        else:
            plotting_data = features
            unique_labels = plotting_data["label"].unique()

        # Create consistent label-to-color mapping
        colormap = layer.colormap
        label_colors = {label: to_rgb(colormap.map(label)) for label in unique_labels}

        # Colours are given explicitly; a cmap would be ignored by matplotlib
        # (and it warns about it), so none is passed.
        self.ax.scatter(
            plotting_data[x_axis_property],
            plotting_data[y_axis_property],
            c=[label_colors[label] for label in plotting_data["label"]],
            s=10,
        )

        # Line plot for time_point x-axis
        if x_axis_property == "time_point":
            for label, color in label_colors.items():
                label_data = plotting_data[
                    plotting_data["label"] == label
                ].sort_values(by=x_axis_property)
                self.ax.plot(
                    label_data[x_axis_property],
                    label_data[y_axis_property],
                    linestyle="-",
                    color=color,
                    linewidth=1,
                )

    def _update_plot(self) -> None:
        """Update the plot by plotting the features selected by the user."""
        x_axis_property = self.x_combo.currentText()
        y_axis_property = self.y_combo.currentText()
        group = self.group_combo.currentText()

        # Clear data points, and reset the axis scaling and labels.
        for artist in [*self.ax.lines, *self.ax.collections]:
            artist.remove()
        self.ax.set_xlabel(x_axis_property)
        self.ax.set_ylabel(y_axis_property)
        self.ax.relim()  # Recalculate limits for the current data
        self.ax.autoscale_view()  # Update the view to include the new limits

        layer = self.label_manager.selected_layer
        if (
            layer is not None
            and len(layer.features) > 0
            and "" not in (x_axis_property, y_axis_property, group)
        ):
            if group == "label":
                self._plot_by_label(layer, x_axis_property, y_axis_property)
            else:
                # Continuous colormap for other grouping
                features = layer.features
                self.ax.scatter(
                    features[x_axis_property],
                    features[y_axis_property],
                    c=features[group],
                    cmap="summer",
                    s=10,
                )

        self.plot_canvas.draw()
231 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/prop_filter_widget.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | import shutil
4 | from warnings import warn
5 |
6 | import dask.array as da
7 | import napari
8 | import numpy as np
9 | import pandas as pd
10 | import tifffile
11 | from qtpy.QtWidgets import (
12 | QComboBox,
13 | QDoubleSpinBox,
14 | QFileDialog,
15 | QGroupBox,
16 | QHBoxLayout,
17 | QPushButton,
18 | QVBoxLayout,
19 | QWidget,
20 | )
21 | from skimage.io import imread
22 |
23 | from .layer_manager import LayerManager
24 |
25 |
class PropertyFilterWidget(QWidget):
    """Widget to filter objects by numerical property.

    Labels whose property value passes ``<property> <operation> <threshold>``
    are kept (or deleted) and the result is added to the viewer as a new Labels
    layer named ``<layer>_filtered``.

    When the properties contain a ``time_point`` column, filtering is done per
    slice along the first axis. For dask-backed layers, those slices are
    written as uint16 TIFFs into ``<outputdir>/<layer>_filtered`` and then read
    back.
    """

    # Dropdown operation -> pandas Series comparison method.
    _OPERATIONS = {">": "gt", "<": "lt", ">=": "ge", "<=": "le"}

    def __init__(
        self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
    ) -> None:
        super().__init__()

        self.viewer = viewer
        self.label_manager = label_manager
        self.outputdir = None

        filterbox = QGroupBox("Filter objects by property")
        filter_layout = QHBoxLayout()

        self.property = QComboBox()
        self.property.currentIndexChanged.connect(self.update_min_max_value)
        self.label_manager.layer_update.connect(self.set_properties)
        self.operation = QComboBox()
        self.operation.addItems([">", "<", ">=", "<="])
        self.operation.setToolTip("Operation to apply for filtering")
        self.value = QDoubleSpinBox()
        self.value.setValue(0.0)
        self.value.setSingleStep(0.1)
        self.value.setToolTip("Threshold value for the selected property")
        self.keep_delete = QComboBox()
        self.keep_delete.addItems(["Keep", "Delete"])
        self.keep_delete.setToolTip(
            "Choose whether to keep or delete objects matching the criteria"
        )
        self.run_btn = QPushButton("Run")
        self.run_btn.clicked.connect(self.filter_by_property)

        filter_layout.addWidget(self.property)
        filter_layout.addWidget(self.value)
        filter_layout.addWidget(self.operation)
        filter_layout.addWidget(self.keep_delete)

        main_layout = QVBoxLayout()
        main_layout.addLayout(filter_layout)
        main_layout.addWidget(self.run_btn)

        filterbox.setLayout(main_layout)

        layout = QVBoxLayout()
        layout.addWidget(filterbox)
        self.setLayout(layout)

    def set_properties(self) -> None:
        """Set available properties for the selected layer.

        ``label`` and ``time_point`` are not offered for filtering. The previous
        choice is restored when still available. Run is enabled only when at
        least one property can be filtered on.
        """
        current_prop = self.property.currentText()
        layer = self.label_manager.selected_layer
        props = (
            []
            if layer is None
            else [p for p in layer.properties if p not in ("label", "time_point")]
        )
        self.property.clear()
        self.property.addItems(props)
        if current_prop in props:
            self.property.setCurrentText(current_prop)
        self.run_btn.setEnabled(len(props) > 0)

    def update_min_max_value(self) -> None:
        """Update min and max values for the threshold spinbox based on selected property.

        Does nothing when no layer is selected, the property is unknown, or it
        has no finite numeric values.
        """
        layer = self.label_manager.selected_layer
        prop = self.property.currentText()
        if layer is None or prop not in layer.properties:
            return
        try:
            values = np.asarray(layer.properties[prop], dtype=float)
        except (TypeError, ValueError):
            return  # non-numeric property: no meaningful range
        values = values[np.isfinite(values)]
        if values.size == 0:
            return

        low, high = float(values.min()), float(values.max())
        self.value.setMinimum(low)
        self.value.setMaximum(high)
        if not low <= self.value.value() <= high:
            self.value.setValue(low)

    @classmethod
    def _select_labels(cls, df, prop, operation, value) -> np.ndarray:
        """Return the ``label`` values of the rows where ``df[prop] <operation> value``."""
        compare = getattr(df[prop], cls._OPERATIONS[operation])
        return df.loc[compare(value), "label"].to_numpy()

    @staticmethod
    def _apply_selection(labels, selected, keep: bool) -> np.ndarray:
        """Keep (``keep=True``) or delete the labels listed in ``selected``.

        An empty ``selected`` matches nothing, so "Keep" yields an empty image
        and "Delete" leaves ``labels`` unchanged.
        """
        mask = np.isin(labels, selected)
        return np.where(mask, labels, 0) if keep else np.where(~mask, labels, 0)

    def _prepare_output_dir(self, folder_name: str):
        """Return a fresh ``<outputdir>/<folder_name>`` directory, or None if cancelled.

        The user is asked for ``self.outputdir`` the first time. An existing
        folder with the same name is deleted and recreated.
        """
        if not self.outputdir:
            chosen = QFileDialog.getExistingDirectory(self, "Select Output Folder")
            if not chosen:
                # Dialog cancelled: never fall back to the working directory.
                self.outputdir = None
                return None
            self.outputdir = chosen

        outputdir = os.path.join(self.outputdir, folder_name)
        if os.path.exists(outputdir):
            shutil.rmtree(outputdir)
        os.mkdir(outputdir)
        return outputdir

    def filter_by_property(self) -> None:
        """Filter objects by selected property and threshold value."""
        layer = self.label_manager.selected_layer
        if layer is None:
            return

        prop = self.property.currentText()
        value = self.value.value()
        operation = self.operation.currentText()
        keep = self.keep_delete.currentText() != "Delete"

        df = pd.DataFrame(layer.properties)
        if prop not in df.columns:
            warn(f"Property {prop} not found in layer properties", stacklevel=2)
            return

        name = layer.name
        is_dask = isinstance(layer.data, da.core.Array)

        if "time_point" in df.columns:
            outputdir = None
            if is_dask:
                outputdir = self._prepare_output_dir(name + "_filtered")
                if outputdir is None:
                    return

            filtered_label_imgs = []
            for time_point in range(layer.data.shape[0]):
                df_subset = df.loc[df["time_point"] == time_point]
                selected = self._select_labels(df_subset, prop, operation, value)

                frame = layer.data[time_point]
                labels = np.asarray(frame.compute() if is_dask else frame)
                new_labels = self._apply_selection(labels, selected, keep)

                if is_dask:
                    tifffile.imwrite(
                        os.path.join(
                            outputdir, f"{name}_filtered_TP{str(time_point).zfill(4)}.tif"
                        ),
                        np.array(new_labels, dtype="uint16"),
                    )
                else:
                    filtered_label_imgs.append(new_labels)

            if is_dask:
                file_list = sorted(
                    os.path.join(outputdir, fname)
                    for fname in os.listdir(outputdir)
                    if fname.endswith(".tif")
                )
                result = da.stack([imread(fname) for fname in file_list])
            else:
                result = np.stack(filtered_label_imgs)
        else:
            selected = self._select_labels(df, prop, operation, value)
            result = self._apply_selection(np.asarray(layer.data), selected, keep)

        self.label_manager.selected_layer = self.viewer.add_labels(
            result,
            name=name + "_filtered",
            scale=layer.scale,
        )
239 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/regionprops_extended.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from skimage.measure import marching_cubes, mesh_surface_area, regionprops
6 | from skimage.measure._regionprops import RegionProperties
7 |
8 |
class ExtendedRegionProperties(RegionProperties):
    """Adding additional properties to skimage.measure._regionprops following the logic from the porespy package with some modifications to include the spacing information."""

    @property
    def ellipse_axes(self):
        """Major and minor axis lengths of the fitted ellipse.

        Returns:
            tuple: ``(axis_major_length, axis_minor_length)`` as computed by skimage.
        """
        return (self.axis_major_length, self.axis_minor_length)

    @property
    def ellipsoid_axes(self):
        """
        Calculate the three axes radii of the fitted ellipsoid.

        The principal moments of inertia of the region are computed in
        physical units (spacing applied). Each radius of the equivalent
        ellipsoid is then ``sqrt(5/2 * (I_a + I_b - I_c) / N)``, where
        ``N`` is the voxel count.

        Returns:
            tuple: A tuple containing the lengths of the three principal axes radii (longr, midr, shortr).
        """
        # Extract the coordinates of the region's voxels (3D image expected).
        z, y, x = np.nonzero(self.image)
        voxel_count = z.size

        # Center the coordinates and apply spacing
        z = (z - np.mean(z)) * self._spacing[0]
        y = (y - np.mean(y)) * self._spacing[1]
        x = (x - np.mean(x)) * self._spacing[2]

        # Calculate the elements of the inertia tensor
        i_xx = np.sum(y**2 + z**2)
        i_yy = np.sum(x**2 + z**2)
        i_zz = np.sum(x**2 + y**2)
        i_xy = np.sum(x * y)
        i_xz = np.sum(x * z)
        i_yz = np.sum(y * z)
        inertia = np.array(
            [[i_xx, -i_xy, -i_xz], [-i_xy, i_yy, -i_yz], [-i_xz, -i_yz, i_zz]]
        )

        # The tensor is symmetric: eigvalsh returns real eigenvalues in
        # ascending order. The long axis has the smallest moment of inertia.
        i_min, i_mid, i_max = np.linalg.eigvalsh(inertia)

        def _radius(a, b, c):
            # Flat or line-like regions give values that are zero in exact
            # arithmetic but may round to tiny negatives; clamp them before sqrt.
            return math.sqrt(max(0.0, 5.0 / 2.0 * (a + b - c) / voxel_count))

        longr = _radius(i_mid, i_max, i_min)
        midr = _radius(i_max, i_min, i_mid)
        shortr = _radius(i_min, i_mid, i_max)

        return (longr, midr, shortr)

    @property
    def circularity(self):
        """
        Calculate the circularity of the region.

        Circularity is defined as 4π * (Area / Perimeter^2).

        Returns:
            float: The circularity of the region, or NaN when the perimeter is
            zero (e.g. single-pixel regions), where it is undefined.
        """
        perimeter = self.perimeter
        if perimeter == 0:
            return math.nan
        return 4 * math.pi * self.area / perimeter**2

    @property
    def pixel_count(self):
        """
        Get the number of pixels in the region.

        Returns:
            int: The number of pixels in the region.
        """
        return self.voxel_count

    @property
    def surface_area(self):
        """
        Calculate the surface area of the region using the marching cubes algorithm.

        The region mask is padded with one layer of background first. This
        closes the surface where the region touches its bounding box, and it
        lets regions that fill their bounding box or are only one voxel thick
        be meshed at all. The mesh area is computed with
        skimage.measure.mesh_surface_area in physical units (spacing applied).

        Returns:
            float: The surface area of the region.
        """
        padded = np.pad(self.image, pad_width=1, mode="constant", constant_values=0)
        verts, faces, _, _ = marching_cubes(padded, level=0.5, spacing=self._spacing)
        return mesh_surface_area(verts, faces)

    @property
    def sphericity(self):
        """
        Calculate the sphericity of the region.

        Sphericity is defined as the ratio of the surface area of a sphere with the same volume as the region to the surface area of the region.

        Returns:
            float: The sphericity of the region.
        """
        vol = self.volume
        r = (3 / 4 / np.pi * vol) ** (1 / 3)
        a_equiv = 4 * np.pi * r**2
        a_region = self.surface_area
        return a_equiv / a_region

    @property
    def volume(self):
        """
        Calculate the volume of the region.

        The volume is calculated as the number of voxels in the region multiplied by the product of the spacing in each dimension.

        Returns:
            float: The volume of the region.
        """
        return np.sum(self.image) * np.prod(self._spacing)

    @property
    def voxel_count(self):
        """
        Get the number of voxels in the region.

        Returns:
            int: The number of voxels in the region.
        """
        return int(np.sum(self.image))
163 |
164 |
def regionprops_extended(
    img: np.ndarray, spacing: tuple[float], intensity_image: np.ndarray | None = None
) -> list[ExtendedRegionProperties]:
    """
    Measure ``img`` with skimage and rewrap every region as ExtendedRegionProperties.

    The standard ``regionprops`` objects are rebuilt with the same slice, label,
    label image, intensity image, cache setting and spacing. This way the extra
    properties (volume, surface_area, sphericity, ...) become available on each
    region.

    Args:
        img (np.ndarray): The labeled image.
        spacing (tuple[float]): The spacing between voxels in each dimension.
        intensity_image (np.ndarray, optional): The intensity image.

    Returns:
        list[ExtendedRegionProperties]: One entry per label, in skimage's order.
    """
    base_regions = regionprops(img, intensity_image=intensity_image, spacing=spacing)
    return [
        ExtendedRegionProperties(
            slice=region.slice,
            label=region.label,
            label_image=region._label_image,
            intensity_image=region._intensity_image,
            cache_active=region._cache_active,
            spacing=region._spacing,
        )
        for region in base_regions
    ]
193 |
194 |
def props_to_dataframe(regionprops, selected_properties=None) -> pd.DataFrame:
    """Convert ExtendedRegionProperties instances to a pandas DataFrame.

    Follows the logic of porespy's ``props_to_dataframe``. The first column is
    always ``label``. Tuple-valued properties (e.g. ``ellipse_axes``) are
    expanded into one column per element, named ``<property>-1``,
    ``<property>-2``, ...

    Args:
        regionprops: Sequence of region objects exposing ``label`` and the
            requested properties as attributes.
        selected_properties: Names of the properties to include. Defaults to
            ``("bbox",)``, mirroring ``skimage.measure.regionprops_table``.

    Returns:
        pd.DataFrame: One row per region; an empty DataFrame if ``regionprops``
        is empty.
    """
    if len(regionprops) == 0:
        return pd.DataFrame()
    if selected_properties is None:
        selected_properties = ("bbox",)

    columns = {"label": np.array([region.label for region in regionprops])}
    for name in selected_properties:
        if name == "label":
            continue  # always present as the first column
        # Evaluate each property once per region, even when it is tuple-valued.
        values = [getattr(region, name) for region in regionprops]
        if isinstance(values[0], tuple):
            for idx in range(len(values[0])):
                columns[f"{name}-{idx + 1}"] = np.array([v[idx] for v in values])
        else:
            columns[name] = np.array(values)

    return pd.DataFrame(columns)
225 |
226 |
def calculate_extended_props(
    image: np.ndarray,
    properties: list[str],
    spacing: list[float],
    intensity_image: np.ndarray | None = None,
) -> pd.DataFrame:
    """Measure ``properties`` for every label in ``image`` and tabulate them.

    Args:
        image: Labeled image to measure.
        properties: Names of the (extended) region properties to compute.
        spacing: Physical voxel size per axis.
        intensity_image: Optional intensity image for intensity-based properties.

    Returns:
        pd.DataFrame: One row per region, or an empty DataFrame when the image
        contains no labels.
    """
    region_list = regionprops_extended(
        image, spacing=spacing, intensity_image=intensity_image
    )
    if not region_list:
        return pd.DataFrame()
    return props_to_dataframe(region_list, properties)
239 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/regionprops_widget.py:
--------------------------------------------------------------------------------
1 | import dask.array as da
2 | import napari
3 | import numpy as np
4 | import pandas as pd
5 | from qtpy.QtCore import Qt
6 | from qtpy.QtWidgets import (
7 | QCheckBox,
8 | QDoubleSpinBox,
9 | QGridLayout,
10 | QGroupBox,
11 | QHBoxLayout,
12 | QLabel,
13 | QMessageBox,
14 | QPushButton,
15 | QSizePolicy,
16 | QVBoxLayout,
17 | QWidget,
18 | )
19 | from tqdm import tqdm
20 |
21 | from .custom_table_widget import ColoredTableWidget
22 | from .layer_dropdown import LayerDropdown
23 | from .layer_manager import LayerManager
24 | from .prop_filter_widget import PropertyFilterWidget
25 | from .regionprops_extended import calculate_extended_props
26 |
27 |
28 | class RegionPropsWidget(QWidget):
29 | """Widget showing region props as a table and plot widget"""
30 |
31 | def __init__(
32 | self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
33 | ) -> None:
34 | super().__init__()
35 |
36 | self.viewer = viewer
37 | self.label_manager = label_manager
38 | self.label_manager.layer_update.connect(self.update_dims)
39 | self.table = None
40 | self.ndims = 2
41 | self.feature_dims = 2
42 | self.axis_widgets = []
43 |
44 | dim_box = QGroupBox("Dimensions")
45 | grid = QGridLayout()
46 |
47 | # headers
48 | grid.addWidget(QLabel("Axis"), 0, 0)
49 | grid.addWidget(QLabel("Index"), 0, 1)
50 | grid.addWidget(QLabel("Pixel scaling"), 0, 2)
51 |
52 | # Z row
53 | self.z_label = QLabel("Z")
54 | self.z_label.setVisible(False)
55 | self.z_axis = QLabel("0")
56 | self.z_axis.setVisible(False)
57 | self.axis_widgets.append(self.z_axis)
58 | self.z_scale = QDoubleSpinBox()
59 | self.z_scale.setValue(1.0)
60 | self.z_scale.setSingleStep(0.1)
61 | self.z_scale.setMinimum(0.01)
62 | self.z_scale.setToolTip("Voxel size along Z axis")
63 | self.z_scale.setVisible(False)
64 | self.z_scale.setDecimals(3)
65 |
66 | grid.addWidget(self.z_label, 1, 0)
67 | grid.addWidget(self.z_axis, 1, 1)
68 | grid.addWidget(self.z_scale, 1, 2)
69 |
70 | # Y row
71 | self.y_label = QLabel("Y")
72 | self.y_axis = QLabel("1")
73 | self.axis_widgets.append(self.y_axis)
74 | self.y_scale = QDoubleSpinBox()
75 | self.y_scale.setValue(1.0)
76 | self.y_scale.setSingleStep(0.1)
77 | self.y_scale.setMinimum(0.01)
78 | self.y_scale.setToolTip("Voxel size along Y axis")
79 | self.y_scale.setDecimals(3)
80 |
81 | grid.addWidget(self.y_label, 2, 0)
82 | grid.addWidget(self.y_axis, 2, 1)
83 | grid.addWidget(self.y_scale, 2, 2)
84 |
85 | # X row
86 | self.x_label = QLabel("X")
87 | self.x_axis = QLabel("2")
88 | self.axis_widgets.append(self.x_axis)
89 | self.x_scale = QDoubleSpinBox()
90 | self.x_scale.setValue(1.0)
91 | self.x_scale.setSingleStep(0.1)
92 | self.x_scale.setMinimum(0.01)
93 | self.x_scale.setToolTip("Voxel size along X axis")
94 | self.x_scale.setDecimals(3)
95 |
96 | grid.addWidget(self.x_label, 3, 0)
97 | grid.addWidget(self.x_axis, 3, 1)
98 | grid.addWidget(self.x_scale, 3, 2)
99 |
100 | # add "use z" checkbox above grid
101 | main_layout = QVBoxLayout()
102 | self.use_z = QCheckBox("3D data (use Z axis)")
103 | self.use_z.setVisible(False)
104 | self.use_z.setEnabled(False)
105 | self.use_z.stateChanged.connect(self.update_use_z)
106 | main_layout.addWidget(self.use_z)
107 | main_layout.addLayout(grid)
108 |
109 | dim_box.setLayout(main_layout)
110 | dim_box.setMaximumHeight(200)
111 |
112 | # features widget
113 | self.feature_properties = [
114 | {
115 | "region_prop_name": "intensity_mean",
116 | "display_name": "Mean intensity",
117 | "enabled": False,
118 | "selected": False,
119 | "dims": [2, 3],
120 | },
121 | {
122 | "region_prop_name": "area",
123 | "display_name": "Area",
124 | "enabled": True,
125 | "selected": True,
126 | "dims": [2],
127 | },
128 | {
129 | "region_prop_name": "perimeter",
130 | "display_name": "Perimeter",
131 | "enabled": True,
132 | "selected": False,
133 | "dims": [2],
134 | },
135 | {
136 | "region_prop_name": "circularity",
137 | "display_name": "Circularity",
138 | "enabled": True,
139 | "selected": False,
140 | "dims": [2],
141 | },
142 | {
143 | "region_prop_name": "ellipse_axes",
144 | "display_name": "Ellipse axes",
145 | "enabled": True,
146 | "selected": False,
147 | "dims": [2],
148 | },
149 | {
150 | "region_prop_name": "volume",
151 | "display_name": "Volume",
152 | "enabled": True,
153 | "selected": True,
154 | "dims": [3],
155 | },
156 | {
157 | "region_prop_name": "surface_area",
158 | "display_name": "Surface area",
159 | "enabled": True,
160 | "selected": False,
161 | "dims": [3],
162 | },
163 | {
164 | "region_prop_name": "sphericity",
165 | "display_name": "Sphericity",
166 | "enabled": True,
167 | "selected": False,
168 | "dims": [3],
169 | },
170 | {
171 | "region_prop_name": "ellipsoid_axes",
172 | "display_name": "Ellipsoid axes",
173 | "enabled": True,
174 | "selected": False,
175 | "dims": [3],
176 | },
177 | ]
178 |
179 | feature_box = QGroupBox("Features to measure")
180 | feature_box.setMaximumHeight(250)
181 | self.checkbox_layout = QVBoxLayout()
182 | self.checkboxes = []
183 | self.intensity_image_dropdown = None
184 |
185 | feature_box.setLayout(self.checkbox_layout)
186 |
187 | # Push button to measure features
188 | self.measure_btn = QPushButton("Measure properties")
189 | self.measure_btn.clicked.connect(self.measure)
190 |
191 | ### Add widget for property filtering
192 | self.prop_filter_widget = PropertyFilterWidget(self.viewer, self.label_manager)
193 | self.prop_filter_widget.setVisible(False)
194 |
195 | # Add table layout
196 | self.regionprops_layout = QVBoxLayout()
197 |
198 | # Assemble layout
199 | main_box = QGroupBox("Region properties")
200 | main_layout = QVBoxLayout()
201 | main_layout.addWidget(dim_box)
202 | main_layout.addWidget(feature_box)
203 | main_layout.addWidget(self.measure_btn)
204 | main_layout.addWidget(self.prop_filter_widget)
205 | main_layout.addLayout(self.regionprops_layout)
206 | main_box.setLayout(main_layout)
207 |
208 | layout = QVBoxLayout()
209 | layout.addWidget(main_box)
210 |
211 | layout.setAlignment(Qt.AlignTop)
212 |
213 | self.setLayout(layout)
214 |
215 | # refresh dimensions based on current label layer, if any
216 | self.update_dims()
217 |
def update_properties(self) -> None:
    """Rebuild the "Features to measure" checkboxes for ``self.feature_dims``.

    The previous intensity-image dropdown (if any) is disconnected and
    scheduled for deletion, and every widget in ``self.checkbox_layout`` is
    removed and scheduled for deletion.

    Then one checkbox is created for every entry of ``self.feature_properties``
    whose ``"dims"`` contains ``self.feature_dims``. Its initial checked state
    comes from the entry's ``"selected"`` value and is tracked in
    ``self.checkbox_state``.

    The ``intensity_mean`` checkbox is placed in a row together with a
    ``LayerDropdown`` of Image layers, stored in
    ``self.intensity_image_dropdown``. That dropdown enables the checkbox
    immediately when it already has a layer selected, and again via
    ``_update_intensity_checkbox`` whenever its selection changes.
    """
    previous_dropdown = getattr(self, "intensity_image_dropdown", None)
    if previous_dropdown:
        previous_dropdown.layer_changed.disconnect()
        previous_dropdown.deleteLater()

    layout = self.checkbox_layout
    while layout.count():
        child = layout.takeAt(0).widget()
        if child is not None:
            child.deleteLater()
    self.intensity_image_dropdown = None
    self.checkboxes = []

    # Only the features that apply to the current dimensionality are offered.
    wanted_dims = self.feature_dims
    self.properties = [
        entry for entry in self.feature_properties if wanted_dims in entry["dims"]
    ]
    self.checkbox_state = {
        entry["region_prop_name"]: entry["selected"] for entry in self.properties
    }

    for entry in self.properties:
        key = entry["region_prop_name"]

        box = QCheckBox(entry["display_name"])
        box.setEnabled(entry["enabled"])
        box.setStyleSheet("QCheckBox:disabled { color: grey }")
        box.setChecked(self.checkbox_state[key])
        # state == 2 corresponds to a checked box
        box.stateChanged.connect(
            lambda state, key=key: self.checkbox_state.update({key: state == 2})
        )
        self.checkboxes.append({"region_prop_name": key, "checkbox": box})

        if key != "intensity_mean":
            self.checkbox_layout.addWidget(box)
            continue

        # The intensity feature needs an image: pair the checkbox with a dropdown.
        dropdown = LayerDropdown(self.viewer, napari.layers.Image)
        self.intensity_image_dropdown = dropdown
        dropdown.layer_changed.connect(self._update_intensity_checkbox)
        if dropdown.selected_layer is not None:
            box.setEnabled(True)

        row = QHBoxLayout()
        row.addWidget(box)
        row.addWidget(dropdown)
        row.setContentsMargins(0, 0, 0, 0)
        row.setAlignment(Qt.AlignLeft | Qt.AlignVCenter)

        container = QWidget()
        container.setLayout(row)

        # Keep the row as tall as a plain checkbox would be.
        container.setSizePolicy(box.sizePolicy())
        dropdown.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

        self.checkbox_layout.addWidget(container)
279 |
280 | def _update_intensity_checkbox(self) -> None:
281 | """Enable or disable the intensity_mean checkbox based on the selected layer."""
282 | if self.intensity_image_dropdown is not None:
283 | checkbox = next(
284 | (
285 | cb["checkbox"]
286 | for cb in self.checkboxes
287 | if cb["region_prop_name"] == "intensity_mean"
288 | ),
289 | None,
290 | )
291 | if checkbox is not None:
292 | checkbox.setEnabled(
293 | isinstance(
294 | self.intensity_image_dropdown.selected_layer,
295 | napari.layers.Image,
296 | )
297 | )
298 |
def update_use_z(self, state: int) -> None:
    """React to the "3D data (use Z axis)" checkbox.

    Args:
        state: the checkbox state emitted by ``stateChanged``; 2 means checked.

    Shows/enables the Z row when checked (hides/disables it otherwise), sets
    ``self.feature_dims`` to 3 or 2 accordingly, then calls ``update_dims``.
    """
    use_z = state == 2
    for widget in (self.z_label, self.z_axis, self.z_scale):
        widget.setVisible(use_z)
    for widget in (self.z_axis, self.z_scale):
        widget.setEnabled(use_z)
    self.feature_dims = 3 if use_z else 2
    self.update_dims()
307 |
def update_dims(self) -> None:
    """Synchronise the dimension widgets with the selected labels layer.

    When ``self.label_manager.selected_layer`` is None the measure button is
    disabled and nothing else changes. Otherwise:

    * ``self.ndims`` is set to the layer's ``ndim``;
    * the "use Z" checkbox is shown/enabled only for 3D layers. A check left
      over from a previous 3D layer is cleared (with signals blocked, so
      ``update_use_z`` does not re-enter) for any other dimensionality;
    * ``self.feature_dims`` is recomputed for every layer: 3 for 4D layers or
      3D layers with "use Z" checked, 2 otherwise. The Z row is
      shown/enabled to match. Previously it was only ever set for 4D layers,
      so switching from a 4D layer to a 2D/3D layer kept 3D features;
    * the Z/Y/X axis labels show the indices of the last three axes;
    * the feature checkboxes and the table are rebuilt;
    * the pixel-scaling spin boxes are filled from the layer's ``scale``
      (Z falls back to 1.0 when the scale has fewer than 3 entries).
    """
    layer = self.label_manager.selected_layer
    if layer is None:
        self.measure_btn.setEnabled(False)
        return

    self.measure_btn.setEnabled(True)
    self.ndims = layer.ndim

    is_3d = self.ndims == 3
    self.use_z.setVisible(is_3d)
    self.use_z.setEnabled(is_3d)
    if not is_3d and self.use_z.isChecked():
        # The checkbox only applies to 3D layers; clear a stale check silently.
        self.use_z.blockSignals(True)
        self.use_z.setChecked(False)
        self.use_z.blockSignals(False)

    use_z = self.ndims == 4 or (is_3d and self.use_z.isChecked())
    self.feature_dims = 3 if use_z else 2
    for widget in (self.z_axis, self.z_label, self.z_scale):
        widget.setVisible(use_z)
    self.z_axis.setEnabled(use_z)
    self.z_scale.setEnabled(use_z)

    # axis_widgets is (Z, Y, X): label them with the last three axis indices.
    # For 2D data the (hidden) Z label would be -1, so clamp it to 0.
    first_axis = self.ndims - 3
    for offset, widget in enumerate(self.axis_widgets):
        widget.setText(str(max(first_axis + offset, 0)))

    self.update_properties()
    self.update_table()

    scale = layer.scale
    self.z_scale.setValue(scale[-3] if len(scale) >= 3 else 1.0)
    self.y_scale.setValue(scale[-2])
    self.x_scale.setValue(scale[-1])
345 |
def measure(self):
    """Measure the selected region properties of the current labels layer.

    Spacing is read from the scale spin boxes:

    * as (z, y, x) when "use Z" is checked or the layer is 4D;
    * as (y, x) otherwise.

    It is written onto the trailing axes of the layer's ``scale`` before
    measuring. Sizing this update by ``len(spacing)`` (instead of the former
    ``len(scale) > 3`` test) also applies the Z spacing to 3D layers measured
    with "use Z" checked.

    2D data, or 3D data with "use Z" checked, is measured with a single
    ``calculate_extended_props`` call. Otherwise the first axis is treated as
    time: each index is measured separately, tagged with a ``time_point``
    column, and the tables are concatenated once at the end.

    If the intensity image's shape differs from the labels, an error dialog is
    shown and nothing is measured. The result is stored in the layer's
    ``properties``, and the filter widget and table are refreshed.
    """
    layer = self.label_manager.selected_layer

    if self.use_z.isChecked() or self.ndims == 4:
        spacing = (self.z_scale.value(), self.y_scale.value(), self.x_scale.value())
    else:
        spacing = (self.y_scale.value(), self.x_scale.value())

    # Apply the spacing to the trailing axes of the layer scale.
    layer_scale = list(layer.scale)
    n_spatial = min(len(spacing), len(layer_scale))
    layer_scale[-n_spatial:] = spacing[-n_spatial:]

    # NOTE(review): current_step is re-assigned as index * (pre-rescale) world
    # step size per axis after the view reset; confirm these units are intended.
    old_step = list(self.viewer.dims.current_step)
    step_sizes = [dim_range.step for dim_range in self.viewer.dims.range]
    new_step = [
        index * size for index, size in zip(old_step, step_sizes, strict=False)
    ]
    layer.scale = layer_scale
    self.viewer.reset_view()
    self.viewer.dims.current_step = new_step

    features = self.get_selected_features()
    # get_selected_features() drops "intensity_mean" unless an Image layer is
    # selected in the dropdown, so its presence implies a usable image.
    if "intensity_mean" in features:
        intensity_image = self.intensity_image_dropdown.selected_layer.data
    else:
        intensity_image = None

    data = layer.data
    if intensity_image is not None and intensity_image.shape != data.shape:
        msg = QMessageBox()
        msg.setWindowTitle("Shape mismatch")
        msg.setText(
            f"Label layer and intensity image must have the same shape. Got {data.shape} and {intensity_image.shape}."
        )
        msg.setIcon(QMessageBox.Critical)
        msg.setStandardButtons(QMessageBox.Ok)
        msg.exec_()
        return

    if (self.use_z.isChecked() and self.ndims == 3) or (
        not self.use_z.isChecked() and self.ndims == 2
    ):
        props = calculate_extended_props(
            data,
            intensity_image=intensity_image,
            properties=features,
            spacing=spacing,
        )
    else:
        # First axis is time: measure every time point on its own.
        labels_are_dask = isinstance(data, da.core.Array)
        frames = []
        for i in tqdm(range(data.shape[0])):
            labels_t = data[i].compute() if labels_are_dask else data[i]
            if isinstance(intensity_image, da.core.Array):
                intensity_t = intensity_image[i].compute()
            elif isinstance(intensity_image, np.ndarray):
                intensity_t = intensity_image[i]
            else:
                intensity_t = None
            frame = calculate_extended_props(
                labels_t,
                intensity_image=intensity_t,
                properties=features,
                spacing=spacing,
            )
            frame["time_point"] = i
            frames.append(frame)
        if not frames:
            # Empty first axis: nothing to measure (previously UnboundLocalError).
            return
        # Concatenate once rather than re-copying the growing table every step.
        props = frames[0] if len(frames) == 1 else pd.concat(frames, ignore_index=True)

    if hasattr(layer, "properties"):
        layer.properties = props
    self.prop_filter_widget.set_properties()
    self.update_table()
425 |
def update_table(self) -> None:
    """Rebuild the region-properties table for the selected labels layer.

    Any existing table is hidden, removed from ``self.regionprops_layout`` and
    scheduled for deletion. Previously it was only hidden, so every refresh
    left another widget in the layout. The property-filter widget is then
    hidden.

    A new ``ColoredTableWidget`` is added, and the filter widget shown again,
    only when a viewer is set, a layer is selected and that layer's
    ``properties`` is non-empty.

    NOTE(review): ``self.table`` is reset to None when no new table is
    created; confirm no other code keeps using the old table object.
    """
    if self.table is not None:
        old_table = self.table
        self.table = None
        old_table.hide()
        self.regionprops_layout.removeWidget(old_table)
        old_table.deleteLater()
    self.prop_filter_widget.setVisible(False)

    layer = self.label_manager.selected_layer
    if self.viewer is not None and layer is not None and len(layer.properties) > 0:
        self.table = ColoredTableWidget(layer, self.viewer)
        self.table.setMinimumWidth(500)
        self.regionprops_layout.addWidget(self.table)
        self.prop_filter_widget.setVisible(True)
442 |
def get_selected_features(self) -> list[str]:
    """Return the names of the checked features, followed by "label" and "centroid".

    ``intensity_mean`` is dropped when the intensity dropdown does not
    currently have an Image layer selected. "label" and "centroid" are
    always appended, in that order.
    """
    chosen = [name for name, checked in self.checkbox_state.items() if checked]
    wants_intensity = "intensity_mean" in chosen
    if wants_intensity and not isinstance(
        self.intensity_image_dropdown.selected_layer, napari.layers.Image
    ):
        chosen.remove("intensity_mean")
    return chosen + ["label", "centroid"]
456 |
def set_selected_features(self, features: list[str]) -> None:
    """Check exactly those feature checkboxes whose property name is in ``features``."""
    for entry in self.checkboxes:
        should_check = entry["region_prop_name"] in features
        entry["checkbox"].setChecked(should_check)
462 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/save_labels_widget.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import dask.array as da
4 | import napari
5 | import numpy as np
6 | import tifffile
7 | from napari.layers import Labels
8 | from qtpy.QtWidgets import (
9 | QCheckBox,
10 | QComboBox,
11 | QFileDialog,
12 | QGroupBox,
13 | QHBoxLayout,
14 | QLineEdit,
15 | QPushButton,
16 | QVBoxLayout,
17 | QWidget,
18 | )
19 |
20 | from .layer_manager import LayerManager
21 |
22 |
class SaveLabelsWidget(QWidget):
    """Widget for saving label data with options for datatype, compression, and whether to split time points."""

    # Entries of the datatype combo box and the NumPy dtype each maps to.
    _DTYPES = {
        "np.uint8": np.uint8,
        "np.uint16": np.uint16,
        "np.uint32": np.uint32,
        "np.uint64": np.uint64,
    }

    def __init__(
        self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
    ) -> None:
        super().__init__()

        self.viewer = viewer
        self.label_manager = label_manager

        # Select datatype (same order as _DTYPES)
        self.select_dtype = QComboBox()
        self.select_dtype.addItems(list(self._DTYPES))
        self.select_dtype.setToolTip("File bit depth for saving.")

        # Split time points
        self.split_time_points = QCheckBox("Split time points")
        self.split_time_points.setToolTip(
            "Saves each time point to a separate file. Assumes that the time dimension is along the first axis."
        )

        # Use compression
        self.use_compression = QCheckBox("Use compression")
        self.use_compression.setChecked(True)
        self.use_compression.setToolTip(
            "Use zstd compression? This may take a bit longer to save."
        )

        # Filename
        self.filename = QLineEdit()
        self.filename.setPlaceholderText("File name")
        self.filename.setToolTip("Filename for saving labels")

        # Save button
        self.save_btn = QPushButton("Save labels")
        self.save_btn.clicked.connect(self._save_labels)

        # Keep the controls in sync with the selected layer.
        self._update_controls()
        self.label_manager.layer_update.connect(self._update_controls)
        self.label_manager.layer_update.connect(self.update_filename)

        # Combine layouts
        save_box = QGroupBox("Save labels")
        layout = QVBoxLayout()

        settings_layout = QHBoxLayout()
        settings_layout.addWidget(self.select_dtype)
        settings_layout.addWidget(self.split_time_points)
        settings_layout.addWidget(self.use_compression)
        layout.addLayout(settings_layout)
        layout.addWidget(self.filename)
        layout.addWidget(self.save_btn)
        save_box.setLayout(layout)

        main_layout = QVBoxLayout()
        main_layout.addWidget(save_box)
        self.setLayout(main_layout)

    def _update_controls(self) -> None:
        """Enable saving only for a Labels layer, and splitting only for Labels with >= 3 dims."""
        layer = self.label_manager.selected_layer
        is_labels = isinstance(layer, Labels)
        self.save_btn.setEnabled(is_labels)
        self.split_time_points.setEnabled(is_labels and layer.data.ndim >= 3)

    def update_filename(self) -> None:
        """Update the default filename"""

        if isinstance(self.label_manager.selected_layer, napari.layers.Labels):
            self.filename.setText(self.label_manager.selected_layer.name)

    def _base_filename(self) -> str:
        """Return the output file stem.

        This is the text field's value with a trailing .tif/.tiff removed, or
        the layer name when the field is blank.
        """
        name = self.filename.text()
        if not name.strip():
            name = self.label_manager.selected_layer.name
        stem, ext = os.path.splitext(name)
        # Strip only a real TIFF extension; "a.tif_x" is kept as-is.
        return stem if ext.lower() in (".tif", ".tiff") else name

    def _save_labels(self) -> None:
        """Save the selected labels layer as TIFF file(s) in a user-chosen folder.

        The data is cast to the dtype chosen in the combo box and is
        zstd-compressed when "Use compression" is checked.

        With "Split time points" checked (and >= 3 dims), each index along the
        first axis is written to '<name>_TP####.tif'. Otherwise the whole
        array is written to '<name>.tif'.

        Cancelling the folder dialog aborts the save; previously the empty
        path made files land in the current working directory.
        """
        layer = self.label_manager.selected_layer
        if not isinstance(layer, Labels):
            return

        data = layer.data
        split_time_points = data.ndim >= 3 and self.split_time_points.isChecked()
        dtype = self._DTYPES[self.select_dtype.currentText()]
        compression = "zstd" if self.use_compression.isChecked() else None
        base_name = self._base_filename()

        destination = QFileDialog.getExistingDirectory(self, "Select Output Folder")
        if not destination:
            return

        if split_time_points:
            is_dask = isinstance(data, da.core.Array)
            for i in range(data.shape[0]):
                current_stack = data[i].compute() if is_dask else data[i]
                tifffile.imwrite(
                    os.path.join(destination, f"{base_name}_TP{i:04d}.tif"),
                    current_stack.astype(dtype),
                    compression=compression,
                )
        else:
            tifffile.imwrite(
                os.path.join(destination, base_name + ".tif"),
                data.astype(dtype),
                compression=compression,
            )
151 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/select_delete_widget.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 | import shutil
5 |
6 | import dask.array as da
7 | import napari
8 | import numpy as np
9 | import tifffile
10 | from napari.layers import Labels
11 | from qtpy.QtWidgets import (
12 | QCheckBox,
13 | QFileDialog,
14 | QGroupBox,
15 | QHBoxLayout,
16 | QLabel,
17 | QMessageBox,
18 | QPushButton,
19 | QVBoxLayout,
20 | QWidget,
21 | )
22 | from skimage.io import imread
23 |
24 | from .layer_dropdown import LayerDropdown
25 |
26 |
27 | class SelectDeleteMask(QWidget):
28 | """Widget to select labels to keep or to delete based on overlap with a mask."""
29 |
def __init__(self, viewer: "napari.viewer.Viewer") -> None:
    """Build the select/delete-by-mask widget.

    Args:
        viewer: napari viewer whose Labels layers populate both dropdowns.
    """
    super().__init__()

    self.viewer = viewer
    self.image1_layer = None
    self.mask_layer = None
    # Root folder for dask output; asked for on first use.
    self.outputdir = None

    ### Add one image to another
    select_delete_box = QGroupBox("Select / Delete labels by mask")
    select_delete_box_layout = QVBoxLayout()

    image1_layout = QHBoxLayout()
    image1_layout.addWidget(QLabel("Labels"))
    self.image1_dropdown = LayerDropdown(self.viewer, Labels)
    self.image1_dropdown.layer_changed.connect(self._update_image1)
    image1_layout.addWidget(self.image1_dropdown)

    image2_layout = QHBoxLayout()
    image2_layout.addWidget(QLabel("Mask"))
    self.mask_dropdown = LayerDropdown(self.viewer, Labels)
    self.mask_dropdown.layer_changed.connect(self._update_image2)
    image2_layout.addWidget(self.mask_dropdown)

    select_delete_box_layout.addLayout(image1_layout)
    select_delete_box_layout.addLayout(image2_layout)

    self.stack_checkbox = QCheckBox("Apply 3D mask to all time points in 4D array")
    self.stack_checkbox.setEnabled(False)
    select_delete_box_layout.addWidget(self.stack_checkbox)

    self.select_btn = QPushButton("Select labels")
    self.select_btn.clicked.connect(self.select_labels)
    select_delete_box_layout.addWidget(self.select_btn)

    self.delete_btn = QPushButton("Delete labels")
    self.delete_btn.clicked.connect(self.delete_labels)
    select_delete_box_layout.addWidget(self.delete_btn)

    self.image1_dropdown.layer_changed.connect(self._update_buttons)
    self.mask_dropdown.layer_changed.connect(self._update_buttons)

    # The dropdowns may already show a layer before any layer_changed signal
    # reaches this widget. _update_buttons() enables the buttons from the
    # dropdowns, so adopt the same layers here; otherwise the button handlers
    # would operate on None.
    self.image1_layer = self.image1_dropdown.selected_layer
    self.mask_layer = self.mask_dropdown.selected_layer
    self._update_buttons()
    self.stack_checkbox.setEnabled(
        self.image1_layer is not None
        and self.mask_layer is not None
        and len(self.image1_layer.data.shape) == len(self.mask_layer.data.shape) + 1
    )

    select_delete_box.setLayout(select_delete_box_layout)
    main_layout = QVBoxLayout()
    main_layout.addWidget(select_delete_box)
    self.setLayout(main_layout)
77 |
78 | def _update_buttons(self) -> None:
79 | """Update button state according to whether image layers are present"""
80 |
81 | active = (
82 | self.image1_dropdown.selected_layer is not None
83 | and self.mask_dropdown.selected_layer is not None
84 | )
85 | self.select_btn.setEnabled(active)
86 | self.delete_btn.setEnabled(active)
87 |
88 | def _update_image1(self, selected_layer: str) -> None:
89 | """Update the layer that is set to be the 'source labels' layer for copying labels from."""
90 |
91 | if selected_layer == "":
92 | self.image1_layer = None
93 | else:
94 | self.image1_layer = self.viewer.layers[selected_layer]
95 | self.image1_dropdown.setCurrentText(selected_layer)
96 |
97 | # update the checkbox and buttons as needed needed
98 | if self.mask_layer is not None and self.image1_layer is not None:
99 | self.select_btn.setEnabled(True)
100 | self.delete_btn.setEnabled(True)
101 | if len(self.image1_layer.data.shape) == len(self.mask_layer.data.shape) + 1:
102 | self.stack_checkbox.setEnabled(True)
103 | else:
104 | self.stack_checkbox.setEnabled(False)
105 | self.stack_checkbox.setCheckState(False)
106 | else:
107 | self.select_btn.setEnabled(False)
108 | self.delete_btn.setEnabled(False)
109 | self.stack_checkbox.setEnabled(False)
110 | self.stack_checkbox.setCheckState(False)
111 |
112 | def _update_image2(self, selected_layer: str) -> None:
113 | """Update the layer that is set to be the 'source labels' layer for copying labels from."""
114 |
115 | if selected_layer == "":
116 | self.mask_layer = None
117 | else:
118 | self.mask_layer = self.viewer.layers[selected_layer]
119 | self.mask_dropdown.setCurrentText(selected_layer)
120 |
121 | # update the checkbox and buttons as needed
122 | if self.mask_layer is not None and self.image1_layer is not None:
123 | self.select_btn.setEnabled(True)
124 | self.delete_btn.setEnabled(True)
125 | if len(self.image1_layer.data.shape) == len(self.mask_layer.data.shape) + 1:
126 | self.stack_checkbox.setEnabled(True)
127 | else:
128 | self.stack_checkbox.setEnabled(False)
129 | self.stack_checkbox.setCheckState(False)
130 | else:
131 | self.select_btn.setEnabled(False)
132 | self.delete_btn.setEnabled(False)
133 | self.stack_checkbox.setEnabled(False)
134 | self.stack_checkbox.setCheckState(False)
135 |
def select_labels(self):
    """Keep only the labels that overlap the mask (mask > 0); all other labels become 0.

    Supported layouts:

    * labels with one more leading axis than the mask (e.g. (t, z, y, x)
      labels with a (z, y, x) mask). The mask is applied to every time point
      when the stack checkbox is checked, otherwise only to time point
      ``viewer.dims.current_step[0]``;
    * labels and mask of identical shape.

    NumPy labels are edited in place (stacked layouts, followed by a layer
    refresh) or copied into a new "selected labels" layer (identical shapes).
    Dask labels are written per time point as uint16 TIFFs to a user-chosen
    folder and read back as a new "<name>_filtered_labels" layer. Cancelling
    a folder dialog aborts without touching any files.

    A mask that overlaps no label now yields all zeros instead of raising
    TypeError from ``functools.reduce`` on an empty sequence.

    Returns:
        False when the shapes are incompatible (after informing the user),
        otherwise None.
    """
    image_shape = self.image1_layer.data.shape
    mask_shape = self.mask_layer.data.shape
    labels_are_dask = isinstance(self.image1_layer.data, da.core.Array)

    if len(image_shape) == len(mask_shape) + 1 and image_shape[1:] == mask_shape:
        # The same mask applies to every time point: materialise it once.
        mask = np.asarray(self.mask_layer.data)
        if self.stack_checkbox.isChecked():
            print("applying the mask to all time points")
            if labels_are_dask:
                outputdir = self._make_filtered_labels_dir()
                if outputdir is None:
                    return None
                for i in range(image_shape[0]):
                    stack = self.image1_layer.data[i].compute()
                    self._write_filtered_time_point(
                        outputdir, i, self._keep_labels_in_mask(stack, mask)
                    )
                self._add_filtered_labels_layer(self._sorted_tif_files(outputdir))
            else:
                for tp in range(image_shape[0]):
                    self.image1_layer.data[tp] = self._keep_labels_in_mask(
                        self.image1_layer.data[tp], mask
                    )
                # In-place edits are not seen by napari until the layer refreshes.
                self.image1_layer.refresh()
        else:
            tp = self.viewer.dims.current_step[0]
            print("applying the mask to the current time point only", tp)
            if labels_are_dask:
                self._select_current_time_point_on_disk(tp, mask)
            else:
                self.image1_layer.data[tp] = self._keep_labels_in_mask(
                    self.image1_layer.data[tp], mask
                )
                self.image1_layer.refresh()

    elif image_shape == mask_shape:
        if labels_are_dask:
            outputdir = self._make_filtered_labels_dir()
            if outputdir is None:
                return None
            for i in range(image_shape[0]):
                stack = self.image1_layer.data[i].compute()
                mask_i = self.mask_layer.data[i]
                if isinstance(mask_i, da.core.Array):
                    mask_i = mask_i.compute()
                self._write_filtered_time_point(
                    outputdir, i, self._keep_labels_in_mask(stack, mask_i)
                )
            self._add_filtered_labels_layer(self._sorted_tif_files(outputdir))
        else:
            self.viewer.add_labels(
                self._keep_labels_in_mask(self.image1_layer.data, self.mask_layer.data),
                name="selected labels",
                scale=self.image1_layer.scale,
            )

    else:
        msg = QMessageBox()
        msg.setWindowTitle("Images do not have compatible shapes")
        msg.setText("Please provide images that have matching dimensions")
        msg.setIcon(QMessageBox.Information)
        msg.setStandardButtons(QMessageBox.Ok)
        msg.exec_()
        return False
    return None

@staticmethod
def _keep_labels_in_mask(labels, mask):
    """Return ``labels`` with every label that has no pixel under ``mask > 0`` set to 0."""
    to_keep = np.unique(labels[mask > 0])
    # np.isin handles an empty to_keep (result: all False) without special-casing.
    return np.where(np.isin(labels, to_keep), labels, 0)

def _make_filtered_labels_dir(self):
    """Return a fresh '<name>_filtered_labels' folder under self.outputdir, or None if the user cancels.

    The root folder is asked for when unset. A cancelled dialog is not
    remembered, so the user is asked again next time. An existing output
    folder of the same name is deleted first.
    """
    if not self.outputdir:
        chosen = QFileDialog.getExistingDirectory(self, "Select Output Folder")
        if not chosen:
            return None
        self.outputdir = chosen
    outputdir = os.path.join(
        self.outputdir, (self.image1_layer.name + "_filtered_labels")
    )
    if os.path.exists(outputdir):
        shutil.rmtree(outputdir)
    os.mkdir(outputdir)
    return outputdir

def _write_filtered_time_point(self, outputdir, index, stack):
    """Write one time point as uint16 '<name>_filtered_labels_TP####.tif' into outputdir."""
    filename = f"{self.image1_layer.name}_filtered_labels_TP{index:04d}.tif"
    tifffile.imwrite(
        os.path.join(outputdir, filename), np.array(stack, dtype="uint16")
    )

@staticmethod
def _sorted_tif_files(directory):
    """Return the full paths of the '.tif' files in directory, sorted."""
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        if fname.endswith(".tif")
    )

def _add_filtered_labels_layer(self, file_list):
    """Read file_list, stack it into a dask array and add it as '<name>_filtered_labels' (becomes image1_layer)."""
    self.image1_layer = self.viewer.add_labels(
        da.stack([imread(fname) for fname in file_list]),
        name=self.image1_layer.name + "_filtered_labels",
        scale=self.image1_layer.scale,
    )

def _select_current_time_point_on_disk(self, tp, mask):
    """Filter only time point ``tp`` of dask labels on disk, then reload the stack.

    The user picks a folder. An empty folder receives a full copy of the
    labels, with only ``tp`` filtered. Otherwise the folder's sorted '.tif'
    files are assumed to be the time points, and file ``tp`` is overwritten.
    Cancelling the dialog does nothing.
    """
    outputdir = QFileDialog.getExistingDirectory(
        self,
        "Please select the directory that holds the images. Data will be changed here. Selecting a new empty directory will create a copy of all data",
    )
    if not outputdir:
        return

    if not os.listdir(outputdir):
        for i in range(self.image1_layer.data.shape[0]):
            stack = self.image1_layer.data[i].compute()
            if i == tp:
                stack = self._keep_labels_in_mask(stack, mask)
            self._write_filtered_time_point(outputdir, i, stack)
        file_list = self._sorted_tif_files(outputdir)
    else:
        stack = self._keep_labels_in_mask(
            self.image1_layer.data[tp].compute(), mask
        )
        file_list = self._sorted_tif_files(outputdir)
        tifffile.imwrite(file_list[tp], np.array(stack, dtype="uint16"))

    self._add_filtered_labels_layer(file_list)
380 |
381 | def delete_labels(self):
382 | """Delete labels that overlap with given mask. If the shape of the mask has 1 dimension less than the image, the mask will be applied to the current time point (index in the first dimension) of the image data."""
383 |
384 | if isinstance(self.mask_layer.data, da.core.Array):
385 | msg = QMessageBox()
386 | msg.setWindowTitle("Please provide a mask that is not a Dask array")
387 | msg.setText("Please provide a mask that is not a Dask array")
388 | msg.setIcon(QMessageBox.Information)
389 | msg.setStandardButtons(QMessageBox.Ok)
390 | msg.exec_()
391 | return False
392 |
393 | # check data dimensions first
394 | image_shape = self.image1_layer.data.shape
395 | mask_shape = self.mask_layer.data.shape
396 |
397 | if len(image_shape) == len(mask_shape) + 1 and image_shape[1:] == mask_shape:
398 | # apply mask to single time point or to full stack depending on checkbox state
399 | if self.stack_checkbox.isChecked():
400 | # loop over all time points
401 | print("applying the mask to all time points")
402 | # check if the data is a dask array
403 | if isinstance(self.image1_layer.data, da.core.Array):
404 | if self.outputdir is None:
405 | self.outputdir = QFileDialog.getExistingDirectory(
406 | self, "Select Output Folder"
407 | )
408 |
409 | outputdir = os.path.join(
410 | self.outputdir,
411 | (self.image1_layer.name + "_filtered_labels"),
412 | )
413 | if os.path.exists(outputdir):
414 | shutil.rmtree(outputdir)
415 | os.mkdir(outputdir)
416 |
417 | for i in range(
418 | self.image1_layer.data.shape[0]
419 | ): # Loop over the first dimension
420 | current_stack = self.image1_layer.data[
421 | i
422 | ].compute() # Compute the current stack
423 |
424 | to_delete = np.unique(current_stack[self.mask_layer.data > 0])
425 | for label in to_delete:
426 | current_stack[current_stack == label] = 0
427 | tifffile.imwrite(
428 | os.path.join(
429 | outputdir,
430 | (
431 | self.image1_layer.name
432 | + "_filtered_labels_TP"
433 | + str(i).zfill(4)
434 | + ".tif"
435 | ),
436 | ),
437 | np.array(current_stack, dtype="uint16"),
438 | )
439 |
440 | file_list = [
441 | os.path.join(outputdir, fname)
442 | for fname in os.listdir(outputdir)
443 | if fname.endswith(".tif")
444 | ]
445 | self.image1_layer = self.viewer.add_labels(
446 | da.stack([imread(fname) for fname in sorted(file_list)]),
447 | name=self.image1_layer.name + "_filtered_labels",
448 | scale=self.image1_layer.scale,
449 | )
450 |
451 | else:
452 | for tp in range(self.image1_layer.data.shape[0]):
453 | to_delete = np.unique(
454 | self.image1_layer.data[tp][self.mask_layer.data > 0]
455 | )
456 | for label in to_delete:
457 | self.image1_layer.data[tp][
458 | self.image1_layer.data[tp] == label
459 | ] = 0
460 |
461 | else:
462 | tp = self.viewer.dims.current_step[0]
463 | if isinstance(self.image1_layer.data, da.core.Array):
464 | outputdir = QFileDialog.getExistingDirectory(
465 | self,
466 | "Please select the directory that holds the images. Data will be changed here. Selecting a new empty directory will create a copy of all data",
467 | )
468 |
469 | if len(os.listdir(outputdir)) == 0:
470 | for i in range(
471 | self.image1_layer.data.shape[0]
472 | ): # Loop over the first dimension
473 | current_stack = self.image1_layer.data[
474 | i
475 | ].compute() # Compute the current stack
476 |
477 | if i == tp:
478 | to_delete = np.unique(
479 | current_stack[self.mask_layer.data > 0]
480 | )
481 | for label in to_delete:
482 | current_stack[current_stack == label] = 0
483 | tifffile.imwrite(
484 | os.path.join(
485 | outputdir,
486 | (
487 | self.image1_layer.name
488 | + "_filtered_labels_TP"
489 | + str(i).zfill(4)
490 | + ".tif"
491 | ),
492 | ),
493 | np.array(current_stack, dtype="uint16"),
494 | )
495 |
496 | file_list = sorted(
497 | [
498 | os.path.join(outputdir, fname)
499 | for fname in os.listdir(outputdir)
500 | if fname.endswith(".tif")
501 | ]
502 | )
503 | else:
504 | current_stack = self.image1_layer.data[
505 | tp
506 | ].compute() # Compute the current stack
507 | to_delete = np.unique(current_stack[self.mask_layer.data > 0])
508 | for label in to_delete:
509 | current_stack[current_stack == label] = 0
510 |
511 | file_list = sorted(
512 | [
513 | os.path.join(outputdir, fname)
514 | for fname in os.listdir(outputdir)
515 | if fname.endswith(".tif")
516 | ]
517 | )
518 |
519 | tifffile.imwrite(
520 | file_list[tp],
521 | np.array(current_stack, dtype="uint16"),
522 | )
523 |
524 | self.image1_layer = self.viewer.add_labels(
525 | da.stack([imread(fname) for fname in file_list]),
526 | name=self.image1_layer.name + "_filtered_labels",
527 | scale=self.image1_layer.scale,
528 | )
529 |
530 | else:
531 | tp = self.viewer.dims.current_step[0]
532 | to_delete = np.unique(
533 | self.image1_layer.data[tp][self.mask_layer.data > 0]
534 | )
535 | for label in to_delete:
536 | self.image1_layer.data[tp][
537 | self.image1_layer.data[tp] == label
538 | ] = 0
539 |
540 | elif image_shape == mask_shape:
541 | if isinstance(self.image1_layer.data, da.core.Array):
542 | if self.outputdir is None:
543 | self.outputdir = QFileDialog.getExistingDirectory(
544 | self, "Select Output Folder"
545 | )
546 |
547 | outputdir = os.path.join(
548 | self.outputdir,
549 | (self.image1_layer.name + "_filtered_labels"),
550 | )
551 | if os.path.exists(outputdir):
552 | shutil.rmtree(outputdir)
553 | os.mkdir(outputdir)
554 |
555 | for i in range(
556 | self.image1_layer.data.shape[0]
557 | ): # Loop over the first dimension
558 | current_stack = self.image1_layer.data[
559 | i
560 | ].compute() # Compute the current stack
561 |
562 | to_delete = np.unique(current_stack[self.mask_layer.data[tp] > 0])
563 | for label in to_delete:
564 | current_stack[current_stack == label] = 0
565 | tifffile.imwrite(
566 | os.path.join(
567 | outputdir,
568 | (
569 | self.image1_layer.name
570 | + "_filtered_labels_TP"
571 | + str(i).zfill(4)
572 | + ".tif"
573 | ),
574 | ),
575 | np.array(current_stack, dtype="uint16"),
576 | )
577 |
578 | file_list = [
579 | os.path.join(outputdir, fname)
580 | for fname in os.listdir(outputdir)
581 | if fname.endswith(".tif")
582 | ]
583 | self.image1_layer = self.viewer.add_labels(
584 | da.stack([imread(fname) for fname in sorted(file_list)]),
585 | name=self.image1_layer.name + "_filtered_labels",
586 | scale=self.image1_layer.scale,
587 | )
588 | else:
589 | to_delete = np.unique(self.image1_layer.data[self.mask_layer.data > 0])
590 | selected_labels = copy.deepcopy(self.image1_layer.data)
591 | for label in to_delete:
592 | selected_labels[selected_labels == label] = 0
593 | self.viewer.add_labels(
594 | selected_labels,
595 | name=self.image1_layer.name + "_filtered_labels",
596 | scale=self.image1_layer.scale,
597 | )
598 |
599 | else:
600 | msg = QMessageBox()
601 | msg.setWindowTitle("Images do not have compatible shapes")
602 | msg.setText("Please provide images that have matching dimensions")
603 | msg.setIcon(QMessageBox.Information)
604 | msg.setStandardButtons(QMessageBox.Ok)
605 | msg.exec_()
606 | return False
607 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/smoothing_widget.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import dask.array as da
5 | import napari
6 | import numpy as np
7 | import tifffile
8 | from qtpy.QtWidgets import (
9 | QGroupBox,
10 | QHBoxLayout,
11 | QLabel,
12 | QMessageBox,
13 | QPushButton,
14 | QSpinBox,
15 | QVBoxLayout,
16 | QWidget,
17 | )
18 | from scipy import ndimage
19 | from skimage.io import imread
20 |
21 | from .layer_manager import LayerManager
22 |
23 |
class SmoothingWidget(QWidget):
    """Widget that 'smooths' labels by applying a median filter.

    The filter runs on the layer currently selected in the shared
    ``LayerManager``. The result is added to the viewer as a new Labels layer
    (name suffix ``_smoothed``), which then becomes the manager's selected
    layer.
    """

    def __init__(
        self, viewer: "napari.viewer.Viewer", label_manager: LayerManager
    ) -> None:
        super().__init__()

        self.viewer = viewer
        self.label_manager = label_manager
        # Parent folder for the per-frame TIFFs written when smoothing a dask
        # array. Nothing in this class assigns it; NOTE(review): presumably it
        # is set from outside (e.g. by a parent widget) — confirm.
        self.outputdir = None

        smoothbox = QGroupBox("Smooth objects")
        smooth_boxlayout = QVBoxLayout()

        smooth_layout = QHBoxLayout()
        self.median_radius_field = QSpinBox()
        # The spin box value is passed to ``ndimage.median_filter`` as ``size``
        # (filter window width per axis). ``size=0`` makes scipy raise, and a
        # QSpinBox starts at a minimum/value of 0, so require at least 1
        # (a 1-pixel window leaves the labels unchanged).
        self.median_radius_field.setMinimum(1)
        self.median_radius_field.setMaximum(100)
        self.smooth_btn = QPushButton("Smooth")
        smooth_layout.addWidget(self.median_radius_field)
        smooth_layout.addWidget(self.smooth_btn)

        smooth_boxlayout.addWidget(QLabel("Median filter radius"))
        smooth_boxlayout.addLayout(smooth_layout)

        self.smooth_btn.clicked.connect(self._smooth_objects)
        # Only allow smoothing while a Labels layer is selected.
        self.smooth_btn.setEnabled(
            isinstance(self.label_manager.selected_layer, napari.layers.Labels)
        )
        self.label_manager.layer_update.connect(
            lambda: self.smooth_btn.setEnabled(
                isinstance(self.label_manager.selected_layer, napari.layers.Labels)
            )
        )

        smoothbox.setLayout(smooth_boxlayout)
        layout = QVBoxLayout()
        layout.addWidget(smoothbox)
        self.setLayout(layout)

    def _smooth_objects(self) -> "bool | None":
        """Smooth the objects of the selected Labels layer with a median filter.

        - Dask-backed data: every slice along axis 0 is computed, filtered and
          written as a uint16 TIFF into ``<outputdir>/<layer name>_smoothed``
          (an existing folder with that name is deleted first). The TIFFs are
          then stacked back with ``da.stack`` into the new layer.
        - 4D in-memory data: each slice along axis 0 is filtered separately.
        - 2D/3D in-memory data: filtered as a whole.
        - Any other dimensionality: a message is printed and nothing is added.

        Returns:
            False when a dask layer is selected but ``self.outputdir`` has not
            been set (an information dialog is shown instead); otherwise None.
        """
        layer = self.label_manager.selected_layer
        size = self.median_radius_field.value()

        if isinstance(layer.data, da.core.Array):
            if self.outputdir is None:
                msg = QMessageBox()
                msg.setWindowTitle("No output directory selected")
                msg.setText("Please specify an output directory first!")
                msg.setIcon(QMessageBox.Information)
                msg.setStandardButtons(QMessageBox.Ok)
                msg.exec_()
                return False

            outputdir = os.path.join(self.outputdir, layer.name + "_smoothed")
            # Start from an empty folder so stale frames are never reloaded.
            if os.path.exists(outputdir):
                shutil.rmtree(outputdir)
            os.mkdir(outputdir)

            for i in range(layer.data.shape[0]):
                # Materialise one frame at a time to keep memory bounded.
                current_stack = layer.data[i].compute()
                smoothed = ndimage.median_filter(current_stack, size=size)
                tifffile.imwrite(
                    os.path.join(
                        outputdir,
                        layer.name + "_smoothed_TP" + str(i).zfill(4) + ".tif",
                    ),
                    np.array(smoothed, dtype="uint16"),
                )

            # Zero-padded frame indices make lexical order equal frame order.
            file_list = sorted(
                os.path.join(outputdir, fname)
                for fname in os.listdir(outputdir)
                if fname.endswith(".tif")
            )
            self.label_manager.selected_layer = self.viewer.add_labels(
                da.stack([imread(fname) for fname in file_list]),
                name=layer.name + "_smoothed",
                scale=layer.scale,
            )
            return None

        data = layer.data
        if data.ndim == 4:
            # Filter each slice along axis 0 independently so the filter
            # window never spans across that axis.
            smoothed = np.stack(
                [
                    ndimage.median_filter(data[i], size=size)
                    for i in range(data.shape[0])
                ],
                axis=0,
            )
        elif data.ndim in (2, 3):
            smoothed = ndimage.median_filter(data, size=size)
        else:
            print("input should be a 2D, 3D or 4D array")
            return None

        self.label_manager.selected_layer = self.viewer.add_labels(
            smoothed,
            name=layer.name + "_smoothed",
            scale=layer.scale,
        )
        return None
145 |
--------------------------------------------------------------------------------
/src/napari_segmentation_correction/threshold_widget.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import dask.array as da
5 | import napari
6 | import numpy as np
7 | import tifffile
8 | from napari.layers import Image, Labels
9 | from qtpy.QtWidgets import (
10 | QFileDialog,
11 | QGroupBox,
12 | QHBoxLayout,
13 | QLabel,
14 | QPushButton,
15 | QSpinBox,
16 | QVBoxLayout,
17 | QWidget,
18 | )
19 | from skimage.io import imread
20 |
21 | from .layer_dropdown import LayerDropdown
22 |
23 |
class ThresholdWidget(QWidget):
    """Widget that applies a threshold to an image or labels layer.

    Pixels whose value lies in the inclusive range [min, max] become 1 and all
    others 0. The result is added to the viewer as a new Labels layer with the
    name suffix ``_thresholded``.
    """

    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
        super().__init__()

        self.viewer = viewer
        # Parent folder for per-frame TIFFs of dask-backed layers; asked for
        # via a dialog the first time it is needed.
        self.outputdir = None
        self.threshold_layer = None

        threshold_box = QGroupBox("Threshold")
        threshold_box_layout = QVBoxLayout()

        self.threshold_layer_dropdown = LayerDropdown(
            self.viewer, (Image, Labels)
        )
        self.threshold_layer_dropdown.layer_changed.connect(
            self._update_threshold_layer
        )
        threshold_box_layout.addWidget(self.threshold_layer_dropdown)

        # The dropdown may already show a layer before the signal above was
        # connected. Pick it up so an enabled "Run" button never acts on None.
        initial_layer = self.threshold_layer_dropdown.selected_layer
        if isinstance(initial_layer, (Image, Labels)):
            self.threshold_layer = initial_layer

        min_threshold_layout = QHBoxLayout()
        min_threshold_layout.addWidget(QLabel("Min value"))
        self.min_threshold = QSpinBox()
        self.min_threshold.setMaximum(65535)
        min_threshold_layout.addWidget(self.min_threshold)

        max_threshold_layout = QHBoxLayout()
        max_threshold_layout.addWidget(QLabel("Max value"))
        self.max_threshold = QSpinBox()
        self.max_threshold.setMaximum(65535)
        self.max_threshold.setValue(65535)
        max_threshold_layout.addWidget(self.max_threshold)

        threshold_box_layout.addLayout(min_threshold_layout)
        threshold_box_layout.addLayout(max_threshold_layout)
        threshold_btn = QPushButton("Run")
        threshold_btn.clicked.connect(self._threshold)

        # "Run" is only usable while an Image or Labels layer is selected.
        threshold_btn.setEnabled(
            isinstance(self.threshold_layer_dropdown.selected_layer, (Image, Labels))
        )
        self.threshold_layer_dropdown.layer_changed.connect(
            lambda: threshold_btn.setEnabled(
                isinstance(
                    self.threshold_layer_dropdown.selected_layer, (Image, Labels)
                )
            )
        )

        threshold_box_layout.addWidget(threshold_btn)
        threshold_box.setLayout(threshold_box_layout)

        layout = QVBoxLayout()
        layout.addWidget(threshold_box)
        self.setLayout(layout)

    def _update_threshold_layer(self, selected_layer) -> None:
        """Update the layer that the threshold is applied to.

        Args:
            selected_layer: Name of the layer chosen in the dropdown. An empty
                string clears the selection.
        """

        if selected_layer == "":
            self.threshold_layer = None
        else:
            self.threshold_layer = self.viewer.layers[selected_layer]
            self.threshold_layer_dropdown.setCurrentText(selected_layer)

    def _threshold(self) -> None:
        """Threshold the selected label or intensity image.

        For dask-backed data, each slice along axis 0 is computed,
        thresholded and written as a uint8 TIFF into
        ``<outputdir>/<layer name>_threshold``. An existing folder with that
        name is deleted first. The TIFFs are then stacked back with
        ``da.stack`` into the new layer. If no output folder has been chosen
        yet, a dialog asks for one. Cancelling that dialog aborts the
        operation.

        In-memory data is thresholded in one step.
        """
        layer = self.threshold_layer
        if layer is None:
            # Nothing selected; nothing to threshold.
            return

        min_value = int(self.min_threshold.value())
        max_value = int(self.max_threshold.value())

        if isinstance(layer.data, da.core.Array):
            if not self.outputdir:
                chosen_dir = QFileDialog.getExistingDirectory(
                    self, "Select Output Folder"
                )
                if not chosen_dir:
                    # Dialog cancelled (empty string): don't fall back to
                    # writing into — and deleting from — the working directory.
                    return
                self.outputdir = chosen_dir

            outputdir = os.path.join(
                self.outputdir,
                (layer.name + "_threshold"),
            )
            # Start from an empty folder so stale frames are never reloaded.
            if os.path.exists(outputdir):
                shutil.rmtree(outputdir)
            os.mkdir(outputdir)

            for i in range(layer.data.shape[0]):
                # Materialise one frame at a time to keep memory bounded.
                data = layer.data[i].compute()
                thresholded = (data >= min_value) & (data <= max_value)

                tifffile.imwrite(
                    os.path.join(
                        outputdir,
                        layer.name + "_thresholded_TP" + str(i).zfill(4) + ".tif",
                    ),
                    np.array(thresholded, dtype="uint8"),
                )

            # Zero-padded frame indices make lexical order equal frame order.
            file_list = sorted(
                os.path.join(outputdir, fname)
                for fname in os.listdir(outputdir)
                if fname.endswith(".tif")
            )
            self.viewer.add_labels(
                da.stack([imread(fname) for fname in file_list]),
                name=layer.name + "_thresholded",
                scale=layer.scale,
            )

        else:
            thresholded = (layer.data >= min_value) & (layer.data <= max_value)
            self.viewer.add_labels(
                thresholded,
                name=layer.name + "_thresholded",
                scale=layer.scale,
            )
142 |
--------------------------------------------------------------------------------
/tox.ini:
--------------------------------------------------------------------------------
1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/
2 | [tox]
3 | envlist = py{310, 311}-{linux,macos,windows}
4 | isolated_build=true
5 |
6 | [gh-actions]
7 | python =
8 | 3.10: py310
9 |
10 | [gh-actions:env]
11 | PLATFORM =
12 | ubuntu-latest: linux
13 | macos-latest: macos
14 | windows-latest: windows
15 |
16 | [testenv]
17 | platform =
18 | macos: darwin
19 | linux: linux
20 | windows: win32
21 | passenv =
22 | CI
23 | GITHUB_ACTIONS
24 | DISPLAY
25 | XAUTHORITY
26 | NUMPY_EXPERIMENTAL_ARRAY_FUNCTION
27 | PYVISTA_OFF_SCREEN
28 | extras =
29 | testing
30 | commands = pytest -v --color=yes --cov=napari_segmentation_correction --cov-report=xml
31 |
--------------------------------------------------------------------------------