├── .coveragerc ├── .github └── workflows │ ├── build_docs.yml │ ├── preview_metadata.yml │ └── test_and_deploy.yml ├── .gitignore ├── .isort.cfg ├── .napari ├── DESCRIPTION.md └── config.yml ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── conda └── napari_CellSeg3D_ARM64.yml ├── docs ├── TODO.md ├── _config.yml ├── _templates │ ├── custom-class-template.rst │ └── custom-module-template.rst ├── _toc.yml ├── references.bib ├── source │ ├── code │ │ └── api.rst │ ├── guides │ │ ├── cropping_module_guide.rst │ │ ├── custom_model_template.rst │ │ ├── detailed_walkthrough.rst │ │ ├── inference_module_guide.rst │ │ ├── installation_guide.rst │ │ ├── metrics_module_guide.rst │ │ ├── review_module_guide.rst │ │ ├── training_module_guide.rst │ │ ├── training_wnet.rst │ │ └── utils_module_guide.rst │ ├── images │ │ ├── Review_Parameters.png │ │ ├── WNet_architecture.svg │ │ ├── converted_labels.png │ │ ├── cropping_process_example.png │ │ ├── inference_plugin_layout.png │ │ ├── inference_results_example.png │ │ ├── init_image_labels.png │ │ ├── plot_example_metrics.png │ │ ├── plots_train.png │ │ ├── plugin_crop.png │ │ ├── plugin_inference.png │ │ ├── plugin_menu.png │ │ ├── plugin_review.png │ │ ├── plugin_train.png │ │ ├── plugin_welcome.png │ │ ├── review_process_example.png │ │ ├── stat_plots.png │ │ ├── training_tab_1.png │ │ ├── training_tab_2.png │ │ ├── training_tab_3.png │ │ └── training_tab_4.png │ └── logo │ │ └── logo_alpha.png └── welcome.rst ├── examples ├── README.md ├── c5image.tif └── test_very_small.tif ├── napari_cellseg3d ├── __init__.py ├── _tests │ ├── __init__.py │ ├── conftest.py │ ├── fixtures.py │ ├── pytest.ini │ ├── res │ │ ├── test.png │ │ ├── test.tif │ │ ├── test_labels.tif │ │ └── wnet_test │ │ │ ├── lab │ │ │ └── test.tif │ │ │ ├── test.tif │ │ │ └── vol │ │ │ └── test.tif │ ├── test_base_plugin.py │ ├── test_dock_widget.py │ ├── test_helper.py │ ├── test_inference.py │ ├── test_interface.py │ ├── test_labels_correction.py │ ├── test_model_framework.py │ ├── test_models.py │ ├── test_plugin_inference.py │ ├── test_plugin_training.py │ ├── test_plugin_utils.py │ ├── test_plugins.py │ ├── test_review.py │ ├── test_training.py │ ├── test_utils.py │ ├── test_weight_download.py │ └── test_wnet_training.py ├── code_models │ ├── __init__.py │ ├── crf.py │ ├── instance_segmentation.py │ ├── model_framework.py │ ├── models │ │ ├── TEMPLATE_model.py │ │ ├── __init__.py │ │ ├── model_SegResNet.py │ │ ├── model_SwinUNetR.py │ │ ├── model_TRAILMAP.py │ │ ├── model_TRAILMAP_MS.py │ │ ├── model_VNet.py │ │ ├── model_WNet.py │ │ ├── model_test.py │ │ ├── pretrained │ │ │ ├── __init__.py │ │ │ └── pretrained_model_urls.json │ │ ├── unet │ │ │ ├── __init__.py │ │ │ ├── buildingblocks.py │ │ │ └── model.py │ │ └── wnet │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── soft_Ncuts.py │ ├── worker_inference.py │ ├── worker_training.py │ └── workers_utils.py ├── code_plugins │ ├── __init__.py │ ├── plugin_base.py │ ├── plugin_convert.py │ ├── plugin_crf.py │ ├── plugin_crop.py │ ├── plugin_helper.py │ ├── plugin_metrics.py │ ├── plugin_model_inference.py │ ├── plugin_model_training.py │ ├── plugin_review.py │ ├── plugin_review_dock.py │ └── plugin_utilities.py ├── config.py ├── dev_scripts │ ├── __init__.py │ ├── artefact_labeling.py │ ├── classifier_test.ipynb │ ├── colab_training.py │ ├── correct_labels.py │ ├── crop_data.py │ ├── evaluate_labels.py │ ├── remote_inference.py │ ├── remote_training.py │ ├── sliding_window_voronoi.py │ ├── 
test_new_evaluation.ipynb │ ├── thread_test.py │ └── whole_brain_utils.py ├── interface.py ├── napari.yaml ├── plugins.py ├── res │ ├── __init__.py │ └── logo_alpha.png ├── setup.py └── utils.py ├── notebooks ├── Colab_WNet3D_training.ipynb ├── Colab_inference_demo.ipynb ├── assess_instance.ipynb ├── label_stats_csv_plot.ipynb ├── labels_plot.ipynb ├── plots_data.csv ├── plots_data2.csv └── view_wnet.ipynb ├── pyproject.toml ├── requirements.txt ├── setup.cfg └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | if __name__ == .__main__.: 4 | 5 | [run] 6 | omit = 7 | napari_cellseg3d/setup.py, napari_cellseg3d/code_models/models/wnet/train_wnet.py, napari_cellseg3d/code_models/models/wnet/model.py,napari_cellseg3d/code_models/models/TEMPLATE_model.py, napari_cellseg3d/code_models/models/unet/*, napari_cellseg3d/dev_scripts/*, napari_cellseg3d/code_plugins/plugin_metrics.py 8 | -------------------------------------------------------------------------------- /.github/workflows/build_docs.yml: -------------------------------------------------------------------------------- 1 | name: deploy 2 | 3 | on: 4 | push: 5 | branches: # branch to trigger deployment 6 | - main 7 | 8 | # This job installs dependencies, build the book, and pushes it to `gh-pages` 9 | jobs: 10 | build-and-deploy-book: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | matrix: 14 | os: [ubuntu-latest] 15 | python-version: [3.9] 16 | steps: 17 | - uses: actions/checkout@v2 18 | 19 | # Install dependencies 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v1 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | pip install jupyter-book 27 | pip install -e . 28 | 29 | # Build the book 30 | - name: Build the book 31 | run: | 32 | jupyter-book build docs/ 33 | 34 | # Deploy the book's HTML to gh-pages branch 35 | - name: GitHub Pages action 36 | uses: peaceiris/actions-gh-pages@v3.6.1 37 | with: 38 | github_token: ${{ secrets.GITHUB_TOKEN }} 39 | publish_dir: docs/_build/html 40 | -------------------------------------------------------------------------------- /.github/workflows/preview_metadata.yml: -------------------------------------------------------------------------------- 1 | name: napari hub Preview Page # we use this name to find your preview page artifact, so don't change it! 
2 | # For more info on this action, see https://github.com/chanzuckerberg/napari-hub-preview-action/blob/main/action.yml 3 | 4 | on: 5 | pull_request: 6 | branches: 7 | - 'test' # '**' for all 8 | 9 | jobs: 10 | preview-page: 11 | name: Preview Page Deploy 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Checkout repo 16 | uses: actions/checkout@v2 17 | 18 | - name: napari hub Preview Page Builder 19 | uses: chanzuckerberg/napari-hub-preview-action@v0.1.6 20 | with: 21 | hub-ref: main 22 | -------------------------------------------------------------------------------- /.github/workflows/test_and_deploy.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: test and deploy 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | tags: 11 | - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 12 | pull_request: 13 | branches: 14 | - main 15 | workflow_dispatch: 16 | 17 | jobs: 18 | test: 19 | name: ${{ matrix.platform }} py${{ matrix.python-version }} 20 | runs-on: ${{ matrix.platform }} 21 | strategy: 22 | matrix: 23 | # platform: [ubuntu-latest, windows-latest] # , macos-latest 24 | platform: [ubuntu-latest] 25 | python-version: ['3.8', '3.9'] #issues with monai and 3.10; pausing for now. users should use python 3.9 26 | 27 | steps: 28 | - uses: actions/checkout@v3 29 | 30 | - name: Set up Python ${{ matrix.python-version }} 31 | uses: actions/setup-python@v4 32 | with: 33 | python-version: ${{ matrix.python-version }} 34 | 35 | # these libraries enable testing on Qt on linux 36 | - uses: tlambert03/setup-qt-libs@v1 37 | 38 | # strategy borrowed from vispy for installing opengl libs on windows 39 | - name: Install Windows OpenGL 40 | if: runner.os == 'Windows' 41 | run: | 42 | git clone --depth 1 https://github.com/pyvista/gl-ci-helpers.git 43 | powershell gl-ci-helpers/appveyor/install_opengl.ps1 44 | if (Test-Path -Path "C:\Windows\system32\opengl32.dll" -PathType Leaf) {Exit 0} else {Exit 1} 45 | 46 | # note: if you need dependencies from conda, considering using 47 | # setup-miniconda: https://github.com/conda-incubator/setup-miniconda 48 | # and 49 | # tox-conda: https://github.com/tox-dev/tox-conda 50 | - name: Install dependencies 51 | run: | 52 | python -m pip install --upgrade pip 53 | python -m pip install setuptools tox tox-gh-actions 54 | # pip install git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf 55 | 56 | # this runs the platform-specific tests declared in tox.ini 57 | - name: Test with tox 58 | uses: GabrielBB/xvfb-action@v1 # aganders3/headless-gui@v1 59 | with: 60 | run: python -m tox 61 | env: 62 | PLATFORM: ${{ matrix.platform }} 63 | 64 | - name: Coverage 65 | uses: codecov/codecov-action@v2 66 | 67 | deploy: 68 | # this will run when you have tagged a commit, starting with "v*" 69 | # and requires that you have put your twine API key in your 70 | # github secrets (see readme for details) 71 | needs: [test] 72 | runs-on: ubuntu-latest 73 | if: contains(github.ref, 'tags') 74 | steps: 75 | - uses: actions/checkout@v2 76 | - name: Set up Python 77 | uses: actions/setup-python@v2 78 | with: 79 | python-version: "3.x" 80 | - name: Install dependencies 81 | run: | 82 | python -m pip install --upgrade pip 83 | pip install -U setuptools setuptools_scm 
wheel twine build 84 | - name: Build and publish 85 | env: 86 | TWINE_USERNAME: __token__ 87 | TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} 88 | run: | 89 | git tag 90 | python -m build . 91 | twine upload dist/* 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # unwanted results files 10 | *.tif 11 | *.tiff 12 | napari_cellseg3d/_tests/res/*.csv 13 | *.pth 14 | *.pt 15 | *.onnx 16 | *.tar.gz 17 | *.db 18 | 19 | # Distribution / packaging 20 | .Python 21 | env/ 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *,cover 56 | .hypothesis/ 57 | .napari_cache 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | 67 | # Flask instance folder 68 | instance/ 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # MkDocs documentation 74 | /site/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Pycharm and VSCode 80 | .idea/ 81 | venv/ 82 | .vscode/ 83 | 84 | # IPython Notebook 85 | .ipynb_checkpoints 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # OS 91 | .DS_Store 92 | 93 | # written by setuptools_scm 94 | **/_version.py 95 | 96 | # WANDB 97 | /wandb/ 98 | 99 | 100 | ######## 101 | #project specific 102 | #dataset, weights, old logos, requirements 103 | /napari_cellseg3d/code_models/models/dataset/ 104 | /napari_cellseg3d/code_models/models/saved_weights/ 105 | /docs/source/logo/old_logo/ 106 | /docs/source/code/_autosummary/ 107 | /reqs/ 108 | /loss_plots/ 109 | /wandb/ 110 | notebooks/csv_cell_plot.html 111 | notebooks/full_plot.html 112 | *.csv 113 | *.png 114 | notebooks/instance_test.ipynb 115 | *.prof 116 | /docs/_build/ 117 | /docs/source/code/_autosummary/*.rst 118 | cov.syspath.txt 119 | napari_cellseg3d/dev_scripts/wandb 120 | wandb/ 121 | 122 | #include test data 123 | !napari_cellseg3d/_tests/res/test.tif 124 | !napari_cellseg3d/_tests/res/test.png 125 | !napari_cellseg3d/_tests/res/test_labels.tif 126 | !napari_cellseg3d/_tests/res/wnet_test/*.tif 127 | !napari_cellseg3d/_tests/res/wnet_test/lab/*.tif 128 | !napari_cellseg3d/_tests/res/wnet_test/vol/*.tif 129 | !examples/* 130 | 131 | #include docs images 132 | !docs/source/logo/* 133 | !docs/source/images/* 134 | lcov.info 135 | coverage.lcov 136 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | force_single_line = True 3 | force_sort_within_sections = False 4 | lexicographical = True 5 | single_line_exclusions = ('typing',) 6 | order_by_type = False 7 | group_by_package = True 8 | skip=__init__.py 9 | -------------------------------------------------------------------------------- 
/.napari/DESCRIPTION.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 13 | ## Description 14 | 15 | A napari plugin for 3D cell segmentation: training, inference, and data review. In particular, this project was developed for analysis of mesoSPIM-acquired (cleared tissue + lightsheet) datasets. 16 | 17 | A detailed walk-through and description is available [on the documentation website](https://adaptivemotorcontrollab.github.io/cellseg3d-docs/res/welcome.html). 18 | 19 | 49 | ## Intended Audience & Supported Data 50 | 51 | This plugin requires basic knowledge in machine learning; 52 | all the concepts required for the parameters of the plugin are still covered and explained for their contextual use in the plugin. 53 | 54 | Currently, this plugin requires 3D volumes as .tif files, for review and cropping 2D stacks as .tif or .png are supported as well. 55 | Feel free to open an issue on Github if you'd like to discuss implementation of a specific file type ! 56 | 57 | 71 | ## Quickstart 72 | 73 | Install from pip with `pip install napari-cellseg3d` 74 | 75 | OR 76 | 77 | - Install napari from pip with `pip install "napari[all]"`, 78 | then from the “Plugins” menu within the napari application, select “Install/Uninstall Package(s)...” 79 | - Copy `napari-cellseg3d` and paste it where it says “Install by name/url…” 80 | - Click “Install” 81 | 94 | ## Additional Install Steps 95 | 96 | **Python >= 3.8 required** 97 | 98 | Requires manual installation of **pytorch** and **MONAI**. 99 | 100 | For Pytorch, please see [PyTorch's website for installation instructions](https://pytorch.org/get-started/locally/). 101 | 102 | A **CUDA-capable GPU** is not needed but very strongly recommended, especially for training. 103 | Simply follow the instructions on Pytorch's install page. 104 | 105 | If you get errors from MONAI regarding missing readers, please see [MONAI's optional dependencies](https://docs.monai.io/en/stable/installation.html#installing-the-recommended-dependencies) page for instructions on getting the readers required by your images. 
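As an optional sanity check after installing PyTorch and MONAI, a short snippet like the one below (not part of the plugin itself) confirms that both packages import correctly and whether PyTorch can see a CUDA-capable GPU:

```python
# Optional check: verify the PyTorch/MONAI installation and GPU visibility.
import torch
import monai

print(f"PyTorch {torch.__version__} - CUDA available: {torch.cuda.is_available()}")
print(f"MONAI {monai.__version__}")
```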
106 | 107 | 117 | ## Getting Help 118 | 119 | If you would like to report an issue with the plugin, 120 | please open an [issue on Github](https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues) 121 | 127 | 128 | ## How to Cite 129 | 130 | 131 | 138 | -------------------------------------------------------------------------------- /.napari/config.yml: -------------------------------------------------------------------------------- 1 | visibility: public 2 | # note : remember to add napari-hub preview app to repo 3 | 4 | project_urls: 5 | Bug Tracker: https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues 6 | Documentation: https://adaptivemotorcontrollab.github.io/CellSeg3D/res/welcome.html 7 | Source Code: https://github.com/AdaptiveMotorControlLab/CellSeg3D 8 | Project Site: https://github.com/AdaptiveMotorControlLab/CellSeg3D 9 | User Support: https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues 10 | 11 | # Add labels from the EDAM Bioimaging ontology 12 | labels: 13 | ontology: EDAM-BIOIMAGING:alpha06 14 | terms: 15 | - 3D image 16 | - Cell segmentation 17 | - Light-sheet microscopy 18 | - Image visualisation 19 | - Neuron image analysis 20 | - Convolutional neural network 21 | - Watershed segmentation 22 | - Object counting 23 | - Image crop 24 | - Plotting 25 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | # - id: check-docstring-first 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - id: check-yaml 9 | - id: check-toml 10 | # - id: check-added-large-files 11 | # args: [--maxkb=5000] 12 | # - repo: https://github.com/pycqa/isort 13 | # rev: 5.12.0 14 | # hooks: 15 | # - id: isort 16 | # args: ["--profile", "black", --line-length=79] 17 | - repo: https://github.com/charliermarsh/ruff-pre-commit 18 | # Ruff version. 19 | rev: 'v0.5.6' 20 | hooks: 21 | - id: ruff 22 | args: [ --fix, --exit-non-zero-on-fix ] 23 | - repo: https://github.com/psf/black 24 | rev: 24.8.0 25 | hooks: 26 | - id: black 27 | args: [--line-length=79] 28 | - repo: https://github.com/tlambert03/napari-plugin-checks 29 | rev: v0.3.0 30 | hooks: 31 | - id: napari-plugin-checks 32 | # https://mypy.readthedocs.io/en/stable/introduction.html 33 | # you may wish to add this as well! 34 | # - repo: https://github.com/pre-commit/mirrors-mypy 35 | # rev: v0.910-1 36 | # hooks: 37 | # - id: mypy 38 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at cyril.achard@epfl.ch. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2022 Cyril Achard, Maxime Vidal, Mackenzie Mathis 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include README.md 3 | include requirements.txt 4 | include napari.yaml 5 | recursive-include res *.png 6 | recursive-include code_models *.json 7 | 8 | recursive-exclude * __pycache__ 9 | recursive-exclude * *.py[co] 10 | -------------------------------------------------------------------------------- /conda/napari_CellSeg3D_ARM64.yml: -------------------------------------------------------------------------------- 1 | name: napari_CellSeg3D_ARM64 2 | channels: 3 | - anaconda 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.9 8 | - pip 9 | - pyqt 10 | - imagecodecs 11 | - pip: 12 | - numpy 13 | - napari>=0.4.14 14 | - scikit-image>=0.19.2 15 | - matplotlib>=3.4.1 16 | - tifffile>=2022.2.9 17 | - torch>=1.11 18 | - monai[nibabel, einops]>=0.9.0 19 | - tqdm 20 | - scikit-image 21 | - pyclesperanto-prototype 22 | - tqdm 23 | - matplotlib 24 | - napari_cellseg3d 25 | -------------------------------------------------------------------------------- /docs/TODO.md: -------------------------------------------------------------------------------- 1 | [//]: # ( 2 | TODO: 3 | - [ ] Add a way to get the current version of the library 4 | - [x] Update all modules 5 | - [x] Better WNet3D tutorial 6 | - [x] Setup GH Actions 7 | - [ ] Add a bibliography 8 | ) 9 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Book settings 2 | # Learn more at https://jupyterbook.org/customize/config.html 3 | 4 | title: 
napari-cellseg3d Documentation 5 | author: Cyril Achard, Maxime Vidal, Timokleia Kousi, Mackenzie Mathis | Mathis Laboratory 6 | logo: source/logo/logo_alpha.png 7 | 8 | # Force re-execution of notebooks on each build. 9 | # See https://jupyterbook.org/content/execute.html 10 | execute: 11 | execute_notebooks: 'off' 12 | 13 | # Define the name of the latex output file for PDF builds 14 | latex: 15 | latex_documents: 16 | targetname: book.tex 17 | 18 | # Add a bibtex file so that we can create citations 19 | bibtex_bibfiles: 20 | - references.bib 21 | 22 | # Information about where the book exists on the web 23 | repository: 24 | url: https://github.com/AdaptiveMotorControlLab/CellSeg3D # Online location of your book 25 | path_to_book: docs # Optional path to your book, relative to the repository root 26 | branch: main # Which branch of the repository should be used when creating links (optional) 27 | 28 | # Add GitHub buttons to your book 29 | # See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository 30 | html: 31 | use_issues_button: true 32 | use_repository_button: true 33 | 34 | # Add auto-generated API docs 35 | sphinx: 36 | extra_extensions: 37 | - 'sphinx.ext.napoleon' 38 | - 'sphinx.ext.autodoc' 39 | - 'sphinx.ext.autosummary' 40 | - 'sphinx.ext.viewcode' 41 | - 'sphinx.ext.autosectionlabel' 42 | config: 43 | add_module_names: False 44 | autosectionlabel_prefix_document: True 45 | autosummary_generate: True 46 | autoclass_content: "both" 47 | # templates_path: ['_templates'] 48 | exclude_patterns: 49 | - '_build' 50 | - '_templates' 51 | # - 'napari_cellseg3d/__pycache__' 52 | -------------------------------------------------------------------------------- /docs/_templates/custom-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | :members: 7 | :show-inheritance: 8 | 9 | {% block methods %} 10 | {% if methods %} 11 | .. rubric:: {{ _('Methods') }} 12 | 13 | .. autosummary:: 14 | :nosignatures: 15 | {% for item in methods %} 16 | {%- if not item.startswith('_') %} 17 | ~{{ name }}.{{ item }} 18 | {%- endif -%} 19 | {%- endfor %} 20 | {% endif %} 21 | {% endblock %} 22 | 23 | {% block attributes %} 24 | {% if attributes %} 25 | .. rubric:: {{ _('Attributes') }} 26 | 27 | .. autosummary:: 28 | {% for item in attributes %} 29 | ~{{ name }}.{{ item }} 30 | {%- endfor %} 31 | {% endif %} 32 | {% endblock %} 33 | -------------------------------------------------------------------------------- /docs/_templates/custom-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: Module attributes 8 | 9 | .. autosummary:: 10 | :toctree: 11 | {% for item in attributes %} 12 | {{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block functions %} 18 | {% if functions %} 19 | .. rubric:: {{ _('Functions') }} 20 | 21 | .. autosummary:: 22 | :toctree: 23 | :nosignatures: 24 | {% for item in functions %} 25 | {{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block classes %} 31 | {% if classes %} 32 | .. rubric:: {{ _('Classes') }} 33 | 34 | .. 
autosummary:: 35 | :toctree: 36 | :template: custom-class-template.rst 37 | :nosignatures: 38 | {% for item in classes %} 39 | {{ item }} 40 | {%- endfor %} 41 | {% endif %} 42 | {% endblock %} 43 | 44 | {% block exceptions %} 45 | {% if exceptions %} 46 | .. rubric:: {{ _('Exceptions') }} 47 | 48 | .. autosummary:: 49 | :toctree: 50 | {% for item in exceptions %} 51 | {{ item }} 52 | {%- endfor %} 53 | {% endif %} 54 | {% endblock %} 55 | 56 | {% block modules %} 57 | {% if modules %} 58 | .. autosummary:: 59 | :toctree: 60 | :template: custom-module-template.rst 61 | :recursive: 62 | {% for item in modules %} 63 | {{ item }} 64 | {%- endfor %} 65 | {% endif %} 66 | {% endblock %} 67 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | 3 | format: jb-book 4 | root: welcome.rst 5 | parts: 6 | - caption : User guides 7 | chapters: 8 | - file: source/guides/installation_guide.rst 9 | - file: source/guides/review_module_guide.rst 10 | - file: source/guides/training_module_guide.rst 11 | - file: source/guides/inference_module_guide.rst 12 | - file: source/guides/cropping_module_guide.rst 13 | - file: source/guides/utils_module_guide.rst 14 | - caption : Walkthroughs 15 | chapters: 16 | - file: source/guides/detailed_walkthrough.rst 17 | - file: source/guides/training_wnet.rst 18 | - caption : Advanced guides 19 | chapters: 20 | - file: source/guides/custom_model_template.rst 21 | - caption : Code 22 | chapters: 23 | - file: source/code/api.rst 24 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | --- 2 | --- 3 | 4 | @inproceedings{holdgraf_evidence_2014, 5 | address = {Brisbane, Australia, Australia}, 6 | title = {Evidence for {Predictive} {Coding} in {Human} {Auditory} {Cortex}}, 7 | booktitle = {International {Conference} on {Cognitive} {Neuroscience}}, 8 | publisher = {Frontiers in Neuroscience}, 9 | author = {Holdgraf, Christopher Ramsay and de Heer, Wendy and Pasley, Brian N. and Knight, Robert T.}, 10 | year = {2014} 11 | } 12 | 13 | @article{holdgraf_rapid_2016, 14 | title = {Rapid tuning shifts in human auditory cortex enhance speech intelligibility}, 15 | volume = {7}, 16 | issn = {2041-1723}, 17 | url = {http://www.nature.com/doifinder/10.1038/ncomms13654}, 18 | doi = {10.1038/ncomms13654}, 19 | number = {May}, 20 | journal = {Nature Communications}, 21 | author = {Holdgraf, Christopher Ramsay and de Heer, Wendy and Pasley, Brian N. and Rieger, Jochem W. and Crone, Nathan and Lin, Jack J. and Knight, Robert T. and Theunissen, Frédéric E.}, 22 | year = {2016}, 23 | pages = {13654}, 24 | file = {Holdgraf et al. - 2016 - Rapid tuning shifts in human auditory cortex enhance speech intelligibility.pdf:C\:\\Users\\chold\\Zotero\\storage\\MDQP3JWE\\Holdgraf et al. - 2016 - Rapid tuning shifts in human auditory cortex enhance speech intelligibility.pdf:application/pdf} 25 | } 26 | 27 | @inproceedings{holdgraf_portable_2017, 28 | title = {Portable learning environments for hands-on computational instruction using container-and cloud-based technology to teach data science}, 29 | volume = {Part F1287}, 30 | isbn = {978-1-4503-5272-7}, 31 | doi = {10.1145/3093338.3093370}, 32 | abstract = {© 2017 ACM. There is an increasing interest in learning outside of the traditional classroom setting. 
This is especially true for topics covering computational tools and data science, as both are challenging to incorporate in the standard curriculum. These atypical learning environments offer new opportunities for teaching, particularly when it comes to combining conceptual knowledge with hands-on experience/expertise with methods and skills. Advances in cloud computing and containerized environments provide an attractive opportunity to improve the effciency and ease with which students can learn. This manuscript details recent advances towards using commonly-Available cloud computing services and advanced cyberinfrastructure support for improving the learning experience in bootcamp-style events. We cover the benets (and challenges) of using a server hosted remotely instead of relying on student laptops, discuss the technology that was used in order to make this possible, and give suggestions for how others could implement and improve upon this model for pedagogy and reproducibility.}, 33 | booktitle = {{ACM} {International} {Conference} {Proceeding} {Series}}, 34 | author = {Holdgraf, Christopher Ramsay and Culich, A. and Rokem, A. and Deniz, F. and Alegro, M. and Ushizima, D.}, 35 | year = {2017}, 36 | keywords = {Teaching, Bootcamps, Cloud computing, Data science, Docker, Pedagogy} 37 | } 38 | 39 | @article{holdgraf_encoding_2017, 40 | title = {Encoding and decoding models in cognitive electrophysiology}, 41 | volume = {11}, 42 | issn = {16625137}, 43 | doi = {10.3389/fnsys.2017.00061}, 44 | abstract = {© 2017 Holdgraf, Rieger, Micheli, Martin, Knight and Theunissen. Cognitive neuroscience has seen rapid growth in the size and complexity of data recorded from the human brain as well as in the computational tools available to analyze this data. This data explosion has resulted in an increased use of multivariate, model-based methods for asking neuroscience questions, allowing scientists to investigate multiple hypotheses with a single dataset, to use complex, time-varying stimuli, and to study the human brain under more naturalistic conditions. These tools come in the form of “Encoding” models, in which stimulus features are used to model brain activity, and “Decoding” models, in which neural features are used to generated a stimulus output. Here we review the current state of encoding and decoding models in cognitive electrophysiology and provide a practical guide toward conducting experiments and analyses in this emerging field. Our examples focus on using linear models in the study of human language and audition. We show how to calculate auditory receptive fields from natural sounds as well as how to decode neural recordings to predict speech. The paper aims to be a useful tutorial to these approaches, and a practical introduction to using machine learning and applied statistics to build models of neural activity. The data analytic approaches we discuss may also be applied to other sensory modalities, motor systems, and cognitive systems, and we cover some examples in these areas. In addition, a collection of Jupyter notebooks is publicly available as a complement to the material covered in this paper, providing code examples and tutorials for predictive modeling in python. The aimis to provide a practical understanding of predictivemodeling of human brain data and to propose best-practices in conducting these analyses.}, 45 | journal = {Frontiers in Systems Neuroscience}, 46 | author = {Holdgraf, Christopher Ramsay and Rieger, J.W. and Micheli, C. and Martin, S. and Knight, R.T. 
and Theunissen, F.E.}, 47 | year = {2017}, 48 | keywords = {Decoding models, Encoding models, Electrocorticography (ECoG), Electrophysiology/evoked potentials, Machine learning applied to neuroscience, Natural stimuli, Predictive modeling, Tutorials} 49 | } 50 | 51 | @book{ruby, 52 | title = {The Ruby Programming Language}, 53 | author = {Flanagan, David and Matsumoto, Yukihiro}, 54 | year = {2008}, 55 | publisher = {O'Reilly Media} 56 | } 57 | -------------------------------------------------------------------------------- /docs/source/code/api.rst: -------------------------------------------------------------------------------- 1 | API reference 2 | ================= 3 | 4 | .. autosummary:: 5 | :toctree: _autosummary 6 | :recursive: 7 | 8 | napari_cellseg3d.interface 9 | napari_cellseg3d.code_models 10 | napari_cellseg3d.code_plugins 11 | napari_cellseg3d.utils 12 | -------------------------------------------------------------------------------- /docs/source/guides/cropping_module_guide.rst: -------------------------------------------------------------------------------- 1 | .. _cropping_module_guide: 2 | 3 | Cropping✂️ 4 | ========== 5 | 6 | .. figure:: ../images/plugin_crop.png 7 | :align: center 8 | 9 | Layout of the cropping module 10 | 11 | **Cropping** allows you to crop your volumes and labels dynamically, 12 | by selecting a fixed size volume and moving it around the image. 13 | 14 | To access it: 15 | - Navigate to **`Plugins -> Utilities`**. 16 | - Choose **`Crop`** from the bottom menu. 17 | 18 | Once cropped, you have multiple options to save the volumes and labels: 19 | - Use the **`Quicksave`** button in Napari. 20 | - Select the layer and then go to **`File` -> `Save selected layers`**. 21 | - With the correct layer highlighted, simply press **`CTRL + S`**. 22 | 23 | .. Note:: 24 | For more on utility tools, see :doc:`utils_module_guide`. 25 | 26 | Launching the cropping process 27 | ------------------------------ 28 | 1. From the layer selection dropdown menu, select your image. If you want to crop a second image with the same dimensions simultaneously, 29 | check the **`Crop another image simultaneously`** option and then select the relevant layer. 30 | 31 | 2. Define your desired cropped volume size. This size will remain fixed for the duration of the session. 32 | To update the size, you will need to restart the process. 33 | 34 | 3. You can also correct the anisotropy, if you work with anisotropic data: simply set your microscope's resolution in microns. 35 | 36 | .. important:: 37 | This will scale the image in the viewer, but saved images will **still be anisotropic.** To resize your image, see :doc:`utils_module_guide`. 38 | 39 | 4. Press **`Start`** to start the cropping process. 40 | If you'd like to modify the volume size, change the parameters as described and hit **`Start`** again. 41 | 42 | Creating new layers 43 | ------------------- 44 | To "zoom in" on a specific portion of your volume: 45 | 46 | - Use the `Create new layers` checkbox next time you hit `Start`. This option lets you make an additional cropping layer instead of replacing the current one. 47 | 48 | - This way, you can first select your region of interest by using the tool as described above, then enable the option, select the cropped region produced before as the input layer, and define a smaller crop size in order to further crop within your region of interest. 49 | 50 | Interface & functionalities 51 | --------------------------- 52 | 53 | .. 
figure:: ../images/cropping_process_example.png 54 | :align: center 55 | 56 | Example of the cropping process interface. 57 | 58 | Once you have launched the review process, you will gain control over three sliders, which will let 59 | you to **adjust the position** of the cropped volumes and labels in the x,y and z positions. 60 | 61 | .. note:: 62 | * If your **cropped volume isn't visible**, consider changing the **colormap** of the image and the cropped 63 | volume to improve their visibility. 64 | * You may want to adjust the **opacity** and **contrast thresholds** depending on your image. 65 | * If the image appears empty: 66 | - Right-click on the contrast limits sliders. 67 | - Select **`Full Range`** and then **`Reset`**. 68 | 69 | Saving your cropped volume 70 | -------------------------- 71 | - When you are done, you can save the cropped volume and labels directly with the **`Quicksave`** button located at the bottom left. Your work will be saved in the same folder as the image you choose. 72 | 73 | - If you want more options (name, format) when saving: 74 | - Select the desired layer. 75 | - Navigate in the napari menu to **`File -> Save selected layer`**. 76 | - Press **`CTRL+S`** once you have selected the correct layer. 77 | 78 | 79 | Source code 80 | ------------------------------------------------- 81 | 82 | * :doc:`../code/_autosummary/napari_cellseg3d.code_plugins.plugin_crop` 83 | * :doc:`../code/_autosummary/napari_cellseg3d.code_plugins.plugin_base` 84 | -------------------------------------------------------------------------------- /docs/source/guides/custom_model_template.rst: -------------------------------------------------------------------------------- 1 | .. _custom_model_guide: 2 | 3 | Advanced : Custom models 4 | ============================================= 5 | 6 | .. warning:: 7 | **WIP** : Adding new models is still a work in progress and will likely not work out of the box, leading to errors. 8 | 9 | Please `file an issue`_ if you would like to add a custom model and we will help you get it working. 10 | 11 | To add a custom model, you will need a **.py** file with the following structure to be placed in the *napari_cellseg3d/models* folder:: 12 | 13 | class ModelTemplate_(ABC): # replace ABC with your PyTorch model class name 14 | weights_file = ( 15 | "model_template.pth" # specify the file name of the weights file only 16 | ) # download URL goes in pretrained_models.json 17 | 18 | @abstractmethod 19 | def __init__( 20 | self, input_image_size, in_channels=1, out_channels=1, **kwargs 21 | ): 22 | """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" 23 | pass 24 | 25 | @abstractmethod 26 | def forward(self, x): 27 | """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" 28 | pass 29 | 30 | 31 | .. note:: 32 | **WIP** : Currently you must modify :doc:`model_framework.py <../code/_autosummary/napari_cellseg3d.code_models.model_framework>` as well : import your model class and add it to the ``model_dict`` attribute 33 | 34 | .. 
_file an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues 35 | -------------------------------------------------------------------------------- /docs/source/guides/installation_guide.rst: -------------------------------------------------------------------------------- 1 | Installation guide ⚙ 2 | ====================== 3 | This guide outlines the steps for installing CellSeg3D and its dependencies. The plugin is compatible with Windows, Linux, and MacOS. 4 | 5 | **Note for ARM64 Mac Users:** 6 | Please refer to the :ref:`section below ` for specific instructions. 7 | 8 | .. warning:: 9 | If you encounter any issues during installation, feel free to open an issue on our `GitHub repository`_. 10 | 11 | .. _GitHub repository: https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues 12 | 13 | 14 | Installing pre-requisites 15 | --------------------------- 16 | 17 | PyQt5 or PySide2 18 | _____________________ 19 | 20 | CellSeg3D requires either **PyQt5** or **PySide2** as a Qt backend for napari. If you don't have a Qt backend installed: 21 | 22 | .. code-block:: 23 | 24 | pip install napari[all] 25 | 26 | This command installs PyQt5 by default. 27 | 28 | PyTorch 29 | _____________________ 30 | 31 | For PyTorch installation, refer to `PyTorch's website`_ , with or without CUDA according to your hardware. 32 | Select the installation criteria that match your OS and hardware (GPU or CPU). 33 | 34 | .. note:: 35 | While a **CUDA-capable GPU** is not mandatory, it is highly recommended for both training and inference. 36 | 37 | 38 | * Running into MONAI-related errors? Consult MONAI’s optional dependencies for solutions. Please see `MONAI's optional dependencies`_ page for instructions on getting the readers required by your images. 39 | 40 | .. _MONAI's optional dependencies: https://docs.monai.io/en/stable/installation.html#installing-the-recommended-dependencies 41 | .. _PyTorch's website: https://pytorch.org/get-started/locally/ 42 | 43 | 44 | 45 | Installing CellSeg3D 46 | -------------------------------------------- 47 | 48 | .. warning:: 49 | For ARM64 Mac users, please see the :ref:`section below ` 50 | 51 | **Via pip**: 52 | 53 | .. code-block:: 54 | 55 | pip install napari-cellseg3d 56 | 57 | **Directly in napari**: 58 | 59 | - Navigate to **Plugins > Install/Uninstall Packages** 60 | - Search for ``napari-cellseg3d`` 61 | 62 | **For local installation** (after cloning from GitHub) 63 | Navigate to the cloned CellSeg3D folder and run: 64 | 65 | .. code-block:: 66 | 67 | pip install -e . 68 | 69 | Successful installation will add the napari-cellseg3D plugin to napari’s Plugins section. 70 | 71 | 72 | ARM64 Mac installation 73 | -------------------------------------------- 74 | .. _ARM64_Mac_installation: 75 | 76 | For ARM64 Macs, we recommend using our custom CONDA environment. This is particularly important for ARM64 (Silicon chips) MacBooks. 77 | 78 | Start by installing `miniconda3`_. 79 | 80 | Creating the environment 81 | ______________________________ 82 | 83 | .. _miniconda3: https://docs.conda.io/projects/conda/en/latest/user-guide/install/macos.html 84 | 85 | 1. **Clone the repository** (`link `_): 86 | 87 | .. code-block:: 88 | 89 | git clone https://github.com/AdaptiveMotorControlLab/CellSeg3D.git 90 | 91 | 2. **Create the Conda Environment** : 92 | In the terminal, navigate to the CellSeg3D folder: 93 | 94 | .. 
code-block:: 95 | 96 | cd CellSeg3D 97 | conda env create -f conda/napari_cellseg3d_ARM64.yml 98 | 99 | This will also install the necessary dependencies as well as the plugin. 100 | 101 | 3. **Activate the environment** : 102 | 103 | .. code-block:: 104 | 105 | conda activate napari_cellseg3d_ARM64 106 | 107 | 4. **Install a Qt backend** : 108 | Important : you only need to install one of the following backends. 109 | PyQt5: 110 | 111 | .. code-block:: 112 | 113 | pip install PyQt5 114 | 115 | OR 116 | PySide2: 117 | 118 | .. code-block:: 119 | 120 | pip install PySide2 121 | 122 | 5. **Install PyTorch** : 123 | Refer to `PyTorch's website`_ for installation instructions. 124 | 125 | 6. **Launch napari** : 126 | You should now see the CellSeg3D plugin in the Plugins section of napari. 127 | See `Usage section `_ for a guide on how to use the plugin. 128 | 129 | Updating the environment 130 | ______________________________ 131 | 132 | In order to update the environment, navigate to the CellSeg3D folder and run: 133 | 134 | .. code-block:: 135 | 136 | conda deactivate 137 | conda env update -f conda/napari_cellseg3d_ARM64.yml 138 | 139 | 140 | Troubleshoting 141 | ------------------------------ 142 | 143 | pyClesperanto 144 | _____________________ 145 | 146 | If you encounter the following error : *clGetPlatformIDs failed: PLATFORM_NOT_FOUND_KHR* : 147 | 148 | Please install `clinfo `_ and check if your OpenCL platform is available. 149 | 150 | If not, please install the OpenCL driver for your hardware. 151 | 152 | `[Source] `_ 153 | 154 | 155 | --- 156 | 157 | 158 | **Please help us make this section better by reporting any issues you encounter during installation.** 159 | 160 | Optional requirements 161 | ------------------------------ 162 | 163 | Additional functionalities 164 | ______________________________ 165 | 166 | Several additional functionalities are available optionally. To install them, use the following commands: 167 | 168 | - CRF post-processing: 169 | 170 | .. code-block:: 171 | 172 | pip install pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master 173 | 174 | - Weights & Biases integration: 175 | 176 | .. code-block:: 177 | 178 | pip install napari-cellseg3D[wandb] 179 | 180 | 181 | - ONNX model support (EXPERIMENTAL): 182 | Depending on your hardware, you can install the CPU or GPU version of ONNX. 183 | 184 | .. code-block:: 185 | 186 | pip install napari-cellseg3D[onnx-cpu] 187 | pip install napari-cellseg3D[onnx-gpu] 188 | 189 | Development requirements 190 | ______________________________ 191 | 192 | - Building the documentation: 193 | 194 | .. code-block:: 195 | 196 | pip install napari-cellseg3D[docs] 197 | 198 | - Running tests locally: 199 | 200 | .. code-block:: 201 | 202 | pip install pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master 203 | pip install napari-cellseg3D[test] 204 | 205 | - Dev utilities: 206 | 207 | .. code-block:: 208 | 209 | pip install napari-cellseg3D[dev] 210 | -------------------------------------------------------------------------------- /docs/source/guides/metrics_module_guide.rst: -------------------------------------------------------------------------------- 1 | .. _metrics_module_guide: 2 | 3 | Metrics utility guide 📈 4 | ======================== 5 | 6 | .. figure:: ../images/plot_example_metrics.png 7 | :scale: 35 % 8 | :align: right 9 | 10 | Dice metric plot result 11 | 12 | This tool computes the Dice coefficient, a similarity measure, between two sets of label folders. 
13 | Ranges from 0 (no similarity) to 1 (perfect similarity). 14 | 15 | The Dice coefficient is defined as : 16 | 17 | .. math:: \frac {2|X \cap Y|} {|X|+|Y|} 18 | 19 | Required parameters: 20 | -------------------- 21 | 22 | * Ground Truth Labels folder 23 | * Prediction Labels folder 24 | * Threshold for sufficient score. Pairs below this score are highlighted in the viewer and marked in red on the plot. 25 | * Whether to automatically determine the best orientation for the computation by rotating and flipping; 26 | useful if your images have varied orientation. 27 | 28 | .. note:: 29 | - The tool might rotate and flip images randomly to find the best Dice coefficient. If you have small images with a large number of labels, this might lead to metric inaccuracies. Low score images might be in the wrong orientation when displayed for comparison. 30 | - This tool assumes that **predictions are padded to a power of two.** Ground truth labels can be smaller, as they will be padded to match the prediction size. 31 | - Your files should have names that can be sorted numerically; please ensure that each ground truth label has a matching prediction label. 32 | 33 | To begin, press the **`Compute Dice`** button. This will plot the Dice score for each ground truth-prediction labels pair. 34 | Pairs below the threshold will be displayed on the viewer for verification, ground truth appears in **blue**, and low score predictions in **red**. 35 | 36 | Source code 37 | ------------------------------------------------- 38 | 39 | * :doc:`../code/plugin_base` 40 | * :doc:`../code/plugin_metrics` 41 | -------------------------------------------------------------------------------- /docs/source/guides/review_module_guide.rst: -------------------------------------------------------------------------------- 1 | .. _review_module_guide: 2 | 3 | Labeling🔍 4 | ================================= 5 | 6 | .. figure:: ../images/plugin_review.png 7 | :align: center 8 | 9 | Layout of the review module 10 | 11 | **Labeling** allows you to inspect your labels, which may be manually created or predicted by a model, and make necessary corrections. 12 | The system will save the updated status of each file in a csv file. 13 | Additionally, the time taken per slice review is logged, enabling efficient monitoring. 14 | 15 | See `Usage section `_ for instructions on launching the plugin. 16 | 17 | Launching the review process 18 | --------------------------------- 19 | .. figure:: ../images/Review_Parameters.png 20 | :align: right 21 | :width: 300px 22 | 23 | 24 | 1. **Data paths:** 25 | - *Starting a new review:* Choose the **`New review`** option, and select the corresponding layers within Napari. 26 | - *Continuing an existing review:* Select the **`Existing review`** option, and choose the folder that contains the image, labels, and CSV file. 27 | 28 | .. note:: 29 | Cellseg3D supports 3D **`.tif`** files at the moment. 30 | If you have a stack, open it as a folder in Napari, then save it as a single **`.tif`** file. 31 | 32 | 2. **Managing anisotropic data:** 33 | Check this option to scale your images to visually remove the anisotropy, so as to make review easier. 34 | 35 | .. note:: 36 | The results will be saved as anisotropic images. If you want to resize them, check the :doc:`utils_module_guide` 37 | 38 | 3. **CSV file naming:** 39 | - Select a name for your review, which will be used for the CSV file that logs the status of each slice. 40 | - If an identical CSV file already exists, it will be used. 
If not, a new one will be generated. 41 | - If you choose to create a new dataset, a new CSV will always be created. If multiple copies already exist, a sequential number will be appended to the new file's name. 42 | 43 | 4. **Beginning the labeling:** 44 | Press **`Start reviewing`** once you are ready to start the review process. 45 | 46 | .. warning:: 47 | Starting a review session opens a new window and closes the current one. 48 | Make sure you have saved your work before starting the review session. 49 | 50 | Interface & functionalities 51 | --------------------------- 52 | 53 | .. figure:: ../images/review_process_example.png 54 | :align: center 55 | 56 | Interface for the labeling process. 57 | 58 | Once you have launched the labeling process, you will have access to the following functionalities: 59 | 60 | .. hlist:: 61 | :columns: 1 62 | 63 | * A dialog to choose where to save the verified and/or corrected annotations, and a button to save the labels. They will be using the provided file format. 64 | * A button to update the status of the slice in the csv file (in this case : checked/not checked) 65 | * A graph with projections in the x-y, y-z and x-z planes, to allow the reviewer to better understand the context of the volume and decide whether the image should be labeled or not. Use **shift-click** anywhere on the image or label layer to update the plot to the location being reviewed. 66 | 67 | To recap, you can check your labels, correct them, save them and keep track of which slices have been checked or not. 68 | 69 | .. note:: 70 | You can find the csv file containing the annotation status **in the same folder as the labels**. 71 | It will also keep track of the time taken to review each slice, which can be useful to monitor the progress of the review. 72 | 73 | Source code 74 | ------------------------------------------------- 75 | 76 | * :doc:`../code/_autosummary/napari_cellseg3d.code_plugins.plugin_review` 77 | * :doc:`../code/_autosummary/napari_cellseg3d.code_plugins.plugin_review_dock` 78 | * :doc:`../code/_autosummary/napari_cellseg3d.code_plugins.plugin_base` 79 | -------------------------------------------------------------------------------- /docs/source/guides/training_wnet.rst: -------------------------------------------------------------------------------- 1 | .. _training_wnet: 2 | 3 | Walkthrough - WNet3D training 4 | =============================== 5 | 6 | This plugin provides a reimplemented, custom version of the WNet3D model from `WNet, A Deep Model for Fully Unsupervised Image Segmentation`_. 7 | 8 | For training your model, you can choose among: 9 | 10 | * Directly within the plugin 11 | * The provided Jupyter notebook (locally) 12 | * Our Colab notebook (inspired by https://github.com/HenriquesLab/ZeroCostDL4Mic) 13 | 14 | Selecting training data 15 | ------------------------- 16 | 17 | The WNet3D **does not require a large amount of data to train**, but **choosing the right data** to train this unsupervised model **is crucial**. 18 | 19 | You may find below some guidelines, based on our own data and testing. 20 | 21 | The WNet3D is a self-supervised learning approach for 3D cell segmentation, and relies on the assumption that structural and morphological features of cells can be inferred directly from unlabeled data. This involves leveraging inherent properties such as spatial coherence and local contrast in imaging volumes to distinguish cellular structures. 
This approach assumes that meaningful representations of cellular boundaries and nuclei can emerge solely from raw 3D volumes. Thus, we strongly recommend that you use WNet3D on stacks that have clear foreground/background segregation and limited noise. Even if your final samples have noise, it is best to train on data that is as clean as you can. 22 | 23 | 24 | .. important:: 25 | For optimal performance, the following should be avoided for training: 26 | 27 | - Images with over-exposed pixels/artifacts you do not want to be learned! 28 | - Almost-empty and/or fully empty images, especially if noise is present (it will learn to segment very small objects!). 29 | 30 | However, the model may accomodate: 31 | 32 | - Uneven brightness distribution in your image! 33 | - Varied object shapes and radius! 34 | - Noisy images (as long as resolution is sufficient and boundaries are clear)! 35 | - Uneven illumination across the image! 36 | 37 | For optimal results, during inference, images should be similar to those the model was trained on; however this is not a strict requirement. 38 | 39 | You may also retrain from our pretrained model to your image dataset to help quickly reach good performance if, simply check "Use pre-trained weights" in the training module, and lower the learning rate. 40 | 41 | .. note:: 42 | - The WNet3D relies on brightness to distinguish objects from the background. For better results, use image regions with minimal artifacts. If you notice many artifacts, consider trying one of our supervised models (for lightsheet microscopy). 43 | - The model has two losses, the **`SoftNCut loss`**, which clusters pixels according to brightness, and a reconstruction loss, either **`Mean Square Error (MSE)`** or **`Binary Cross Entropy (BCE)`**. 44 | - For good performance, wait for the SoftNCut to reach a plateau; the reconstruction loss should also be decreasing overall, but this is generally less critical for segmentation performance. 45 | 46 | Parameters 47 | ------------- 48 | 49 | .. figure:: ../images/training_tab_4.png 50 | :scale: 100 % 51 | :align: right 52 | 53 | Advanced tab 54 | 55 | _`When using the WNet3D training module`, the **Advanced** tab contains a set of additional options: 56 | 57 | - **Number of classes** : Dictates the segmentation classes (default is 2). Increasing the number of classes will result in a more progressive segmentation according to brightness; can be useful if you have "halos" around your objects, or to approximate boundary labels. 58 | - **Reconstruction loss** : Choose between MSE or BCE (default is MSE). MSE is more precise but also sensitive to outliers; BCE is more robust against outliers at the cost of precision. 59 | 60 | - NCuts parameters: 61 | - **Intensity sigma** : Standard deviation of the feature similarity term, focusing on brightness (default is 1). 62 | - **Spatial sigma** : Standard deviation for the spatial proximity term (default is 4). 63 | - **Radius** : Pixel radius for the loss computation (default is 2). 64 | 65 | .. note:: 66 | - The **Intensity Sigma** depends on image pixel values. The default of 1 is optimised for images being mapped between 0 and 100, which is done automatically by the plugin. 67 | - Raising the **Radius** might improve performance in certain cases, but will also greatly increase computation time. 68 | 69 | - Weights for the sum of losses : 70 | - **NCuts weight** : Sets the weight of the NCuts loss (default is 0.5). 
81 | Troubleshooting common issues 82 | ------------------------------ 83 | 84 | .. important:: 85 | If you do not find a satisfactory answer here, please do not hesitate to `open an issue`_ on GitHub. 86 | 87 | 88 | - **The NCuts loss "explodes" upward after a few epochs** : Lower the learning rate; for example, start by dividing it by a factor of two, then ten. 89 | 90 | - **Reconstruction (decoder) performance is poor** : First, try increasing the weight of the reconstruction loss. If this is ineffective, switch to BCE loss and set the scaling factor of the reconstruction loss to 0.5, or adjust the weight of the MSE loss. 91 | 92 | - **Segmentation only separates the brighter versus dimmer regions** : Increase the weight of the reconstruction loss. 93 | 94 | 95 | .. _WNet, A Deep Model for Fully Unsupervised Image Segmentation: https://arxiv.org/abs/1711.08506 96 | .. _open an issue: https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues 97 | -------------------------------------------------------------------------------- /docs/source/guides/utils_module_guide.rst: -------------------------------------------------------------------------------- 1 | .. _utils_module_guide: 2 | 3 | Utilities 🛠 4 | ============ 5 | 6 | Here you will find a range of tools for image processing and analysis. 7 | See `Usage section `_ for instructions on launching the plugin. 8 | 9 | .. note:: 10 | The utility selection menu is found at the bottom of the plugin window. 11 | 12 | You may specify the results directory for saving; afterwards you can run each action on a folder or on the currently selected layer. 13 | 14 | Default Paths for Saving Results 15 | ________________________________ 16 | 17 | Each utility saves results to a default directory under the user's home directory. The default paths are as follows: 18 | 19 | * Artifact Removal: ``~/cellseg3d/artifact_removed`` 20 | * Fragmentation: ``~/cellseg3d/fragmented`` 21 | * Anisotropy Correction: ``~/cellseg3d/anisotropy`` 22 | * Small Object Removal: ``~/cellseg3d/small_removed`` 23 | * Semantic Label Conversion: ``~/cellseg3d/semantic_labels`` 24 | * Instance Label Conversion: ``~/cellseg3d/instance_labels`` 25 | * Thresholding: ``~/cellseg3d/threshold`` 26 | * Statistics: ``~/cellseg3d/stats`` 27 | * Threshold Grid Search: ``~/cellseg3d/threshold_grid_search`` 28 | 29 | Available actions 30 | __________________ 31 | 32 | 1. Crop 3D volumes 33 | ------------------ 34 | Please refer to :ref:`cropping_module_guide` for a guide on using the cropping utility. 35 | 36 | 2. Convert to instance labels 37 | ----------------------------- 38 | This will convert semantic (binary) labels to instance labels (with a unique ID for each object). 39 | The available methods for this are (a short code sketch follows the list): 40 | 41 | * `Connected Components`_ : a simple method that assigns a unique ID to each connected component. It does not work well for touching objects (they will often be fused). 42 | * `Watershed`_ : a method based on topographic maps. It works well for clumped objects and anisotropic volumes, provided the intensity "topography" is informative; otherwise, clumped objects may still be fused. 43 | * `Voronoi-Otsu`_ : a method based on Voronoi diagrams and Otsu thresholding. It works well for clumped objects, but only for roughly "round" objects. 
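For illustration only, the sketch below shows how the first two conversions could be reproduced outside the plugin with ``scikit-image`` and ``scipy`` (this is not the plugin's implementation; the marker threshold of 0.7 is an arbitrary choice for this toy example, and Voronoi-Otsu additionally requires ``pyclesperanto``).

.. code-block:: python

    import numpy as np
    from scipy import ndimage as ndi
    from skimage.measure import label
    from skimage.segmentation import watershed

    # Toy binary (semantic) labels; replace with your own 3D array.
    semantic = np.random.rand(64, 64, 64) > 0.95

    # Connected Components: one ID per connected blob; touching objects are merged.
    instance_cc = label(semantic)

    # Watershed: split touching blobs using the distance transform as "topography".
    distance = ndi.distance_transform_edt(semantic)
    markers = label(distance > 0.7 * distance.max())
    instance_ws = watershed(-distance, markers=markers, mask=semantic)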
44 | 45 | 3. Convert to semantic labels 46 | ----------------------------- 47 | Transforms instance labels into 0/1 semantic labels, useful for training purposes. 48 | 49 | 4. Remove small objects 50 | ----------------------- 51 | Input a size threshold (in pixels) to eliminate objects below this size. 52 | 53 | 5. Resize anisotropic images 54 | ---------------------------- 55 | Input your microscope's resolution to remove anisotropy in images. 56 | 57 | 6. Threshold images 58 | ------------------- 59 | Removes values beneath a certain threshold. 60 | 61 | 7. Fragment image 62 | ----------------- 63 | Breaks down large images into smaller cubes, which is optimal for training. 64 | 65 | 8. Conditional Random Field (CRF) 66 | --------------------------------- 67 | 68 | .. note:: 69 | This utility is only available if you have installed the `pydensecrf` package. 70 | You may install it by using the command ``pip install pydensecrf@git+https://github.com/lucasb-eyer/pydensecrf.git#egg=master``. 71 | 72 | | Refines semantic predictions by pairing them with the original image. 73 | | For a list of parameters, see the :doc:`CRF API page<../code/_autosummary/napari_cellseg3d.code_models.crf>`. 74 | 75 | 9. Label statistics 76 | ------------------------------------------------ 77 | | Computes statistics for each object in the image. 78 | | Enter the name of the CSV file to save the results, then select your layer or folder of labels to compute the statistics. 79 | 80 | .. note:: 81 | Images that do not contain only integer labels will be ignored. 82 | 83 | The available statistics are: 84 | 85 | For each object : 86 | 87 | * Object volume (pixels) 88 | * :math:`X,Y,Z` coordinates of the centroid 89 | * Sphericity 90 | 91 | Global metrics : 92 | 93 | * Image size 94 | * Total image volume (pixels) 95 | * Total object (labeled) volume (pixels) 96 | * Filling ratio (fraction of the volume that is labeled) 97 | * The number of labeled objects 98 | 99 | .. hint:: 100 | Check the ``notebooks`` folder for examples of plots using the statistics CSV file. 101 | 102 | 10. Clear large labels 103 | ---------------------- 104 | | Clears labels that are larger than a given threshold. 105 | | This is useful for removing artifacts that are larger than the objects of interest. 106 | 107 | 11. Find the best threshold 108 | --------------------------- 109 | | Finds the best threshold for separating objects from the background. 110 | | Requires a prediction from a model and GT labels as input. 111 | 112 | .. caution:: 113 | If the input prediction is not from the plugin, it will be remapped to the 0-1 range. 114 | 115 | | The threshold is found by maximizing the Dice coefficient between the thresholded prediction and the binarized GT labels. 116 | 117 | | The value for the best threshold will be displayed, and the prediction will be thresholded and saved with this value. 118 | 119 | Source code 120 | ___________ 121 | 122 | * :doc:`../code/_autosummary/napari_cellseg3d.code_plugins.plugin_convert` 123 | * :doc:`../code/_autosummary/napari_cellseg3d.code_plugins.plugin_crf` 124 | 125 | 126 | .. links 127 | 128 | .. 
_Watershed: https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_watershed.html 129 | .. _Connected Components: https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label 130 | .. _Voronoi-Otsu: https://haesleinhuepf.github.io/BioImageAnalysisNotebooks/20_image_segmentation/11_voronoi_otsu_labeling.html 131 | -------------------------------------------------------------------------------- /docs/source/images/Review_Parameters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/Review_Parameters.png -------------------------------------------------------------------------------- /docs/source/images/converted_labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/converted_labels.png -------------------------------------------------------------------------------- /docs/source/images/cropping_process_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/cropping_process_example.png -------------------------------------------------------------------------------- /docs/source/images/inference_plugin_layout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/inference_plugin_layout.png -------------------------------------------------------------------------------- /docs/source/images/inference_results_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/inference_results_example.png -------------------------------------------------------------------------------- /docs/source/images/init_image_labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/init_image_labels.png -------------------------------------------------------------------------------- /docs/source/images/plot_example_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/plot_example_metrics.png -------------------------------------------------------------------------------- /docs/source/images/plots_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/plots_train.png -------------------------------------------------------------------------------- /docs/source/images/plugin_crop.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/plugin_crop.png -------------------------------------------------------------------------------- /docs/source/images/plugin_inference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/plugin_inference.png -------------------------------------------------------------------------------- /docs/source/images/plugin_menu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/plugin_menu.png -------------------------------------------------------------------------------- /docs/source/images/plugin_review.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/plugin_review.png -------------------------------------------------------------------------------- /docs/source/images/plugin_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/plugin_train.png -------------------------------------------------------------------------------- /docs/source/images/plugin_welcome.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/plugin_welcome.png -------------------------------------------------------------------------------- /docs/source/images/review_process_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/review_process_example.png -------------------------------------------------------------------------------- /docs/source/images/stat_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/stat_plots.png -------------------------------------------------------------------------------- /docs/source/images/training_tab_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/training_tab_1.png -------------------------------------------------------------------------------- /docs/source/images/training_tab_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/training_tab_2.png -------------------------------------------------------------------------------- /docs/source/images/training_tab_3.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/training_tab_3.png -------------------------------------------------------------------------------- /docs/source/images/training_tab_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/images/training_tab_4.png -------------------------------------------------------------------------------- /docs/source/logo/logo_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/docs/source/logo/logo_alpha.png -------------------------------------------------------------------------------- /docs/welcome.rst: -------------------------------------------------------------------------------- 1 | Welcome to CellSeg3D! 2 | ===================== 3 | 4 | 5 | .. figure:: ./source/images/plugin_welcome.png 6 | :align: center 7 | 8 | **CellSeg3D** is a toolbox for 3D segmentation of cells in light-sheet microscopy images, using napari. 9 | Use CellSeg3D to: 10 | 11 | * Review labeled cell volumes from whole-brain samples of mice imaged by mesoSPIM microscopy [1]_ 12 | * Train and use segmentation models from the MONAI project [2]_ 13 | * Train and use our WNet3D unsupervised model 14 | * Or implement your own custom 3D segmentation models using PyTorch! 15 | 16 | 17 | .. figure:: https://images.squarespace-cdn.com/content/v1/57f6d51c9f74566f55ecf271/0d16a71b-3ff2-477a-9d83-18d96cb1ce28/full_demo.gif?format=500w 18 | :alt: CellSeg3D demo 19 | :width: 800 20 | :align: center 21 | 22 | Demo of the plugin 23 | 24 | 25 | Requirements 26 | -------------------------------------------- 27 | 28 | .. important:: 29 | This package requires **PyQt5** or **PySide2** to be installed first for napari to run. 30 | If you do not have a Qt backend installed, you can use: 31 | ``pip install napari[all]`` 32 | to install PyQt5 by default. 33 | 34 | This package depends on PyTorch and certain optional dependencies of MONAI. These come as requirements, but if 35 | you need further assistance, please see below. 36 | 37 | .. note:: 38 | A **CUDA-capable GPU** is not needed but **very strongly recommended**, especially for training and, to a lesser degree, inference. 39 | 40 | * For help with PyTorch, please see `PyTorch's website`_ for installation instructions, with or without CUDA according to your hardware. 41 | **Depending on your setup, you might wish to install torch first.** 42 | 43 | * If you get errors from MONAI regarding missing readers, please see `MONAI's optional dependencies`_ page for instructions on getting the readers required by your images. 44 | 45 | .. _MONAI's optional dependencies: https://docs.monai.io/en/stable/installation.html#installing-the-recommended-dependencies 46 | .. _PyTorch's website: https://pytorch.org/get-started/locally/ 47 | 48 | 49 | 50 | Installation 51 | -------------------------------------------- 52 | CellSeg3D can be run on Windows, Linux, or macOS. 53 | 54 | For detailed installation instructions, including installing prerequisites, 55 | please see :ref:`source/guides/installation_guide:Installation guide ⚙` 56 | 57 | .. 
warning:: 58 | **ARM64 macOS users**, please refer to the :ref:`dedicated section ` 59 | 60 | You can install ``napari-cellseg3d`` via pip: 61 | 62 | .. code-block:: 63 | 64 | pip install napari-cellseg3d 65 | 66 | For local installation after cloning from GitHub, please run the following in the CellSeg3D folder: 67 | 68 | .. code-block:: 69 | 70 | pip install -e . 71 | 72 | If the installation was successful, you will find the napari-cellseg3d plugin in the Plugins section of napari. 73 | 74 | 75 | Usage 76 | -------------------------------------------- 77 | 78 | 79 | To use the plugin, please run: 80 | 81 | .. code-block:: 82 | 83 | napari 84 | 85 | Then go into **Plugins > CellSeg3D** 86 | 87 | .. figure:: ./source/images/plugin_menu.png 88 | :align: center 89 | 90 | 91 | and choose the correct tool to use: 92 | 93 | - :ref:`review_module_guide`: Examine and refine your labels, whether manually annotated or predicted by a pre-trained model. 94 | - :ref:`training_module_guide`: Train segmentation algorithms on your own data. 95 | - :ref:`inference_module_guide`: Use pre-trained segmentation algorithms on volumes to automate cell labeling. 96 | - :ref:`utils_module_guide`: Leverage various utilities, including cropping your volumes and labels, converting semantic to instance labels, and more. 97 | - **Help/About...** : Quick access to version info, GitHub pages and documentation. 98 | 99 | .. hint:: 100 | Many buttons have tooltips to help you understand what they do. 101 | Simply hover over them to see the tooltip. 102 | 103 | 104 | Documentation contents 105 | -------------------------------------------- 106 | _`From this page you can access the guides on the several modules available for your tasks`, such as: 107 | 108 | 109 | * Main modules : 110 | * :ref:`review_module_guide` 111 | * :ref:`training_module_guide` 112 | * :ref:`inference_module_guide` 113 | * Utilities : 114 | * :ref:`cropping_module_guide` 115 | * :ref:`utils_module_guide` 116 | 117 | .. 118 | * Convert labels : :ref:`utils_module_guide` 119 | .. 120 | * Compute scores : :ref:`metrics_module_guide` 121 | 122 | * Advanced : 123 | * :ref:`training_wnet` 124 | * :ref:`custom_model_guide` **(WIP)** 125 | 126 | Other useful napari plugins 127 | --------------------------------------------- 128 | 129 | .. important:: 130 | | Please note that these plugins are not developed by us, and we cannot guarantee their compatibility, functionality, or support. 131 | | Installing napari plugins in separate environments is recommended. 132 | 133 | * `brainreg-napari`_ : Whole-brain registration in napari 134 | * `napari-brightness-contrast`_ : Adjust brightness and contrast of your images, visualize histograms and more 135 | * `napari-pyclesperanto-assistant`_ : Image processing workflows using pyclEsperanto 136 | * `napari-skimage-regionprops`_ : Compute region properties on your labels 137 | 138 | .. _napari-pyclesperanto-assistant: https://www.napari-hub.org/plugins/napari-pyclesperanto-assistant 139 | .. _napari-brightness-contrast: https://www.napari-hub.org/plugins/napari-brightness-contrast 140 | .. _brainreg-napari: https://github.com/brainglobe/brainreg?tab=readme-ov-file#installation 141 | .. 
_napari-skimage-regionprops: https://www.napari-hub.org/plugins/napari-skimage-regionprops 142 | 143 | Acknowledgments & References 144 | --------------------------------------------- 145 | If you find our code or ideas useful, please cite: 146 | 147 | Achard Cyril, Kousi Timokleia, Frey Markus, Vidal Maxime, Paychère Yves, Hofmann Colin, Iqbal Asim, Hausmann Sebastien B, Pagès Stéphane, Mathis Mackenzie Weygandt (2024). 148 | CellSeg3D: self-supervised 3D cell segmentation for microscopy. eLife. https://doi.org/10.7554/eLife.99848.1 149 | 150 | 151 | 152 | This plugin additionally uses the following libraries and software: 153 | 154 | * `napari`_ 155 | 156 | * `PyTorch`_ 157 | 158 | * `MONAI project`_ (various models used here are credited `on their website`_) 159 | 160 | * `pyclEsperanto`_ (for the Voronoi-Otsu labeling) by Robert Haase 161 | 162 | 163 | 164 | .. _Mathis Laboratory of Adaptive Intelligence: http://www.mackenziemathislab.org/ 165 | .. _Wyss Center: https://wysscenter.ch/ 166 | .. _TRAILMAP project on GitHub: https://github.com/AlbertPun/TRAILMAP 167 | .. _napari: https://napari.org/ 168 | .. _PyTorch: https://pytorch.org/ 169 | .. _MONAI project: https://monai.io/ 170 | .. _on their website: https://docs.monai.io/en/stable/networks.html#nets 171 | .. _pyclEsperanto: https://github.com/clEsperanto/pyclesperanto_prototype 172 | .. _WNet: https://arxiv.org/abs/1711.08506 173 | 174 | .. rubric:: References 175 | 176 | .. [1] The mesoSPIM initiative: open-source light-sheet microscopes for imaging cleared tissue, Voigt et al., 2019 ( https://doi.org/10.1038/s41592-019-0554-0 ) 177 | .. [2] MONAI Project website ( https://monai.io/ ) 178 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Testing CellSeg3D on demo data 2 | 3 | Here is a very small volume (from [IDR project 853](https://idr.openmicroscopy.org/webclient/?show=project-853)) to test on. 4 | All credit goes to the original authors of the data. 5 | After installing the plugin, launch `napari`, activate the CellSeg3D plugin, and drag & drop this volume into the canvas. 6 | Then, for example, run the `Inference` module with one of our models. 7 | 8 | See [CellSeg3D documentation](https://adaptivemotorcontrollab.github.io/CellSeg3D/welcome.html) for more details. 
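If you prefer to load the demo volume from a script rather than drag & drop, here is a minimal sketch (assuming `napari` and `tifffile` are installed, e.g. via `pip install napari[all] tifffile`; the path is relative to the repository root):

```python
import napari
from tifffile import imread

volume = imread("examples/c5image.tif")  # load the demo volume
viewer = napari.Viewer()
viewer.add_image(volume, name="c5image")
napari.run()  # then open Plugins > CellSeg3D and run the Inference module
```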
-------------------------------------------------------------------------------- /examples/c5image.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/examples/c5image.tif -------------------------------------------------------------------------------- /examples/test_very_small.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/examples/test_very_small.tif -------------------------------------------------------------------------------- /napari_cellseg3d/__init__.py: -------------------------------------------------------------------------------- 1 | """napari-cellseg3d - napari plugin for cell segmentation.""" 2 | 3 | __version__ = "0.2.2" 4 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/_tests/__init__.py -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | 6 | @pytest.fixture(scope="session", autouse=True) 7 | def env_config(): 8 | """ 9 | Configure environment variables needed for the test session 10 | """ 11 | 12 | # This makes QT render everything offscreen and thus prevents 13 | # any Modals / Dialogs or other Widgets being rendered on the screen while running unit tests 14 | os.environ["QT_QPA_PLATFORM"] = "offscreen" 15 | 16 | yield 17 | 18 | os.environ.pop("QT_QPA_PLATFORM") 19 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/fixtures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from qtpy.QtWidgets import QTextEdit 3 | 4 | from napari_cellseg3d.utils import LOGGER as logger 5 | 6 | 7 | class LogFixture(QTextEdit): 8 | """Fixture for testing, replaces napari_cellseg3d.interface.Log in model_workers during testing.""" 9 | 10 | def __init__(self): 11 | super(LogFixture, self).__init__() 12 | 13 | def print_and_log(self, text, printing=None): 14 | print(text) 15 | 16 | def warn(self, warning): 17 | logger.warning(warning) 18 | 19 | def error(self, e): 20 | raise (e) 21 | 22 | 23 | class WNetFixture(torch.nn.Module): 24 | """Fixture for testing, replaces napari_cellseg3d.models.WNet during testing.""" 25 | 26 | def __init__(self): 27 | super().__init__() 28 | self.mock_conv = torch.nn.Conv3d(1, 1, 1) 29 | self.mock_conv.requires_grad_(False) 30 | 31 | def forward_encoder(self, x): 32 | """Forward pass through encoder.""" 33 | return x 34 | 35 | def forward_decoder(self, x): 36 | """Forward pass through decoder.""" 37 | return x 38 | 39 | def forward(self, x): 40 | """Forward pass through WNet.""" 41 | return self.forward_encoder(x), self.forward_decoder(x) 42 | 43 | 44 | class ModelFixture(torch.nn.Module): 45 | """Fixture for testing, replaces napari_cellseg3d models during testing.""" 46 | 47 | def __init__(self): 48 | """Fixture for testing, replaces models during testing.""" 49 | super().__init__() 50 | self.mock_conv = 
torch.nn.Conv3d(1, 1, 1) 51 | self.mock_conv.requires_grad_(False) 52 | 53 | def forward(self, x): 54 | """Forward pass through model.""" 55 | return x 56 | 57 | 58 | class OptimizerFixture: 59 | """Fixture for testing, replaces optimizers during testing.""" 60 | 61 | def __init__(self): 62 | self.param_groups = [] 63 | self.param_groups.append({"lr": 0}) 64 | 65 | def zero_grad(self): 66 | """Dummy function for zero_grad.""" 67 | pass 68 | 69 | def step(self, *args): 70 | """Dummy function for step.""" 71 | pass 72 | 73 | 74 | class SchedulerFixture: 75 | """Fixture for testing, replaces schedulers during testing.""" 76 | 77 | def step(self, *args): 78 | """Dummy function for step.""" 79 | pass 80 | 81 | 82 | class LossFixture: 83 | """Fixture for testing, replaces losses during testing.""" 84 | 85 | def __call__(self, *args): 86 | """Dummy function for __call__.""" 87 | return self 88 | 89 | def backward(self, *args): 90 | """Dummy function for backward.""" 91 | pass 92 | 93 | def item(self): 94 | """Dummy function for item.""" 95 | return 0 96 | 97 | def detach(self): 98 | """Dummy function for detach.""" 99 | return self 100 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | qt_api=pyqt5 3 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/res/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/_tests/res/test.png -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/res/test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/_tests/res/test.tif -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/res/test_labels.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/_tests/res/test_labels.tif -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/res/wnet_test/lab/test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/_tests/res/wnet_test/lab/test.tif -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/res/wnet_test/test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/_tests/res/wnet_test/test.tif -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/res/wnet_test/vol/test.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/_tests/res/wnet_test/vol/test.tif 
-------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_base_plugin.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from napari_cellseg3d.code_plugins.plugin_base import ( 4 | BasePluginSingleImage, 5 | ) 6 | 7 | 8 | def test_base_single_image(make_napari_viewer_proxy): 9 | viewer = make_napari_viewer_proxy() 10 | plugin = BasePluginSingleImage(viewer) 11 | 12 | test_folder = Path(__file__).parent.resolve() 13 | test_image = str(test_folder / "res/test.tif") 14 | 15 | assert plugin._check_results_path(str(test_folder)) 16 | plugin.image_path = test_image 17 | assert plugin._default_path[0] != test_image 18 | plugin._update_default_paths() 19 | assert plugin._default_path[0] == test_image 20 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_dock_widget.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from tifffile import imread 4 | 5 | from napari_cellseg3d.code_plugins.plugin_review_dock import Datamanager 6 | 7 | 8 | def test_prepare(make_napari_viewer_proxy): 9 | path_image = str(Path(__file__).resolve().parent / "res/test.tif") 10 | image = imread(str(path_image)) 11 | viewer = make_napari_viewer_proxy() 12 | viewer.add_image(image) 13 | widget = Datamanager(viewer) 14 | viewer.window.add_dock_widget(widget) 15 | 16 | widget.prepare(path_image, ".tif", "", False) 17 | 18 | assert widget.filetype == ".tif" 19 | assert widget.as_folder is False 20 | assert Path(widget.csv_path) == ( 21 | Path(__file__).resolve().parent / "res/_train0.csv" 22 | ) 23 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_helper.py: -------------------------------------------------------------------------------- 1 | from napari_cellseg3d.code_plugins.plugin_helper import Helper 2 | 3 | 4 | def test_helper(make_napari_viewer_proxy): 5 | viewer = make_napari_viewer_proxy() 6 | widget = Helper(viewer) 7 | 8 | dock = viewer.window.add_dock_widget(widget) 9 | children = len(viewer.window._dock_widgets) 10 | 11 | assert dock is not None 12 | 13 | widget.btnc.click() 14 | 15 | assert len(viewer.window._dock_widgets) == children - 1 16 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import napari 4 | import numpy as np 5 | import pytest 6 | import torch 7 | from monai.data import DataLoader 8 | 9 | from napari_cellseg3d.code_models.worker_inference import InferenceWorker 10 | from napari_cellseg3d.code_models.workers_utils import ( 11 | PRETRAINED_WEIGHTS_DIR, 12 | InferenceResult, 13 | ONNXModelWrapper, 14 | WeightsDownloader, 15 | ) 16 | from napari_cellseg3d.config import InferenceWorkerConfig 17 | from napari_cellseg3d.utils import rand_gen 18 | 19 | 20 | def test_onnx_inference(make_napari_viewer_proxy): 21 | downloader = WeightsDownloader() 22 | downloader.download_weights("WNet_ONNX", "wnet.onnx") 23 | path = str(Path(PRETRAINED_WEIGHTS_DIR).resolve() / "wnet.onnx") 24 | assert Path(path).is_file() 25 | dims = 64 26 | batch = 1 27 | x = torch.randn(size=(batch, 1, dims, dims, dims)) 28 | worker = ONNXModelWrapper(file_location=path) 29 | assert worker.eval() is None 30 | assert 
worker.to(device="cpu") is None 31 | assert worker.forward(x).shape == (batch, 2, dims, dims, dims) 32 | 33 | 34 | def test_load_folder(): 35 | config = InferenceWorkerConfig() 36 | worker = InferenceWorker(worker_config=config) 37 | 38 | images_path = Path(__file__).resolve().parent / "res/test.tif" 39 | worker.config.images_filepaths = [str(images_path)] 40 | dataloader = worker.load_folder() 41 | assert isinstance(dataloader, DataLoader) 42 | assert len(dataloader) == 1 43 | worker.config.sliding_window_config.window_size = [64, 64, 64] 44 | dataloader = worker.load_folder() 45 | assert isinstance(dataloader, DataLoader) 46 | assert len(dataloader) == 1 47 | 48 | mock_layer = napari.layers.Image(data=rand_gen.random((64, 64, 64))) 49 | worker.config.layer = mock_layer 50 | input_image = worker.load_layer() 51 | assert input_image.shape == (1, 1, 64, 64, 64) 52 | 53 | mock_layer = napari.layers.Image(data=rand_gen.random((5, 2, 64, 64, 64))) 54 | worker.config.layer = mock_layer 55 | assert len(mock_layer.data.shape) == 5 56 | with pytest.raises( 57 | ValueError, 58 | match="Data array is not 3-dimensional but 5-dimensional, please check for extra channel/batch dimensions", 59 | ): 60 | worker.load_layer() 61 | 62 | 63 | def test_inference_on_folder(): 64 | config = InferenceWorkerConfig() 65 | config.filetype = ".tif" 66 | config.images_filepaths = [ 67 | str(Path(__file__).resolve().parent / "res/test.tif") 68 | ] 69 | 70 | config.sliding_window_config.window_size = 8 71 | 72 | class mock_work: 73 | @staticmethod 74 | def eval(): 75 | return True 76 | 77 | def __call__(self, x): 78 | return torch.Tensor(x) 79 | 80 | worker = InferenceWorker(worker_config=config) 81 | worker.aniso_transform = mock_work() 82 | 83 | image = torch.Tensor(rand_gen.random(size=(1, 1, 8, 8, 8))) 84 | assert image.shape == (1, 1, 8, 8, 8) 85 | assert image.dtype == torch.float32 86 | res = worker.inference_on_folder( 87 | {"image": image}, 88 | 0, 89 | model=mock_work(), 90 | post_process_transforms=mock_work(), 91 | ) 92 | assert isinstance(res, InferenceResult) 93 | assert res.semantic_segmentation is not None 94 | 95 | 96 | def test_post_processing(): 97 | image = rand_gen.random((1, 1, 64, 64, 64)) 98 | labels = rand_gen.random((1, 2, 64, 64, 64)) 99 | mock_layer = napari.layers.Image(data=image) 100 | mock_layer.name = "test" 101 | 102 | config = InferenceWorkerConfig() 103 | config.layer = mock_layer 104 | worker = InferenceWorker(worker_config=config) 105 | 106 | results = worker.run_crf(image, labels, lambda x: x) 107 | assert results.shape == (2, 64, 64, 64) 108 | 109 | worker.stats_csv(np.squeeze(labels)) 110 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_interface.py: -------------------------------------------------------------------------------- 1 | from napari_cellseg3d.interface import AnisotropyWidgets, Log 2 | 3 | 4 | def test_log(qtbot): 5 | log = Log() 6 | log.print_and_log("test") 7 | 8 | assert log.toPlainText() == "\ntest" 9 | 10 | log.replace_last_line("test2") 11 | 12 | assert log.toPlainText() == "\ntest2" 13 | 14 | qtbot.add_widget(log) 15 | 16 | 17 | def test_zoom_factor(): 18 | resolution = [5.0, 10.0, 5.0] 19 | zoom = AnisotropyWidgets.anisotropy_zoom_factor(resolution) 20 | assert zoom == [1, 0.5, 1] 21 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_labels_correction.py: -------------------------------------------------------------------------------- 
1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from tifffile import imread 5 | 6 | from napari_cellseg3d.dev_scripts import artefact_labeling as al 7 | from napari_cellseg3d.dev_scripts import correct_labels as cl 8 | from napari_cellseg3d.dev_scripts import evaluate_labels as el 9 | 10 | res_folder = Path(__file__).resolve().parent / "res" 11 | image_path = res_folder / "test.tif" 12 | image = imread(str(image_path)) 13 | 14 | labels_path = res_folder / "test_labels.tif" 15 | labels = imread(str(labels_path)) # .astype(np.int32) 16 | 17 | 18 | def test_artefact_labeling(): 19 | output_path = str(res_folder / "test_artifacts.tif") 20 | al.create_artefact_labels(image, labels, output_path=output_path) 21 | assert Path(output_path).is_file() 22 | 23 | 24 | def test_artefact_labeling_utils(): 25 | crop_test = al.crop_image(image) 26 | assert isinstance(crop_test, np.ndarray) 27 | output_path = str(res_folder / "test_cropped.tif") 28 | al.crop_image_path(image, path_image_out=output_path) 29 | assert Path(output_path).is_file() 30 | 31 | 32 | def test_correct_labels(): 33 | output_path = res_folder / "test_correct" 34 | output_path.mkdir(exist_ok=True, parents=True) 35 | cl.relabel_non_unique_i( 36 | labels, str(output_path / "corrected.tif"), go_fast=True 37 | ) 38 | 39 | 40 | def test_relabel(): 41 | cl.relabel( 42 | str(image_path), 43 | str(labels_path), 44 | go_fast=True, 45 | test=True, 46 | ) 47 | 48 | 49 | def test_evaluate_model_performance(): 50 | el.evaluate_model_performance( 51 | labels, labels, print_details=True, visualize=False 52 | ) 53 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_model_framework.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from napari_cellseg3d.code_models import model_framework 4 | from napari_cellseg3d.config import MODEL_LIST 5 | 6 | 7 | def pth(path): 8 | return str(Path(path)) 9 | 10 | 11 | def test_update_default(make_napari_viewer_proxy): 12 | view = make_napari_viewer_proxy() 13 | widget = model_framework.ModelFramework(view) 14 | 15 | widget.images_filepaths = [] 16 | widget.results_path = None 17 | 18 | widget._update_default_paths() 19 | 20 | assert widget._default_path == [None, None, None, None] 21 | 22 | widget.images_filepaths = [ 23 | pth("C:/test/test/images.tif"), 24 | pth("C:/images/test/data.png"), 25 | ] 26 | widget.labels_filepaths = [ 27 | pth("C:/dataset/labels/lab1.tif"), 28 | pth("C:/data/labels/lab2.tif"), 29 | ] 30 | widget.results_path = pth("D:/dataset/res") 31 | # widget.model_path = None 32 | 33 | widget._update_default_paths() 34 | 35 | assert widget._default_path == [ 36 | pth("C:/test/test"), 37 | pth("C:/dataset/labels"), 38 | None, 39 | pth("D:/dataset/res"), 40 | ] 41 | 42 | 43 | def test_create_train_dataset_dict(make_napari_viewer_proxy): 44 | view = make_napari_viewer_proxy() 45 | widget = model_framework.ModelFramework(view) 46 | 47 | widget.images_filepaths = [str(f"{i}.tif") for i in range(3)] 48 | widget.labels_filepaths = [str(f"lab_{i}.tif") for i in range(3)] 49 | 50 | expect = [ 51 | {"image": "0.tif", "label": "lab_0.tif"}, 52 | {"image": "1.tif", "label": "lab_1.tif"}, 53 | {"image": "2.tif", "label": "lab_2.tif"}, 54 | ] 55 | 56 | assert widget.create_train_dataset_dict() == expect 57 | 58 | 59 | def test_log(make_napari_viewer_proxy): 60 | mock_test = "test" 61 | framework = model_framework.ModelFramework( 62 | viewer=make_napari_viewer_proxy() 63 | ) 
64 | framework.log.print_and_log(mock_test) 65 | assert len(framework.log.toPlainText()) != 0 66 | assert framework.log.toPlainText() == "\n" + mock_test 67 | 68 | framework.results_path = str(Path(__file__).resolve().parent / "res") 69 | framework.save_log(do_timestamp=False) 70 | log_path = Path(__file__).resolve().parent / "res/Log_report.txt" 71 | assert log_path.is_file() 72 | with Path.open(log_path.resolve(), "r") as f: 73 | assert f.read() == "\n" + mock_test 74 | 75 | # remove log file 76 | log_path.unlink(missing_ok=False) 77 | log_path = Path(__file__).resolve().parent / "res/Log_report.txt" 78 | framework.save_log_to_path(str(log_path.parent), do_timestamp=False) 79 | assert log_path.is_file() 80 | with Path.open(log_path.resolve(), "r") as f: 81 | assert f.read() == "\n" + mock_test 82 | log_path.unlink(missing_ok=False) 83 | 84 | 85 | def test_display_elements(make_napari_viewer_proxy): 86 | framework = model_framework.ModelFramework( 87 | viewer=make_napari_viewer_proxy() 88 | ) 89 | 90 | framework.display_status_report() 91 | framework.display_status_report() 92 | 93 | framework.custom_weights_choice.setChecked(False) 94 | framework._toggle_weights_path() 95 | assert not framework.weights_filewidget.isVisible() 96 | 97 | 98 | def test_available_models_retrieval(make_napari_viewer_proxy): 99 | framework = model_framework.ModelFramework( 100 | viewer=make_napari_viewer_proxy() 101 | ) 102 | assert framework.get_available_models() == MODEL_LIST 103 | 104 | 105 | def test_update_weights_path(make_napari_viewer_proxy): 106 | framework = model_framework.ModelFramework( 107 | viewer=make_napari_viewer_proxy() 108 | ) 109 | assert ( 110 | framework._update_weights_path(framework._default_weights_folder) 111 | is None 112 | ) 113 | name = str(Path.home() / "test/weight.pth") 114 | framework._update_weights_path([name]) 115 | assert framework.weights_config.path == name 116 | assert framework.weights_filewidget.text_field.text() == name 117 | assert framework._default_weights_folder == str(Path.home() / "test") 118 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_models.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from napari_cellseg3d.code_models.crf import ( 8 | CRFWorker, 9 | correct_shape_for_crf, 10 | crf_batch, 11 | crf_with_config, 12 | ) 13 | from napari_cellseg3d.code_models.models.model_TRAILMAP_MS import TRAILMAP_MS_ 14 | from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss 15 | from napari_cellseg3d.config import MODEL_LIST, CRFConfig 16 | from napari_cellseg3d.utils import rand_gen 17 | 18 | 19 | def test_correct_shape_for_crf(): 20 | test = rand_gen.random(size=(1, 1, 8, 8, 8)) 21 | assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) 22 | test = rand_gen.random(size=(8, 8, 8)) 23 | assert correct_shape_for_crf(test).shape == (1, 8, 8, 8) 24 | 25 | 26 | def test_model_list(): 27 | for model_name in MODEL_LIST: 28 | # if model_name=="test": 29 | # continue 30 | dims = 64 31 | test = MODEL_LIST[model_name]( 32 | input_img_size=[dims, dims, dims], 33 | in_channels=1, 34 | out_channels=1, 35 | dropout_prob=0.3, 36 | ) 37 | assert isinstance(test, MODEL_LIST[model_name]) 38 | 39 | 40 | def test_soft_ncuts_loss(): 41 | dims = 8 42 | labels = torch.rand([1, 1, dims, dims, dims]) 43 | 44 | loss = SoftNCutsLoss( 45 | data_shape=[dims, dims, dims], 46 
| device="cpu", 47 | intensity_sigma=4, 48 | spatial_sigma=4, 49 | radius=2, 50 | ) 51 | 52 | res = loss.forward(labels, labels) 53 | assert isinstance(res, torch.Tensor) 54 | assert 0 <= res <= 1 # ASSUMES NUMBER OF CLASS IS 2, NOT CORRECT IF K>2 55 | 56 | loss = SoftNCutsLoss( 57 | data_shape=[dims, dims, dims], 58 | device="cpu", 59 | intensity_sigma=4, 60 | spatial_sigma=4, 61 | radius=None, # test radius=None init 62 | ) 63 | assert loss.radius == 5 64 | 65 | 66 | def test_crf_batch(): 67 | dims = 8 68 | mock_image = rand_gen.random(size=(1, dims, dims, dims)) 69 | mock_label = rand_gen.random(size=(2, dims, dims, dims)) 70 | config = CRFConfig() 71 | 72 | result = crf_batch( 73 | np.array([mock_image, mock_image, mock_image]), 74 | np.array([mock_label, mock_label, mock_label]), 75 | sa=config.sa, 76 | sb=config.sb, 77 | sg=config.sg, 78 | w1=config.w1, 79 | w2=config.w2, 80 | ) 81 | 82 | assert result.shape == (3, 2, dims, dims, dims) 83 | 84 | 85 | def test_crf_config(): 86 | dims = 8 87 | mock_image = rand_gen.random(size=(1, dims, dims, dims)) 88 | mock_label = rand_gen.random(size=(2, dims, dims, dims)) 89 | config = CRFConfig() 90 | 91 | result = crf_with_config(mock_image, mock_label, config) 92 | assert result.shape == mock_label.shape 93 | 94 | 95 | def test_crf_worker(qtbot): 96 | dims = 8 97 | mock_image = rand_gen.random(size=(1, dims, dims, dims)) 98 | mock_label = rand_gen.random(size=(2, dims, dims, dims)) 99 | assert len(mock_label.shape) == 4 100 | crf = CRFWorker([mock_image], [mock_label]) 101 | 102 | def on_yield(result): 103 | assert len(result.shape) == 4 104 | assert len(mock_label.shape) == 4 105 | assert result.shape[-3:] == mock_label.shape[-3:] 106 | 107 | result = next(crf._run_crf_job()) 108 | on_yield(result) 109 | 110 | 111 | def test_pretrained_weights_compatibility(): 112 | from napari_cellseg3d.code_models.workers_utils import WeightsDownloader 113 | from napari_cellseg3d.config import MODEL_LIST, PRETRAINED_WEIGHTS_DIR 114 | 115 | for model_name in MODEL_LIST: 116 | file_name = MODEL_LIST[model_name].weights_file 117 | WeightsDownloader().download_weights(model_name, file_name) 118 | model = MODEL_LIST[model_name](input_img_size=[64, 64, 64]) 119 | try: 120 | model.load_state_dict( 121 | torch.load( 122 | str(Path(PRETRAINED_WEIGHTS_DIR) / file_name), 123 | map_location="cpu", 124 | ), 125 | strict=True, 126 | ) 127 | except RuntimeError: 128 | pytest.fail(f"Failed to load weights for {model_name}") 129 | 130 | 131 | def test_trailmap_init(): 132 | test = TRAILMAP_MS_( 133 | input_img_size=[128, 128, 128], 134 | in_channels=1, 135 | out_channels=1, 136 | dropout_prob=0.3, 137 | ) 138 | assert isinstance(test, TRAILMAP_MS_) 139 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_plugin_inference.py: -------------------------------------------------------------------------------- 1 | from napari_cellseg3d._tests.fixtures import LogFixture 2 | from napari_cellseg3d.code_models.instance_segmentation import ( 3 | INSTANCE_SEGMENTATION_METHOD_LIST, 4 | volume_stats, 5 | ) 6 | from napari_cellseg3d.code_models.models.model_test import TestModel 7 | from napari_cellseg3d.code_models.workers_utils import InferenceResult 8 | from napari_cellseg3d.code_plugins.plugin_model_inference import ( 9 | Inferer, 10 | ) 11 | from napari_cellseg3d.config import MODEL_LIST 12 | from napari_cellseg3d.utils import rand_gen 13 | 14 | 15 | def test_inference(make_napari_viewer_proxy, qtbot): 16 | dims = 6 17 | 
image = rand_gen.random(size=(dims, dims, dims)) 18 | # assert image.shape == (dims, dims, dims) 19 | 20 | viewer = make_napari_viewer_proxy() 21 | widget = Inferer(viewer) 22 | widget.log = LogFixture() 23 | viewer.window.add_dock_widget(widget) 24 | viewer.add_image(image) 25 | 26 | assert len(viewer.layers) == 1 27 | 28 | widget.use_window_choice.setChecked(True) 29 | widget.window_overlap_slider.setValue(0) 30 | widget.keep_data_on_cpu_box.setChecked(True) 31 | 32 | assert widget.check_ready() 33 | 34 | widget.model_choice.setCurrentText("WNet3D") 35 | widget._restrict_window_size_for_model() 36 | assert widget.use_window_choice.isChecked() 37 | assert widget.window_size_choice.currentText() == "64" 38 | 39 | test_model_name = "test" 40 | MODEL_LIST[test_model_name] = TestModel 41 | widget.model_choice.addItem(test_model_name) 42 | widget.model_choice.setCurrentText(test_model_name) 43 | 44 | widget.use_window_choice.setChecked(False) 45 | widget.worker_config = widget._set_worker_config() 46 | assert widget.worker_config is not None 47 | assert widget.model_info is not None 48 | 49 | worker = widget._create_worker_from_config(widget.worker_config) 50 | assert worker.config is not None 51 | assert worker.config.model_info is not None 52 | assert worker.config.sliding_window_config.is_enabled() is False 53 | worker.config.layer = viewer.layers[0] 54 | worker.config.post_process_config.instance.enabled = True 55 | worker.config.post_process_config.instance.method = ( 56 | INSTANCE_SEGMENTATION_METHOD_LIST["Watershed"]() 57 | ) 58 | 59 | assert worker.config.layer is not None 60 | worker.log_parameters() 61 | 62 | res = next(worker.inference()) 63 | assert isinstance(res, InferenceResult) 64 | assert res.semantic_segmentation.shape == (8, 8, 8) 65 | assert res.instance_labels.shape == (8, 8, 8) 66 | widget.on_yield(res) 67 | 68 | mock_image = rand_gen.random(size=(10, 10, 10)) 69 | mock_labels = rand_gen.integers(0, 10, (10, 10, 10)) 70 | mock_results = InferenceResult( 71 | image_id=0, 72 | original=mock_image, 73 | instance_labels=mock_labels, 74 | crf_results=mock_image, 75 | stats=[volume_stats(mock_labels)], 76 | semantic_segmentation=mock_image, 77 | model_name="test", 78 | ) 79 | num_layers = len(viewer.layers) 80 | widget.worker_config.post_process_config.instance.enabled = True 81 | widget._display_results(mock_results) 82 | assert len(viewer.layers) == num_layers + 4 83 | 84 | # assert widget.check_ready() 85 | # widget._setup_worker() 86 | # # widget.config.show_results = True 87 | # with qtbot.waitSignal(widget.worker.yielded, timeout=10000) as blocker: 88 | # blocker.connect( 89 | # widget.worker.errored 90 | # ) # Can add other signals to blocker 91 | # widget.worker.start() 92 | 93 | assert widget.on_finish() 94 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_plugin_training.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from monai.utils import set_determinism 4 | 5 | from napari_cellseg3d import config 6 | from napari_cellseg3d.code_plugins.plugin_model_training import ( 7 | Trainer, 8 | ) 9 | from napari_cellseg3d.config import MODEL_LIST 10 | 11 | im_path = Path(__file__).resolve().parent / "res/test.tif" 12 | im_path_str = str(im_path) 13 | 14 | 15 | def test_worker_configs(make_napari_viewer_proxy): 16 | set_determinism(seed=0) 17 | viewer = make_napari_viewer_proxy() 18 | widget = Trainer(viewer=viewer) 19 | # test supervised 
config and worker 20 | widget.device_choice.setCurrentIndex(0) 21 | widget.model_choice.setCurrentIndex(0) 22 | widget._toggle_unsupervised_mode(enabled=False) 23 | assert widget.model_choice.currentText() == list(MODEL_LIST.keys())[0] 24 | worker = widget._create_worker(additional_results_description="test") 25 | default_config = config.SupervisedTrainingWorkerConfig() 26 | excluded = [ 27 | "results_path_folder", 28 | "loss_function", 29 | "model_info", 30 | "sample_size", 31 | "weights_info", 32 | ] 33 | for attr in dir(default_config): 34 | if not attr.startswith("__") and attr not in excluded: 35 | assert getattr(default_config, attr) == getattr( 36 | worker.config, attr 37 | ) 38 | # test unsupervised config and worker 39 | widget.model_choice.setCurrentText("WNet3D") 40 | widget._toggle_unsupervised_mode(enabled=True) 41 | default_config = config.WNetTrainingWorkerConfig() 42 | worker = widget._create_worker(additional_results_description="TEST_1") 43 | excluded = ["results_path_folder", "sample_size", "weights_info"] 44 | for attr in dir(default_config): 45 | if not attr.startswith("__") and attr not in excluded: 46 | assert getattr(default_config, attr) == getattr( 47 | worker.config, attr 48 | ) 49 | widget.unsupervised_images_filewidget.text_field.setText( 50 | str((im_path.parent / "wnet_test").resolve()) 51 | ) 52 | widget.data = widget.create_dataset_dict_no_labs() 53 | worker = widget._create_worker(additional_results_description="TEST_2") 54 | dataloader, eval_dataloader, data_shape = worker._get_data() 55 | assert eval_dataloader is None 56 | assert data_shape == (6, 6, 6) 57 | 58 | widget.images_filepaths = [str(im_path)] 59 | widget.labels_filepaths = [str(im_path)] 60 | # widget.unsupervised_eval_data = widget.create_train_dataset_dict() 61 | worker = widget._create_worker(additional_results_description="TEST_3") 62 | dataloader, eval_dataloader, data_shape = worker._get_data() 63 | assert widget.unsupervised_eval_data is not None 64 | assert eval_dataloader is not None 65 | assert widget.unsupervised_eval_data[0]["image"] is not None 66 | assert widget.unsupervised_eval_data[0]["label"] is not None 67 | assert data_shape == (6, 6, 6) 68 | 69 | 70 | def test_update_loss_plot(make_napari_viewer_proxy): 71 | view = make_napari_viewer_proxy() 72 | widget = Trainer(view) 73 | 74 | widget.worker_config = config.SupervisedTrainingWorkerConfig() 75 | assert widget._is_current_job_supervised() is True 76 | widget.worker_config.validation_interval = 1 77 | widget.worker_config.results_path_folder = "." 
78 | 79 | epoch_loss_values = {"loss": [1]} 80 | metric_values = [] 81 | widget.update_loss_plot(epoch_loss_values, metric_values) 82 | assert widget.plot_2 is None 83 | assert widget.plot_1 is None 84 | 85 | widget.worker_config.validation_interval = 2 86 | 87 | epoch_loss_values = {"loss": [0, 1]} 88 | metric_values = [0.2] 89 | widget.update_loss_plot(epoch_loss_values, metric_values) 90 | assert widget.plot_2 is None 91 | assert widget.plot_1 is None 92 | 93 | epoch_loss_values = {"loss": [0, 1, 0.5, 0.7]} 94 | metric_values = [0.1, 0.2] 95 | widget.update_loss_plot(epoch_loss_values, metric_values) 96 | assert widget.plot_2 is not None 97 | assert widget.plot_1 is not None 98 | 99 | epoch_loss_values = {"loss": [0, 1, 0.5, 0.7, 0.5, 0.7]} 100 | metric_values = [0.2, 0.3, 0.5, 0.7] 101 | widget.update_loss_plot(epoch_loss_values, metric_values) 102 | assert widget.plot_2 is not None 103 | assert widget.plot_1 is not None 104 | 105 | 106 | def test_check_matching_losses(): 107 | plugin = Trainer(None) 108 | config = plugin._set_worker_config() 109 | worker = plugin._create_supervised_worker_from_config(config) 110 | 111 | assert plugin.loss_list == list(worker.loss_dict.keys()) 112 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_plugin_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from napari_cellseg3d.code_plugins.plugin_convert import StatsUtils 4 | from napari_cellseg3d.code_plugins.plugin_crop import Cropping 5 | from napari_cellseg3d.code_plugins.plugin_utilities import ( 6 | UTILITIES_WIDGETS, 7 | Utilities, 8 | ) 9 | from napari_cellseg3d.utils import rand_gen 10 | 11 | 12 | def test_utils_plugin(make_napari_viewer_proxy): 13 | view = make_napari_viewer_proxy() 14 | widget = Utilities(view) 15 | 16 | image = rand_gen.random((10, 10, 10)) # .astype(np.uint8) 17 | image_layer = view.add_image(image, name="image") 18 | label_layer = view.add_labels(image.astype(np.uint8), name="labels") 19 | 20 | view.window.add_dock_widget(widget) 21 | view.dims.ndisplay = 3 22 | for i, utils_name in enumerate(UTILITIES_WIDGETS.keys()): 23 | widget.utils_choice.setCurrentIndex(i) 24 | assert isinstance( 25 | widget.utils_widgets[i], UTILITIES_WIDGETS[utils_name] 26 | ) 27 | if utils_name == "Convert to instance labels": 28 | # to avoid issues with Voronoi-Otsu missing runtime 29 | menu = widget.utils_widgets[i].instance_widgets.method_choice 30 | menu.setCurrentIndex(menu.currentIndex() + 1) 31 | 32 | assert len(image_layer.data.shape) == 3 33 | assert len(label_layer.data.shape) == 3 34 | widget.utils_widgets[i]._start() 35 | 36 | 37 | def test_crop_widget(make_napari_viewer_proxy): 38 | view = make_napari_viewer_proxy() 39 | widget = Cropping(view) 40 | 41 | image = rand_gen.random((10, 10, 10)) 42 | image_layer_1 = view.add_image(image, name="image") 43 | image_layer_2 = view.add_labels(image.astype(np.uint16), name="image2") 44 | 45 | view.window.add_dock_widget(widget) 46 | view.dims.ndisplay = 3 47 | assert len(image_layer_1.data.shape) == 3 48 | assert len(image_layer_2.data.shape) == 3 49 | widget.crop_second_image_choice.setChecked(True) 50 | widget.aniso_widgets.checkbox.setChecked(True) 51 | 52 | widget._start() 53 | widget.create_new_layer.setChecked(True) 54 | widget.quicksave() 55 | 56 | widget.sliders[0].setValue(2) 57 | widget.sliders[1].setValue(2) 58 | widget.sliders[2].setValue(2) 59 | 60 | widget._start() 61 | 62 | 63 | def 
test_stats_plugin(make_napari_viewer_proxy): 64 | view = make_napari_viewer_proxy() 65 | widget = StatsUtils(view) 66 | 67 | labels = rand_gen.random((10, 10, 10)).astype(np.uint8) 68 | view.add_labels(labels, name="labels") 69 | 70 | view.window.add_dock_widget(widget) 71 | widget.csv_name.setText("test.csv") 72 | widget._start() 73 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_plugins.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from napari_cellseg3d import plugins 4 | from napari_cellseg3d.code_plugins import plugin_metrics as m 5 | 6 | 7 | def test_all_plugins_import(make_napari_viewer_proxy): 8 | plugins.napari_experimental_provide_dock_widget() 9 | 10 | 11 | def test_plugin_metrics(make_napari_viewer_proxy): 12 | viewer = make_napari_viewer_proxy() 13 | w = m.MetricsUtils(viewer=viewer, parent=None) 14 | viewer.window.add_dock_widget(w) 15 | 16 | im_path = str(Path(__file__).resolve().parent / "res/test.tif") 17 | labels_path = im_path 18 | 19 | w.image_filewidget.text_field = im_path 20 | w.labels_filewidget.text_field = labels_path 21 | w.compute_dice() 22 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_review.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from napari_cellseg3d.code_plugins import plugin_review as rev 4 | 5 | 6 | def test_launch_review(make_napari_viewer_proxy): 7 | view = make_napari_viewer_proxy() 8 | widget = rev.Reviewer(view) 9 | 10 | # widget.filetype_choice.setCurrentIndex(0) 11 | 12 | im_path = str(Path(__file__).resolve().parent / "res/test.tif") 13 | lab_path = str(Path(__file__).resolve().parent / "res/test_labels.tif") 14 | 15 | widget.folder_choice.setChecked(True) 16 | widget.image_filewidget.text_field = im_path 17 | widget.labels_filewidget.text_field = lab_path 18 | widget.results_filewidget.text_field = str( 19 | Path(__file__).resolve().parent / "res" 20 | ) 21 | 22 | widget.run_review() 23 | widget._viewer.close() 24 | 25 | assert widget._viewer is not None 26 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_training.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from napari_cellseg3d._tests.fixtures import ( 6 | LogFixture, 7 | LossFixture, 8 | ModelFixture, 9 | OptimizerFixture, 10 | SchedulerFixture, 11 | WNetFixture, 12 | ) 13 | from napari_cellseg3d.code_models.models.model_test import TestModel 14 | from napari_cellseg3d.code_models.workers_utils import TrainingReport 15 | from napari_cellseg3d.code_plugins.plugin_model_training import ( 16 | Trainer, 17 | ) 18 | from napari_cellseg3d.config import MODEL_LIST 19 | 20 | WANDB_MODE = "disabled" 21 | 22 | im_path = Path(__file__).resolve().parent / "res/test.tif" 23 | im_path_str = str(im_path) 24 | lab_path = Path(__file__).resolve().parent / "res/test_labels.tif" 25 | lab_path_str = str(lab_path) 26 | 27 | 28 | def test_supervised_training(make_napari_viewer_proxy): 29 | viewer = make_napari_viewer_proxy() 30 | widget = Trainer(viewer) 31 | widget.log = LogFixture() 32 | widget.model_choice.setCurrentIndex(0) 33 | 34 | widget.images_filepath = [] 35 | widget.labels_filepaths = [] 36 | 37 | assert not widget.unsupervised_mode 38 | assert not 
widget.check_ready() 39 | 40 | widget.images_filepaths = [im_path_str] 41 | widget.labels_filepaths = [lab_path_str] 42 | widget.epoch_choice.setValue(1) 43 | widget.val_interval_choice.setValue(1) 44 | widget.device_choice.setCurrentIndex(0) 45 | 46 | assert widget.check_ready() 47 | 48 | MODEL_LIST["test"] = TestModel 49 | widget.model_choice.addItem("test") 50 | widget.model_choice.setCurrentText("test") 51 | widget.unsupervised_mode = False 52 | worker_config = widget._set_worker_config() 53 | assert worker_config.model_info.name == "test" 54 | worker = widget._create_supervised_worker_from_config(worker_config) 55 | worker.config.train_data_dict = [ 56 | {"image": im_path_str, "label": im_path_str} 57 | ] 58 | worker.config.val_data_dict = [ 59 | {"image": im_path_str, "label": im_path_str} 60 | ] 61 | worker.config.max_epochs = 2 62 | worker.config.validation_interval = 2 63 | 64 | worker.log_parameters() 65 | for res_i in worker.train( 66 | provided_model=ModelFixture(), 67 | provided_optimizer=OptimizerFixture(), 68 | provided_loss=LossFixture(), 69 | provided_scheduler=SchedulerFixture(), 70 | ): 71 | assert isinstance(res_i, TrainingReport) 72 | res = res_i 73 | assert res.epoch == 1 74 | 75 | widget.worker = worker 76 | res.show_plot = True 77 | res.loss_1_values = {"loss": [1, 1, 1, 1]} 78 | res.loss_2_values = [1, 1, 1, 1] 79 | widget.on_yield(res) 80 | assert widget.loss_1_values["loss"] == [1, 1, 1, 1] 81 | assert widget.loss_2_values == [1, 1, 1, 1] 82 | 83 | 84 | def test_unsupervised_training(make_napari_viewer_proxy): 85 | viewer = make_napari_viewer_proxy() 86 | widget = Trainer(viewer) 87 | widget.log = LogFixture() 88 | widget.worker = None 89 | widget.model_choice.setCurrentText("WNet3D") 90 | widget._toggle_unsupervised_mode(enabled=True) 91 | 92 | widget.patch_choice.setChecked(True) 93 | [w.setValue(4) for w in widget.patch_size_widgets] 94 | 95 | widget.unsupervised_images_filewidget.text_field.setText( 96 | str((im_path.parent / "wnet_test").resolve()) 97 | ) 98 | # widget.start() 99 | widget.data = widget.create_dataset_dict_no_labs() 100 | widget.worker = widget._create_worker( 101 | additional_results_description="wnet_test" 102 | ) 103 | assert widget.worker.config.train_data_dict is not None 104 | widget.worker.config.max_epochs = 1 105 | for res_i in widget.worker.train( 106 | provided_model=WNetFixture(), 107 | provided_optimizer=OptimizerFixture(), 108 | provided_loss=LossFixture(), 109 | ): 110 | assert isinstance(res_i, TrainingReport) 111 | res = res_i 112 | assert res.epoch == 0 113 | widget.worker._abort_requested = True 114 | res = next( 115 | widget.worker.train( 116 | provided_model=WNetFixture(), 117 | provided_optimizer=OptimizerFixture(), 118 | provided_loss=LossFixture(), 119 | ) 120 | ) 121 | assert isinstance(res, TrainingReport) 122 | assert not res.show_plot 123 | with pytest.raises( 124 | AttributeError, 125 | match="'WNetTrainingWorker' object has no attribute 'model'", 126 | ): 127 | assert widget.worker.model is None 128 | 129 | widget.worker.config.eval_volume_dict = [ 130 | {"image": im_path_str, "label": im_path_str} 131 | ] 132 | widget.worker._get_data() 133 | eval_res = widget.worker.eval( 134 | model=WNetFixture(), 135 | epoch=-10, 136 | ) 137 | assert isinstance(eval_res, TrainingReport) 138 | assert eval_res.show_plot 139 | assert eval_res.epoch == -10 140 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_utils.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | from functools import partial 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytest 7 | import torch 8 | 9 | from napari_cellseg3d import utils 10 | from napari_cellseg3d.dev_scripts import thread_test 11 | 12 | rand_gen = utils.rand_gen 13 | 14 | 15 | def test_singleton_class(): 16 | class TestSingleton(metaclass=utils.Singleton): 17 | def __init__(self, value): 18 | self.value = value 19 | 20 | a = TestSingleton(1) 21 | b = TestSingleton(2) 22 | 23 | assert a.value == b.value 24 | 25 | 26 | def test_save_folder(): 27 | test_path = Path(__file__).resolve().parent / "res" 28 | folder_name = "test_folder" 29 | images = [rand_gen.random((5, 5, 5)).astype(np.float32) for _ in range(10)] 30 | images_paths = [f"{i}.tif" for i in range(10)] 31 | 32 | utils.save_folder( 33 | test_path, folder_name, images, images_paths, exist_ok=True 34 | ) 35 | assert (test_path / folder_name).is_dir() 36 | for i in range(10): 37 | assert (test_path / folder_name / images_paths[i]).is_file() 38 | 39 | 40 | def test_normalize_y(): 41 | test_array = np.array([0, 255, 127.5]) 42 | results = utils.normalize_y(test_array) 43 | expected = test_array / 255 44 | assert np.all(results == expected) 45 | assert np.all(test_array == utils.denormalize_y(results)) 46 | 47 | 48 | def test_sphericities(): 49 | for _i in range(100): 50 | mock_volume = random.randint(1, 10) 51 | mock_surface = random.randint( 52 | 100, 1000 53 | ) # assuming surface is always larger than volume 54 | sphericity_vol = utils.sphericity_volume_area( 55 | mock_volume, mock_surface 56 | ) 57 | assert 0 <= sphericity_vol <= 1 58 | 59 | semi_major = random.randint(10, 100) 60 | semi_minor = random.randint(10, 100) 61 | try: 62 | sphericity_axes = utils.sphericity_axis(semi_major, semi_minor) 63 | except ZeroDivisionError: 64 | sphericity_axes = 0 65 | except ValueError: 66 | sphericity_axes = 0 67 | if sphericity_axes is None: 68 | sphericity_axes = ( 69 | 0 # errors already handled in function, returns None 70 | ) 71 | assert 0 <= sphericity_axes <= 1 72 | 73 | 74 | def test_normalize_max(): 75 | test_array = np.array([0, 255, 127.5]) 76 | expected = np.array([0, 1, 0.5]) 77 | assert np.all(utils.normalize_max(test_array) == expected) 78 | 79 | 80 | def test_dice_coeff(): 81 | test_array = rand_gen.integers(0, 2, (64, 64, 64)) 82 | test_array_2 = rand_gen.integers(0, 2, (64, 64, 64)) 83 | assert utils.dice_coeff(test_array, test_array) == 1 84 | assert utils.dice_coeff(test_array, test_array_2) <= 1 85 | 86 | 87 | def test_fill_list_in_between(): 88 | test_list = [1, 2, 3, 4, 5, 6] 89 | res = [ 90 | 1, 91 | "", 92 | "", 93 | 2, 94 | "", 95 | "", 96 | 3, 97 | "", 98 | "", 99 | 4, 100 | "", 101 | "", 102 | 5, 103 | "", 104 | "", 105 | 6, 106 | "", 107 | "", 108 | ] 109 | 110 | assert utils.fill_list_in_between(test_list, 2, "") == res 111 | 112 | fill = partial(utils.fill_list_in_between, n=2, fill_value="") 113 | 114 | assert fill(test_list) == res 115 | 116 | 117 | def test_align_array_sizes(): 118 | im = np.zeros((128, 512, 256)) 119 | print(im.shape) 120 | 121 | dim_1 = (64, 64, 512) 122 | ground = np.array((512, 64, 64)) 123 | pred = np.array(dim_1) 124 | 125 | ori, targ = utils.align_array_sizes(ground, pred) 126 | 127 | im_1 = np.moveaxis(im, ori, targ) 128 | print(im_1.shape) 129 | assert im_1.shape == (512, 256, 128) 130 | 131 | dim_2 = (512, 256, 128) 132 | ground = np.array((128, 512, 256)) 133 | pred = np.array(dim_2) 134 
| 135 | ori, targ = utils.align_array_sizes(ground, pred) 136 | 137 | im_2 = np.moveaxis(im, ori, targ) 138 | print(im_2.shape) 139 | assert im_2.shape == dim_2 140 | 141 | dim_3 = (128, 128, 128) 142 | ground = np.array(dim_3) 143 | pred = np.array(dim_3) 144 | 145 | ori, targ = utils.align_array_sizes(ground, pred) 146 | im_3 = np.moveaxis(im, ori, targ) 147 | print(im_3.shape) 148 | assert im_3.shape == im.shape 149 | 150 | 151 | def test_get_padding_dim(): 152 | tensor = torch.randn(100, 30, 40) 153 | size = tensor.size() 154 | 155 | pad = utils.get_padding_dim(size) 156 | 157 | assert pad == [128, 32, 64] 158 | 159 | tensor = torch.randn(2000, 30, 40) 160 | size = tensor.size() 161 | 162 | # warn = logger.warning( 163 | # "Warning : a very large dimension for automatic padding has been computed.\n" 164 | # "Ensure your images are of an appropriate size and/or that you have enough memory." 165 | # "The padding value is currently 2048." 166 | # ) 167 | # 168 | pad = utils.get_padding_dim(size) 169 | # 170 | # pytest.warns(warn, (lambda: utils.get_padding_dim(size))) 171 | 172 | assert pad == [2048, 32, 64] 173 | 174 | tensor = torch.randn(65, 70, 80) 175 | size = tensor.size() 176 | 177 | pad = utils.get_padding_dim(size) 178 | 179 | assert pad == [128, 128, 128] 180 | 181 | tensor_wrong = torch.randn(65, 70, 80, 90) 182 | with pytest.raises( 183 | ValueError, 184 | match="Please check the dimensions of the input, only 2 or 3-dimensional data is supported currently", 185 | ): 186 | utils.get_padding_dim(tensor_wrong.size()) 187 | 188 | 189 | def test_normalize_x(): 190 | test_array = utils.normalize_x(np.array([0, 255, 127.5])) 191 | expected = np.array([-1, 1, 0]) 192 | assert np.all(test_array == expected) 193 | 194 | 195 | def test_load_images(): 196 | path = Path(__file__).resolve().parent / "res" 197 | # with pytest.raises( 198 | # ValueError, match="If loading as a folder, filetype must be specified" 199 | # ): 200 | # images = utils.load_images(str(path), as_folder=True) 201 | # with pytest.raises( 202 | # NotImplementedError, 203 | # match="Loading as folder not implemented yet. 
Use napari to load as folder", 204 | # ): 205 | # images = utils.load_images(str(path), as_folder=True, filetype=".tif") 206 | # # assert len(images) == 1 207 | path = path / "test.tif" 208 | images = utils.load_images(str(path)) 209 | assert images.shape == (6, 6, 6) 210 | 211 | 212 | def test_parse_default_path(): 213 | user_path = Path.home() 214 | assert utils.parse_default_path([None]) == str(user_path) 215 | 216 | test_path = (Path.home() / "test" / "test" / "test" / "test").as_posix() 217 | path = [test_path, None, None] 218 | assert utils.parse_default_path(path, check_existence=False) == str( 219 | test_path 220 | ) 221 | 222 | test_path = (Path.home() / "test" / "does" / "not" / "exist").as_posix() 223 | path = [test_path, None, None] 224 | assert utils.parse_default_path(path, check_existence=True) == str( 225 | Path.home() 226 | ) 227 | 228 | long_path = Path.home() 229 | long_path = ( 230 | long_path 231 | / "very" 232 | / "long" 233 | / "path" 234 | / "what" 235 | / "a" 236 | / "bore" 237 | / "ifonlytherewassomething" 238 | / "tohelpmenotsearchit" 239 | / "allthetime" 240 | ) 241 | path = [test_path, None, None, long_path, ""] 242 | assert utils.parse_default_path(path, check_existence=False) == str( 243 | long_path.as_posix() 244 | ) 245 | 246 | 247 | def test_thread_test(make_napari_viewer_proxy): 248 | viewer = make_napari_viewer_proxy() 249 | w = thread_test.create_connected_widget(viewer) 250 | viewer.window.add_dock_widget(w) 251 | 252 | 253 | def test_quantile_norm(): 254 | array = rand_gen.random(size=(100, 100, 100)) 255 | low_quantile = np.quantile(array, 0.01) 256 | high_quantile = np.quantile(array, 0.99) 257 | array_norm = utils.quantile_normalization(array) 258 | assert array_norm.min() >= low_quantile 259 | assert array_norm.max() <= high_quantile 260 | 261 | 262 | def test_get_all_matching_files(): 263 | test_image_path = Path(__file__).resolve().parent / "res/wnet_test" 264 | paths = utils.get_all_matching_files(test_image_path) 265 | 266 | assert len(paths) == 1 267 | assert [Path(p).is_file() for p in paths] 268 | assert [Path(p).suffix == ".tif" for p in paths] 269 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_weight_download.py: -------------------------------------------------------------------------------- 1 | from napari_cellseg3d.code_models.workers_utils import ( 2 | PRETRAINED_WEIGHTS_DIR, 3 | WeightsDownloader, 4 | ) 5 | 6 | 7 | # DISABLED, causes GitHub actions to freeze 8 | def test_weight_download(): 9 | downloader = WeightsDownloader() 10 | downloader.download_weights("test", "test.pth") 11 | result_path = PRETRAINED_WEIGHTS_DIR / "test.pth" 12 | 13 | assert result_path.is_file() 14 | -------------------------------------------------------------------------------- /napari_cellseg3d/_tests/test_wnet_training.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/_tests/test_wnet_training.py -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/__init__.py: -------------------------------------------------------------------------------- 1 | """This folder contains the code used by models in the plugin. 2 | 3 | * ``models`` folder: contains the model classes, which are wrappers for the actual models. 
The wrappers are used to ensure that the models are compatible with the plugin. 4 | * model_framework.py: contains the code for the model framework, used by training and inference plugins 5 | * worker_inference.py: contains the code for the inference worker 6 | * worker_training.py: contains the code for the training worker 7 | * instance_segmentation.py: contains the code for instance segmentation 8 | * crf.py: contains the code for the CRF postprocessing 9 | * worker_utils.py: contains functions used by the workers 10 | 11 | """ 12 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/TEMPLATE_model.py: -------------------------------------------------------------------------------- 1 | """This is a template for a model class. It is not used in the plugin, but is here to show how to implement a model class. 2 | 3 | Please note that custom model implementations are not fully supported out of the box yet, but might be in the future. 4 | """ 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | 9 | class ModelTemplate_(ABC): 10 | """Template for a model class. This is not used in the plugin, but is here to show how to implement a model class.""" 11 | 12 | weights_file = ( 13 | "model_template.pth" # specify the file name of the weights file only 14 | ) 15 | default_threshold = 0.5 # specify the default threshold for the model 16 | 17 | @abstractmethod 18 | def __init__( 19 | self, input_image_size, in_channels=1, out_channels=1, **kwargs 20 | ): 21 | """Reimplement this as needed; only include input_image_size if necessary. For now only in/out channels = 1 is supported.""" 22 | pass 23 | 24 | @abstractmethod 25 | def forward(self, x): 26 | """Reimplement this as needed. Ensure that output is a torch tensor with dims (batch, channels, z, y, x).""" 27 | pass 28 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Contains model code and wrappers for the models, as classes.""" 2 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/model_SegResNet.py: -------------------------------------------------------------------------------- 1 | """SegResNet wrapper for napari_cellseg3d.""" 2 | 3 | from monai.networks.nets import SegResNetVAE 4 | 5 | 6 | class SegResNet_(SegResNetVAE): 7 | """SegResNet_ wrapper for napari_cellseg3d.""" 8 | 9 | weights_file = "SegResNet_latest.pth" 10 | default_threshold = 0.3 11 | 12 | def __init__( 13 | self, input_img_size, out_channels=1, dropout_prob=0.3, **kwargs 14 | ): 15 | """Create a SegResNet model. 16 | 17 | Args: 18 | input_img_size (tuple): input image size 19 | out_channels (int): number of output channels 20 | dropout_prob (float): dropout probability. 21 | **kwargs: additional arguments to SegResNetVAE. 
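# --- Illustrative example (editor's sketch, not part of the plugin) --------
# A minimal concrete model following the ModelTemplate_ contract shown above.
# The class name, weights file name and threshold are hypothetical; only the
# attributes and signatures mirror the template. As the template notes,
# custom models are not fully supported out of the box yet.
import torch
from torch import nn


class MyCustomModel(nn.Module):
    weights_file = "my_custom_model.pth"  # hypothetical weights file name
    default_threshold = 0.5

    def __init__(self, input_image_size=None, in_channels=1, out_channels=1, **kwargs):
        super().__init__()
        # input_image_size is accepted but unused here, as the template allows
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # output must be a torch tensor with dims (batch, channels, z, y, x)
        return torch.sigmoid(self.conv(x))
# ----------------------------------------------------------------------------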
22 | """ 23 | super().__init__( 24 | input_img_size, 25 | out_channels=out_channels, 26 | dropout_prob=dropout_prob, 27 | ) 28 | 29 | def forward(self, x): 30 | """Forward pass of the SegResNet model.""" 31 | res = SegResNetVAE.forward(self, x) 32 | # logger.debug(f"SegResNetVAE.forward: {res[0].shape}") 33 | return res[0] 34 | 35 | # def get_model_test(self, size): 36 | # return SegResNetVAE( 37 | # size, in_channels=1, out_channels=1, dropout_prob=0.3 38 | # ) 39 | 40 | # def get_output(model, input): 41 | # out = model(input)[0] 42 | # return out 43 | 44 | # def get_validation(model, val_inputs): 45 | # val_outputs = model(val_inputs) 46 | # return val_outputs[0] 47 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/model_SwinUNetR.py: -------------------------------------------------------------------------------- 1 | """SwinUNetR wrapper for napari_cellseg3d.""" 2 | 3 | from monai.networks.nets import SwinUNETR 4 | 5 | from napari_cellseg3d.utils import LOGGER 6 | 7 | logger = LOGGER 8 | 9 | 10 | class SwinUNETR_(SwinUNETR): 11 | """SwinUNETR wrapper for napari_cellseg3d.""" 12 | 13 | weights_file = "SwinUNetR_latest.pth" 14 | default_threshold = 0.4 15 | 16 | def __init__( 17 | self, 18 | in_channels=1, 19 | out_channels=1, 20 | input_img_size=(64, 64, 64), 21 | use_checkpoint=True, 22 | **kwargs, 23 | ): 24 | """Create a SwinUNetR model. 25 | 26 | Args: 27 | in_channels (int): number of input channels 28 | out_channels (int): number of output channels 29 | input_img_size (tuple): input image size 30 | use_checkpoint (bool): whether to use checkpointing during training. 31 | **kwargs: additional arguments to SwinUNETR. 32 | """ 33 | try: 34 | super().__init__( 35 | input_img_size, 36 | in_channels=in_channels, 37 | out_channels=out_channels, 38 | feature_size=48, 39 | use_checkpoint=use_checkpoint, 40 | drop_rate=0.5, 41 | attn_drop_rate=0.5, 42 | use_v2=True, 43 | **kwargs, 44 | ) 45 | except TypeError as e: 46 | logger.warning(f"Caught TypeError: {e}") 47 | super().__init__( 48 | input_img_size, 49 | in_channels=1, 50 | out_channels=1, 51 | feature_size=48, 52 | use_checkpoint=use_checkpoint, 53 | drop_rate=0.5, 54 | attn_drop_rate=0.5, 55 | use_v2=True, 56 | ) 57 | 58 | # def forward(self, x_in): 59 | # y = super().forward(x_in) 60 | # return softmax(y, dim=1) 61 | # return sigmoid(y) 62 | 63 | # def get_output(self, input): 64 | # out = self(input) 65 | # return torch.sigmoid(out) 66 | 67 | # def get_validation(self, val_inputs): 68 | # return self(val_inputs) 69 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/model_TRAILMAP.py: -------------------------------------------------------------------------------- 1 | """Legacy version of adapted TRAILMAP model, not used in the current version of the plugin.""" 2 | # import torch 3 | # from torch import nn 4 | # 5 | # 6 | # class TRAILMAP(nn.Module): 7 | # def __init__(self, in_ch, out_ch, *args, **kwargs): 8 | # super().__init__() 9 | # self.conv0 = self.encoderBlock(in_ch, 32, 3) # input 10 | # self.conv1 = self.encoderBlock(32, 64, 3) # l1 11 | # self.conv2 = self.encoderBlock(64, 128, 3) # l2 12 | # self.conv3 = self.encoderBlock(128, 256, 3) # l3 13 | # 14 | # self.bridge = self.bridgeBlock(256, 512, 3) 15 | # 16 | # self.up5 = self.decoderBlock(256 + 512, 256, 2) 17 | # 18 | # self.up6 = self.decoderBlock(128 + 256, 128, 2) 19 | # self.up7 = self.decoderBlock(128 + 64, 64, 2) # l2 20 | # 
self.up8 = self.decoderBlock(64 + 32, 32, 2) # l1 21 | # self.out = self.outBlock(32, out_ch, 1) 22 | # 23 | # def forward(self, x): 24 | # conv0 = self.conv0(x) # l0 25 | # conv1 = self.conv1(conv0) # l1 26 | # conv2 = self.conv2(conv1) # l2 27 | # conv3 = self.conv3(conv2) # l3 28 | # 29 | # bridge = self.bridge(conv3) # bridge 30 | # # print("bridge :") 31 | # # print(bridge.shape) 32 | # 33 | # up5 = self.up5(torch.cat([conv3, bridge], 1)) # l3 34 | # # print("up") 35 | # # print(up5.shape) 36 | # up6 = self.up6(torch.cat([up5, conv2], 1)) # l2 37 | # # print(up6.shape) 38 | # up7 = self.up7(torch.cat([up6, conv1], 1)) # l1 39 | # # print(up7.shape) 40 | # 41 | # up8 = self.up8(torch.cat([up7, conv0], 1)) # l1 42 | # # print(up8.shape) 43 | # return self.out(up8) 44 | # # print("out:") 45 | # # print(out.shape) 46 | # 47 | # def encoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): 48 | # return nn.Sequential( 49 | # nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), 50 | # nn.BatchNorm3d(out_ch), 51 | # nn.ReLU(), 52 | # nn.Conv3d( 53 | # out_ch, out_ch, kernel_size=kernel_size, padding=padding 54 | # ), 55 | # nn.BatchNorm3d(out_ch), 56 | # nn.ReLU(), 57 | # nn.MaxPool3d(2), 58 | # ) 59 | # 60 | # def bridgeBlock(self, in_ch, out_ch, kernel_size, padding="same"): 61 | # return nn.Sequential( 62 | # nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), 63 | # nn.BatchNorm3d(out_ch), 64 | # nn.ReLU(), 65 | # nn.Conv3d( 66 | # out_ch, out_ch, kernel_size=kernel_size, padding=padding 67 | # ), 68 | # nn.BatchNorm3d(out_ch), 69 | # nn.ReLU(), 70 | # ) 71 | # 72 | # def decoderBlock(self, in_ch, out_ch, kernel_size, padding="same"): 73 | # return nn.Sequential( 74 | # nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), 75 | # nn.BatchNorm3d(out_ch), 76 | # nn.ReLU(), 77 | # nn.Conv3d( 78 | # out_ch, out_ch, kernel_size=kernel_size, padding=padding 79 | # ), 80 | # nn.BatchNorm3d(out_ch), 81 | # nn.ReLU(), 82 | # nn.ConvTranspose3d( 83 | # out_ch, out_ch, kernel_size=kernel_size, stride=(2, 2, 2) 84 | # ), 85 | # ) 86 | # 87 | # def outBlock(self, in_ch, out_ch, kernel_size, padding="same"): 88 | # return nn.Sequential( 89 | # nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=padding), 90 | # ) 91 | # 92 | # 93 | # class TRAILMAP_(TRAILMAP): 94 | # weights_file = "TRAILMAP_PyTorch.pth" # model additionally trained on Mathis/Wyss mesoSPIM data 95 | # # FIXME currently incorrect, find good weights from TRAILMAP_test and upload them 96 | # 97 | # def __init__(self, in_channels=1, out_channels=1, **kwargs): 98 | # super().__init__(in_channels, out_channels, **kwargs) 99 | # 100 | # # def get_output(model, input): 101 | # # out = model(input) 102 | # # 103 | # # return out 104 | # 105 | # # def get_validation(model, val_inputs): 106 | # # return model(val_inputs) 107 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/model_TRAILMAP_MS.py: -------------------------------------------------------------------------------- 1 | """TRAILMAP model, reimplemented in PyTorch.""" 2 | 3 | from napari_cellseg3d.code_models.models.unet.model import UNet3D 4 | from napari_cellseg3d.utils import LOGGER as logger 5 | 6 | 7 | class TRAILMAP_MS_(UNet3D): 8 | """TRAILMAP_MS wrapper for napari_cellseg3d.""" 9 | 10 | weights_file = "TRAILMAP_MS_best_metric.pth" 11 | default_threshold = 0.15 12 | 13 | # original model from Liqun Luo lab, transferred to pytorch and trained on 
mesoSPIM-acquired data (mostly TPH2 as of July 2022) 14 | 15 | def __init__(self, in_channels=1, out_channels=1, **kwargs): 16 | """Create a TRAILMAP_MS model. 17 | 18 | Args: 19 | in_channels (int): number of input channels 20 | out_channels (int): number of output channels. 21 | **kwargs: additional arguments to UNet3D. 22 | """ 23 | try: 24 | super().__init__( 25 | in_channels=in_channels, out_channels=out_channels, **kwargs 26 | ) 27 | except TypeError as e: 28 | logger.warning(f"Caught TypeError: {e}") 29 | super().__init__( 30 | in_channels=in_channels, out_channels=out_channels 31 | ) 32 | 33 | # def get_output(self, input): 34 | # out = self(input) 35 | 36 | # return out 37 | # 38 | # def get_validation(self, val_inputs): 39 | # return self(val_inputs) 40 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/model_VNet.py: -------------------------------------------------------------------------------- 1 | """VNet wrapper for napari_cellseg3d.""" 2 | 3 | from monai.networks.nets import VNet 4 | 5 | 6 | class VNet_(VNet): 7 | """VNet wrapper for napari_cellseg3d.""" 8 | 9 | weights_file = "VNet_latest.pth" 10 | default_threshold = 0.15 11 | 12 | def __init__(self, in_channels=1, out_channels=1, **kwargs): 13 | """Create a VNet model. 14 | 15 | Args: 16 | in_channels (int): number of input channels 17 | out_channels (int): number of output channels. 18 | **kwargs: additional arguments to VNet. 19 | """ 20 | try: 21 | super().__init__( 22 | in_channels=in_channels, 23 | out_channels=out_channels, 24 | bias=True, 25 | **kwargs, 26 | ) 27 | except TypeError: 28 | super().__init__( 29 | in_channels=in_channels, out_channels=out_channels, bias=True 30 | ) 31 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/model_WNet.py: -------------------------------------------------------------------------------- 1 | """Wrapper for the W-Net model, with the decoder weights removed. 2 | 3 | .. important:: Used for inference only. For training the base class is used. 4 | """ 5 | 6 | # local 7 | from napari_cellseg3d.code_models.models.wnet.model import WNet_encoder 8 | from napari_cellseg3d.utils import remap_image 9 | 10 | 11 | class WNet_(WNet_encoder): 12 | """W-Net wrapper for napari_cellseg3d. 13 | 14 | ..important:: Used for inference only, therefore only the encoder is used. For training the base class is used. 15 | """ 16 | 17 | weights_file = "wnet_latest.pth" 18 | default_threshold = 0.6 19 | 20 | def __init__( 21 | self, 22 | in_channels=1, 23 | out_channels=2, 24 | # num_classes=2, 25 | **kwargs, 26 | ): 27 | """Create a W-Net model. 28 | 29 | Args: 30 | in_channels (int): number of input channels 31 | out_channels (int): number of output channels. 32 | **kwargs: additional arguments to WNet_encoder. 
33 | """ 34 | super().__init__( 35 | in_channels=in_channels, 36 | out_channels=out_channels, 37 | # num_classes=num_classes, 38 | softmax=False, 39 | ) 40 | 41 | # def train(self: T, mode: bool = True) -> T: 42 | # raise NotImplementedError("Training not implemented for WNet") 43 | 44 | def forward(self, x): 45 | """Forward pass of the W-Net model.""" 46 | norm_x = remap_image(x) 47 | return super().forward(norm_x) 48 | 49 | def load_state_dict(self, state_dict, strict=True): 50 | """Load the model state dict for inference, without the decoder weights.""" 51 | encoder_checkpoint = state_dict.copy() 52 | for k in state_dict: 53 | if k.startswith("decoder"): 54 | encoder_checkpoint.pop(k) 55 | # print(encoder_checkpoint.keys()) 56 | super().load_state_dict(encoder_checkpoint, strict=strict) 57 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/model_test.py: -------------------------------------------------------------------------------- 1 | """Model for testing purposes.""" 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class TestModel(nn.Module): 8 | """For tests only.""" 9 | 10 | weights_file = "test.pth" 11 | default_threshold = 0.5 12 | 13 | def __init__(self, **kwargs): 14 | """Create a TestModel model.""" 15 | super().__init__() 16 | self.linear = nn.Linear(8, 8) 17 | 18 | def forward(self, x): 19 | """Forward pass of the TestModel model.""" 20 | return self.linear(torch.tensor(x, requires_grad=True)) 21 | 22 | # def get_output(self, _, input): 23 | # return input 24 | 25 | # def get_validation(self, val_inputs): 26 | # return val_inputs 27 | 28 | 29 | if __name__ == "__main__": 30 | model = TestModel() 31 | model.train() 32 | model.zero_grad() 33 | from napari_cellseg3d.config import PRETRAINED_WEIGHTS_DIR 34 | 35 | torch.save( 36 | model.state_dict(), 37 | PRETRAINED_WEIGHTS_DIR + f"/{TestModel.weights_file}", 38 | ) 39 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/pretrained/__init__.py: -------------------------------------------------------------------------------- 1 | """Hosts the downloaded pretrained model weights. 2 | 3 | Please feel free to delete weights if you do not need them. 4 | They will be downloaded again automatically if needed. 
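# --- Illustrative example (editor's sketch, not part of the plugin) --------
# Manually triggering the automatic weights download described above,
# mirroring _tests/test_weight_download.py. The "test"/"test.pth" entry is the
# tiny test checkpoint listed in pretrained_model_urls.json.
from napari_cellseg3d.code_models.workers_utils import (
    PRETRAINED_WEIGHTS_DIR,
    WeightsDownloader,
)

downloader = WeightsDownloader()
downloader.download_weights("test", "test.pth")  # (model key, weights file name)
assert (PRETRAINED_WEIGHTS_DIR / "test.pth").is_file()
# ----------------------------------------------------------------------------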
5 | """ 6 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/pretrained/pretrained_model_urls.json: -------------------------------------------------------------------------------- 1 | { 2 | "TRAILMAP_MS": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/TRAILMAP_latest.tar.gz", 3 | "SegResNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SegResNet_latest.tar.gz", 4 | "VNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/VNet_latest.tar.gz", 5 | "SwinUNetR": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/SwinUNetR_latest.tar.gz", 6 | "WNet": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet_latest.tar.gz", 7 | "WNet3D": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet_latest.tar.gz", 8 | "WNet_ONNX": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/wnet_onnx.tar.gz", 9 | "test": "https://huggingface.co/C-Achard/cellseg3d/resolve/main/test.tar.gz" 10 | } 11 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/unet/__init__.py: -------------------------------------------------------------------------------- 1 | """Building block of a UNet model. 2 | 3 | Used mostly by the TRAILMAP model. 4 | """ 5 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/unet/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from napari_cellseg3d.code_models.models.unet.buildingblocks import ( 4 | DoubleConv, 5 | create_decoders, 6 | create_encoders, 7 | ) 8 | 9 | 10 | def number_of_features_per_level(init_channel_number, num_levels): 11 | return [init_channel_number * 2**k for k in range(num_levels)] 12 | 13 | 14 | class Abstract3DUNet(nn.Module): 15 | """ 16 | Base class for standard and residual UNet. 17 | 18 | Args: 19 | in_channels (int): number of input channels 20 | out_channels (int): number of output segmentation masks; 21 | Note that that the of out_channels might correspond to either 22 | different semantic classes or to different binary segmentation mask. 23 | It's up to the user of the class to interpret the out_channels and 24 | use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class) 25 | or BCEWithLogitsLoss (two-class) respectively) 26 | f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number 27 | of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4 28 | final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the 29 | final 1x1 convolution, otherwise apply nn.Softmax. MUST be True if nn.BCELoss (two-class) is used 30 | to train the model. MUST be False if nn.CrossEntropyLoss (multi-class) is used to train the model. 31 | basic_module: basic model for the encoder/decoder (DoubleConv, ExtResNetBlock, ....) 32 | layer_order (string): determines the order of layers 33 | in `SingleConv` module. e.g. 'crg' stands for Conv3d+ReLU+GroupNorm3d. 
34 | See `SingleConv` for more info 35 | num_groups (int): number of groups for the GroupNorm 36 | num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int) 37 | is_segmentation (bool): if True (semantic segmentation problem) Sigmoid/Softmax normalization is applied 38 | after the final convolution; if False (regression problem) the normalization layer is skipped at the end 39 | conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module 40 | pool_kernel_size (int or tuple): the size of the window 41 | conv_padding (int or tuple): add zero-padding added to all three sides of the input 42 | """ 43 | 44 | def __init__( 45 | self, 46 | in_channels, 47 | out_channels, 48 | final_sigmoid, 49 | basic_module, 50 | f_maps=64, 51 | layer_order="gcr", 52 | num_groups=8, 53 | num_levels=4, 54 | is_segmentation=True, 55 | conv_kernel_size=3, 56 | pool_kernel_size=2, 57 | conv_padding=1, 58 | **kwargs, 59 | ): 60 | super(Abstract3DUNet, self).__init__() 61 | 62 | if isinstance(f_maps, int): 63 | f_maps = number_of_features_per_level( 64 | f_maps, num_levels=num_levels 65 | ) 66 | 67 | assert isinstance(f_maps, (list, tuple)) 68 | assert len(f_maps) > 1, "Required at least 2 levels in the U-Net" 69 | 70 | # create encoder path 71 | self.encoders = create_encoders( 72 | in_channels, 73 | f_maps, 74 | basic_module, 75 | conv_kernel_size, 76 | conv_padding, 77 | layer_order, 78 | num_groups, 79 | pool_kernel_size, 80 | ) 81 | 82 | # create decoder path 83 | self.decoders = create_decoders( 84 | f_maps, 85 | basic_module, 86 | conv_kernel_size, 87 | conv_padding, 88 | layer_order, 89 | num_groups, 90 | upsample=True, 91 | ) 92 | 93 | # in the last layer a 1×1 convolution reduces the number of output 94 | # channels to the number of labels 95 | self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1) 96 | 97 | if is_segmentation: 98 | # semantic segmentation problem 99 | if final_sigmoid: 100 | self.final_activation = nn.Sigmoid() 101 | else: 102 | self.final_activation = nn.Softmax(dim=1) 103 | else: 104 | # regression problem 105 | self.final_activation = None 106 | 107 | def forward(self, x): 108 | # encoder part 109 | encoders_features = [] 110 | for encoder in self.encoders: 111 | x = encoder(x) 112 | # reverse the encoder outputs to be aligned with the decoder 113 | encoders_features.insert(0, x) 114 | 115 | # remove the last encoder's output from the list 116 | # !!remember: it's the 1st in the list 117 | encoders_features = encoders_features[1:] 118 | 119 | # decoder part 120 | for decoder, encoder_features in zip(self.decoders, encoders_features): 121 | # pass the output from the corresponding encoder and the output 122 | # of the previous decoder 123 | x = decoder(encoder_features, x) 124 | 125 | x = self.final_conv(x) 126 | 127 | # apply final_activation (i.e. Sigmoid or Softmax) only during prediction. During training the network outputs logits 128 | if not self.training and self.final_activation is not None: 129 | x = self.final_activation(x) 130 | 131 | return x 132 | 133 | 134 | class UNet3D(Abstract3DUNet): 135 | """ 136 | 3DUnet model from 137 | `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation" 138 | `. 
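# --- Illustrative example (editor's sketch, not part of the plugin) --------
# Instantiating the UNet3D defined in this module. With an integer f_maps the
# encoder widths follow number_of_features_per_level, e.g. f_maps=64,
# num_levels=4 -> [64, 128, 256, 512]. Input size and batch are illustrative;
# with conv_padding=1 the output is expected to keep the input spatial size.
import torch

from napari_cellseg3d.code_models.models.unet.model import UNet3D

net = UNet3D(in_channels=1, out_channels=1, f_maps=64, num_levels=4)
net.eval()  # in eval mode the final Sigmoid/Softmax is applied
with torch.no_grad():
    out = net(torch.zeros(1, 1, 64, 64, 64))  # (batch, channel, z, y, x)
print(out.shape)  # expected: torch.Size([1, 1, 64, 64, 64])
# ----------------------------------------------------------------------------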
139 | 140 | Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder 141 | """ 142 | 143 | def __init__( 144 | self, 145 | in_channels, 146 | out_channels, 147 | final_sigmoid=True, 148 | f_maps=64, 149 | layer_order="crb", # gcr for groupnorm 150 | num_groups=8, 151 | num_levels=4, 152 | is_segmentation=True, 153 | conv_padding=1, 154 | **kwargs, 155 | ): 156 | super(UNet3D, self).__init__( 157 | in_channels=in_channels, 158 | out_channels=out_channels, 159 | final_sigmoid=final_sigmoid, 160 | basic_module=DoubleConv, 161 | f_maps=f_maps, 162 | layer_order=layer_order, 163 | num_groups=num_groups, 164 | num_levels=num_levels, 165 | is_segmentation=is_segmentation, 166 | conv_padding=conv_padding, 167 | **kwargs, 168 | ) 169 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/wnet/__init__.py: -------------------------------------------------------------------------------- 1 | """Building blocks for WNet model.""" 2 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_models/models/wnet/soft_Ncuts.py: -------------------------------------------------------------------------------- 1 | """Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. 2 | 3 | The implementation was adapted and approximated to reduce computational and memory cost. 4 | This faster version was proposed on https://github.com/fkodom/wnet-unsupervised-image-segmentation. 5 | """ 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from scipy.stats import norm 14 | 15 | from napari_cellseg3d.utils import LOGGER as logger 16 | 17 | __author__ = "Yves Paychère, Colin Hofmann, Cyril Achard" 18 | __credits__ = [ 19 | "Yves Paychère", 20 | "Colin Hofmann", 21 | "Cyril Achard", 22 | "Xide Xia", 23 | "Brian Kulis", 24 | "Jianbo Shi", 25 | "Jitendra Malik", 26 | "Frank Odom", 27 | ] 28 | 29 | 30 | class SoftNCutsLoss(nn.Module): 31 | """Implementation of a 3D Soft N-Cuts loss based on https://arxiv.org/abs/1711.08506 and https://ieeexplore.ieee.org/document/868688. 32 | 33 | Args: 34 | data_shape (H, W, D): shape of the images as a tuple. 35 | intensity_sigma (scalar): scale of the gaussian kernel of pixels brightness. 36 | spatial_sigma (scalar): scale of the gaussian kernel of pixels spacial distance. 37 | radius (scalar): radius of pixels for which we compute the weights 38 | """ 39 | 40 | def __init__( 41 | self, data_shape, device, intensity_sigma, spatial_sigma, radius=None 42 | ): 43 | """Initialize the Soft N-Cuts loss. 44 | 45 | Args: 46 | data_shape (H, W, D): shape of the images as a tuple. 47 | device (torch.device): device on which the loss is computed. 48 | intensity_sigma (scalar): scale of the gaussian kernel of pixels brightness. 49 | spatial_sigma (scalar): scale of the gaussian kernel of pixels spacial distance. 
50 | radius (scalar): radius of pixels for which we compute the weights 51 | """ 52 | super(SoftNCutsLoss, self).__init__() 53 | self.intensity_sigma = intensity_sigma 54 | self.spatial_sigma = spatial_sigma 55 | self.radius = radius 56 | self.H = data_shape[0] 57 | self.W = data_shape[1] 58 | self.D = data_shape[2] 59 | self.device = device 60 | 61 | if self.radius is None: 62 | self.radius = min( 63 | max(5, math.ceil(min(self.H, self.W, self.D) / 20)), 64 | self.H, 65 | self.W, 66 | self.D, 67 | ) 68 | logger.info(f"Radius set to {self.radius}") 69 | 70 | def forward(self, labels, inputs): 71 | """Forward pass of the Soft N-Cuts loss. 72 | 73 | Args: 74 | labels (torch.Tensor): Tensor of shape (N, K, H, W, D) containing the predicted class probabilities for each pixel. 75 | inputs (torch.Tensor): Tensor of shape (N, C, H, W, D) containing the input images. 76 | 77 | Returns: 78 | The Soft N-Cuts loss of shape (N,). 79 | """ 80 | # inputs.shape[0] 81 | # inputs.shape[1] 82 | K = labels.shape[1] 83 | 84 | labels.to(self.device) 85 | inputs.to(self.device) 86 | 87 | loss = 0 88 | 89 | kernel = self.gaussian_kernel(self.radius, self.spatial_sigma).to( 90 | self.device 91 | ) 92 | 93 | for k in range(K): 94 | # Compute the average pixel value for this class, and the difference from each pixel 95 | class_probs = labels[:, k].unsqueeze(1) 96 | class_mean = torch.mean( 97 | inputs * class_probs, dim=(2, 3, 4), keepdim=True 98 | ) / torch.add( 99 | torch.mean(class_probs, dim=(2, 3, 4), keepdim=True), 1e-5 100 | ) 101 | diff = (inputs - class_mean).pow(2).sum(dim=1).unsqueeze(1) 102 | 103 | # Weight the loss by the difference from the class average. 104 | weights = torch.exp( 105 | diff.pow(2).mul(-1 / self.intensity_sigma**2) 106 | ) 107 | 108 | numerator = torch.sum( 109 | class_probs 110 | * F.conv3d(class_probs * weights, kernel, padding=self.radius), 111 | dim=(1, 2, 3, 4), 112 | ) 113 | denominator = torch.sum( 114 | class_probs * F.conv3d(weights, kernel, padding=self.radius), 115 | dim=(1, 2, 3, 4), 116 | ) 117 | loss += nn.L1Loss()( 118 | numerator / torch.add(denominator, 1e-6), 119 | torch.zeros_like(numerator), 120 | ) 121 | 122 | return K - loss 123 | 124 | def gaussian_kernel(self, radius, sigma): 125 | """Computes the Gaussian kernel. 126 | 127 | Args: 128 | radius (int): The radius of the kernel. 129 | sigma (float): The standard deviation of the Gaussian distribution. 130 | 131 | Returns: 132 | The Gaussian kernel of shape (1, 1, 2*radius+1, 2*radius+1, 2*radius+1). 
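# --- Illustrative example (editor's sketch, not part of the plugin) --------
# Typical use of SoftNCutsLoss on a small dummy batch. Shapes and sigma values
# are illustrative only: `labels` are the (N, K, H, W, D) class probabilities
# produced by the W-Net encoder, `inputs` the (N, C, H, W, D) image.
import torch

from napari_cellseg3d.code_models.models.wnet.soft_Ncuts import SoftNCutsLoss

data_shape = (32, 32, 32)
criterion = SoftNCutsLoss(
    data_shape=data_shape,
    device="cpu",
    intensity_sigma=1.0,
    spatial_sigma=4.0,
    radius=2,
)
probs = torch.softmax(torch.rand(1, 2, *data_shape), dim=1)  # K=2 classes
image = torch.rand(1, 1, *data_shape)
loss = criterion(probs, image)  # K minus the soft association terms
# ----------------------------------------------------------------------------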
133 | """ 134 | x_2 = np.linspace(-radius, radius, 2 * radius + 1) ** 2 135 | dist = ( 136 | np.sqrt( 137 | x_2.reshape(-1, 1, 1) 138 | + x_2.reshape(1, -1, 1) 139 | + x_2.reshape(1, 1, -1) 140 | ) 141 | / sigma 142 | ) 143 | kernel = norm.pdf(dist) / norm.pdf(0) 144 | kernel = torch.from_numpy(kernel.astype(np.float32)) 145 | return kernel.view( 146 | (1, 1, kernel.shape[0], kernel.shape[1], kernel.shape[2]) 147 | ) 148 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | """This folder contains all plugin-related code.""" 2 | -------------------------------------------------------------------------------- /napari_cellseg3d/code_plugins/plugin_helper.py: -------------------------------------------------------------------------------- 1 | """Tiny plugin showing link to documentation and about page.""" 2 | 3 | import pathlib 4 | from typing import TYPE_CHECKING 5 | 6 | if TYPE_CHECKING: 7 | import napari 8 | 9 | # Qt 10 | from qtpy.QtCore import QSize 11 | from qtpy.QtGui import QIcon, QPixmap 12 | from qtpy.QtWidgets import QVBoxLayout, QWidget 13 | 14 | # local 15 | from napari_cellseg3d import interface as ui 16 | 17 | 18 | class Helper(QWidget, metaclass=ui.QWidgetSingleton): 19 | """Tiny plugin showing link to documentation and about page.""" 20 | 21 | def __init__(self, viewer: "napari.viewer.Viewer"): 22 | """Creates a widget with links to documentation and about page.""" 23 | super().__init__() 24 | 25 | self.help_url = "https://adaptivemotorcontrollab.github.io/CellSeg3D/" 26 | 27 | self.about_url = "https://wysscenter.ch/advances/3d-computer-vision-for-brain-analysis" 28 | self.repo_url = "https://github.com/AdaptiveMotorControlLab/CellSeg3D" 29 | self._viewer = viewer 30 | 31 | logo_path = str( 32 | pathlib.Path(__file__).parent.resolve() / "../res/logo_alpha.png" 33 | ) 34 | print(logo_path) 35 | image = QPixmap(logo_path) 36 | 37 | self.logo_label = ui.Button(func=lambda: ui.open_url(self.repo_url)) 38 | self.logo_label.setIcon(QIcon(image)) 39 | self.logo_label.setMinimumSize(200, 200) 40 | self.logo_label.setIconSize(QSize(200, 200)) 41 | self.logo_label.setStyleSheet( 42 | "QPushButton { background-color: transparent }" 43 | ) 44 | self.logo_label.setToolTip("Open Github page") 45 | 46 | self.info_label = ui.make_label( 47 | f"You are using napari-cellseg3d v.{'0.2.2'}\n\n" 48 | f"Plugin for cell segmentation developed\n" 49 | f"by the Mathis Lab of Adaptive Motor Control\n\n" 50 | f"Code by :\nCyril Achard\nMaxime Vidal\nJessy Lauer\nMackenzie Mathis\n" 51 | f"\nReleased under the MIT license", 52 | self, 53 | ) 54 | 55 | self.btn1 = ui.Button("Help...", lambda: ui.open_url(self.help_url)) 56 | self.btn1.setToolTip("Go to documentation") 57 | 58 | self.btn2 = ui.Button("About...", lambda: ui.open_url(self.about_url)) 59 | 60 | self.btnc = ui.Button("Close", self.remove_from_viewer) 61 | 62 | self.build() 63 | 64 | def build(self): 65 | """Build the widget."".""" 66 | vbox = QVBoxLayout() 67 | 68 | widgets = [ 69 | self.logo_label, 70 | self.info_label, 71 | self.btn1, 72 | self.btn2, 73 | self.btnc, 74 | ] 75 | ui.add_widgets(vbox, widgets) 76 | self.setLayout(vbox) 77 | 78 | def remove_from_viewer(self): 79 | """Remove the widget from the viewer.""" 80 | self._viewer.window.remove_dock_widget(self) 81 | -------------------------------------------------------------------------------- 
/napari_cellseg3d/dev_scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/dev_scripts/__init__.py -------------------------------------------------------------------------------- /napari_cellseg3d/dev_scripts/colab_training.py: -------------------------------------------------------------------------------- 1 | """Script to run WNet training in Google Colab.""" 2 | 3 | import time 4 | from pathlib import Path 5 | from typing import TYPE_CHECKING 6 | 7 | from monai.data import CacheDataset 8 | 9 | # MONAI 10 | from monai.metrics import DiceMetric 11 | from monai.transforms import ( 12 | Compose, 13 | EnsureChannelFirstd, 14 | EnsureTyped, 15 | LoadImaged, 16 | Orientationd, 17 | ) 18 | 19 | # local 20 | from napari_cellseg3d import config, utils 21 | from napari_cellseg3d.code_models.worker_training import WNetTrainingWorker 22 | from napari_cellseg3d.code_models.workers_utils import ( 23 | PRETRAINED_WEIGHTS_DIR, 24 | ) 25 | 26 | if TYPE_CHECKING: 27 | from monai.data import DataLoader 28 | 29 | logger = utils.LOGGER 30 | VERBOSE_SCHEDULER = True 31 | logger.debug(f"PRETRAINED WEIGHT DIR LOCATION : {PRETRAINED_WEIGHTS_DIR}") 32 | 33 | 34 | class LogFixture: 35 | """Fixture for napari-less logging, replaces napari_cellseg3d.interface.Log in model_workers. 36 | 37 | This allows to redirect the output of the workers to stdout instead of a specialized widget. 38 | """ 39 | 40 | def __init__(self): 41 | """Creates a LogFixture object.""" 42 | super(LogFixture, self).__init__() 43 | 44 | def print_and_log(self, text, printing=None): 45 | """Prints and logs text.""" 46 | print(text) 47 | 48 | def warn(self, warning): 49 | """Logs warning.""" 50 | logger.warning(warning) 51 | 52 | def error(self, e): 53 | """Logs error.""" 54 | raise (e) 55 | 56 | 57 | class WNetTrainingWorkerColab(WNetTrainingWorker): 58 | """A custom worker to run WNet (unsupervised) training jobs in. 59 | 60 | Inherits from :py:class:`napari.qt.threading.GeneratorWorker` via :py:class:`TrainingWorkerBase`. 61 | """ 62 | 63 | def __init__( 64 | self, 65 | worker_config: config.WNetTrainingWorkerConfig, 66 | wandb_config: config.WandBConfig = None, 67 | ): 68 | """Create a WNet training worker for Google Colab. 69 | 70 | Args: 71 | worker_config: worker configuration 72 | wandb_config: optional wandb configuration 73 | """ 74 | super().__init__(worker_config) 75 | super().__init__(worker_config) 76 | self.wandb_config = ( 77 | wandb_config if wandb_config is not None else config.WandBConfig() 78 | ) 79 | 80 | self.dice_metric = DiceMetric( 81 | include_background=False, reduction="mean", get_not_nans=False 82 | ) 83 | self.normalize_function = utils.remap_image 84 | self.start_time = time.time() 85 | self.ncuts_losses = [] 86 | self.rec_losses = [] 87 | self.total_losses = [] 88 | self.best_dice = -1 89 | self.dice_values = [] 90 | 91 | self.dataloader: DataLoader = None 92 | self.eval_dataloader: DataLoader = None 93 | self.data_shape = None 94 | 95 | def get_dataset(self, train_transforms): 96 | """Creates a Dataset applying some transforms/augmentation on the data using the MONAI library. 
97 | 98 | Args: 99 | train_transforms (monai.transforms.Compose): The transforms to apply to the data 100 | 101 | Returns: 102 | (tuple): A tuple containing the shape of the data and the dataset 103 | """ 104 | train_files = self.config.train_data_dict 105 | 106 | first_volume = LoadImaged(keys=["image"])(train_files[0]) 107 | first_volume_shape = first_volume["image"].shape 108 | 109 | if len(first_volume_shape) != 3: 110 | raise ValueError( 111 | f"Expected 3D volumes, got {len(first_volume_shape)} dimensions" 112 | ) 113 | 114 | # Transforms to be applied to each volume 115 | load_single_images = Compose( 116 | [ 117 | LoadImaged(keys=["image"]), 118 | EnsureChannelFirstd( 119 | keys=["image"], 120 | channel_dim="no_channel", 121 | strict_check=False, 122 | ), 123 | Orientationd(keys=["image"], axcodes="PLI"), 124 | # SpatialPadd( 125 | # keys=["image"], 126 | # spatial_size=(utils.get_padding_dim(first_volume_shape)), 127 | # ), 128 | EnsureTyped(keys=["image"]), 129 | # RemapTensord(keys=["image"], new_min=0.0, new_max=100.0), 130 | ] 131 | ) 132 | 133 | # Create the dataset 134 | dataset = CacheDataset( 135 | data=train_files, 136 | transform=Compose([load_single_images, train_transforms]), 137 | ) 138 | 139 | return first_volume_shape, dataset 140 | 141 | 142 | def get_colab_worker( 143 | worker_config: config.WNetTrainingWorkerConfig, 144 | wandb_config: config.WandBConfig, 145 | ): 146 | """Train a WNet model in Google Colab. 147 | 148 | Args: 149 | worker_config (config.WNetTrainingWorkerConfig): config for the training worker 150 | wandb_config (config.WandBConfig): config for wandb 151 | """ 152 | log = LogFixture() 153 | worker = WNetTrainingWorkerColab(worker_config, wandb_config) 154 | 155 | worker.log_signal.connect(log.print_and_log) 156 | worker.warn_signal.connect(log.warn) 157 | worker.error_signal.connect(log.error) 158 | 159 | return worker 160 | 161 | 162 | def create_dataset_dict_no_labs(volume_directory): 163 | """Creates unsupervised data dictionary for MONAI transforms and training.""" 164 | if not volume_directory.exists(): 165 | raise ValueError(f"Data folder {volume_directory} does not exist") 166 | images_filepaths = utils.get_all_matching_files(volume_directory) 167 | if len(images_filepaths) == 0: 168 | raise ValueError(f"Data folder {volume_directory} is empty") 169 | 170 | logger.info("Images :") 171 | for file in images_filepaths: 172 | logger.info(Path(file).stem) 173 | logger.info("*" * 10) 174 | return [{"image": str(image_name)} for image_name in images_filepaths] 175 | 176 | 177 | def create_eval_dataset_dict(image_directory, label_directory): 178 | """Creates data dictionary for MONAI transforms and training. 
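# --- Illustrative example (editor's sketch, not part of the plugin) --------
# Wiring the helpers defined in this script together. The data folder, epoch
# count and results folder are placeholders; the config fields used here
# (train_data_dict, max_epochs, results_path_folder) follow
# config.WNetTrainingWorkerConfig as exercised in the plugin tests.
from pathlib import Path

from napari_cellseg3d import config

train_data = create_dataset_dict_no_labs(
    Path("/content/drive/MyDrive/volumes")  # hypothetical Colab data folder
)
worker_config = config.WNetTrainingWorkerConfig(
    device="cuda:0",
    max_epochs=50,
    train_data_dict=train_data,
    results_path_folder=str(Path("./results").resolve()),
)
worker = get_colab_worker(worker_config, config.WandBConfig(mode="disabled"))
for report in worker.train():
    pass  # TrainingReport objects are yielded as training progresses
# ----------------------------------------------------------------------------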
179 | 180 | Returns: 181 | A dict with the following keys 182 | 183 | * "image": image 184 | * "label" : corresponding label 185 | """ 186 | images_filepaths = utils.get_all_matching_files(image_directory) 187 | labels_filepaths = utils.get_all_matching_files(label_directory) 188 | 189 | if len(images_filepaths) == 0 or len(labels_filepaths) == 0: 190 | raise ValueError("Data folders are empty") 191 | 192 | if not Path(images_filepaths[0]).parent.exists(): 193 | raise ValueError("Images folder does not exist") 194 | if not Path(labels_filepaths[0]).parent.exists(): 195 | raise ValueError("Labels folder does not exist") 196 | 197 | logger.info("Images :\n") 198 | for file in images_filepaths: 199 | logger.info(Path(file).name) 200 | logger.info("*" * 10) 201 | logger.info("Labels :\n") 202 | for file in labels_filepaths: 203 | logger.info(Path(file).name) 204 | return [ 205 | {"image": image_name, "label": label_name} 206 | for image_name, label_name in zip(images_filepaths, labels_filepaths) 207 | ] 208 | -------------------------------------------------------------------------------- /napari_cellseg3d/dev_scripts/crop_data.py: -------------------------------------------------------------------------------- 1 | """Simple script to fragment a 3d image into smaller 3d images of size roi_size.""" 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from tifffile import imread, imwrite 6 | 7 | from napari_cellseg3d.utils import get_all_matching_files 8 | 9 | 10 | def crop_3d_image(image, roi_size): 11 | """Crops a 3d image by extracting all regions of size roi_size. 12 | 13 | If the edge of the array is reached, the cropped region is overlapped with the previous cropped region. 14 | """ 15 | image_size = image.shape 16 | cropped_images = [] 17 | for i in range(0, image_size[0], roi_size[0]): 18 | for j in range(0, image_size[1], roi_size[1]): 19 | for k in range(0, image_size[2], roi_size[2]): 20 | if i + roi_size[0] >= image_size[0]: 21 | crop_location_i = image_size[0] - roi_size[0] 22 | else: 23 | crop_location_i = i 24 | if j + roi_size[1] >= image_size[1]: 25 | crop_location_j = image_size[1] - roi_size[1] 26 | else: 27 | crop_location_j = j 28 | if k + roi_size[2] >= image_size[2]: 29 | crop_location_k = image_size[2] - roi_size[2] 30 | else: 31 | crop_location_k = k 32 | cropped_images.append( 33 | image[ 34 | crop_location_i : crop_location_i + roi_size[0], 35 | crop_location_j : crop_location_j + roi_size[1], 36 | crop_location_k : crop_location_k + roi_size[2], 37 | ] 38 | ) 39 | return cropped_images 40 | 41 | 42 | if __name__ == "__main__": 43 | image_path = ( 44 | Path().home() 45 | # / "Desktop/Code/CELLSEG_BENCHMARK/TPH2_DATA/somatomotor_iso" 46 | # / "Desktop/Code/CELLSEG_BENCHMARK/TPH2_DATA/somatomotor_iso/labels/semantic" 47 | / "Desktop/Code/CELLSEG_BENCHMARK/TPH2_mesospim/visual_iso/labels/semantic" 48 | ) 49 | if not image_path.exists() or not image_path.is_dir(): 50 | raise ValueError(f"Image path {image_path} does not exist") 51 | image_list = get_all_matching_files(image_path) 52 | for j in image_list: 53 | print(j) 54 | image = imread(str(j)) 55 | # crops = crop_3d_image(image, (64, 64, 64)) 56 | crops = [image] 57 | # viewer = napari.Viewer() 58 | if not (image_path / "cropped").exists(): 59 | (image_path / "cropped").mkdir(exist_ok=False) 60 | for i, im in enumerate(crops): 61 | print(im.shape) 62 | # viewer.add_image(im) 63 | imwrite( 64 | str(image_path / f"cropped/{j.stem}_{i}_crop.tif"), 65 | im.astype(np.uint16), 66 | dtype="uint16", 67 | ) 68 | # 
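# --- Illustrative example (editor's sketch, not part of the plugin) --------
# Cropping a 100x100x100 volume into 64^3 tiles with crop_3d_image above;
# edge tiles are shifted back inside the volume, so 2x2x2 = 8 overlapping
# tiles are produced here. Sizes are illustrative.
import numpy as np

vol = np.zeros((100, 100, 100), dtype=np.uint16)
tiles = crop_3d_image(vol, (64, 64, 64))
print(len(tiles), tiles[0].shape)  # 8 (64, 64, 64)
# ----------------------------------------------------------------------------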
napari.run() 69 | -------------------------------------------------------------------------------- /napari_cellseg3d/dev_scripts/remote_inference.py: -------------------------------------------------------------------------------- 1 | """Script to perform inference on a single image and run post-processing on the results, withot napari.""" 2 | import logging 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from napari_cellseg3d.code_models.instance_segmentation import ( 11 | clear_large_objects, 12 | clear_small_objects, 13 | threshold, 14 | volume_stats, 15 | voronoi_otsu, 16 | ) 17 | from napari_cellseg3d.code_models.worker_inference import InferenceWorker 18 | from napari_cellseg3d.config import ( 19 | InferenceWorkerConfig, 20 | InstanceSegConfig, 21 | ModelInfo, 22 | SlidingWindowConfig, 23 | ) 24 | from napari_cellseg3d.utils import resize 25 | 26 | logger = logging.getLogger(__name__) 27 | logging.basicConfig(level=logging.INFO) 28 | 29 | 30 | class LogFixture: 31 | """Fixture for napari-less logging, replaces napari_cellseg3d.interface.Log in model_workers. 32 | 33 | This allows to redirect the output of the workers to stdout instead of a specialized widget. 34 | """ 35 | 36 | def __init__(self): 37 | """Creates a LogFixture object.""" 38 | super(LogFixture, self).__init__() 39 | 40 | def print_and_log(self, text, printing=None): 41 | """Prints and logs text.""" 42 | print(text) 43 | 44 | def warn(self, warning): 45 | """Logs warning.""" 46 | logger.warning(warning) 47 | 48 | def error(self, e): 49 | """Logs error.""" 50 | raise (e) 51 | 52 | 53 | WINDOW_SIZE = 64 54 | 55 | MODEL_INFO = ModelInfo( 56 | name="SwinUNetR", 57 | model_input_size=64, 58 | ) 59 | 60 | CONFIG = InferenceWorkerConfig( 61 | device="cuda" if torch.cuda.is_available() else "cpu", 62 | model_info=MODEL_INFO, 63 | results_path=str(Path("./results").absolute()), 64 | compute_stats=False, 65 | sliding_window_config=SlidingWindowConfig(WINDOW_SIZE, 0.25), 66 | ) 67 | 68 | 69 | @dataclass 70 | class PostProcessConfig: 71 | """Config for post-processing.""" 72 | 73 | threshold: float = 0.4 74 | spot_sigma: float = 0.55 75 | outline_sigma: float = 0.55 76 | isotropic_spot_sigma: float = 0.2 77 | isotropic_outline_sigma: float = 0.2 78 | anisotropy_correction: List[ 79 | float 80 | ] = None # TODO change to actual values, should be a ratio like [1,1/5,1] 81 | clear_small_size: int = 5 82 | clear_large_objects: int = 500 83 | 84 | 85 | def inference_on_images( 86 | image: np.array, config: InferenceWorkerConfig = CONFIG 87 | ): 88 | """This function provides inference on an image with minimal config. 89 | 90 | Args: 91 | image (np.array): Image to perform inference on. 92 | config (InferenceWorkerConfig, optional): Config for InferenceWorker. Defaults to CONFIG, see above. 
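# --- Illustrative example (editor's sketch, not part of the plugin) --------
# Using the helpers in this script on a single volume. The input file name is
# a placeholder; the toy semantic map below only illustrates the
# post-processing chain (threshold -> clear large objects -> Voronoi-Otsu ->
# clear small objects -> stats) without depending on the inference output
# format.
import numpy as np
from tifffile import imread

volume = imread("my_volume.tif").astype(np.float32)  # hypothetical 3D stack
results = inference_on_images(volume, config=CONFIG)

toy_semantic = np.zeros((64, 64, 64), dtype=np.float32)
toy_semantic[20:26, 20:26, 20:26] = 0.9  # one fake "cell"
toy_semantic[40:45, 40:45, 40:45] = 0.8  # another one
labels, stats = post_processing(
    toy_semantic, config=PostProcessConfig(threshold=0.5)
)
# ----------------------------------------------------------------------------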
93 | """ 94 | # instance_method = InstanceSegmentationWrapper(voronoi_otsu, {"spot_sigma": 0.7, "outline_sigma": 0.7}) 95 | 96 | config.post_process_config.zoom.enabled = False 97 | config.post_process_config.thresholding.enabled = ( 98 | False # will need to be enabled and set to 0.5 for the test images 99 | ) 100 | config.post_process_config.instance = InstanceSegConfig( 101 | enabled=False, 102 | ) 103 | 104 | config.layer = image 105 | 106 | log = LogFixture() 107 | worker = InferenceWorker(config) 108 | logger.debug(f"Worker config: {worker.config}") 109 | 110 | worker.log_signal.connect(log.print_and_log) 111 | worker.warn_signal.connect(log.warn) 112 | worker.error_signal.connect(log.error) 113 | 114 | worker.log_parameters() 115 | 116 | results = [] 117 | # append the InferenceResult when yielded by worker to results 118 | for result in worker.inference(): 119 | results.append(result) 120 | 121 | return results 122 | 123 | 124 | def post_processing(semantic_segmentation, config: PostProcessConfig = None): 125 | """Run post-processing on inference results.""" 126 | config = PostProcessConfig() if config is None else config 127 | # if config.anisotropy_correction is None: 128 | # config.anisotropy_correction = [1, 1, 1 / 5] 129 | if config.anisotropy_correction is None: 130 | config.anisotropy_correction = [1, 1, 1] 131 | 132 | image = semantic_segmentation 133 | # apply threshold to semantic segmentation 134 | logger.info(f"Thresholding with {config.threshold}") 135 | image = threshold(image, config.threshold) 136 | logger.debug(f"Thresholded image shape: {image.shape}") 137 | # remove artifacts by clearing large objects 138 | logger.info(f"Clearing large objects with {config.clear_large_objects}") 139 | image = clear_large_objects(image, config.clear_large_objects) 140 | # run instance segmentation 141 | logger.info( 142 | f"Running instance segmentation with {config.spot_sigma} and {config.outline_sigma}" 143 | ) 144 | labels = voronoi_otsu( 145 | image, 146 | spot_sigma=config.spot_sigma, 147 | outline_sigma=config.outline_sigma, 148 | ) 149 | # clear small objects 150 | logger.info(f"Clearing small objects with {config.clear_small_size}") 151 | labels = clear_small_objects(labels, config.clear_small_size).astype( 152 | np.uint16 153 | ) 154 | logger.debug(f"Labels shape: {labels.shape}") 155 | # get volume stats WITH ANISOTROPY 156 | logger.debug(f"NUMBER OF OBJECTS: {np.max(np.unique(labels))-1}") 157 | stats_not_resized = volume_stats(labels) 158 | ######## RUN WITH ANISOTROPY ######## 159 | result_dict = {} 160 | result_dict["Not resized"] = { 161 | "labels": labels, 162 | "stats": stats_not_resized, 163 | } 164 | 165 | if config.anisotropy_correction != [1, 1, 1]: 166 | logger.info("Resizing image to correct anisotropy") 167 | image = resize(image, config.anisotropy_correction) 168 | logger.debug(f"Resized image shape: {image.shape}") 169 | logger.info("Running labels without anisotropy") 170 | labels_resized = voronoi_otsu( 171 | image, 172 | spot_sigma=config.isotropic_spot_sigma, 173 | outline_sigma=config.isotropic_outline_sigma, 174 | ) 175 | logger.info( 176 | f"Clearing small objects with {config.clear_small_size}" 177 | ) 178 | labels_resized = clear_small_objects( 179 | labels_resized, config.clear_small_size 180 | ).astype(np.uint16) 181 | logger.debug( 182 | f"NUMBER OF OBJECTS: {np.max(np.unique(labels_resized))-1}" 183 | ) 184 | logger.info("Getting volume stats without anisotropy") 185 | stats_resized = volume_stats(labels_resized) 186 | return 
labels_resized, stats_resized 187 | 188 | return labels, stats_not_resized 189 | -------------------------------------------------------------------------------- /napari_cellseg3d/dev_scripts/remote_training.py: -------------------------------------------------------------------------------- 1 | """Showcases how to train a model without napari.""" 2 | 3 | from pathlib import Path 4 | 5 | from napari_cellseg3d import config as cfg 6 | from napari_cellseg3d.code_models.worker_training import ( 7 | SupervisedTrainingWorker, 8 | ) 9 | from napari_cellseg3d.utils import LOGGER as logger 10 | from napari_cellseg3d.utils import get_all_matching_files 11 | 12 | TRAINING_SPLIT = 0.2 # 0.4, 0.8 13 | MODEL_NAME = "SegResNet" # "SwinUNetR" 14 | BATCH_SIZE = 10 if MODEL_NAME == "SegResNet" else 5 15 | # BATCH_SIZE = 1 16 | 17 | SPLIT_FOLDER = "1_c15" # "2_c1_c4_visual" "3_c1245_visual" 18 | RESULTS_PATH = ( 19 | Path("/data/cyril") 20 | / "CELLSEG_BENCHMARK/cellseg3d_train" 21 | / f"{MODEL_NAME}_{SPLIT_FOLDER}_{int(TRAINING_SPLIT*100)}" 22 | ) 23 | 24 | IMAGES = ( 25 | Path("/data/cyril") 26 | / f"CELLSEG_BENCHMARK/TPH2_mesospim/SPLITS/{SPLIT_FOLDER}" 27 | ) 28 | LABELS = ( 29 | Path("/data/cyril") 30 | / f"CELLSEG_BENCHMARK/TPH2_mesospim/SPLITS/{SPLIT_FOLDER}/labels/semantic" 31 | ) 32 | 33 | 34 | class LogFixture: 35 | """Fixture for napari-less logging, replaces napari_cellseg3d.interface.Log in model_workers. 36 | 37 | This allows redirecting the output of the workers to stdout instead of a specialized widget. 38 | """ 39 | 40 | def __init__(self): 41 | """Creates a LogFixture object.""" 42 | super(LogFixture, self).__init__() 43 | 44 | def print_and_log(self, text, printing=None): 45 | """Prints and logs text.""" 46 | print(text) 47 | 48 | def warn(self, warning): 49 | """Logs warning.""" 50 | logger.warning(warning) 51 | 52 | def error(self, e): 53 | """Logs error.""" 54 | raise (e) 55 | 56 | 57 | def prepare_data(images_path, labels_path): 58 | """Prepares data for training.""" 59 | assert images_path.exists(), f"Images path does not exist: {images_path}" 60 | assert labels_path.exists(), f"Labels path does not exist: {labels_path}" 61 | if not RESULTS_PATH.exists(): 62 | RESULTS_PATH.mkdir(parents=True, exist_ok=True) 63 | 64 | images = get_all_matching_files(images_path) 65 | labels = get_all_matching_files(labels_path) 66 | 67 | print(f"Images paths: {images}") 68 | print(f"Labels paths: {labels}") 69 | 70 | logger.info("Images :\n") 71 | for file in images: 72 | logger.info(Path(file).name) 73 | logger.info("*" * 10) 74 | logger.info("Labels :\n") 75 | for file in labels: 76 | logger.info(Path(file).name) 77 | 78 | assert len(images) == len( 79 | labels 80 | ), "Number of images and labels must be the same" 81 | 82 | return [ 83 | {"image": str(image_path), "label": str(label_path)} 84 | for image_path, label_path in zip(images, labels) 85 | ] 86 | 87 | 88 | def remote_training(): 89 | """Function to train a model without napari.""" 90 | # print(f"Results path: {RESULTS_PATH.resolve()}") 91 | 92 | wandb_config = cfg.WandBConfig( 93 | mode="online", 94 | save_model_artifact=True, 95 | ) 96 | 97 | deterministic_config = cfg.DeterministicConfig( 98 | seed=34936339, 99 | ) 100 | 101 | worker_config = cfg.SupervisedTrainingWorkerConfig( 102 | device="cuda:0", 103 | max_epochs=50, 104 | learning_rate=0.001, # 1e-3 105 | validation_interval=2, 106 | batch_size=BATCH_SIZE, # 10 for SegResNet 107 | deterministic_config=deterministic_config, 108 | scheduler_factor=0.5, 109 | scheduler_patience=10, # 
use default scheduler 110 | weights_info=cfg.WeightsInfo(), # no pretrained weights 111 | results_path_folder=str(RESULTS_PATH), 112 | sampling=False, 113 | do_augmentation=True, 114 | train_data_dict=prepare_data(IMAGES, LABELS), 115 | # supervised specific 116 | model_info=cfg.ModelInfo( 117 | name=MODEL_NAME, 118 | model_input_size=(64, 64, 64), 119 | ), 120 | loss_function="Generalized Dice", 121 | training_percent=TRAINING_SPLIT, 122 | ) 123 | 124 | worker = SupervisedTrainingWorker(worker_config) 125 | worker.wandb_config = wandb_config 126 | ######### SET LOG 127 | log = LogFixture() 128 | worker.log_signal.connect(log.print_and_log) 129 | worker.warn_signal.connect(log.warn) 130 | worker.error_signal.connect(log.error) 131 | 132 | results = [] 133 | for result in worker.train(): 134 | results.append(result) 135 | print("Training finished") 136 | 137 | 138 | if __name__ == "__main__": 139 | results = remote_training() 140 | -------------------------------------------------------------------------------- /napari_cellseg3d/dev_scripts/sliding_window_voronoi.py: -------------------------------------------------------------------------------- 1 | """Test script for sliding window Voronoi-Otsu segmentation.""" 2 | import numpy as np 3 | import pyclesperanto_prototype as cle 4 | from tqdm import tqdm 5 | 6 | 7 | def sliding_window_voronoi_otsu(volume, spot_sigma, outline_sigma, patch_size): 8 | """Given a volume of dimensions HxWxD, a spot_sigma and an outline_sigma, perform Voronoi-Otsu segmentation on the volume using a sliding window of size patch_size. 9 | 10 | If the edge has been reached, the patch size is reduced 11 | to fit the remaining space. The result is a segmentation of the same size 12 | as the input volume. 13 | 14 | Args: 15 | volume (np.array): The volume to segment. 16 | spot_sigma (float): The sigma for the spot detection. 17 | outline_sigma (float): The sigma for the outline detection. 18 | patch_size (int): The size of the sliding window. 
19 | """ 20 | result = np.zeros(volume.shape, dtype=np.uint32) 21 | max_label_id = 0 22 | x, y, z = volume.shape[-3:] 23 | for i in tqdm(range(0, x, patch_size)): 24 | for j in range(0, y, patch_size): 25 | for k in range(0, z, patch_size): 26 | patch = volume[ 27 | i : min(i + patch_size, x), 28 | j : min(j + patch_size, y), 29 | k : min(k + patch_size, z), 30 | ] 31 | patch_result = cle.voronoi_otsu_labeling( 32 | patch, spot_sigma=spot_sigma, outline_sigma=outline_sigma 33 | ) 34 | patch_result = np.array(patch_result) 35 | # make sure labels are unique, only where result is not 0 36 | patch_result[patch_result > 0] += max_label_id 37 | result[ 38 | i : min(i + patch_size, x), 39 | j : min(j + patch_size, y), 40 | k : min(k + patch_size, z), 41 | ] = patch_result 42 | max_label_id = np.max(patch_result) 43 | return result 44 | 45 | 46 | # if __name__ == "__main__": 47 | # import napari 48 | # 49 | # rand_array = np.random.random((525, 621, 400)) 50 | # rand_array = rand_array > 0.999 51 | # 52 | # result = sliding_window_voronoi_otsu(rand_array, 0.1, 0.1, 128) 53 | # 54 | # viewer = napari.Viewer() 55 | # viewer.add_image(rand_array) 56 | # viewer.add_labels(result) 57 | # napari.run() 58 | -------------------------------------------------------------------------------- /napari_cellseg3d/dev_scripts/thread_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import napari 4 | from napari.qt.threading import thread_worker 5 | from qtpy.QtWidgets import ( 6 | QGridLayout, 7 | QLabel, 8 | QProgressBar, 9 | QPushButton, 10 | QTextEdit, 11 | QVBoxLayout, 12 | QWidget, 13 | ) 14 | 15 | from napari_cellseg3d.utils import rand_gen 16 | 17 | #################################### 18 | # Tutorial code from napari forums # 19 | #################################### 20 | # not covered by tests 21 | 22 | 23 | @thread_worker 24 | def two_way_communication_with_args(start, end): 25 | """Both sends and receives values to & from the main thread. 26 | 27 | Accepts arguments, puts them on the worker object. 28 | Receives values from main thread with ``incoming = yield`` 29 | Optionally returns a value at the end. 
30 | """ 31 | # do computationally intensive work here 32 | i = start 33 | while i < end: 34 | i += 1 35 | time.sleep(0.1) 36 | # incoming receives values from the main thread 37 | # while yielding sends values back to the main thread 38 | incoming = yield i 39 | i = incoming if incoming is not None else i 40 | 41 | # do optional teardown here 42 | return "done" 43 | 44 | 45 | class Controller(QWidget): 46 | """Widget that controls a function running in another thread.""" 47 | 48 | def __init__(self, viewer): 49 | """Build the widget.""" 50 | super().__init__() 51 | 52 | self.viewer = viewer 53 | layout = QGridLayout() 54 | self.setLayout(layout) 55 | self.status = QLabel('Click "Start"', self) 56 | self.play_btn = QPushButton("Start", self) 57 | self.abort_btn = QPushButton("Abort!", self) 58 | self.reset_btn = QPushButton("Reset", self) 59 | self.progress_bar = QProgressBar() 60 | 61 | layout.addWidget(self.play_btn, 0, 0) 62 | layout.addWidget(self.reset_btn, 0, 1) 63 | layout.addWidget(self.abort_btn, 0, 2) 64 | layout.addWidget(self.status, 0, 3) 65 | layout.setColumnStretch(3, 1) 66 | layout.addWidget(self.progress_bar, 1, 0, 1, 4) 67 | 68 | self.btn = QPushButton("Oui") 69 | self.log = QTextEdit() 70 | self.prog = QProgressBar() 71 | self.build() 72 | 73 | def build(self): 74 | container = QWidget() 75 | layout = QVBoxLayout() 76 | layout.addWidget(self.prog) 77 | layout.addWidget(self.log) 78 | layout.addWidget(self.btn) 79 | container.setLayout(layout) 80 | self.viewer.window.add_dock_widget(container, area="left") 81 | 82 | 83 | def create_connected_widget(viewer): 84 | """Builds a widget that can control a function in another thread.""" 85 | w = Controller(viewer) 86 | steps = 40 87 | 88 | # the decorated function now returns a GeneratorWorker object, and the 89 | # Qthread in which it's running. 
90 | # (optionally pass start=False to prevent immediate running) 91 | worker = two_way_communication_with_args(0, steps) 92 | 93 | w.play_btn.clicked.connect(worker.start) 94 | 95 | # it provides signals like {started, yielded, returned, errored, finished} 96 | worker.returned.connect(lambda x: w.status.setText(f"worker returned {x}")) 97 | worker.errored.connect(lambda x: w.status.setText(f"worker errored {x}")) 98 | worker.started.connect(lambda: w.status.setText("worker started...")) 99 | worker.aborted.connect(lambda: w.status.setText("worker aborted")) 100 | 101 | # send values into the function (like generator.send) using worker.send 102 | # abort thread with worker.abort() 103 | w.abort_btn.clicked.connect(lambda: worker.quit()) 104 | 105 | def on_reset_button_pressed(): 106 | # we want to avoid sending into a unstarted worker 107 | if worker.is_running: 108 | worker.send(0) 109 | 110 | def on_yield(x, test): 111 | # Receive events and update widget progress 112 | w.progress_bar.setValue(100 * x // steps) 113 | w.log.insertPlainText(str(x) + "\n") 114 | w.log.verticalScrollBar().setValue(w.log.verticalScrollBar().maximum()) 115 | w.status.setText(f"worker yielded {x}") 116 | print(test) 117 | 118 | def on_start(): 119 | def handle_pause(): 120 | worker.toggle_pause() 121 | w.play_btn.setText("Pause" if worker.is_paused else "Continue") 122 | 123 | w.play_btn.clicked.disconnect(worker.start) 124 | w.play_btn.setText("Pause") 125 | w.play_btn.clicked.connect(handle_pause) 126 | 127 | def on_finish(): 128 | w.play_btn.setDisabled(True) 129 | w.reset_btn.setDisabled(True) 130 | w.abort_btn.setDisabled(True) 131 | w.play_btn.setText("Done") 132 | 133 | w.reset_btn.clicked.connect(on_reset_button_pressed) 134 | worker.yielded.connect(lambda x: on_yield(x, test="oui")) 135 | worker.started.connect(on_start) 136 | worker.finished.connect(on_finish) 137 | return w 138 | 139 | 140 | if __name__ == "__main__": 141 | viewer = napari.view_image(rand_gen.random((512, 512))) 142 | w = create_connected_widget(viewer) 143 | viewer.window.add_dock_widget(w) 144 | 145 | napari.run() 146 | -------------------------------------------------------------------------------- /napari_cellseg3d/dev_scripts/whole_brain_utils.py: -------------------------------------------------------------------------------- 1 | """Utilities to improve whole-brain regions segmentation.""" 2 | import numpy as np 3 | from skimage.measure import label 4 | from skimage.segmentation import find_boundaries 5 | 6 | 7 | def extract_continuous_region(image): 8 | """Extract continuous region from image.""" 9 | image = np.where(image > 0, 1, 0) 10 | return label(image) 11 | 12 | 13 | def get_boundaries(image_regions, num_iters=1): 14 | """Obtain boundaries from image regions.""" 15 | boundaries = np.zeros_like(image_regions) 16 | label_values = np.unique(image_regions) 17 | iter_n = 0 18 | new_labels = image_regions 19 | while iter_n < num_iters: 20 | for i in label_values: 21 | if i == 0: 22 | continue 23 | boundary = find_boundaries(new_labels == i) 24 | boundaries += np.where(boundary > 0, i, 0) 25 | new_labels = np.where(boundary > 0, 0, new_labels) 26 | iter_n += 1 27 | return boundaries 28 | 29 | 30 | def remove_boundaries_from_segmentation( 31 | image_segmentation, image_labels=None, image=None, thickness_num_iters=1 32 | ): 33 | """Remove boundaries from segmentation. 34 | 35 | Args: 36 | image_segmentation (np.ndarray): 3D image segmentation. 37 | image_labels (np.ndarray): 3D integer labels of image segmentation. 
Use output from extract_continuous_region. 38 | image (np.ndarray): Additional 3D image used to extract continuous region. 39 | thickness_num_iters (int): Number of iterations to remove boundaries. A greater number will remove more boundary pixels. 40 | """ 41 | if image_labels is None: 42 | image_regions = extract_continuous_region(image_segmentation) 43 | elif image is not None: 44 | image_regions = extract_continuous_region(image) 45 | else: 46 | image_regions = image_labels 47 | boundaries = get_boundaries(image_regions, num_iters=thickness_num_iters) 48 | 49 | seg_in = np.where(image_regions > 0, image_segmentation, 0) 50 | return np.where(boundaries > 0, 0, seg_in) 51 | -------------------------------------------------------------------------------- /napari_cellseg3d/napari.yaml: -------------------------------------------------------------------------------- 1 | name: napari_cellseg3d 2 | display_name: CellSeg3D 3 | schema_version: 0.0.4 4 | 5 | contributions: 6 | commands: 7 | - id: napari_cellseg3d.load 8 | title: Create reviewer 9 | python_name: napari_cellseg3d.plugins:Reviewer 10 | 11 | - id: napari_cellseg3d.help 12 | title: Create Help 13 | python_name: napari_cellseg3d.plugins:Helper 14 | 15 | - id: napari_cellseg3d.utils 16 | title: Create utilities 17 | python_name: napari_cellseg3d.plugins:Utilities 18 | 19 | - id: napari_cellseg3d.infer 20 | title: Create Inference widget 21 | python_name: napari_cellseg3d.plugins:Inferer 22 | 23 | - id: napari_cellseg3d.train 24 | title: Create Trainer widget 25 | python_name: napari_cellseg3d.plugins:Trainer 26 | 27 | 28 | widgets: 29 | - command: napari_cellseg3d.load 30 | display_name: Labeling 31 | 32 | - command: napari_cellseg3d.infer 33 | display_name: Inference 34 | 35 | - command: napari_cellseg3d.train 36 | display_name: Training 37 | 38 | - command: napari_cellseg3d.utils 39 | display_name: Utilities 40 | 41 | - command: napari_cellseg3d.help 42 | display_name: Help/About... 43 | -------------------------------------------------------------------------------- /napari_cellseg3d/plugins.py: -------------------------------------------------------------------------------- 1 | """napari-cellseg3d: napari plugin for 3D cell segmentation. 2 | 3 | Main plugins menu for napari-cellseg3d. 
4 | """ 5 | from napari_cellseg3d.code_plugins.plugin_helper import Helper 6 | from napari_cellseg3d.code_plugins.plugin_model_inference import Inferer 7 | from napari_cellseg3d.code_plugins.plugin_model_training import Trainer 8 | from napari_cellseg3d.code_plugins.plugin_review import Reviewer 9 | from napari_cellseg3d.code_plugins.plugin_utilities import Utilities 10 | 11 | 12 | def napari_experimental_provide_dock_widget(): 13 | return [ 14 | (Reviewer, {"name": "Review loader"}), 15 | (Helper, {"name": "Help/About..."}), 16 | (Inferer, {"name": "Inference loader"}), 17 | (Trainer, {"name": "Training loader"}), 18 | (Utilities, {"name": "Utilities"}), 19 | ] 20 | -------------------------------------------------------------------------------- /napari_cellseg3d/res/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/res/__init__.py -------------------------------------------------------------------------------- /napari_cellseg3d/res/logo_alpha.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AdaptiveMotorControlLab/CellSeg3D/6de4b86a671ffcd4b5535277a53082ac5ecc00a1/napari_cellseg3d/res/logo_alpha.png -------------------------------------------------------------------------------- /napari_cellseg3d/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import setuptools 3 | 4 | if __name__ == "__main__": 5 | setuptools.setup() 6 | -------------------------------------------------------------------------------- /notebooks/labels_plot.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Labels plot\n", 8 | "\n", 9 | "This simple notebook shows how you can plot your labels in Jupyter using matplotlib.\n", 10 | "Viewing in napari is recommended, however, as it allows you to interact with the labels." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "d1c53d64-3b95-454c-9183-3fd95f3154ee", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import matplotlib.pyplot as plt\n", 21 | "import os\n", 22 | "import numpy as np\n", 23 | "from tifffile import imread" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": { 30 | "collapsed": false 31 | }, 32 | "outputs": [], 33 | "source": [ 34 | "import sys\n", 35 | "!{sys.executable} -m pip install plotly" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "cb6ce5bf-0a09-46a4-988c-2183d09a8211", 41 | "metadata": {}, 42 | "source": [ 43 | "## Plot of the labels" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "7aa3b205-2795-4f24-951a-90211a1a96fa", 49 | "metadata": {}, 50 | "source": [ 51 | "**Enter your image folder below, make sure the tif images you want to see plotted are present.**" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "0e3915c7-ad7c-4303-bbf9-4ef8be3c7a0a", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "pred_path = \"C:/Users/Cyril/Desktop/test/pred/large/\"\n", 62 | "\n", 63 | "pred_images = []\n", 64 | "for filename in sorted(os.listdir(pred_path)):\n", 65 | " img = imread(os.path.join(pred_path, filename))\n", 66 | " pred_images.append(np.array(img.compute()))" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "85aa1e5f-eb1c-4d7f-b7f3-af15bbc479e0", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "for y_pred in pred_images:\n", 77 | " y_pred[y_pred > 0.9] = 1\n", 78 | " y_pred[y_pred <= 0.9] = 0\n", 79 | " pred3d = y_pred\n", 80 | " z, x, y = pred3d.nonzero()\n", 81 | " fig = plt.figure(figsize=(10, 10))\n", 82 | " ax = plt.axes(projection=\"3d\")\n", 83 | " ax.scatter3D(x, y, z, c=z, alpha=1)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "48db4c1f-683a-4d1a-9f06-6e681091b038", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "import plotly.graph_objects as go\n", 94 | "from plotly.offline import iplot, init_notebook_mode" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "9f138504-2cb3-4008-83da-7e494fecb903", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "for pred3d in pred_images:\n", 105 | " z, x, y = pred3d.nonzero()\n", 106 | " fig = go.Figure(\n", 107 | " data=go.Scatter3d(\n", 108 | " x=x,\n", 109 | " y=y,\n", 110 | " z=z,\n", 111 | " mode=\"markers\",\n", 112 | " marker=dict(\n", 113 | " size=4,\n", 114 | " color=z, # set color to an array/list of desired values\n", 115 | " colorscale=\"Viridis\", # choose a colorscale\n", 116 | " opacity=0.8,\n", 117 | " ),\n", 118 | " )\n", 119 | " )\n", 120 | "\n", 121 | " fig.update_layout(\n", 122 | " height=600,\n", 123 | " width=600,\n", 124 | " )\n", 125 | "\n", 126 | " fig.show()" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "e9b9bb2a-1b86-4b8f-a380-cd21c2681860", 132 | "metadata": { 133 | "tags": [] 134 | }, 135 | "source": [ 136 | "Save as html in case plotly plots do not render correctly :" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "c7b50635-2797-4c16-912f-239982c0ded8", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "import os\n", 147 | "\n", 148 | "os.system(\"jupyter nbconvert --to html full_plot.ipynb\")" 149 | ] 150 | }, 151 | { 
152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "fca42ff2-6d27-461f-aa6f-2f01b2631b0f", 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "Python 3 (ipykernel)", 163 | "language": "python", 164 | "name": "python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 170 | }, 171 | "file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.8.11" 177 | } 178 | }, 179 | "nbformat": 4, 180 | "nbformat_minor": 5 181 | } 182 | -------------------------------------------------------------------------------- /notebooks/plots_data2.csv: -------------------------------------------------------------------------------- 1 | Volume,Centroid x,Centroid y,Centroid z,Sphericity (axes),Image size,Total image volume,Total object volume (pixels),Filling ratio,Number objects 2 | 31,0.4838709677419355,15.0,115.16129032258064,0.6450106202713449,"(64, 128, 153)",1253376,11865,0.009466433057598039,102 3 | 93,3.2688172043010755,31.537634408602152,116.02150537634408,0.6658406191165697,,,,, 4 | 136,1.7941176470588236,42.80882352941177,117.375,0.7093448289217301,,,,, 5 | 47,1.4893617021276595,38.851063829787236,109.72340425531915,0.7965232988603199,,,,, 6 | 39,1.641025641025641,49.282051282051285,113.02564102564102,0.8481975118577634,,,,, 7 | 231,1.5454545454545454,52.27272727272727,43.34632034632035,0.6616320516251043,,,,, 8 | 68,1.7058823529411764,57.5,118.8970588235294,0.7626068213630461,,,,, 9 | 101,2.118811881188119,61.48514851485149,44.504950495049506,0.8781711034689408,,,,, 10 | 100,1.27,81.28,48.6,0.8727176428899344,,,,, 11 | 122,2.360655737704918,85.76229508196721,117.8688524590164,0.8412965543243708,,,,, 12 | 103,1.5048543689320388,89.54368932038835,68.50485436893204,0.8165368443026244,,,,, 13 | 81,4.283950617283951,15.17283950617284,35.333333333333336,0.9005042399298055,,,,, 14 | 109,4.614678899082569,78.5137614678899,36.055045871559635,0.8733098387087967,,,,, 15 | 93,5.010752688172043,68.04301075268818,50.24731182795699,0.7590625341479772,,,,, 16 | 99,5.090909090909091,107.4040404040404,53.292929292929294,0.859130593365946,,,,, 17 | 41,7.146341463414634,108.4390243902439,115.07317073170732,0.9184875572032806,,,,, 18 | 119,8.445378151260504,41.05882352941177,45.38655462184874,0.8297995432915728,,,,, 19 | 123,9.105691056910569,54.0650406504065,47.520325203252035,0.8921347978266936,,,,, 20 | 35,9.485714285714286,55.4,116.31428571428572,0.7971467037126605,,,,, 21 | 76,10.236842105263158,56.23684210526316,73.21052631578948,0.8852614598258646,,,,, 22 | 80,10.15,70.7,112.4375,0.9233100446504423,,,,, 23 | 267,11.104868913857677,96.01123595505618,51.10112359550562,0.5976002177170818,,,,, 24 | 139,11.33093525179856,24.489208633093526,69.38848920863309,0.8089813798530895,,,,, 25 | 152,11.039473684210526,70.51973684210526,49.9671052631579,0.6243663969275468,,,,, 26 | 117,11.495726495726496,29.45299145299145,41.27350427350427,0.8369347777397995,,,,, 27 | 118,11.584745762711865,61.46610169491525,63.432203389830505,0.8504587700091666,,,,, 28 | 39,10.692307692307692,79.46153846153847,42.256410256410255,0.9001134522912314,,,,, 29 | 138,12.601449275362318,63.34057971014493,49.93478260869565,0.8185376701026623,,,,, 30 | 65,14.753846153846155,16.43076923076923,118.18461538461538,0.927190969047071,,,,, 31 | 
225,16.27111111111111,51.626666666666665,51.48888888888889,0.687863052578344,,,,, 32 | 36,15.277777777777779,59.166666666666664,115.5,0.8548343194855755,,,,, 33 | 84,16.845238095238095,41.20238095238095,38.214285714285715,0.8163542691116327,,,,, 34 | 63,16.41269841269841,70.06349206349206,92.74603174603175,0.8896698904491191,,,,, 35 | 65,16.16923076923077,75.21538461538462,75.6923076923077,0.869797622352677,,,,, 36 | 75,17.6,14.56,39.373333333333335,0.8548205422178841,,,,, 37 | 100,17.37,63.15,60.2,0.8947628894427309,,,,, 38 | 102,18.735294117647058,96.36274509803921,90.00980392156863,0.8764523221569422,,,,, 39 | 107,18.588785046728972,99.60747663551402,48.44859813084112,0.8203180976193716,,,,, 40 | 100,19.71,31.85,73.98,0.8381112248589747,,,,, 41 | 126,20.714285714285715,23.73015873015873,60.94444444444444,0.8582497946561416,,,,, 42 | 258,23.37984496124031,25.007751937984494,78.62015503875969,0.8020745456084516,,,,, 43 | 272,22.30514705882353,51.62132352941177,42.35294117647059,0.5605069436020935,,,,, 44 | 118,22.228813559322035,88.82203389830508,66.38135593220339,0.856271263684927,,,,, 45 | 90,21.944444444444443,101.45555555555555,38.24444444444445,0.8901324964553012,,,,, 46 | 112,23.357142857142858,93.9375,49.339285714285715,0.8609449976320952,,,,, 47 | 105,23.514285714285716,32.40952380952381,40.114285714285714,0.8382834679422576,,,,, 48 | 107,23.49532710280374,68.59813084112149,40.97196261682243,0.8505970858679607,,,,, 49 | 97,23.68041237113402,71.79381443298969,61.02061855670103,0.8674537376209845,,,,, 50 | 90,27.044444444444444,36.644444444444446,34.522222222222226,0.9115901791172402,,,,, 51 | 106,26.462264150943398,76.91509433962264,61.405660377358494,0.8521232208462145,,,,, 52 | 104,27.567307692307693,81.47115384615384,67.91346153846153,0.8856333469575824,,,,, 53 | 31,28.387096774193548,101.06451612903226,118.54838709677419,0.6733136737854146,,,,, 54 | 68,27.514705882352942,101.67647058823529,101.25,0.9229020737673537,,,,, 55 | 95,28.652631578947368,23.273684210526316,55.21052631578947,0.7884335121029805,,,,, 56 | 58,28.155172413793103,58.08620689655172,88.24137931034483,0.8787860420711744,,,,, 57 | 107,28.64485981308411,80.16822429906541,48.308411214953274,0.8219017552933839,,,,, 58 | 186,29.956989247311828,95.09677419354838,64.91397849462365,0.7281758776689308,,,,, 59 | 41,29.878048780487806,45.09756097560975,116.04878048780488,0.88591764616251,,,,, 60 | 105,30.685714285714287,110.87619047619047,35.67619047619048,0.8775272958630516,,,,, 61 | 100,32.21,48.53,47.45,0.8562048181553484,,,,, 62 | 198,33.36363636363637,37.56060606060606,39.18181818181818,0.7217935004908436,,,,, 63 | 92,32.358695652173914,67.41304347826087,84.96739130434783,0.8533929372998513,,,,, 64 | 101,34.71287128712871,61.51485148514851,45.148514851485146,0.8091304152391979,,,,, 65 | 141,35.851063829787236,23.673758865248228,54.354609929078016,0.7944676345817062,,,,, 66 | 112,35.410714285714285,62.625,62.0,0.8400571169779933,,,,, 67 | 84,35.785714285714285,92.83333333333333,100.21428571428571,0.8744192159361612,,,,, 68 | 101,36.17821782178218,91.39603960396039,41.87128712871287,0.8306307397615024,,,,, 69 | 124,37.66129032258065,95.13709677419355,53.975806451612904,0.8703042352704017,,,,, 70 | 210,40.94761904761905,78.75238095238095,40.25238095238095,0.7771494996946016,,,,, 71 | 140,39.40714285714286,47.55714285714286,43.76428571428571,0.8012734847195678,,,,, 72 | 127,39.346456692913385,101.46456692913385,55.574803149606296,0.8271814187663593,,,,, 73 | 
111,40.369369369369366,39.810810810810814,49.52252252252252,0.760605901476274,,,,, 74 | 105,40.63809523809524,98.93333333333334,43.628571428571426,0.8271553589005392,,,,, 75 | 182,42.005494505494504,18.087912087912088,37.637362637362635,0.7986773634576607,,,,, 76 | 72,41.47222222222222,37.47222222222222,74.20833333333333,0.8835057428298462,,,,, 77 | 98,41.38775510204081,75.78571428571429,86.22448979591837,0.8290078770517749,,,,, 78 | 161,45.80124223602485,106.41614906832298,45.4472049689441,0.7198341410144687,,,,, 79 | 114,45.35964912280702,100.85087719298245,56.728070175438596,0.8227801400218587,,,,, 80 | 50,48.64,109.58,118.0,0.867465852994249,,,,, 81 | 111,49.648648648648646,24.35135135135135,47.612612612612615,0.8057757698817215,,,,, 82 | 32,49.28125,29.8125,54.3125,0.8540165407362008,,,,, 83 | 145,50.675862068965515,63.83448275862069,70.12413793103448,0.7820923492690073,,,,, 84 | 115,52.25217391304348,44.40869565217391,48.29565217391304,0.8053783508669673,,,,, 85 | 217,52.30875576036866,76.23502304147465,52.857142857142854,0.7704630666972506,,,,, 86 | 134,52.71641791044776,84.19402985074628,66.00746268656717,0.7472184461928553,,,,, 87 | 164,53.28048780487805,101.98780487804878,86.6951219512195,0.7648888539609627,,,,, 88 | 124,53.903225806451616,31.60483870967742,35.83870967741935,0.836658496088752,,,,, 89 | 126,53.69047619047619,57.023809523809526,66.12698412698413,0.77450154403673,,,,, 90 | 372,54.623655913978496,68.63172043010752,46.336021505376344,0.512400378837635,,,,, 91 | 118,54.720338983050844,76.44915254237289,38.55084745762712,0.7820080547887326,,,,, 92 | 36,54.80555555555556,112.55555555555556,54.083333333333336,0.7359388178605194,,,,, 93 | 241,57.7551867219917,27.904564315352697,51.6058091286307,0.5188750708215295,,,,, 94 | 125,58.608,44.584,59.488,0.7867203535081382,,,,, 95 | 210,60.61904761904762,17.080952380952382,43.352380952380955,0.8115493205588132,,,,, 96 | 133,60.278195488721806,99.84962406015038,56.849624060150376,0.7910733249800463,,,,, 97 | 112,61.651785714285715,52.544642857142854,54.419642857142854,0.8500703659928895,,,,, 98 | 158,61.79746835443038,81.42405063291139,68.63291139240506,0.7856704194864025,,,,, 99 | 218,61.76146788990825,87.31192660550458,51.80275229357798,0.7158664669220091,,,,, 100 | 183,62.26775956284153,37.85792349726776,43.557377049180324,0.5736646092986522,,,,, 101 | 77,62.37662337662338,70.77922077922078,50.324675324675326,0.7390700452945853,,,,, 102 | 33,62.21212121212121,88.18181818181819,80.48484848484848,0.8861848417735211,,,,, 103 | 63,62.6031746031746,102.06349206349206,39.20634920634921,0.6780265897476226,,,,, 104 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "napari_cellseg3d" 3 | authors = [ 4 | {name = "Cyril Achard", email = "cyril.achard@epfl.ch"}, 5 | {name = "Maxime Vidal", email = "maxime.vidal@epfl.ch"}, 6 | {name = "Mackenzie Mathis", email = "mackenzie@post.harvard.edu"}, 7 | ] 8 | readme = "README.md" 9 | description = "Plugin for cell segmentation in 3D" 10 | classifiers = [ 11 | "Development Status :: 4 - Beta", 12 | "Intended Audience :: Science/Research", 13 | "Framework :: napari", 14 | "Topic :: Software Development :: Testing", 15 | "Programming Language :: Python", 16 | "Programming Language :: Python :: 3", 17 | "Programming Language :: Python :: 3.8", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | 
"Operating System :: OS Independent", 21 | "License :: OSI Approved :: MIT License", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | "Topic :: Scientific/Engineering :: Image Processing", 24 | "Topic :: Scientific/Engineering :: Visualization", 25 | ] 26 | license = {text = "MIT"} 27 | requires-python = ">=3.8" 28 | dependencies = [ 29 | "numpy", 30 | "napari[all]>=0.4.14", 31 | "QtPy", 32 | # "opencv-python>=4.5.5", 33 | # "dask-image>=0.6.0", 34 | "scikit-image>=0.19.2", 35 | "matplotlib>=3.4.1", 36 | "tifffile>=2022.2.9", 37 | # "imageio-ffmpeg>=0.4.5", 38 | "imagecodecs>=2023.3.16", 39 | "torch>=1.11", 40 | "monai[nibabel,einops]>=0.9.0", 41 | "itk", 42 | "tqdm", 43 | # "nibabel", 44 | # "pillow", 45 | "pyclesperanto-prototype", 46 | "tqdm", 47 | "matplotlib", 48 | "pydensecrf2", 49 | ] 50 | dynamic = ["version", "entry-points"] 51 | 52 | [project.urls] 53 | Homepage = "https://github.com/AdaptiveMotorControlLab/CellSeg3D" 54 | Documentation = "https://adaptivemotorcontrollab.github.io/cellseg3d-docs/res/welcome.html" 55 | Issues = "https://github.com/AdaptiveMotorControlLab/CellSeg3D/issues" 56 | 57 | [build-system] 58 | requires = ["setuptools", "wheel"] 59 | build-backend = "setuptools.build_meta" 60 | 61 | [tool.setuptools] 62 | include-package-data = true 63 | 64 | [tool.setuptools.packages.find] 65 | where = ["."] 66 | 67 | [tool.setuptools.package-data] 68 | "*" = ["res/*.png", "code_models/models/pretrained/*.json", "*.yaml"] 69 | 70 | [tool.ruff] 71 | select = [ 72 | "E", "F", "W", 73 | "A", 74 | "B", 75 | "D", 76 | "G", 77 | "I", 78 | "PT", 79 | "PTH", 80 | "RET", 81 | "SIM", 82 | "TCH", 83 | "NPY", 84 | ] 85 | # Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) 86 | # and 'G004' (do not use f-strings in logging) 87 | # and 'A003' (Shadowing python builtins) 88 | # and 'F401' (imported but unused) 89 | ignore = ["E501", "E741", "G004", "A003", "F401"] 90 | exclude = [ 91 | ".bzr", 92 | ".direnv", 93 | ".eggs", 94 | ".git", 95 | ".git-rewrite", 96 | ".hg", 97 | ".mypy_cache", 98 | ".nox", 99 | ".pants.d", 100 | ".pytype", 101 | ".ruff_cache", 102 | ".svn", 103 | ".tox", 104 | ".venv", 105 | "__pypackages__", 106 | "_build", 107 | "buck-out", 108 | "build", 109 | "dist", 110 | "node_modules", 111 | "venv", 112 | "docs/conf.py", 113 | "napari_cellseg3d/_tests/conftest.py", 114 | ] 115 | 116 | [tool.ruff.pydocstyle] 117 | convention = "google" 118 | 119 | [tool.black] 120 | line-length = 79 121 | 122 | [tool.isort] 123 | profile = "black" 124 | line_length = 79 125 | 126 | [project.optional-dependencies] 127 | pyqt5 = [ 128 | "pyqt5", 129 | ] 130 | pyside2 = [ 131 | "pyside2", 132 | ] 133 | pyside6 = [ 134 | "pyside6", 135 | ] 136 | onnx-cpu = [ 137 | "onnx", 138 | "onnxruntime" 139 | ] 140 | onnx-gpu = [ 141 | "onnx", 142 | "onnxruntime-gpu" 143 | ] 144 | wandb = [ 145 | "wandb" 146 | ] 147 | dev = [ 148 | "isort", 149 | "black", 150 | "ruff", 151 | "pre-commit", 152 | "tuna", 153 | "twine", 154 | ] 155 | docs = [ 156 | "jupyter-book", 157 | ] 158 | test = [ 159 | "pytest", 160 | "pytest_qt", 161 | "pytest-cov", 162 | "coverage", 163 | "tox", 164 | "twine", 165 | "onnx", 166 | "onnxruntime", 167 | ] 168 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | coverage 3 | imageio-ffmpeg>=0.4.5 4 | isort 5 | itk 6 | jupyter-book 7 | pytest 8 | pytest-qt 9 | tox 10 | twine 11 | numpy 12 
| napari[all]>=0.4.14 13 | QtPy 14 | opencv-python>=4.5.5 15 | pre-commit 16 | pyclesperanto-prototype>=0.22.0 17 | matplotlib>=3.4.1 18 | ruff 19 | tifffile>=2022.2.9 20 | torch>=1.11 21 | monai[nibabel,einops,tifffile]>=1.0.1 22 | pillow 23 | scikit-image>=0.19.2 24 | vispy>=0.9.6 25 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = napari_cellseg3d 3 | version = 0.2.2 4 | 5 | [options] 6 | packages = find: 7 | include_package_data = True 8 | python_requires = >=3.8 9 | package_dir = 10 | =. 11 | 12 | # add your package requirements here 13 | install_requires = 14 | numpy 15 | napari[all]>=0.4.14 16 | QtPy 17 | opencv-python>=4.5.5 18 | scikit-image>=0.19.2 19 | matplotlib>=3.4.1 20 | tifffile>=2022.2.9 21 | imageio-ffmpeg>=0.4.5 22 | torch>=1.11 23 | monai[nibabel,einops,tifffile]>=1.0.1 24 | itk 25 | tqdm 26 | nibabel 27 | pyclesperanto-prototype 28 | scikit-image 29 | pillow 30 | tqdm 31 | matplotlib 32 | vispy>=0.9.6 33 | 34 | [options.packages.find] 35 | where = . 36 | 37 | [options.package_data] 38 | napari_cellseg3d = 39 | res/*.png 40 | code_models/models/pretrained/*.json 41 | napari.yaml 42 | 43 | [options.entry_points] 44 | napari.manifest = 45 | napari_cellseg3d = napari_cellseg3d:napari.yaml 46 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/ 2 | [tox] 3 | envlist = py{38,39,310}-{linux} 4 | ; envlist = py{38,39,310}-{linux,macos,windows} 5 | isolated_build=true 6 | 7 | [gh-actions] 8 | python = 9 | 3.8: py38 10 | 3.9: py39 11 | 3.10: py310 12 | 13 | [gh-actions:env] 14 | PLATFORM = 15 | ubuntu-latest: linux 16 | ; windows-latest: windows 17 | ; macos-latest: macos 18 | 19 | [testenv] 20 | platform = 21 | linux: linux 22 | ; windows: win32 23 | ; macos: darwin 24 | passenv = 25 | CI 26 | PYTHONPATH 27 | GITHUB_ACTIONS 28 | DISPLAY 29 | XAUTHORITY 30 | NUMPY_EXPERIMENTAL_ARRAY_FUNCTION 31 | PYVISTA_OFF_SCREEN 32 | deps = 33 | pytest # https://docs.pytest.org/en/latest/contents.html 34 | pytest-cov # https://pytest-cov.readthedocs.io/en/latest/ 35 | napari 36 | PyQt5 37 | magicgui 38 | pytest-qt 39 | qtpy 40 | git+https://github.com/lucasb-eyer/pydensecrf.git@master#egg=pydensecrf 41 | onnx 42 | onnxruntime 43 | ; pyopencl[pocl] 44 | ; opencv-python 45 | extras = crf 46 | usedevelop = true 47 | commands = pytest -v --color=yes --cov=napari_cellseg3d --cov-report=xml 48 | --------------------------------------------------------------------------------