├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── cibuildwheel.yml │ ├── docker-publish.yml │ └── docs.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── Dockerfile ├── LICENSE ├── MANIFEST.in ├── README.md ├── artwork ├── overview.png ├── spotiflow_logo.png └── spotiflow_transp_small.png ├── docker-env-config.yml ├── docs ├── Makefile ├── _static │ └── spotiflow_transp_small.png ├── make.bat └── source │ ├── _static │ ├── spotiflow_napari_gui.png │ ├── spotiflow_napari_preds.png │ └── spotiflow_transp_small.png │ ├── api.rst │ ├── cli.rst │ ├── conf.py │ ├── finetune.rst │ ├── index.rst │ ├── installation.rst │ ├── napari.rst │ ├── pretrained.rst │ └── train.rst ├── examples ├── 1_train.ipynb ├── 2_inference.ipynb ├── 3_finetune.ipynb └── 4_train_3d.ipynb ├── extra ├── analyze_spot_clusters.py ├── annotation_ui.py └── run_starfish_spotiflow.py ├── pyproject.toml ├── scripts ├── predict_3d.py ├── predict_zarr_multigpu.py ├── train_simple.py └── train_simple_3d.py ├── setup.cfg ├── setup.py ├── spotiflow ├── __init__.py ├── augmentations │ ├── __init__.py │ ├── pipeline │ │ ├── __init__.py │ │ └── pipeline.py │ ├── test │ │ ├── test_pipeline.py │ │ └── transforms │ │ │ ├── test_crop.py │ │ │ ├── test_fliprot.py │ │ │ ├── test_intensity_shift.py │ │ │ ├── test_noise.py │ │ │ ├── test_rotation.py │ │ │ ├── test_scale.py │ │ │ └── test_translation.py │ ├── transforms │ │ ├── __init__.py │ │ ├── base.py │ │ ├── crop.py │ │ ├── fliprot.py │ │ ├── intensity_shift.py │ │ ├── noise.py │ │ ├── rotation.py │ │ ├── scale.py │ │ ├── translation.py │ │ └── utils.py │ └── transforms3d │ │ ├── __init__.py │ │ ├── crop.py │ │ ├── fliprot.py │ │ ├── intensity_shift.py │ │ ├── noise.py │ │ ├── rotation.py │ │ └── translation.py ├── cli │ ├── predict.py │ └── train.py ├── data │ ├── __init__.py │ ├── spots.py │ └── spots3d.py ├── lib │ ├── external │ │ └── nanoflann │ │ │ ├── LICENSE.txt │ │ │ └── nanoflann.hpp │ ├── filters.cpp │ ├── filters3d.cpp │ ├── point_nms.cpp │ ├── point_nms3d.cpp │ ├── spotflow2d.cpp │ └── spotflow3d.cpp ├── model │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── resnet.py │ │ └── unet.py │ ├── bg_remover.py │ ├── config.py │ ├── losses │ │ ├── __init__.py │ │ └── adaptive_wing.py │ ├── post.py │ ├── pretrained.py │ ├── spotiflow.py │ └── trainer.py ├── sample_data │ ├── __init__.py │ ├── datasets.py │ └── images │ │ ├── img_hybiss_2d.tif │ │ ├── img_synth_3d.tif │ │ ├── img_terra_2d.tif │ │ └── timelapse_telomeres_2d.tif ├── starfish │ ├── __init__.py │ └── spotiflow_wrapper.py ├── test │ ├── test_model_saveload.py │ └── test_prediction.py └── utils │ ├── __init__.py │ ├── fitting.py │ ├── get_file.py │ ├── matching.py │ ├── parallel.py │ ├── peaks.py │ └── utils.py ├── tests ├── test_data.py ├── test_fit.py ├── test_model.py ├── test_peaks.py ├── test_training.py ├── test_training_simple.py └── utils.py └── tox.ini /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve Spotiflow 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 16 | **Images/screenshots** 17 | If applicable, add images or screenshots to help reproduce/explain your problem. 
18 | 19 | **Environment (please complete the following information):** 20 | - Spotiflow version: [e.g. 0.2.0] 21 | - Conda/mamba version (if applicable): [e.g. 4.13.0] 22 | - OS: [e.g. MacOS] 23 | - OS version: [e.g. 14.3] 24 | - GPU memory (if applicable): [e.g. 8GB] 25 | 26 | 27 | **Additional context** 28 | Add any other context about the problem here. 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest a new feature for Spotiflow 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What type of PR is this? (check all applicable) 2 | 3 | - [ ] Bug Fix 4 | - [ ] Feature 5 | - [ ] Optimization 6 | - [ ] Documentation Update 7 | 8 | ## Description 9 | 10 | ## Related issues (if applicable) 11 | 12 | - Related Issue # 13 | - Closes # 14 | -------------------------------------------------------------------------------- /.github/workflows/cibuildwheel.yml: -------------------------------------------------------------------------------- 1 | # adapted from: 2 | # - https://github.com/matplotlib/matplotlib/blob/master/.github/workflows/cibuildwheel.yml 3 | # - https://github.com/scikit-image/scikit-image/blob/master/.github/workflows/cibuildwheel.yml 4 | # - https://github.com/pypa/cibuildwheel/blob/main/examples/github-deploy.yml 5 | # - https://github.com/stardist/stardist/blob/master/.github/workflows/cibuildwheel.yml 6 | 7 | name: tests 8 | 9 | on: 10 | push: 11 | branches: 12 | - main 13 | - wheels 14 | release: 15 | types: 16 | - published 17 | pull_request: 18 | branches: [ "main" ] 19 | 20 | jobs: 21 | test: 22 | name: Test on ${{ matrix.os }} with Python ${{ matrix.py }} 23 | runs-on: ${{ matrix.os }} 24 | if: github.event_name == 'push' && contains(github.ref, 'main') 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | os: [ubuntu-22.04] 29 | py: [cp311, cp312] 30 | steps: 31 | - uses: actions/checkout@v4 32 | name: Checkout repository 33 | 34 | - uses: actions/setup-python@v4 35 | name: Install Python 36 | with: 37 | python-version: '3.x' 38 | 39 | - name: Install cibuildwheel 40 | run: python -m pip install cibuildwheel 41 | - name: Build wheels for CPython (Linux) 42 | run: | 43 | python -m cibuildwheel --output-dir dist 44 | env: 45 | # only build for specific platforms 46 | CIBW_BUILD: "${{ matrix.py }}-*{x86_64,win_amd64}" 47 | CIBW_SKIP: "*musllinux*" 48 | 49 | CIBW_BUILD_VERBOSITY: 1 50 | CIBW_TEST_REQUIRES: pytest pytest-cov 51 | CIBW_TEST_COMMAND: pytest -v --cov=spotiflow {project} 52 | 53 | build_wheels: 54 | name: Build ${{ matrix.py }} wheels on ${{ matrix.os }} 55 | runs-on: ${{ matrix.os }} 56 | if: 
(github.event_name == 'release' && github.event.action == 'published') || (github.event_name == 'push' && contains(github.ref, 'wheels')) || github.event_name == 'pull_request' 57 | strategy: 58 | fail-fast: false 59 | matrix: 60 | os: [ubuntu-22.04, windows-2019, macos-13] 61 | py: [cp39, cp310, cp311, cp312, cp313] 62 | 63 | steps: 64 | - uses: actions/checkout@v4 65 | name: Checkout repository 66 | 67 | - uses: actions/setup-python@v4 68 | name: Install Python 69 | with: 70 | python-version: '3.x' 71 | 72 | - name: Install cibuildwheel 73 | run: python -m pip install cibuildwheel 74 | 75 | # https://scikit-learn.org/stable/developers/advanced_installation.html#macos 76 | - name: Setup OpenMP (macOS) 77 | if: startsWith(matrix.os, 'macos') 78 | shell: bash 79 | run: | 80 | brew config 81 | brew install libomp 82 | eval `brew shellenv` 83 | tee -a $GITHUB_ENV << END 84 | CC=/usr/bin/clang 85 | CXX=/usr/bin/clang++ 86 | CFLAGS=${CFLAGS} -I${HOMEBREW_PREFIX}/opt/libomp/include 87 | CXXFLAGS=${CXXFLAGS} -I${HOMEBREW_PREFIX}/opt/libomp/include 88 | LDFLAGS=${LDFLAGS} -Wl,-rpath,${HOMEBREW_PREFIX}/opt/libomp/lib -L${HOMEBREW_PREFIX}/opt/libomp/lib -lomp 89 | END 90 | 91 | - name: Build wheels for CPython (macOS) 92 | if: startsWith(matrix.os, 'macos') 93 | run: | 94 | python -m cibuildwheel --output-dir dist 95 | env: 96 | CIBW_BUILD: "${{ matrix.py }}-*" 97 | CIBW_ARCHS_MACOS: arm64 98 | CIBW_BUILD_VERBOSITY: 1 99 | CIBW_TEST_REQUIRES: pytest pytest-cov 100 | CIBW_TEST_COMMAND: pytest -v --cov=spotiflow {project} 101 | 102 | - name: Build wheels for CPython (Linux and Windows) 103 | if: startsWith(matrix.os, 'macos') == false 104 | run: | 105 | python -m cibuildwheel --output-dir dist 106 | env: 107 | # only build for specific platforms 108 | CIBW_BUILD: "${{ matrix.py }}-*{x86_64,win_amd64}" 109 | CIBW_SKIP: "*musllinux*" 110 | 111 | CIBW_BUILD_VERBOSITY: 1 112 | CIBW_TEST_REQUIRES: pytest pytest-cov 113 | CIBW_TEST_COMMAND: pytest -v --cov=spotiflow {project} 114 | 115 | 116 | - uses: actions/upload-artifact@v4 117 | name: Upload wheels 118 | with: 119 | name: dist-${{matrix.os}}-${{matrix.py}} 120 | path: ./dist/*.whl 121 | 122 | build_sdist: 123 | name: Build source distribution 124 | runs-on: ubuntu-latest 125 | if: (github.event_name == 'release' && github.event.action == 'published') || (github.event_name == 'push' && contains(github.ref, 'wheels')) 126 | steps: 127 | - uses: actions/checkout@v4 128 | name: Checkout repository 129 | 130 | - name: Build sdist 131 | run: pipx run build --sdist 132 | 133 | - uses: actions/upload-artifact@v4 134 | name: Upload sdist 135 | with: 136 | name: dist-${{matrix.os}}-${{matrix.py}} 137 | path: dist/*.tar.gz 138 | 139 | 140 | upload_pypi: 141 | name: Upload to PyPI 142 | needs: [build_wheels, build_sdist] 143 | runs-on: ubuntu-latest 144 | if: (github.event_name == 'release' && github.event.action == 'published') || (github.event_name == 'push' && contains(github.ref, 'wheels')) 145 | steps: 146 | - uses: actions/download-artifact@v4 147 | name: Download wheels and sdist 148 | with: 149 | path: dist 150 | pattern: dist-* 151 | merge-multiple: true 152 | 153 | - uses: pypa/gh-action-pypi-publish@release/v1 154 | name: Publish to PyPI 155 | with: 156 | user: __token__ 157 | password: ${{ secrets.PYPI_API_TOKEN }} 158 | verbose: true 159 | -------------------------------------------------------------------------------- /.github/workflows/docker-publish.yml: -------------------------------------------------------------------------------- 1 | name: 
Docker 2 | 3 | # This workflow uses actions that are not certified by GitHub. 4 | # They are provided by a third-party and are governed by 5 | # separate terms of service, privacy policy, and support 6 | # documentation. 7 | 8 | on: 9 | workflow_dispatch: 10 | release: 11 | types: 12 | - published 13 | 14 | env: 15 | # Use docker.io for Docker Hub if empty 16 | REGISTRY: ghcr.io 17 | # github.repository as / 18 | IMAGE_NAME: ${{ github.repository }} 19 | 20 | 21 | jobs: 22 | build: 23 | 24 | runs-on: ubuntu-latest 25 | permissions: 26 | contents: read 27 | packages: write 28 | # This is used to complete the identity challenge 29 | # with sigstore/fulcio when running outside of PRs. 30 | id-token: write 31 | 32 | steps: 33 | - name: Checkout repository 34 | uses: actions/checkout@v4 35 | 36 | # Install the cosign tool except on PR 37 | # https://github.com/sigstore/cosign-installer 38 | - name: Install cosign 39 | if: github.event_name != 'pull_request' 40 | uses: sigstore/cosign-installer@59acb6260d9c0ba8f4a2f9d9b48431a222b68e20 #v3.5.0 41 | with: 42 | cosign-release: 'v2.2.4' 43 | 44 | # Set up BuildKit Docker container builder to be able to build 45 | # multi-platform images and export cache 46 | # https://github.com/docker/setup-buildx-action 47 | - name: Set up Docker Buildx 48 | uses: docker/setup-buildx-action@f95db51fddba0c2d1ec667646a06c2ce06100226 # v3.0.0 49 | 50 | # Login against a Docker registry except on PR 51 | # https://github.com/docker/login-action 52 | - name: Log into registry ${{ env.REGISTRY }} 53 | if: github.event_name != 'pull_request' 54 | uses: docker/login-action@343f7c4344506bcbf9b4de18042ae17996df046d # v3.0.0 55 | with: 56 | registry: ${{ env.REGISTRY }} 57 | username: ${{ github.actor }} 58 | password: ${{ secrets.GITHUB_TOKEN }} 59 | 60 | # Extract metadata (tags, labels) for Docker 61 | # https://github.com/docker/metadata-action 62 | - name: Extract Docker metadata 63 | id: meta 64 | uses: docker/metadata-action@96383f45573cb7f253c731d3b3ab81c87ef81934 # v5.0.0 65 | with: 66 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 67 | 68 | # Build and push Docker image with Buildx (don't push on PR) 69 | # https://github.com/docker/build-push-action 70 | - name: Build and push Docker image 71 | id: build-and-push 72 | uses: docker/build-push-action@0565240e2d4ab88bba5387d719585280857ece09 # v5.0.0 73 | with: 74 | context: . 75 | platforms: linux/amd64,linux/arm64 76 | push: ${{ github.event_name != 'pull_request' }} 77 | tags: ${{ steps.meta.outputs.tags }} 78 | labels: ${{ steps.meta.outputs.labels }} 79 | cache-from: type=gha 80 | cache-to: type=gha,mode=max 81 | 82 | # Sign the resulting Docker image digest except on PRs. 83 | # This will only write to the public Rekor transparency log when the Docker 84 | # repository is public to avoid leaking data. If you would like to publish 85 | # transparency data even for private images, pass --force to cosign below. 86 | # https://github.com/sigstore/cosign 87 | - name: Sign the published Docker image 88 | if: ${{ github.event_name != 'pull_request' }} 89 | env: 90 | # https://docs.github.com/en/actions/security-guides/security-hardening-for-github-actions#using-an-intermediate-environment-variable 91 | TAGS: ${{ steps.meta.outputs.tags }} 92 | DIGEST: ${{ steps.build-and-push.outputs.digest }} 93 | # This step uses the identity token to provision an ephemeral certificate 94 | # against the sigstore community Fulcio instance. 
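# As an illustration (the image name and digest below are made up, not real outputs):
# if TAGS were "ghcr.io/acme/spotiflow:v1.0.0" and DIGEST were "sha256:abc123...",
# the xargs call below would expand to:
#   cosign sign --yes ghcr.io/acme/spotiflow:v1.0.0@sha256:abc123...
# i.e. each published tag is signed against the immutable image digest.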
95 | run: echo "${TAGS}" | xargs -I {} cosign sign --yes {}@${DIGEST} 96 | -------------------------------------------------------------------------------- /.github/workflows/docs.yaml: -------------------------------------------------------------------------------- 1 | name: documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: [tests] 6 | types: 7 | - completed 8 | workflow_dispatch: 9 | 10 | permissions: 11 | contents: write 12 | 13 | jobs: 14 | docs: 15 | if: ${{ (github.event.workflow_run.conclusion == 'success' || github.event_name == 'workflow_dispatch') && !(github.event.workflow_run.event == 'push')}} 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v3 19 | - uses: actions/setup-python@v5 20 | with: 21 | python-version: '3.12' 22 | - name: Install dependencies 23 | run: | 24 | pip install spotiflow[docs] 25 | - name: Sphinx build 26 | run: | 27 | sphinx-build docs/source docs/build 28 | - name: Deploy to GitHub Pages 29 | uses: peaceiris/actions-gh-pages@v3 30 | with: 31 | publish_branch: gh-pages 32 | github_token: ${{ secrets.GITHUB_TOKEN }} 33 | publish_dir: docs/build/ 34 | force_orphan: true 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # Version 131 | cbaidt/_version.py 132 | 133 | # Static files 134 | stat/ 135 | 136 | # wandb 137 | **/wandb/** 138 | 139 | 140 | # Logs 141 | **/logs/** 142 | **/lightning_logs/** 143 | 144 | 145 | # DS_Store 146 | **/.DS_Store 147 | 148 | # Version file 149 | spotiflow/_version.py 150 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.0.254 4 | hooks: 5 | - id: ruff 6 | args: [--fix, --fix-only, --exit-non-zero-on-fix] 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.3.0 9 | hooks: 10 | - id: check-yaml 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | - repo: https://github.com/psf/black 14 | rev: 22.6.0 15 | hooks: 16 | - id: black 17 | - repo: https://github.com/mwouts/jupytext 18 | rev: v1.14.0 19 | hooks: 20 | - id: jupytext 21 | args: [--from, ipynb, --to, "py:percent"] 22 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: If you use this software, please cite both the article from preferred-citation and the software itself. 
3 | authors: 4 | - family-names: Dominguez Mantes 5 | given-names: Albert 6 | - family-names: Herrera 7 | given-names: Antonio 8 | - family-names: Khven 9 | given-names: Irina 10 | - family-names: Schlaeppi 11 | given-names: Anjalie 12 | - family-names: Kyriacou 13 | given-names: Eftychia 14 | - family-names: Tsissios 15 | given-names: Georgios 16 | - family-names: Skoufa 17 | given-names: Evangelia 18 | - family-names: Santangeli 19 | given-names: Luca 20 | - family-names: Buglakova 21 | given-names: Elena 22 | - family-names: Durmus 23 | given-names: Emine Berna 24 | - family-names: Manley 25 | given-names: Suliana 26 | - family-names: Kreshuk 27 | given-names: Anna 28 | - family-names: Arendt 29 | given-names: Detlev 30 | - family-names: Aztekin 31 | given-names: Can 32 | - family-names: Lingner 33 | given-names: Joachim 34 | - family-names: La Manno 35 | given-names: Gioele 36 | - family-names: Weigert 37 | given-names: Martin 38 | title: 'Spotiflow: accurate and efficient spot detection for fluorescence microscopy with deep stereographic flow regression' 39 | version: 1.0.0 40 | url: https://doi.org/10.1038/s41592-025-02662-x 41 | doi: 10.1038/s41592-025-02662-x 42 | date-released: 2025-06-06 43 | preferred-citation: 44 | authors: 45 | - family-names: Dominguez Mantes 46 | given-names: Albert 47 | - family-names: Herrera 48 | given-names: Antonio 49 | - family-names: Khven 50 | given-names: Irina 51 | - family-names: Schlaeppi 52 | given-names: Anjalie 53 | - family-names: Kyriacou 54 | given-names: Eftychia 55 | - family-names: Tsissios 56 | given-names: Georgios 57 | - family-names: Skoufa 58 | given-names: Evangelia 59 | - family-names: Santangeli 60 | given-names: Luca 61 | - family-names: Buglakova 62 | given-names: Elena 63 | - family-names: Durmus 64 | given-names: Emine Berna 65 | - family-names: Manley 66 | given-names: Suliana 67 | - family-names: Kreshuk 68 | given-names: Anna 69 | - family-names: Arendt 70 | given-names: Detlev 71 | - family-names: Aztekin 72 | given-names: Can 73 | - family-names: Lingner 74 | given-names: Joachim 75 | - family-names: La Manno 76 | given-names: Gioele 77 | - family-names: Weigert 78 | given-names: Martin 79 | title: 'Spotiflow: accurate and efficient spot detection for fluorescence microscopy with deep stereographic flow regression' 80 | doi: 10.1038/s41592-025-02662-x 81 | url: https://doi.org/10.1038/s41592-025-02662-x 82 | journal: Nature Methods 83 | year: 2025 84 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mambaorg/micromamba:noble 2 | LABEL authors="Albert Dominguez, Miguel Ibarra" 3 | 4 | # Set the base layer for micromamba 5 | USER root 6 | COPY docker-env-config.yml . 7 | 8 | RUN apt-get update -qq && apt-get install -y \ 9 | build-essential \ 10 | ffmpeg \ 11 | libsm6 \ 12 | libxext6 \ 13 | procps \ 14 | git 15 | 16 | # Set the environment variable for the root prefix 17 | ARG MAMBA_ROOT_PREFIX=/opt/conda 18 | 19 | # Add /opt/conda/bin to the PATH 20 | ENV PATH=$MAMBA_ROOT_PREFIX/bin:$PATH 21 | 22 | # Install stuff with micromamba 23 | RUN micromamba env create -f docker-env-config.yml --yes && \ 24 | micromamba clean --all --yes 25 | 26 | # Add environment to PATH 27 | ENV PATH="/opt/conda/envs/spotiflow/bin:$PATH" 28 | 29 | # Set the working directory 30 | WORKDIR /spotiflow 31 | 32 | # Copy contents of the folder to the working directory 33 | COPY . . 
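# A typical local build-and-run flow for this image (the tag "spotiflow:local"
# is illustrative, not part of the repository):
#   docker build -t spotiflow:local .
#   docker run --rm -it spotiflow:local spotiflow-predict --help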
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Copyright (c) 2023, Albert Dominguez Mantes, Martin Weigert 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | * Neither the name of spotiflow nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include spotiflow/lib/*cpp 2 | include spotiflow/lib/external/nanoflann/* 3 | include spotiflow/sample_data/images/* -------------------------------------------------------------------------------- /artwork/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/artwork/overview.png -------------------------------------------------------------------------------- /artwork/spotiflow_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/artwork/spotiflow_logo.png -------------------------------------------------------------------------------- /artwork/spotiflow_transp_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/artwork/spotiflow_transp_small.png -------------------------------------------------------------------------------- /docker-env-config.yml: -------------------------------------------------------------------------------- 1 | name: spotiflow 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - "python=3.11" 7 | - "pytorch" 8 | - "torchvision" 9 | - "cpuonly" 10 | - "zarr" 11 | - pip: 12 | - "spotiflow" 13 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/spotiflow_transp_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/docs/_static/spotiflow_transp_small.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/spotiflow_napari_gui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/docs/source/_static/spotiflow_napari_gui.png -------------------------------------------------------------------------------- /docs/source/_static/spotiflow_napari_preds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/docs/source/_static/spotiflow_napari_preds.png -------------------------------------------------------------------------------- /docs/source/_static/spotiflow_transp_small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/docs/source/_static/spotiflow_transp_small.png -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ------------- 3 | 4 | .. autoclass:: spotiflow.model.spotiflow.Spotiflow 5 | :members: from_pretrained, from_folder, predict, fit, save, optimize_threshold 6 | 7 | .. autoclass:: spotiflow.model.config.SpotiflowModelConfig 8 | :members: 9 | 10 | .. autoclass:: spotiflow.model.config.SpotiflowTrainingConfig 11 | :members: 12 | 13 | .. autoclass:: spotiflow.data.spots.SpotsDataset 14 | :members: 15 | 16 | .. automethod:: __init__ 17 | 18 | .. autoclass:: spotiflow.data.spots3d.Spots3DDataset 19 | :members: 20 | 21 | .. automethod:: __init__ 22 | 23 | .. automodule:: spotiflow.utils 24 | :members: get_data, read_coords_csv, write_coords_csv, normalize 25 | 26 | .. automodule:: spotiflow.sample_data 27 | :members: 28 | 29 | .. autoclass:: spotiflow.starfish.SpotiflowDetector 30 | :members: run 31 | -------------------------------------------------------------------------------- /docs/source/cli.rst: -------------------------------------------------------------------------------- 1 | Inference via CLI 2 | ----------------- 3 | 4 | Command Line Interface (CLI) 5 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 6 | You can use the CLI to run inference on an image or folder containing several images. To do that, you can use the following command: 7 | 8 | .. code-block:: console 9 | 10 | $ spotiflow-predict PATH 11 | 12 | where ``PATH`` can either point to an image file or to a folder containing several images. By default, the command will use the ``general`` pretrained model. You can specify a different model by using the ``--pretrained-model`` flag for the pre-trained models we offer or ``--model-dir`` if you want to use a custom model. After running, the detected spots are by default saved to a subfolder ``spotiflow_results`` created inside the input folder (this can be changed with the ``--out-dir`` flag).
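For instance (the paths below are placeholders for your own data and models):

.. code-block:: console

    $ spotiflow-predict /path/to/images --pretrained-model hybiss --out-dir /path/to/results
    $ spotiflow-predict /path/to/img.tif --model-dir /path/to/my_trained_model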
For more information, please refer to the help message of the CLI (simply run ``spotiflow-predict -h`` on your command line). 13 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | from spotiflow import __version__ as __spotiflow_version__ 9 | 10 | project = 'Spotiflow' 11 | copyright = '2024, Albert Dominguez Mantes' 12 | author = 'Albert Dominguez Mantes' 13 | release = __spotiflow_version__ 14 | 15 | # -- General configuration --------------------------------------------------- 16 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 17 | 18 | extensions = [ 19 | 'sphinx.ext.duration', 20 | "sphinx_immaterial", 21 | 'sphinx.ext.napoleon', 22 | 'sphinx.ext.autodoc', 23 | 'sphinx.ext.autosummary', 24 | 'sphinx.ext.autosectionlabel', 25 | 'sphinx_immaterial.apidoc.python.apigen', 26 | ] 27 | templates_path = ['_templates'] 28 | exclude_patterns = [] 29 | 30 | autosectionlabel_prefix_document = True 31 | 32 | python_apigen_modules = { 33 | "spotiflow": "api", 34 | } 35 | python_apigen_order_tiebreaker = "definition_order" 36 | 37 | 38 | # -- Options for HTML output ------------------------------------------------- 39 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 40 | 41 | html_theme = 'sphinx_immaterial' 42 | html_static_path = ['_static'] 43 | html_logo = "_static/spotiflow_transp_small.png" 44 | 45 | 46 | html_theme_options = { 47 | "icon": { 48 | "repo": "fontawesome/brands/github", 49 | "edit": "material/file-edit-outline", 50 | }, 51 | # "site_url": "https://jbms.github.io/sphinx-immaterial/", 52 | "repo_url": "https://github.com/weigertlab/spotiflow", 53 | "repo_name": "Spotiflow", 54 | "edit_uri": "blob/main/docs", 55 | "globaltoc_collapse": True, 56 | "features": [ 57 | "navigation.expand", 58 | "navigation.sections", 59 | "navigation.top", 60 | "search.share", 61 | "toc.follow", 62 | "toc.sticky", 63 | "content.tabs.link", 64 | "announce.dismiss", 65 | ], 66 | "palette": [ 67 | { 68 | "media": "(prefers-color-scheme: light)", 69 | "scheme": "default", 70 | "primary": "light-green", 71 | "accent": "light-blue", 72 | "toggle": { 73 | "icon": "material/lightbulb-outline", 74 | "name": "Switch to dark mode", 75 | }, 76 | }, 77 | { 78 | "media": "(prefers-color-scheme: dark)", 79 | "scheme": "slate", 80 | "primary": "deep-orange", 81 | "accent": "lime", 82 | "toggle": { 83 | "icon": "material/lightbulb", 84 | "name": "Switch to light mode", 85 | }, 86 | }, 87 | ], 88 | "toc_title_is_page_title": True, 89 | } 90 | -------------------------------------------------------------------------------- /docs/source/finetune.rst: -------------------------------------------------------------------------------- 1 | Fine-tuning a Spotiflow model on a custom dataset 2 | ------------------------------------------------- 3 | 4 | Data format 5 | ^^^^^^^^^^^ 6 | 7 | See :ref:`train:Data format`. 
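As a quick sanity check of an annotation file, you can load it with the CSV helper documented in the API reference (a minimal sketch; the path is a placeholder):

.. code-block:: python

    from spotiflow.utils import read_coords_csv

    # coordinates for one training image, with columns y,x (or z,y,x in 3D)
    spots = read_coords_csv("/path/to/spots_data/train/img_001.csv")
    print(spots.shape)  # expected: (n_spots, 2) for 2D annotations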
8 | 9 | Fine-tuning (CLI) 10 | ^^^^^^^^^^^^^^^^^ 11 | 12 | You can easily fine-tune from an existing model by simply adding an argument to the CLI call. See :ref:`train:Basic training (CLI)` for more information. 13 | 14 | .. code-block:: console 15 | 16 | spotiflow-train /path/to/spots_data -o /path/to/my_finetuned_model --finetune-from general 17 | 18 | where `/path/to/my_finetuned_model` is the path to the directory where the fine-tuned model will be saved. You can also pass other parameters to the training, such as the number of epochs, the learning rate, etc. For more information on the arguments allowed, please refer to the help message of the CLI (simply run ``spotiflow-train -h`` in your command line). 19 | 20 | 21 | Fine-tuning (API) 22 | ^^^^^^^^^^^^^^^^^ 23 | 24 | In order to fine-tune a pre-trained model on a custom dataset using the API, you can simply load the model very similarly to what you would normally do to predict on new images (you only need to add one extra parameter!): 25 | 26 | .. code-block:: python 27 | 28 | from spotiflow.model import Spotiflow 29 | from spotiflow.utils import get_data 30 | 31 | # Get the data 32 | train_imgs, train_spots, val_imgs, val_spots = get_data("/path/to/spots_data") 33 | 34 | # Initialize the model 35 | model = Spotiflow.from_pretrained( 36 | "general", 37 | inference_mode=False, 38 | ) 39 | 40 | # Train and save the model 41 | model.fit( 42 | train_imgs, 43 | train_spots, 44 | val_imgs, 45 | val_spots, 46 | save_dir="/path/to/my_finetuned_model", 47 | ) 48 | 49 | Of course, you can also fine-tune from a model you have trained before. In that case, use the ``from_folder()`` method instead of ``from_pretrained()`` (see :ref:`index:Predicting spots in an image`). 50 | All the information about training customization from :ref:`train:Customizing the training` applies here as well. However, note that you cannot change the model architecture when fine-tuning! 51 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :hero: Spotiflow: accurate and robust spot detection for fluorescence microscopy 2 | 3 | ========= 4 | Spotiflow 5 | ========= 6 | 7 | Spotiflow is a learning-based subpixel-accurate spot detection method for 2D and 3D fluorescence microscopy. It is primarily developed for spatial transcriptomics workflows that require transcript detection in large, multiplexed FISH-images, although it can also be used to detect spot-like structures in general fluorescence microscopy images and volumes. For more information, please refer to our `paper <https://doi.org/10.1038/s41592-025-02662-x>`__. 8 | 9 | Getting Started 10 | --------------- 11 | 12 | Installation (pip) 13 | ~~~~~~~~~~~~~~~~~~ 14 | 15 | 16 | First, create and activate a fresh ``conda`` environment (we currently support Python 3.9 to 3.13). If you don't have ``conda`` installed, we recommend using `miniforge <https://github.com/conda-forge/miniforge>`__. 17 | 18 | .. code-block:: console 19 | 20 | $ conda create -n spotiflow python=3.12 21 | $ conda activate spotiflow 22 | 23 | Then, install PyTorch using ``pip``: 24 | 25 | .. code-block:: console 26 | 27 | $ pip install torch 28 | 29 | **Note (for Linux/Windows users with a CUDA-capable GPU)**: one might need to change the torch installation command depending on the CUDA version. Please refer to the `PyTorch website <https://pytorch.org/get-started/locally/>`__ for more information.
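For example, a CUDA 11.8 build can typically be installed with a command of the following form (double-check the exact index URL for your CUDA version on that page before running it):

.. code-block:: console

    $ pip install torch --index-url https://download.pytorch.org/whl/cu118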
30 | 31 | **Note (for Windows users):** if using Windows, please install the latest `Build Tools for Visual Studio <https://visualstudio.microsoft.com/visual-cpp-build-tools/>`__ (make sure to select the C++ build tools during installation) before proceeding to install Spotiflow. 32 | 33 | Finally, install ``spotiflow`` using ``pip``: 34 | 35 | .. code-block:: console 36 | 37 | $ pip install spotiflow 38 | 39 | 40 | Installation (conda) 41 | ~~~~~~~~~~~~~~~~~~~~ 42 | 43 | For Linux/MacOS users, you can also install Spotiflow using ``conda`` through the ``conda-forge`` channel. To directly create a fresh environment with Spotiflow named ``spotiflow``, you can use the following command: 44 | 45 | .. code-block:: console 46 | 47 | $ conda create -n spotiflow -c conda-forge spotiflow python=3.12 48 | 49 | Predicting spots in an image 50 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 51 | 52 | Python API 53 | ^^^^^^^^^^ 54 | 55 | The snippet below shows how to retrieve the spots from an image using one of the pretrained models: 56 | 57 | .. code-block:: python 58 | 59 | from skimage.io import imread 60 | from spotiflow.model import Spotiflow 61 | from spotiflow.utils import write_coords_csv 62 | 63 | 64 | # Load the desired image 65 | img = imread("/path/to/your/image") 66 | 67 | # Load a pretrained model 68 | model = Spotiflow.from_pretrained("general") 69 | 70 | # Predict spots 71 | spots, details = model.predict(img) # predict expects a numpy array 72 | 73 | # spots is a numpy array with shape (n_spots, 2) 74 | # details contains additional information about the prediction, like the predicted heatmap, the probability per spot, the flow field, etc. 75 | 76 | # Save the results to a CSV file 77 | write_coords_csv(spots, "/path/to/save/spots.csv") 78 | 79 | If a custom model needs to be used, simply change the model loading step to: 80 | 81 | .. code-block:: python 82 | 83 | # Load a custom model 84 | model = Spotiflow.from_folder("/path/to/model") 85 | 86 | Command Line Interface (CLI) 87 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 88 | You can use the CLI to run inference on an image or folder containing several images. To do that, you can use the following command: 89 | 90 | .. code-block:: console 91 | 92 | $ spotiflow-predict PATH 93 | 94 | where ``PATH`` can either point to an image file or to a folder containing several images. By default, the command will use the ``general`` pretrained model. You can specify a different model by using the ``--pretrained-model`` flag for the pre-trained models we offer or ``--model-dir`` if you want to use a custom model. After running, the detected spots are by default saved to a subfolder ``spotiflow_results`` created inside the input folder (this can be changed with the ``--out-dir`` flag). For more information, please refer to the help message of the CLI (simply run ``spotiflow-predict -h`` on your command line). 95 | 96 | Napari plugin 97 | ^^^^^^^^^^^^^ 98 | Spotiflow can also be run easily in a graphical user interface as a `napari <https://napari.org>`__ plugin. See :ref:`napari:Predicting spots using the napari plugin` for more information. 99 | 100 | Contents 101 | -------- 102 | 103 | ..
toctree:: 104 | :maxdepth: 2 105 | 106 | installation 107 | napari 108 | cli 109 | pretrained 110 | train 111 | finetune 112 | api 113 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation instructions 2 | ------------------------- 3 | 4 | Installation (pip) 5 | ~~~~~~~~~~~~~~~~~~ 6 | 7 | 8 | First, create and activate a fresh ``conda`` environment (we currently support Python 3.9 to 3.13). If you don't have ``conda`` installed, we recommend using `miniforge <https://github.com/conda-forge/miniforge>`__. 9 | 10 | .. code-block:: console 11 | 12 | $ conda create -n spotiflow python=3.12 13 | $ conda activate spotiflow 14 | 15 | Then, install PyTorch using ``pip``: 16 | 17 | .. code-block:: console 18 | 19 | $ pip install torch 20 | 21 | **Note (for Linux/Windows users with a CUDA-capable GPU)**: one might need to change the torch installation command depending on the CUDA version. Please refer to the `PyTorch website <https://pytorch.org/get-started/locally/>`__ for more information. 22 | 23 | **Note (for Windows users):** if using Windows, please install the latest `Build Tools for Visual Studio <https://visualstudio.microsoft.com/visual-cpp-build-tools/>`__ (make sure to select the C++ build tools during installation) before proceeding to install Spotiflow. 24 | 25 | Finally, install ``spotiflow`` using ``pip``: 26 | 27 | .. code-block:: console 28 | 29 | $ pip install spotiflow 30 | 31 | 32 | Installation (conda) 33 | ~~~~~~~~~~~~~~~~~~~~ 34 | 35 | For Linux/MacOS users, you can also install Spotiflow using ``conda`` through the ``conda-forge`` channel. To directly create a fresh environment with Spotiflow named ``spotiflow``, you can use the following command: 36 | 37 | .. code-block:: console 38 | 39 | $ conda create -n spotiflow -c conda-forge spotiflow python=3.12 -------------------------------------------------------------------------------- /docs/source/napari.rst: -------------------------------------------------------------------------------- 1 | Predicting spots using the napari plugin 2 | ---------------------------------------- 3 | 4 | The napari plugin can be used to predict spots in a napari viewer. First, you must install it in the environment containing Spotiflow: 5 | 6 | .. code-block:: console 7 | 8 | (spotiflow) $ pip install napari-spotiflow 9 | 10 | The plugin will then be available in the napari GUI under the name ``Spotiflow widget``. This is what the GUI looks like: 11 | 12 | .. image:: ./_static/spotiflow_napari_gui.png 13 | :width: 700 14 | :align: center 15 | 16 | The plugin can run in two modes: images (``2D``) and volumes (``3D``), which can be toggled using the corresponding buttons in the GUI. Spotiflow can also be run on time-lapses by setting the appropriate axis order (it should start with `T`). 17 | 18 | Upon pressing the button ``Detect spots``, the plugin will create a ``Points`` layer containing the predicted spots: 19 | 20 | .. image:: ./_static/spotiflow_napari_preds.png 21 | :width: 700 22 | :align: center 23 | 24 | If the option ``Show CNN output`` is checked, the plugin will also create two ``Image`` layers containing the heatmap output of the CNN as well as the stereographic flow. 25 | 26 | Finally, the plugin includes two sample 2D images (``HybISS`` and ``Terra``), a synthetic 3D volume (``Synthetic (3D)``) and a time-lapse (``Telomeres (2D+t)``). These samples can be loaded from the ``File`` menu (``File -> Open sample -> napari-spotiflow``).
You can try using the plugin with these samples to get a better idea of how it works! -------------------------------------------------------------------------------- /docs/source/pretrained.rst: -------------------------------------------------------------------------------- 1 | Available pre-trained models 2 | ---------------------------- 3 | 4 | The following pre-trained models are available (for a more detailed description, please refer to the *Methods* section of the paper as well as *Supplementary Table 2*): 5 | 6 | - ``general``: trained on a diverse dataset of spots of different modalities acquired in different microscopes with different settings. This model is the default one used in the CLI (pixel sizes: 0.04 µm, 0.1 µm, 0.11 µm, 0.15 µm, 0.32 µm, 0.34 µm). 7 | - ``hybiss``: trained on HybISS data acquired in 3 different microscopes (pixel sizes: 0.15 µm, 0.32 µm, 0.34 µm). 8 | - ``synth_complex``: trained on synthetic data, which includes simulations of aberrated spots and fluorescence background (pixel size: 0.1 µm). 9 | - ``synth_3d``: trained on synthetic 3D data, which includes simulations of aberrated spots and Z-related artifacts (voxel size: 0.2 µm). 10 | - ``smfish_3d``: fine-tuned from the ``synth_3d`` model on smFISH 3D data of *Platynereis dumerilii* (voxel size: 0.13 µm (YX), 0.48 µm (Z)). 11 | 12 | You can use these models to predict spots in images or to fine-tune them on a few annotations of your own data. The models can be loaded via the API as follows: 13 | 14 | .. code-block:: python 15 | 16 | from spotiflow.model import Spotiflow 17 | 18 | pretrained_model_name = "general" 19 | model = Spotiflow.from_pretrained(pretrained_model_name) 20 | 21 | You can also load them from the napari plugin or from the CLI by specifying the name of the model. See :ref:`napari:Predicting spots using the napari plugin` and :ref:`cli:Inference via CLI`, respectively, for more information. 22 | -------------------------------------------------------------------------------- /docs/source/train.rst: -------------------------------------------------------------------------------- 1 | Training a Spotiflow model on a custom dataset 2 | ---------------------------------------------- 3 | 4 | Data format 5 | ^^^^^^^^^^^ 6 | 7 | First of all, make sure that the data is organized in the following format: 8 | 9 | :: 10 | 11 | spots_data 12 | ├── train 13 | │ ├── img_001.csv 14 | │ ├── img_001.tif 15 | | ... 16 | │ ├── img_XYZ.csv 17 | | └── img_XYZ.tif 18 | └── val 19 | ├── val_img_001.csv 20 | ├── val_img_001.tif 21 | ... 22 | ├── val_img_XYZ.csv 23 | └── val_img_XYZ.tif 24 | 25 | The actual naming of the files is not important, but the ``.csv`` and ``.tif`` files corresponding to the same image **must** have the same name! The ``.csv`` files must contain the spot coordinates in the following format: 26 | 27 | .. code-block:: 28 | 29 | y,x 30 | 42.3,24.24 31 | 252.99, 307.97 32 | ... 33 | 34 | The column names can also be `axis-0` (instead of `y`) and `axis-1` (instead of `x`). For the 3D case, the format is similar but with an additional column corresponding to the `z` coordinate: 35 | 36 | .. code-block:: 37 | 38 | z,y,x 39 | 12.4,42.3,24.24 40 | 61.2,252.99, 307.97 41 | ... 42 | 43 | In this case, you can also use `axis-0`, `axis-1`, and `axis-2` instead of `z`, `y`, and `x`, respectively. 44 | 45 | 46 | Basic training (CLI) 47 | ^^^^^^^^^^^^^^^^^^^^ 48 | 49 | You can train a model using the CLI as follows: 50 | 51 | ..
code-block:: console 52 | 53 | spotiflow-train /path/to/spots_data -o /path/to/my_trained_model 54 | 55 | where `/path/to/spots_data` is the path to the directory containing the data in the format described above and `/path/to/my_trained_model` is the directory where the trained model will be saved. You can also pass other parameters to the training, such as the number of epochs, the learning rate, etc. For more information on the arguments allowed, see the documentation of the CLI command: 56 | 57 | .. code-block:: console 58 | 59 | spotiflow-train --help 60 | 61 | As an example, to train a Spotiflow model on 2-channel 3D data for 100 epochs, you can run: 62 | 63 | .. code-block:: console 64 | 65 | spotiflow-train /path/to/spots_data -o /my/trained/model --is-3d True --num-epochs 100 --in-channels 2 66 | 67 | Basic training (API) 68 | ^^^^^^^^^^^^^^^^^^^^ 69 | 70 | You can easily train a model using the default settings as follows and save it to the directory `/path/to/my_trained_model`: 71 | 72 | .. code-block:: python 73 | 74 | from spotiflow.model import Spotiflow 75 | from spotiflow.utils import get_data 76 | 77 | # Get the data 78 | train_imgs, train_spots, val_imgs, val_spots = get_data("/path/to/spots_data") 79 | 80 | # Initialize the model 81 | model = Spotiflow() 82 | 83 | # Train and save the model 84 | model.fit( 85 | train_imgs, 86 | train_spots, 87 | val_imgs, 88 | val_spots, 89 | save_dir="/path/to/my_trained_model", 90 | ) 91 | 92 | You can then load it by simply calling: 93 | 94 | .. code-block:: python 95 | 96 | model = Spotiflow.from_folder("/path/to/my_trained_model") 97 | 98 | Or using the CLI command: 99 | .. code-block:: console 100 | 101 | spotiflow-predict PATH --model-dir /path/to/my_trained_model ... 102 | 103 | 104 | Note that in order to train a 3D model using the API, you should initialize a :py:mod:`spotiflow.model.config.SpotiflowModelConfig` object and pass it to the `Spotiflow` constructor with the appropriate parameter set (see other options for the configuration at the end of the section): 105 | 106 | .. code-block:: python 107 | 108 | # Same imports as before 109 | from spotiflow.model import SpotiflowModelConfig 110 | 111 | # Create the model config 112 | model_config = SpotiflowModelConfig( 113 | is_3d=True, 114 | grid=2, # subsampling factor for prediction 115 | # you can pass other arguments here 116 | ) 117 | 118 | model = Spotiflow(model_config) 119 | # Train and save the model as before 120 | 121 | 122 | Customizing the training 123 | ^^^^^^^^^^^^^^^^^^^^^^^^ 124 | 125 | You can also pass other parameters relevant for training to the `fit` method. For example, you can change the number of epochs, the batch size, the learning rate, etc. You can do that using the `train_config` parameter. For more information on the arguments allowed, see the documentation of the :py:func:`spotiflow.model.spotiflow.Spotiflow.fit` method as well as :py:mod:`spotiflow.model.config.SpotiflowTrainingConfig`. As an example, let's change the number of epochs and the learning rate: 126 | 127 | ..
code-block:: python 128 | 129 | train_config = { 130 | "num_epochs": 100, 131 | "learning_rate": 0.001, 132 | "smart_crop": True, 133 | # other parameters 134 | } 135 | 136 | model.fit( 137 | train_imgs, 138 | train_spots, 139 | val_imgs, 140 | val_spots, 141 | save_dir="/path/to/my_trained_model", 142 | train_config=train_config, 143 | # other parameters 144 | ) 145 | 146 | 147 | In order to change the model architecture (`e.g.` number of input channels, number of layers, variance for the heatmap generation, etc.), you can create a :py:mod:`spotiflow.model.config.SpotiflowModelConfig` object and populate it accordingly. Then you can pass it to the `Spotiflow` constructor (note that this is necessary for 3D). For example, if our image is RGB and we need the network to use 3 input channels, we can do the following: 148 | 149 | .. code-block:: python 150 | 151 | from spotiflow.model import SpotiflowModelConfig 152 | 153 | # Create the model config 154 | model_config = SpotiflowModelConfig( 155 | in_channels=3, 156 | # you can pass other arguments here 157 | ) 158 | model = Spotiflow(model_config) 159 | -------------------------------------------------------------------------------- /extra/analyze_spot_clusters.py: -------------------------------------------------------------------------------- 1 | """ 2 | Retrieve basic statistics of spot clusters in an image by running Spotiflow to detect individual spots and then aggregating them according to a radius search. 3 | 4 | Usage: 5 | python analyze_spot_clusters.py --input /PATH/TO/IMG --model SPOTIFLOW_MODEL --output ./out 6 | """ 7 | import argparse 8 | from pathlib import Path 9 | 10 | import networkx as nx 11 | import numpy as np 12 | import pandas as pd 13 | from skimage import io 14 | from sklearn.neighbors import radius_neighbors_graph 15 | from spotiflow.model import Spotiflow 16 | from spotiflow.utils import write_coords_csv 17 | 18 | 19 | def analyze_clusters(spots: np.ndarray, max_distance: float = 11.0): 20 | """ 21 | Get information about clusters by building an r-radius graph.
22 | """ 23 | adj_matrix = radius_neighbors_graph( 24 | spots, radius=max_distance, mode="distance", metric="euclidean" 25 | ) 26 | graph = nx.from_scipy_sparse_array(adj_matrix) 27 | conn_components = nx.connected_components(graph) 28 | columns = ["cluster_id", "mean_y", "mean_x", "num_spots"] 29 | if spots.shape[1] == 3: 30 | columns.insert(1, "mean_z") 31 | df = pd.DataFrame(columns=columns) 32 | for i, component in enumerate(conn_components): 33 | curr_spots = spots[list(component)] 34 | center = np.mean(curr_spots, axis=0) 35 | if center.shape[0] == 3: 36 | mean_z, mean_y, mean_x = center 37 | else: 38 | mean_y, mean_x = center 39 | 40 | component_data = { 41 | "cluster_id": i, 42 | "num_spots": len(component), 43 | "mean_y": mean_y, 44 | "mean_x": mean_x, 45 | } 46 | if center.shape[0] == 3: 47 | component_data["mean_z"] = mean_z 48 | curr_df = pd.DataFrame(component_data, index=[0]) 49 | df = pd.concat([df, curr_df], ignore_index=True) 50 | return df 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--input", type=Path, help="Path to the image") 56 | parser.add_argument("--model", type=Path, help="Path to the model") 57 | parser.add_argument("--output", type=Path, help="Path to the output folder") 58 | parser.add_argument("--max-distance", type=float, default=11.0, help="Max distance to consider two spots as part of the same cluster") 59 | args = parser.parse_args() 60 | 61 | args.output.mkdir(exist_ok=True, parents=True) 62 | 63 | img = io.imread(args.input) # load the image 64 | model = Spotiflow.from_folder(args.model) # load the clusters model 65 | spots, _ = model.predict(img, normalizer="auto") 66 | print("Analyzing clusters...") 67 | clusters_df = analyze_clusters(spots, max_distance=args.max_distance) 68 | clusters_df.to_csv(args.output / f"{args.input.stem}_clusters.csv", index=False) 69 | print("Done!") 70 | -------------------------------------------------------------------------------- /extra/annotation_ui.py: -------------------------------------------------------------------------------- 1 | """ 2 | Starts a napari widget for annotating points on an image starting from LoG candidates or existing results stored in a CSV. 
3 | 4 | The widget allows you to: 5 | - toggle points visibility 6 | - save points to a csv file 7 | - detect points using a LoG detector 8 | - increase/decrease LoG threshold with w/e keys and a slider 9 | 10 | Usage: 11 | python annotation_ui.py [options] 12 | """ 13 | import numpy as np 14 | import napari 15 | from pathlib import Path 16 | from tifffile import imread 17 | from qtpy.QtWidgets import QMessageBox 18 | from csbdeep.utils import normalize 19 | from skimage.feature import blob_log 20 | from magicgui import magicgui 21 | import pandas as pd 22 | import argparse 23 | from spotiflow.utils import read_coords_csv 24 | from spotiflow.utils.fitting import estimate_params 25 | 26 | 27 | KEY_SHORTCUTS = { 28 | 'toggle': ('q', 'toggle points'), 29 | 'save': ('s' , 'save csv'), 30 | 'thr_dec': ('w', 'decrease thr'), 31 | 'thr_inc': ('e', 'increase thr'), 32 | 'detect': ('d', 'detect') 33 | } 34 | 35 | def load_points(path): 36 | path = Path(path) 37 | if path.suffix==".npy": 38 | return np.load(path) 39 | elif path.suffix==".csv": 40 | return read_coords_csv(path) 41 | else: 42 | raise ValueError(f'unsupported extension {path.suffix}') 43 | 44 | 45 | def save_points(path, arr): 46 | path = Path(path) 47 | print(path) 48 | if path.suffix==".npy": 49 | np.save(path, arr) 50 | elif path.suffix==".csv": 51 | pd.DataFrame(arr, columns=['y','x']).to_csv(path, index=False) 52 | else: 53 | raise ValueError(f'unsupported extension {path.suffix}') 54 | 55 | def filter_points_bbox(points, bbox): 56 | """ bbox = ((y1,x1), (y2, x2)) """ 57 | inds = np.bitwise_and( 58 | np.bitwise_and(points[:,0]>=bbox[0,0],points[:,0]<bbox[1,0]), 59 | np.bitwise_and(points[:,1]>=bbox[0,1],points[:,1]<bbox[1,1])) 60 | return points[inds] -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", 3 | "setuptools_scm[toml]>=6.2", 4 | "wheel", 5 | "numpy"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [tool.setuptools_scm] 9 | write_to = "spotiflow/_version.py" 10 | -------------------------------------------------------------------------------- /scripts/predict_3d.py: -------------------------------------------------------------------------------- 1 | """Sample script to detect spots in a 3D volume with a Spotiflow model.
2 | """ 3 | 4 | import argparse 5 | import logging 6 | import sys 7 | from pathlib import Path 8 | 9 | import tifffile 10 | from skimage import io 11 | from spotiflow import utils 12 | from spotiflow.model import Spotiflow 13 | from spotiflow.model.pretrained import list_registered 14 | 15 | logging.basicConfig(level=logging.INFO, stream=sys.stdout) 16 | 17 | IMAGE_EXTENSIONS = ("tif", "tiff", "png", "jpg", "jpeg") 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--input", type=Path) 23 | parser.add_argument("--output", type=Path) 24 | parser.add_argument("--model", type=str, default="smfish_3d") 25 | parser.add_argument("--channel", type=int, required=False, default=None) 26 | parser.add_argument("--debug", action="store_true", default=False) 27 | parser.add_argument("--min-distance", type=int, default=1, required=True) 28 | parser.add_argument("--max-yx-tile-size", type=int, default=256) 29 | parser.add_argument("--max-z-tile-size", type=int, default=32) 30 | args = parser.parse_args() 31 | 32 | if args.model not in list_registered(): 33 | args.model = Path(args.model) 34 | assert args.model.exists(), f"Model not found: {args.model}" 35 | 36 | assert args.input.exists(), f"Input file not found: {args.input}" 37 | assert args.output.suffix == ".csv", f"Output file must have a .csv extension" 38 | 39 | 40 | print("Reading input data...") 41 | img = io.imread(str(args.input)) 42 | if img.ndim == 4: 43 | assert args.channel is not None, "Channel argument required if input is 4D (C,Z,Y,X)" 44 | img = img[args.channel] 45 | elif args.channel is not None: 46 | print("Ignoring channel argument as input is (Z,Y,X)") 47 | 48 | if args.debug: 49 | print("Debug mode. Will predict on crop") 50 | img = img[128:192, 1000:1256, 1000:1256] 51 | print("Loading model...") 52 | if args.model not in list_registered(): 53 | model = Spotiflow.from_folder(args.model, map_location="auto") 54 | else: 55 | model = Spotiflow.from_pretrained(args.model, map_location="auto") 56 | 57 | 58 | print(f"Image shape is: {img.shape}") 59 | 60 | args.output.parent.mkdir(parents=True, exist_ok=True) 61 | 62 | n_tiles = tuple(max(s//g, 1) for s, g in zip(img.shape, (args.max_z_tile_size, args.max_yx_tile_size, args.max_yx_tile_size))) 63 | print(n_tiles) 64 | print("Predicting volume...") 65 | spots, details = model.predict( 66 | img, 67 | subpix=True, 68 | n_tiles=n_tiles, # change if you run out of memory 69 | device="auto", 70 | min_distance=args.min_distance, 71 | ) 72 | 73 | if not args.debug: 74 | utils.write_coords_csv(spots, args.output) 75 | 76 | if args.debug: 77 | tifffile.imwrite(args.output.parent/"img_debug.tif", img) 78 | -------------------------------------------------------------------------------- /scripts/predict_zarr_multigpu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to predict spots on a large Zarr volumes using multiple GPUs. 3 | In a nutshell, overlapping tiles are dispatched to different GPUs, which process them. 4 | The results of each GPU are finally gathered and written to a CSV file. 5 | 6 | The script is designed to be run with torchrun, and highly suitable for cluster environments. 7 | The distributed environment is automatically setup (but the number of GPUs, which should be specified). 8 | 9 | For more info on the arguments, please run `python predict_zarr_multigpu.py -h`. 
10 | 11 | Example usage: 12 | torchrun --nproc_per_node=${n_gpus} predict_zarr_multigpu.py --input PATH/TO/ZARR --output PATH/TO/OUTPUT.csv --precomputed-percentiles X Y --dataloader-num-workers 2 13 | """ 14 | import argparse 15 | import logging 16 | import os 17 | import sys 18 | import time 19 | from pathlib import Path 20 | 21 | import dask 22 | import dask.array as da 23 | import numpy as np 24 | import torch 25 | import torch.distributed as dist 26 | from spotiflow.model import Spotiflow 27 | from spotiflow.model.pretrained import list_registered 28 | from spotiflow.utils import write_coords_csv 29 | 30 | logging.basicConfig(level=logging.INFO, stream=sys.stdout) 31 | log = logging.getLogger(__name__) 32 | 33 | APPROX_TILE_SIZE = (256, 256, 256) 34 | 35 | 36 | def get_percentiles( 37 | x: da.Array, 38 | pmin: float = 1.0, 39 | pmax: float = 99.8, 40 | max_samples: int = 1e5, 41 | ): 42 | n_skip = int(max(1, x.size // max_samples)) 43 | with dask.config.set(**{"array.slicing.split_large_chunks": False}): 44 | mi, ma = da.percentile( 45 | x.ravel()[::n_skip], (pmin, pmax), internal_method="tdigest" 46 | ).compute() 47 | return mi, ma 48 | 49 | 50 | def normalize(x, mi, ma, eps: float = 1e-20): 51 | return (x - mi) / (ma - mi + eps) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument( 57 | "--model", type=str, default="smfish_3d", help="Pre-trained model name or path." 58 | ) 59 | parser.add_argument( 60 | "--input", 61 | type=Path, 62 | help="Path pointing to the input volume. Should be in Zarr format", 63 | ) 64 | parser.add_argument( 65 | "--zarr-component", 66 | type=str, 67 | default=None, 68 | help="Zarr component to read from the Zarr file, if necessary.", 69 | ) 70 | parser.add_argument( 71 | "--output", 72 | type=Path, 73 | help="Path pointing to the output CSV file where the detected spots will be written to.", 74 | ) 75 | parser.add_argument( 76 | "--dataloader-num-workers", 77 | type=int, 78 | default=0, 79 | help="Number of workers to use in the dataloader.", 80 | ) 81 | parser.add_argument( 82 | "--prob-thresh", 83 | default=0.4, 84 | type=float, 85 | help="Probability threshold for spot detection.", 86 | ) 87 | parser.add_argument( 88 | "--min-distance", 89 | default=1, 90 | type=int, 91 | help="Minimum distance between detections.", 92 | ) 93 | parser.add_argument( 94 | "--precomputed-percentiles", 95 | nargs=2, 96 | default=None, 97 | help="If given, will use the precomputed percentiles instead of recomputing them.", 98 | ) 99 | args = parser.parse_args() 100 | 101 | img = da.from_zarr(str(args.input), component=args.zarr_component) 102 | 103 | print(f"Array is {img.shape} (~{img.nbytes/1e9:.2f} GB)") 104 | print(f"Number of voxels: ~2^{int(np.log2(img.size))}") 105 | 106 | n_tiles = tuple(max(1, s // t) for s, t in zip(img.shape, APPROX_TILE_SIZE)) 107 | 108 | distributed = torch.cuda.device_count() > 1 109 | torch.distributed.init_process_group(backend="nccl") 110 | gpu_id = int(os.environ["LOCAL_RANK"]) 111 | if gpu_id == 0: 112 | print("Distributed session successfully initialized") 113 | t0 = time.monotonic() 114 | 115 | device = torch.device(f"cuda:{gpu_id}" if gpu_id >= 0 else "cpu") 116 | 117 | mi_ma = torch.zeros(2, device=device) 118 | 119 | if gpu_id == 0: 120 | args.output.parent.mkdir(parents=True, exist_ok=True) 121 | if args.precomputed_percentiles is None: 122 | print("Computing percentiles...") 123 | log.warning( 124 | "It is highly recommended to precompute percentiles and pass them as an argument 
to avoid Dask hijacking threads.\nIf the execution seems to halt, please re-run with the --precomputed-percentiles argument set to the percentiles computed here." 125 | ) 126 | 127 | t0_p = time.monotonic() 128 | mi, ma = get_percentiles(img.astype(np.float32), 1, 99.8) 129 | te_p = time.monotonic() 130 | print(f"Percentiles ({mi:.2f}, {ma:.2f}) computed in {te_p-t0_p:.2f} s") 131 | else: 132 | mi, ma = tuple(float(x) for x in args.precomputed_percentiles) 133 | print(f"Using precomputed percentiles ({mi:.2f}, {ma:.2f})...") 134 | mi_ma = torch.tensor([mi, ma], device=device, dtype=torch.float32) 135 | 136 | if args.model not in list_registered(): 137 | args.model = Path(args.model) 138 | model = Spotiflow.from_folder(args.model) 139 | else: 140 | model = Spotiflow.from_pretrained(args.model) 141 | model.to(device) 142 | model.eval() 143 | model = torch.compile(model) 144 | 145 | dist.barrier() 146 | dist.broadcast(mi_ma, src=0) 147 | dist.barrier() 148 | 149 | p1, p998 = mi_ma[0].item(), mi_ma[1].item() 150 | del mi_ma 151 | 152 | spots, _ = model.predict( 153 | img, 154 | subpix=True, 155 | n_tiles=n_tiles, 156 | min_distance=args.min_distance, 157 | prob_thresh=args.prob_thresh, 158 | device=None, 159 | normalizer=lambda x: normalize(x, p1, p998), 160 | distributed_params={ 161 | "gpu_id": gpu_id, 162 | "num_workers": args.dataloader_num_workers, 163 | "num_replicas": int(os.environ["WORLD_SIZE"]), 164 | }, 165 | ) 166 | 167 | spots = torch.from_numpy(spots).to(device) 168 | # Collect shapes 169 | if gpu_id == 0: 170 | all_shapes = [spots.shape] # Start with the root process's own tensor shape 171 | for src_gpu_id in range(1, dist.get_world_size()): 172 | shape_tensor = torch.zeros(2, dtype=torch.long, device=device) 173 | dist.recv(tensor=shape_tensor, src=src_gpu_id) 174 | all_shapes.append(tuple(shape_tensor.tolist())) 175 | else: 176 | # Non-root processes: Send tensor shape to root process 177 | shape_tensor = torch.tensor(spots.shape, dtype=torch.long, device=device) 178 | dist.send(tensor=shape_tensor, dst=0) 179 | 180 | # Send based on shape 181 | if gpu_id == 0: 182 | all_spots = [spots] # Start with the root process's own tensor 183 | for idx, src_gpu_id in enumerate(range(1, dist.get_world_size())): 184 | recv_tensor = torch.zeros( 185 | all_shapes[idx + 1], device=device, dtype=spots.dtype 186 | ) # Use collected shapes 187 | dist.recv(tensor=recv_tensor, src=src_gpu_id) 188 | all_spots.append(recv_tensor) 189 | else: 190 | # Non-root processes: Send tensor to root process 191 | dist.send(tensor=spots, dst=0) 192 | 193 | # Concat at root and write 194 | if gpu_id == 0: 195 | all_spots = torch.cat(all_spots, dim=0).cpu().numpy() 196 | print("All spots shape is", all_spots.shape) 197 | print("Writing...") 198 | 199 | write_coords_csv( 200 | all_spots, 201 | str(args.output), 202 | ) 203 | print("Written!") 204 | te = time.monotonic() 205 | print(f"Total elapsed time: {te-t0:.2f} s") 206 | dist.barrier() 207 | sys.exit(0) 208 | -------------------------------------------------------------------------------- /scripts/train_simple.py: -------------------------------------------------------------------------------- 1 | """Sample script to train a Spotiflow model.
2 | """ 3 | 4 | import argparse 5 | import numpy as np 6 | from pathlib import Path 7 | from skimage import io 8 | from itertools import chain 9 | 10 | from spotiflow.model import Spotiflow, SpotiflowModelConfig 11 | from spotiflow import utils 12 | import lightning.pytorch as pl 13 | 14 | IMAGE_EXTENSIONS = ("tif", "tiff", "png", "jpg", "jpeg") 15 | 16 | 17 | def get_data(data_dir): 18 | """Load data from data_dir.""" 19 | img_files = sorted(tuple(chain(*tuple(data_dir.glob(f"*.{ext}") for ext in IMAGE_EXTENSIONS)))) 20 | spots_files = sorted(data_dir.glob("*.csv")) 21 | 22 | images = tuple(io.imread(str(f)) for f in img_files) 23 | spots = tuple(utils.read_coords_csv(str(f)).astype(np.float32) for f in spots_files) 24 | return images, spots 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--data-dir", type=Path, default="/data/spots/datasets/hybiss_spots_v4") 30 | parser.add_argument("--save-dir", type=Path, default="/data/tmp/spotiflow_simple_train_debug") 31 | parser.add_argument("--sigma", type=float, default=1.0) 32 | parser.add_argument("--seed", type=int, default=42) 33 | args = parser.parse_args() 34 | 35 | pl.seed_everything(args.seed, workers=True) 36 | 37 | print("Loading training data...") 38 | train_images, train_spots = get_data(args.data_dir / "train") 39 | print(f"Training data loaded (N={len(train_images)}).") 40 | 41 | print("Loading validation data...") 42 | val_images, val_spots = get_data(args.data_dir / "val") 43 | print(f"Validation data loaded (N={len(val_images)}).") 44 | 45 | print("Instantiating model...") 46 | model = Spotiflow(SpotiflowModelConfig(sigma=args.sigma)) 47 | 48 | print("Launching training...") 49 | model.fit( 50 | train_images, 51 | train_spots, 52 | val_images, 53 | val_spots, 54 | save_dir=args.save_dir, 55 | device="auto", 56 | ) 57 | print("Done!") 58 | -------------------------------------------------------------------------------- /scripts/train_simple_3d.py: -------------------------------------------------------------------------------- 1 | """Sample script to train a Spotiflow model. 
2 | """ 3 | 4 | import argparse 5 | import logging 6 | import sys 7 | from itertools import chain 8 | from pathlib import Path 9 | 10 | import lightning.pytorch as pl 11 | import numpy as np 12 | from skimage import io 13 | from spotiflow import utils 14 | from spotiflow.model import Spotiflow, SpotiflowModelConfig 15 | 16 | logging.basicConfig(level=logging.INFO, stream=sys.stdout) 17 | 18 | IMAGE_EXTENSIONS = ("tif", "tiff", "png", "jpg", "jpeg") 19 | 20 | 21 | 22 | def get_data(data_dir, debug=False): 23 | """Load data from data_dir.""" 24 | img_files = sorted(tuple(chain(*tuple(data_dir.glob(f"*.{ext}") for ext in IMAGE_EXTENSIONS)))) 25 | spots_files = sorted(data_dir.glob("*.csv")) 26 | if debug: 27 | img_files = img_files[:32] 28 | spots_files = spots_files[:32] 29 | images = tuple(io.imread(str(f)) for f in img_files) 30 | spots = tuple(utils.read_coords_csv3d(str(f)).astype(np.float32) for f in spots_files) 31 | return images, spots 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--data-dir", type=Path, default="/data/spots/datasets_3d/synth3d") 37 | parser.add_argument("--save-dir", type=Path, default="/data/tmp/spotiflow_3d_debug/synth3d") 38 | parser.add_argument("--sigma", type=float, default=1.0) 39 | parser.add_argument("--seed", type=int, default=42) 40 | parser.add_argument("--levels", type=int, default=4) 41 | parser.add_argument("--pretrained-path", type=Path, default=None) 42 | parser.add_argument("--crop-size", type=int, default=128) 43 | parser.add_argument("--crop-size-depth", type=int, default=32) 44 | parser.add_argument("--num-epochs", type=int, default=200) 45 | parser.add_argument("--debug", action="store_true") 46 | parser.add_argument("--batch-size", type=int, default=4) 47 | parser.add_argument("--pos-weight", type=float, default=10.) 
48 | args = parser.parse_args() 49 | 50 | pl.seed_everything(args.seed, workers=True) 51 | 52 | print("Loading training data...") 53 | train_images, train_spots = get_data(args.data_dir / "train", debug=args.debug) 54 | print(f"Training data loaded (N={len(train_images)}).") 55 | 56 | print("Loading validation data...") 57 | val_images, val_spots = get_data(args.data_dir / "val", debug=args.debug) 58 | print(f"Validation data loaded (N={len(val_images)}).") 59 | 60 | if args.pretrained_path is not None: 61 | print("Loading pretrained model...") 62 | model = Spotiflow.from_folder(args.pretrained_path) 63 | print("Launching fine-tuning...") 64 | else: 65 | print("Instantiating new model...") 66 | model = Spotiflow(SpotiflowModelConfig(in_channels=1, sigma=args.sigma, is_3d=True, levels=args.levels, grid=(1,1,1))) 67 | print("Launching training...") 68 | 69 | model.fit( 70 | train_images, 71 | train_spots, 72 | val_images, 73 | val_spots, 74 | save_dir=args.save_dir if not args.debug else args.save_dir/"debug", 75 | augment_train=True, 76 | device="auto", 77 | deterministic=False, 78 | logger="tensorboard" if not args.debug else "none", 79 | train_config={ 80 | "num_epochs": args.num_epochs if not args.debug else 5, 81 | "crop_size": args.crop_size, 82 | "crop_size_depth": args.crop_size_depth, 83 | "smart_crop": True, 84 | "batch_size": args.batch_size, 85 | "pos_weight": args.pos_weight, 86 | } 87 | ) 88 | print("Done!") 89 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = spotiflow 3 | author = Albert Dominguez Mantes, Martin Weigert 4 | author_email = albert.dominguezmantes@epfl.ch, martin.weigert@epfl.ch 5 | dynamic = ["version"] 6 | license = BSD 3-Clause License 7 | description = Accurate and efficient spot detection for microscopy data 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | classifiers = 11 | Development Status :: 4 - Beta 12 | Intended Audience :: Science/Research 13 | Topic :: Scientific/Engineering 14 | License :: OSI Approved :: BSD License 15 | Programming Language :: Python :: 3.9 16 | Programming Language :: Python :: 3.10 17 | Programming Language :: Python :: 3.11 18 | Programming Language :: Python :: 3.12 19 | Programming Language :: Python :: 3.13 20 | 21 | [options] 22 | packages = find: 23 | install_requires = 24 | configargparse 25 | crick 26 | csbdeep 27 | dask 28 | lightning 29 | networkx 30 | numpy 31 | pandas 32 | Pillow 33 | pydash 34 | scikit-image 35 | scipy 36 | setuptools 37 | tensorboard 38 | tifffile 39 | torchvision 40 | tqdm 41 | typing-extensions 42 | wandb 43 | zarr 44 | python_requires = >=3.9, <3.14 45 | 46 | [options.entry_points] 47 | console_scripts = 48 | spotiflow-predict = spotiflow.cli.predict:main 49 | spotiflow-train = spotiflow.cli.train:main 50 | 51 | [options.extras_require] 52 | testing = 53 | pytest 54 | pytest-cov 55 | pytest-mock 56 | tox 57 | docs = 58 | sphinx 59 | sphinx-immaterial 60 | napari = 61 | napari-spotiflow 62 | starfish = 63 | starfish 64 | 65 | [flake8] 66 | ignore = E116, E501, E203 67 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Adapted from https://github.com/stardist/stardist/blob/master/setup.py""" 2 | from __future__ import absolute_import, print_function 3 |
from setuptools import setup, Extension 4 | from setuptools.command.build_ext import build_ext 5 | from numpy import get_include 6 | from os import path 7 | 8 | # ------------------------------------------------------------------------------------ 9 | _dir = path.dirname(__file__) 10 | 11 | 12 | class build_ext_openmp(build_ext): 13 | # https://www.openmp.org/resources/openmp-compilers-tools/ 14 | # python setup.py build_ext --help-compiler 15 | openmp_compile_args = { 16 | "msvc": [["/openmp"]], 17 | "intel": [["-qopenmp"]], 18 | "*": [["-fopenmp"], ["-Xpreprocessor", "-fopenmp"]], 19 | } 20 | openmp_link_args = openmp_compile_args # ? 21 | 22 | def build_extension(self, ext): 23 | compiler = self.compiler.compiler_type.lower() 24 | if compiler.startswith("intel"): 25 | compiler = "intel" 26 | if compiler not in self.openmp_compile_args: 27 | compiler = "*" 28 | 29 | # thanks to @jaimergp (https://github.com/conda-forge/staged-recipes/pull/17766) 30 | # issue: qhull has a mix of c and c++ source files 31 | # gcc warns about passing -std=c++11 for c files, but clang errors out 32 | compile_original = self.compiler._compile 33 | 34 | def compile_patched(obj, src, ext, cc_args, extra_postargs, pp_opts): 35 | # remove c++ specific (extra) options for c files 36 | if src.lower().endswith(".c"): 37 | extra_postargs = [ 38 | arg for arg in extra_postargs if not arg.lower().startswith("-std") 39 | ] 40 | return compile_original(obj, src, ext, cc_args, extra_postargs, pp_opts) 41 | 42 | # monkey patch the _compile method 43 | self.compiler._compile = compile_patched 44 | 45 | # store original args 46 | _extra_compile_args = list(ext.extra_compile_args) 47 | _extra_link_args = list(ext.extra_link_args) 48 | 49 | # try compiler-specific flag(s) to enable openmp 50 | for compile_args, link_args in zip( 51 | self.openmp_compile_args[compiler], self.openmp_link_args[compiler] 52 | ): 53 | 54 | try: 55 | ext.extra_compile_args = _extra_compile_args + compile_args 56 | ext.extra_link_args = _extra_link_args + link_args 57 | print(">>> try building with OpenMP support: ", compile_args, link_args) 58 | return super(build_ext_openmp, self).build_extension(ext) 59 | except Exception as _: 60 | print(f">>> compiling with '{' '.join(compile_args)}' failed") 61 | 62 | print(">>> compiling with OpenMP support failed, re-trying without") 63 | 64 | ext.extra_compile_args = _extra_compile_args 65 | ext.extra_link_args = _extra_link_args 66 | return super(build_ext_openmp, self).build_extension(ext) 67 | 68 | 69 | external_root = path.join(_dir, "spotiflow", "lib", "external") 70 | nanoflann_root = path.join(external_root, "nanoflann") 71 | 72 | setup( 73 | cmdclass={"build_ext": build_ext_openmp}, 74 | ext_modules=[ 75 | Extension( 76 | "spotiflow.lib.spotflow2d", 77 | sources=["spotiflow/lib/spotflow2d.cpp"], 78 | extra_compile_args=["-std=c++11"], 79 | include_dirs=[get_include()] + [nanoflann_root], 80 | ), 81 | Extension( 82 | "spotiflow.lib.spotflow3d", 83 | sources=["spotiflow/lib/spotflow3d.cpp"], 84 | extra_compile_args=["-std=c++11"], 85 | include_dirs=[get_include()] + [nanoflann_root], 86 | ), 87 | Extension( 88 | "spotiflow.lib.point_nms", 89 | sources=["spotiflow/lib/point_nms.cpp"], 90 | extra_compile_args=["-std=c++11"], 91 | include_dirs=[get_include()] + [nanoflann_root], 92 | ), 93 | Extension( 94 | "spotiflow.lib.point_nms3d", 95 | sources=["spotiflow/lib/point_nms3d.cpp"], 96 | extra_compile_args=["-std=c++11"], 97 | include_dirs=[get_include()] + [nanoflann_root], 98 | ), 99 | Extension( 100 | 
"spotiflow.lib.filters", 101 | sources=["spotiflow/lib/filters.cpp"], 102 | extra_compile_args=["-std=c++11"], 103 | include_dirs=[get_include()] + [nanoflann_root], 104 | ), 105 | Extension( 106 | "spotiflow.lib.filters3d", 107 | sources=["spotiflow/lib/filters3d.cpp"], 108 | extra_compile_args=["-std=c++11"], 109 | include_dirs=[get_include()] + [nanoflann_root], 110 | ), 111 | ], 112 | include_package_data=True, 113 | ) 114 | -------------------------------------------------------------------------------- /spotiflow/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__, __version_tuple__ 2 | -------------------------------------------------------------------------------- /spotiflow/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import Pipeline 2 | from .transforms import * -------------------------------------------------------------------------------- /spotiflow/augmentations/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import Pipeline -------------------------------------------------------------------------------- /spotiflow/augmentations/pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | 5 | from ..transforms.base import BaseAugmentation 6 | 7 | class Pipeline(object): 8 | def __init__(self, *augs) -> None: 9 | super().__init__() 10 | self._augmentations = [] 11 | for aug in augs: 12 | self.add(aug) 13 | 14 | 15 | @property 16 | def augmentations(self) -> List[BaseAugmentation]: 17 | """Ordered augmentations in the pipeline. 18 | 19 | Returns: 20 | List[BaseAugmentation]: Ordered augmentations in the pipeline 21 | """ 22 | return self._augmentations 23 | 24 | def add(self, augmentation: BaseAugmentation): 25 | """Add a new augmentation to the pipeline. 26 | 27 | Args: 28 | augmentation (BaseAugmentation): augmentation object to be added 29 | """ 30 | if not isinstance(augmentation, BaseAugmentation): 31 | raise TypeError("Only BaseAugmentation instances can be added") 32 | self._augmentations.append(augmentation) 33 | 34 | def __call__(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 35 | """Apply augmentations sequentially to an image and the corresponding points. 
36 | 37 | Args: 38 | img (torch.Tensor): (N+1)D tensor of shape (C, ..., H, W) 39 | pts (torch.Tensor): 2D tensor of shape (N, D) 40 | 41 | Returns: 42 | Tuple[torch.Tensor, torch.Tensor]: augmented batch of images and points 43 | """ 44 | for aug in self.augmentations: 45 | img, pts = aug(img, pts) 46 | return img, pts 47 | 48 | 49 | def __repr__(self) -> str: 50 | aug_list = '\n- '.join(str(aug) for aug in self.augmentations) 51 | return f"Pipeline\n- {aug_list}" 52 | 53 | 54 | def __add__(self, other): 55 | if isinstance(other, Pipeline): 56 | return Pipeline(*self.augmentations, *other.augmentations) 57 | else: 58 | raise TypeError("Can only add Pipeline objects") -------------------------------------------------------------------------------- /spotiflow/augmentations/test/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from spotiflow.augmentations import Pipeline 3 | from spotiflow.augmentations.transforms import FlipRot90, Rotation, Translation 4 | from spotiflow.augmentations.transforms.utils import _generate_img_from_points 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | if __name__ == "__main__": 9 | 10 | 11 | 12 | pts = torch.randint(5, 95, (10, 2)) 13 | img = torch.tensor(_generate_img_from_points(pts.numpy(), (100, 100))).unsqueeze(0) 14 | 15 | pipeline = Pipeline( 16 | FlipRot90(), 17 | Rotation(order=1), 18 | Translation(shift=10) 19 | ) 20 | 21 | 22 | plt.ion() 23 | fig, axs = plt.subplots(1, 4, figsize=(16, 8)) 24 | 25 | def _to_rgb(img, pts): 26 | img2 = torch.tensor(_generate_img_from_points(pts.numpy(), (100, 100))).unsqueeze(0) 27 | return torch.stack((img2[0], img[0], img2[0]), -1).numpy() 28 | 29 | axs[0].imshow(_to_rgb(img, pts)) 30 | axs[0].set_title('original') 31 | 32 | for ax in axs[1:]: 33 | img2, pts2 = pipeline(img, pts) 34 | ax.imshow(_to_rgb(img2, pts2)) 35 | ax.set_title('augmented') 36 | 37 | for ax in axs.flatten(): 38 | ax.axis('off') 39 | 40 | plt.show() -------------------------------------------------------------------------------- /spotiflow/augmentations/test/transforms/test_crop.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pytest 3 | import torch 4 | 5 | from spotiflow.augmentations.transforms import Crop 6 | from spotiflow.augmentations.transforms.utils import _generate_img_from_points 7 | from typing import Tuple, Union 8 | 9 | 10 | ABS_TOLERANCE = 1e-8 11 | 12 | @pytest.mark.parametrize("img_size", [(4, 1, 224, 224), (5, 2, 512, 512), (1, 100, 100), (3, 327, 312)]) 13 | @pytest.mark.parametrize("crop_size", [(200, 201), (64, 64), (128, 100), (100, 31)]) 14 | @pytest.mark.parametrize("n_pts", [10, 100]) 15 | def test_crop_augmentation(img_size: Tuple[int, ...], 16 | crop_size: Tuple[int, ...], 17 | n_pts: int, 18 | caplog): 19 | if caplog is not None: 20 | caplog.set_level(logging.CRITICAL) 21 | 22 | torch.manual_seed(img_size[-1]*n_pts*crop_size[-1]) 23 | 24 | img = torch.zeros(img_size) 25 | msize = min(img_size[-2:]) 26 | pts = torch.randint(0, msize, (n_pts, 2)).repeat(img_size[0], 1, 1) 27 | for b in range(img_size[0]): 28 | img[b] = torch.from_numpy(_generate_img_from_points(pts[b].numpy(), img_size[-2:], sigma=0.0)) # Use delta spots (sigma=0.0) so that crops cutting through the tail of a Gaussian-rendered spot do not introduce spurious differences 29 | aug = Crop(size=crop_size) 30 | if crop_size[0] > img_size[-2] or crop_size[1] > img_size[-1]: 31 | with pytest.raises(AssertionError): 32 | img_aug, pts_aug = aug(img, pts) 33 | else:
img_aug, pts_aug = aug(img, pts) 35 | 36 | img_from_aug_pts = torch.zeros(*img_size[:-2], *crop_size) 37 | for b in range(img_size[0]): 38 | img_from_aug_pts[b] = torch.from_numpy(_generate_img_from_points(pts_aug[b].numpy(), crop_size, sigma=0.0)) 39 | 40 | mse = ((img_aug - img_from_aug_pts)**2).mean() 41 | if __name__ == "__main__": 42 | import matplotlib.pyplot as plt 43 | fig, ax = plt.subplots(1, 4, figsize=(16, 4)) 44 | ax[0].imshow(img[0], cmap="magma") 45 | ax[0].title.set_text("Original") 46 | ax[1].imshow(img_aug[0], cmap="magma") 47 | ax[1].title.set_text("Augmented") 48 | ax[2].imshow(img_from_aug_pts[0], cmap="magma") 49 | ax[2].title.set_text("From Augmented Points") 50 | ax[3].imshow((img_aug[0]-img_from_aug_pts[0])**2, cmap="magma") 51 | ax[3].title.set_text(f"Squared Difference (MSE: {mse:.5f})") 52 | fig.show() 53 | assert torch.allclose(img_aug, img_from_aug_pts, atol=ABS_TOLERANCE), "Image augmentation is not correct." 54 | 55 | if __name__ == "__main__": 56 | test_crop_augmentation((3, 327, 312), (128, 100), 100, None) 57 | 58 | -------------------------------------------------------------------------------- /spotiflow/augmentations/test/transforms/test_fliprot.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | 3 | import logging 4 | import pytest 5 | import torch 6 | 7 | from spotiflow.augmentations.transforms import FlipRot90 8 | from spotiflow.augmentations.transforms.utils import _generate_img_from_points 9 | from typing import Tuple 10 | 11 | MSE_TOLERANCE = 1e-8 12 | 13 | 14 | @pytest.mark.parametrize("img_size", [(4, 1, 224, 224), (5, 2, 512, 512), (1, 100, 100), 15 | (3, 327, 312), (8, 100, 101), (10, 242, 256)]) 16 | @pytest.mark.parametrize("n_pts", [10, 100]) 17 | def test_fliprot90_augmentation(img_size: Tuple[int, ...], 18 | n_pts: int, 19 | caplog): 20 | if caplog is not None: 21 | caplog.set_level(logging.CRITICAL) 22 | 23 | torch.manual_seed(img_size[-1]*n_pts) 24 | 25 | img = torch.zeros(img_size) 26 | msize = min(img_size[-2:]) 27 | 28 | 29 | 30 | pts = torch.randint(msize//3, msize-msize//3, (n_pts, 2)).repeat(img_size[0], 1, 1) 31 | for b in range(img_size[0]): 32 | img[b] = torch.from_numpy(_generate_img_from_points(pts[b].numpy(), img_size[-2:], sigma=1)) 33 | 34 | aug = FlipRot90() 35 | img_aug, pts_aug = aug(img, pts) 36 | 37 | img_from_aug_pts = torch.zeros(img_size) 38 | for b in range(img_size[0]): 39 | img_from_aug_pts[b] = torch.from_numpy(_generate_img_from_points(pts_aug[b].round().numpy(), img_size[-2:], sigma=1)) 40 | 41 | mse = ((img_aug - img_from_aug_pts)**2).mean() 42 | 43 | if __name__ == "__main__": 44 | import matplotlib.pyplot as plt 45 | fig, ax = plt.subplots(1, 4, figsize=(16, 4)) 46 | ax[0].imshow(img[0], cmap="magma") 47 | ax[0].title.set_text("Original") 48 | ax[1].imshow(img_aug[0], cmap="magma") 49 | ax[1].title.set_text("Augmented") 50 | ax[2].imshow(img_from_aug_pts[0], cmap="magma") 51 | ax[2].title.set_text("From Augmented Points") 52 | ax[3].imshow((img_aug[0]-img_from_aug_pts[0])**2, cmap="magma") 53 | ax[3].title.set_text(f"Squared Difference (MSE: {mse:.5f})") 54 | fig.show() 55 | assert mse < MSE_TOLERANCE, "FlipRot90 augmentation is not correct." 
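# --- Editor's addition (not part of the original test file): a minimal usage sketch of
# FlipRot90 outside of pytest, based only on the API exercised above. The helper name
# below is hypothetical; the underscore prefix keeps pytest from collecting it.
def _example_fliprot90_usage():
    import torch
    from spotiflow.augmentations.transforms import FlipRot90

    img = torch.rand(1, 64, 64)                     # (C, H, W) image tensor
    pts = torch.randint(0, 64, (1, 10, 2)).float()  # matching (y, x) coordinates
    aug = FlipRot90(probability=1.0)                # always apply for the demo
    img_aug, pts_aug = aug(img, pts)
    # Each flipped axis maps a coordinate c to (size - 1 - c), so the augmented
    # points still index the same spots in the augmented image.
    assert img_aug.shape == img.shape and pts_aug.shape == pts.shape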
56 | 57 | if __name__ == "__main__": 58 | test_fliprot90_augmentation((10, 242, 256), 100, None) -------------------------------------------------------------------------------- /spotiflow/augmentations/test/transforms/test_intensity_shift.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pytest 3 | import torch 4 | 5 | from spotiflow.augmentations.transforms.intensity_shift import IntensityScaleShift 6 | from spotiflow.augmentations.transforms.utils import _generate_img_from_points 7 | from typing import Tuple, Union 8 | 9 | 10 | ABS_TOLERANCE = 1e-8 11 | 12 | @pytest.mark.parametrize("img_size", [(4, 1, 224, 224), (5, 2, 512, 512), (1, 100, 100), (3, 327, 312)]) 13 | @pytest.mark.parametrize("scale", [(.8, 1.2), (.4, 1.6), (.5, 2.)]) 14 | @pytest.mark.parametrize("shift", [(0., 0.05), (-.1, .1), (-.3, .3)]) 15 | @pytest.mark.parametrize("n_pts", [10, 100]) 16 | def test_intensity_shift_augmentation(img_size: Tuple[int, ...], 17 | scale: Tuple[float, float], 18 | shift: Tuple[float, float], 19 | n_pts: int, 20 | caplog): 21 | if caplog is not None: 22 | caplog.set_level(logging.CRITICAL) 23 | 24 | torch.manual_seed(img_size[-1]*n_pts) 25 | 26 | img = torch.zeros(img_size) 27 | msize = min(img_size[-2:]) 28 | pts = torch.randint(0, msize, (n_pts, 2)).repeat(img_size[0], 1, 1) 29 | for b in range(img_size[0]): 30 | img[b] = torch.from_numpy(_generate_img_from_points(pts[b].numpy(), img_size[-2:], sigma=1.)) 31 | aug = IntensityScaleShift(scale=scale, shift=shift) 32 | img_aug, pts_aug = aug(img, pts) 33 | 34 | img_from_aug_pts = torch.zeros(*img_size) 35 | 36 | assert torch.allclose(pts, pts_aug, atol=ABS_TOLERANCE), "Points changed after intensity scale/shift, which should not happen!"
37 | 38 | for b in range(img_size[0]): 39 | img_from_aug_pts[b] = torch.from_numpy(_generate_img_from_points(pts_aug[b].numpy(), img_size[-2:], sigma=1.)) 40 | 41 | if __name__ == "__main__": 42 | import matplotlib.pyplot as plt 43 | fig, ax = plt.subplots(1, 2, figsize=(8, 4)) 44 | ax[0].imshow(img[0], cmap="magma") 45 | ax[0].title.set_text("Original") 46 | ax[1].imshow(img_aug[0], cmap="magma") 47 | ax[1].title.set_text("Augmented") 48 | fig.show() 49 | 50 | 51 | if __name__ == "__main__": 52 | test_intensity_shift_augmentation((3, 327, 312), (.8, 1.2), (-.1, .1), 100, None) 53 | 54 | -------------------------------------------------------------------------------- /spotiflow/augmentations/test/transforms/test_noise.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pytest 3 | import torch 4 | 5 | from spotiflow.augmentations.transforms.noise import GaussianNoise, SaltAndPepperNoise 6 | from spotiflow.augmentations.transforms.utils import _generate_img_from_points 7 | from typing import Tuple, Union 8 | 9 | 10 | ABS_TOLERANCE = 1e-8 11 | 12 | @pytest.mark.parametrize("img_size", [(4, 1, 224, 224), (5, 2, 512, 512), (1, 100, 100), (3, 327, 312)]) 13 | @pytest.mark.parametrize("sigma", [(0., 0.05), (0.4, 1.)]) 14 | @pytest.mark.parametrize("n_pts", [10, 100]) 15 | def test_gaussian_noise_augmentation(img_size: Tuple[int, ...], 16 | sigma: Tuple[float, ...], 17 | n_pts: int, 18 | caplog): 19 | if caplog is not None: 20 | caplog.set_level(logging.CRITICAL) 21 | 22 | torch.manual_seed(img_size[-1]*n_pts) 23 | 24 | img = torch.zeros(img_size) 25 | msize = min(img_size[-2:]) 26 | pts = torch.randint(0, msize, (n_pts, 2)).repeat(img_size[0], 1, 1) 27 | for b in range(img_size[0]): 28 | img[b] = torch.from_numpy(_generate_img_from_points(pts[b].numpy(), img_size[-2:], sigma=1.)) 29 | aug = GaussianNoise(sigma=sigma) 30 | img_aug, pts_aug = aug(img, pts) 31 | 32 | img_from_aug_pts = torch.zeros(*img_size) 33 | assert torch.allclose(pts, pts_aug, atol=ABS_TOLERANCE), "Points changed after Gaussian noise addition, which should not happen!"
34 | 35 | for b in range(img_size[0]): 36 | img_from_aug_pts[b] = torch.from_numpy(_generate_img_from_points(pts_aug[b].numpy(), img_size[-2:], sigma=1.)) 37 | 38 | if __name__ == "__main__": 39 | import matplotlib.pyplot as plt 40 | fig, ax = plt.subplots(1, 2, figsize=(8, 4)) 41 | ax[0].imshow(img[0], cmap="magma") 42 | ax[0].title.set_text("Original") 43 | ax[1].imshow(img_aug[0], cmap="magma") 44 | ax[1].title.set_text("Augmented") 45 | fig.show() 46 | 47 | @pytest.mark.parametrize("img_size", [(4, 1, 224, 224), (5, 2, 512, 512), (1, 100, 100), (3, 327, 312)]) 48 | @pytest.mark.parametrize("prob_pepper", [(0., 0.0005), (0., 0.1), (0., 0.)]) 49 | @pytest.mark.parametrize("prob_salt", [(0., 0.0005), (0., 0.01), (0., 0.)]) 50 | @pytest.mark.parametrize("n_pts", [10, 100]) 51 | def test_saltpepper_noise_augmentation(img_size: Tuple[int, ...], 52 | prob_pepper: Tuple[float, ...], 53 | prob_salt: Tuple[float, ...], 54 | n_pts: int, 55 | caplog): 56 | if caplog is not None: 57 | caplog.set_level(logging.CRITICAL) 58 | 59 | torch.manual_seed(img_size[-1]*n_pts) 60 | 61 | img = torch.zeros(img_size) 62 | msize = min(img_size[-2:]) 63 | pts = torch.randint(0, msize, (n_pts, 2)).repeat(img_size[0], 1, 1) 64 | for b in range(img_size[0]): 65 | img[b] = torch.from_numpy(_generate_img_from_points(pts[b].numpy(), img_size[-2:], sigma=1.)) 66 | aug = SaltAndPepperNoise(prob_pepper=prob_pepper, prob_salt=prob_salt) 67 | img_aug, pts_aug = aug(img, pts) 68 | 69 | img_from_aug_pts = torch.zeros(*img_size) 70 | if all(p == 0 for p in prob_pepper) and all(p == 0 for p in prob_salt): 71 | assert torch.allclose(img, img_aug, atol=ABS_TOLERANCE), "Image changed, but should not!" 72 | assert torch.allclose(pts, pts_aug, atol=ABS_TOLERANCE), "Points changed after salt-and-pepper noise addition, which should not happen!"
73 | 74 | for b in range(img_size[0]): 75 | img_from_aug_pts[b] = torch.from_numpy(_generate_img_from_points(pts_aug[b].numpy(), img_size[-2:], sigma=1.)) 76 | 77 | if __name__ == "__main__": 78 | import matplotlib.pyplot as plt 79 | fig, ax = plt.subplots(1, 2, figsize=(8, 4)) 80 | ax[0].imshow(img[0], cmap="magma") 81 | ax[0].title.set_text("Original") 82 | ax[1].imshow(img_aug[0], cmap="magma") 83 | ax[1].title.set_text("Augmented") 84 | fig.show() 85 | 86 | if __name__ == "__main__": 87 | # test_gaussian_noise_augmentation((3, 327, 312), (0, 0.05), 100, None) 88 | test_saltpepper_noise_augmentation((3, 327, 312), (0., 0.), (0, 0.005), 100, None) 89 | 90 | -------------------------------------------------------------------------------- /spotiflow/augmentations/test/transforms/test_rotation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pytest 3 | import torch 4 | 5 | from spotiflow.augmentations.transforms import Rotation 6 | from spotiflow.augmentations.transforms.utils import _generate_img_from_points 7 | from typing import Tuple 8 | 9 | 10 | 11 | MSE_TOLERANCE = 1e-2 12 | 13 | 14 | @pytest.mark.parametrize("img_size", [(4, 1, 224, 224), (5, 2, 512, 512), (1, 100, 100), (3, 327, 312)]) 15 | @pytest.mark.parametrize("angle", [(-180, 180), (3, 10), None, (0, 0)]) 16 | @pytest.mark.parametrize("n_pts", [10, 100]) 17 | def test_rotation_augmentation(img_size: Tuple[int, ...], 18 | angle: Tuple[int, ...], 19 | n_pts: int, 20 | caplog): 21 | if caplog is not None: 22 | caplog.set_level(logging.CRITICAL) 23 | 24 | torch.manual_seed(img_size[-1]*n_pts) 25 | 26 | img = torch.zeros(img_size) 27 | msize = min(img_size[-2:]) 28 | pts = torch.randint(msize//3, msize-msize//3, (img_size[0], n_pts, 2)) 29 | for b in range(img_size[0]): 30 | img[b] = torch.from_numpy(_generate_img_from_points(pts[b].numpy(), img_size[-2:])) 31 | if angle is not None and (angle == (0, 0) or any(phi is None for phi in angle)): 32 | with pytest.raises(ValueError): 33 | aug = Rotation(order=1, angle=angle) 34 | else: 35 | aug = Rotation(order=1, angle=angle) 36 | img_aug, pts_aug = aug(img, pts) 37 | img_from_aug_pts = torch.zeros(img_size) 38 | for b in range(img_size[0]): 39 | img_from_aug_pts[b] = torch.from_numpy(_generate_img_from_points(pts_aug[b].numpy(), img_size[-2:])) 40 | mse = ((img_aug - img_from_aug_pts)**2).mean() 41 | if __name__ == "__main__": 42 | import matplotlib.pyplot as plt 43 | fig, ax = plt.subplots(1, 4, figsize=(16, 4)) 44 | ax[0].imshow(img[0], cmap="magma") 45 | ax[0].title.set_text("Original") 46 | ax[1].imshow(img_aug[0], cmap="magma") 47 | ax[1].title.set_text("Augmented") 48 | ax[2].imshow(img_from_aug_pts[0], cmap="magma") 49 | ax[2].title.set_text("From Augmented Points") 50 | ax[3].imshow((img_aug[0]-img_from_aug_pts[0])**2, cmap="magma") 51 | ax[3].title.set_text(f"Squared Difference (MSE: {mse:.5f})") 52 | fig.show() 53 | assert mse < MSE_TOLERANCE, "Image augmentation is not correct." 
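# --- Editor's addition (not part of the original test file): a deterministic sketch of
# the point transform that Rotation applies, using a fixed 90-degree angle. Only the
# public API and the centered-affine convention shown in rotation.py are assumed.
def _example_rotation_usage():
    import torch
    from spotiflow.augmentations.transforms import Rotation

    img = torch.rand(1, 101, 101)
    pts = torch.tensor([[[50.0, 70.0]]])     # (B, N, 2) in (y, x) order
    aug = Rotation(order=1, angle=(90, 90))  # degenerate range -> always 90 degrees
    img_aug, pts_aug = aug(img, pts)
    # Rotation by phi about the image center (cy, cx) = ((H-1)/2, (W-1)/2) maps
    #   y' = cy + cos(phi) * (y - cy) - sin(phi) * (x - cx)
    #   x' = cx + sin(phi) * (y - cy) + cos(phi) * (x - cx)
    # so for phi = 90 degrees the point (50, 70) should land at (30, 50).
    assert torch.allclose(pts_aug, torch.tensor([[[30.0, 50.0]]]), atol=1e-4)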
54 | 55 | if __name__ == "__main__": 56 | test_rotation_augmentation((1, 224, 224), (-180, 180), 10, None) -------------------------------------------------------------------------------- /spotiflow/augmentations/test/transforms/test_scale.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | 3 | import logging 4 | import pytest 5 | import torch 6 | 7 | from spotiflow.augmentations.transforms import IsotropicScale 8 | from spotiflow.augmentations.transforms.utils import _generate_img_from_points 9 | from typing import Tuple 10 | 11 | MSE_TOLERANCE = 1e-1 12 | 13 | 14 | @pytest.mark.parametrize("img_size", [(4, 1, 224, 224), (5, 2, 512, 512), (1, 100, 100), (3, 327, 312)]) 15 | @pytest.mark.parametrize("scaling_factor", [(.5, 2), (2, 3), (.2, 5), (-1, 2), (5, .2), (":)", 2)]) 16 | @pytest.mark.parametrize("n_pts", [10, 100]) 17 | def test_scale_augmentation(img_size: Tuple[int, ...], 18 | scaling_factor: Tuple[int, ...], 19 | n_pts: int, 20 | caplog): 21 | if caplog is not None: 22 | caplog.set_level(logging.CRITICAL) 23 | 24 | torch.manual_seed(img_size[-1]*n_pts) 25 | 26 | img = torch.zeros(img_size) 27 | msize = min(img_size[-2:]) 28 | 29 | 30 | 31 | pts = torch.randint(msize//3, msize-msize//3, (n_pts, 2)).repeat(img_size[0], 1, 1) 32 | for b in range(img_size[0]): 33 | img[b] = torch.from_numpy(_generate_img_from_points(pts[b].numpy(), img_size[-2:], sigma=1)) 34 | if any(not isinstance(sf, Number) or sf <= 0 for sf in scaling_factor) or scaling_factor[0] > scaling_factor[1]: 35 | with pytest.raises(ValueError): 36 | aug = IsotropicScale(order=1, scaling_factor=scaling_factor) 37 | else: 38 | aug = IsotropicScale(order=1, scaling_factor=scaling_factor) 39 | img_aug, pts_aug = aug(img, pts) 40 | img_from_aug_pts = torch.zeros(img_size) 41 | for b in range(img_size[0]): 42 | img_from_aug_pts[b] = torch.from_numpy(_generate_img_from_points(pts_aug[b].round().numpy(), img_size[-2:], sigma=1)) 43 | mse = ((img_aug - img_from_aug_pts)**2).mean() 44 | if __name__ == "__main__": 45 | import matplotlib.pyplot as plt 46 | fig, ax = plt.subplots(1, 4, figsize=(16, 4)) 47 | ax[0].imshow(img[0], cmap="magma") 48 | ax[0].title.set_text("Original") 49 | ax[1].imshow(img_aug[0], cmap="magma") 50 | ax[1].title.set_text("Augmented") 51 | ax[2].imshow(img_from_aug_pts[0], cmap="magma") 52 | ax[2].title.set_text("From Augmented Points") 53 | ax[3].imshow((img_aug[0]-img_from_aug_pts[0])**2, cmap="magma") 54 | ax[3].title.set_text(f"Squared Difference (MSE: {mse:.5f})") 55 | fig.show() 56 | assert mse < MSE_TOLERANCE, "Image augmentation is not correct." 
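# --- Editor's addition (not part of the original test file): a fixed-factor sketch of
# IsotropicScale. The exact half-pixel center convention of the resampling is not shown
# in this excerpt, so the check below deliberately uses a loose tolerance.
def _example_isotropic_scale_usage():
    import torch
    from spotiflow.augmentations.transforms import IsotropicScale

    img = torch.rand(1, 101, 101)
    pts = torch.tensor([[[60.0, 50.0]]])                      # (B, N, 2), (y, x) order
    aug = IsotropicScale(order=1, scaling_factor=(2.0, 2.0))  # always zoom by 2x
    img_aug, pts_aug = aug(img, pts)
    # Scaling by s about the image center c maps a coordinate y to c + s * (y - c),
    # so (60, 50) should land near (70, 50) for s = 2 and c = (50, 50).
    assert torch.allclose(pts_aug, torch.tensor([[[70.0, 50.0]]]), atol=1.0)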
57 | 58 | if __name__ == "__main__": 59 | test_scale_augmentation((1, 100, 100), (2, 3), 100, None) -------------------------------------------------------------------------------- /spotiflow/augmentations/test/transforms/test_translation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pytest 3 | import torch 4 | 5 | from spotiflow.augmentations.transforms import Translation 6 | from spotiflow.augmentations.transforms.utils import _generate_img_from_points 7 | from typing import Tuple, Union 8 | 9 | 10 | ABS_TOLERANCE = 1e-5 11 | 12 | @pytest.mark.parametrize("img_size", [(4, 1, 224, 224), (5, 2, 512, 512), (1, 100, 100), (3, 327, 312)]) 13 | @pytest.mark.parametrize("shift", [(-5, 5), (-1, 3), (0, 0), (1, 4), (-5, -2)]) 14 | @pytest.mark.parametrize("n_pts", [10, 100]) 15 | def test_translation_augmentation(img_size: Tuple[int, ...], 16 | shift: Tuple[int, ...], 17 | n_pts: int, 18 | caplog): 19 | if caplog is not None: 20 | caplog.set_level(logging.CRITICAL) 21 | 22 | torch.manual_seed(img_size[-1]*n_pts*shift[-1]) 23 | 24 | img = torch.zeros(img_size) 25 | msize = min(img_size[-2:]) 26 | pts = torch.randint(msize//3, msize-msize//3, (img_size[0], n_pts, 3)) # 3 columns: (y, x) plus a class label 27 | for b in range(img_size[0]): 28 | img[b] = torch.from_numpy(_generate_img_from_points(pts[b,...,:2].numpy(), img_size[-2:])) 29 | 30 | if shift == (0, 0): 31 | with pytest.raises(ValueError): 32 | aug = Translation(order=0, shift=shift) 33 | else: 34 | aug = Translation(order=0, shift=shift) 35 | img_aug, pts_aug = aug(img, pts) 36 | img_from_aug_pts = torch.zeros(img_size) 37 | for b in range(img_size[0]): 38 | img_from_aug_pts[b] = torch.from_numpy(_generate_img_from_points(pts_aug[b,...,:2].numpy(), img_size[-2:])) 39 | mse = ((img_aug - img_from_aug_pts)**2).mean() 40 | if __name__ == "__main__": 41 | import matplotlib.pyplot as plt 42 | fig, ax = plt.subplots(1, 4, figsize=(16, 4)) 43 | ax[0].imshow(img[0], cmap="magma") 44 | ax[0].title.set_text("Original") 45 | ax[1].imshow(img_aug[0], cmap="magma") 46 | ax[1].title.set_text("Augmented") 47 | ax[2].imshow(img_from_aug_pts[0], cmap="magma") 48 | ax[2].title.set_text("From Augmented Points") 49 | ax[3].imshow((img_aug[0]-img_from_aug_pts[0])**2, cmap="magma") 50 | ax[3].title.set_text(f"Squared Difference (MSE: {mse:.5f})") 51 | fig.show() 52 | assert torch.allclose(img_aug, img_from_aug_pts, atol=ABS_TOLERANCE), "Image augmentation is not correct."
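# --- Editor's addition (not part of the original test file): a small sketch of
# Translation with a class-label column, mirroring the test above. That the label
# column passes through unchanged is inferred from this test file, not guaranteed
# by documentation shown here.
def _example_translation_usage():
    import torch
    from spotiflow.augmentations.transforms import Translation

    img = torch.rand(1, 64, 64)
    pts = torch.tensor([[[30.0, 30.0, 1.0]]])  # (B, N, 3): (y, x, class label)
    aug = Translation(order=0, shift=(-5, 5))  # random shift of up to 5 pixels
    img_aug, pts_aug = aug(img, pts)
    # The (y, x) part moves by at most 5 pixels per axis; the label is preserved.
    assert (pts_aug[..., :2] - pts[..., :2]).abs().max() <= 5
    assert torch.equal(pts_aug[..., 2], pts[..., 2])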
53 | 54 | if __name__ == "__main__": 55 | test_translation_augmentation((1, 224, 224), (-5, 5), 10, None) 56 | 57 | -------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .crop import Crop 2 | from .fliprot import FlipRot90 3 | from .intensity_shift import IntensityScaleShift 4 | from .noise import GaussianNoise, SaltAndPepperNoise 5 | from .rotation import Rotation 6 | from .scale import IsotropicScale 7 | from .translation import Translation 8 | -------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Tuple 3 | import torch 4 | 5 | class BaseAugmentation(abc.ABC): 6 | def __init__(self, probability: float, **kwargs): 7 | assert 0 <= probability <= 1 8 | self._probability = probability 9 | 10 | @property 11 | def probability(self) -> float: 12 | """Probability of applying the augmentation at each call 13 | 14 | Returns: 15 | float: probability of applying the augmentation at each call 16 | """ 17 | return self._probability 18 | 19 | @abc.abstractmethod 20 | def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 21 | """Implementation of the augmentation.""" 22 | pass 23 | 24 | def __call__(self, img: torch.Tensor, pts: torch.Tensor): 25 | if self._should_apply(): 26 | return self.apply(img, pts) 27 | else: 28 | return img, pts 29 | 30 | def _should_apply(self) -> bool: 31 | """Sample from a [0,1) uniform distribution and compare to the probability 32 | of applying the augmentation in order to decide whether to apply 33 | the augmentation or not. 34 | 35 | Returns: 36 | bool: return whether the augmentation should be applied 37 | """ 38 | return torch.rand(()) < self.probability 39 | -------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/crop.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torchvision.transforms.functional as tvf 5 | 6 | from .base import BaseAugmentation 7 | from .utils import _filter_points_idx 8 | 9 | 10 | class Crop(BaseAugmentation): 11 | def __init__(self, size: Tuple[int, int], probability: float=1.0, point_priority: float=0): 12 | """Augmentation class for random crops 13 | 14 | Args: 15 | probability (float): probability of applying the augmentation. Must be in [0, 1]. 16 | size (Tuple[int, int]): size of the crop in (y, x) format 17 | point_priority (float): prioritizes crops centered around keypoints. Must be in [0, 1].
18 | 19 | """ 20 | super().__init__(probability) 21 | assert len(size) == 2 and all([isinstance(s, int) for s in size]), "Size must be a 2-length tuple of integers" 22 | self._size = size 23 | self._point_priority = point_priority 24 | 25 | @property 26 | def size(self) -> Tuple[int, int]: 27 | return self._size 28 | 29 | @property 30 | def point_priority(self) -> float: 31 | return self._point_priority 32 | 33 | def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 34 | # Generate random top-left anchor 35 | y, x = img.shape[-2:] 36 | assert y >= self.size[0] and x >= self.size[1], "Image is smaller than crop size" 37 | cy, cx = self._generate_tl_anchor(y, x, pts) 38 | # Crop image 39 | img_c = tvf.crop(img, top=cy, left=cx, height=self.size[0], width=self.size[1]) 40 | 41 | # Crop points 42 | if pts.shape[-1] == 2: 43 | pts_c = pts - torch.FloatTensor([cy, cx]) 44 | else: 45 | pts_c = pts - torch.FloatTensor([cy, cx, 0]) 46 | idxs_in = _filter_points_idx(pts_c, self.size) 47 | return img_c, pts_c[idxs_in].view(*pts.shape[:-2], -1, pts.shape[-1]) 48 | 49 | def _generate_tl_anchor(self, y: int, x: int, pts: torch.Tensor) -> Tuple[int, int]: 50 | prob = torch.FloatTensor(1).uniform_(0, 1).item() 51 | 52 | if prob>self.point_priority: 53 | # Randomly generate top-left anchor 54 | cy, cx = torch.FloatTensor(1).uniform_(0, y-self.size[0]).item(), torch.FloatTensor(1).uniform_(0, x-self._size[1]).item() 55 | return int(cy), int(cx) 56 | else: 57 | width = self.size[0]//4, self.size[1]//4 58 | # Remove points that are not anchor candidates 59 | valid_pt_coords = pts[(pts[..., 0] >= self.size[0]//2-width[0]) & (pts[..., 0] < y-self.size[0]//2+width[0]) & (pts[..., 1] >= self.size[1]//2-width[1]) & (pts[..., 1] < x-self.size[1]//2+width[1])][:, :2] 60 | if valid_pt_coords.shape[0] == 0: 61 | # sample randomly if no points are valid 62 | cy, cx = torch.FloatTensor(1).uniform_(0, y-self.size[0]).item(), torch.FloatTensor(1).uniform_(0, x-self._size[1]).item() 63 | else: 64 | # select a point 65 | center_idx = torch.randint(0, valid_pt_coords.shape[0], (1,)).item() 66 | cy, cx = valid_pt_coords[center_idx] 67 | cy = cy + torch.randint(-width[0], width[0]+1, (1,)) 68 | cx = cx + torch.randint(-width[1], width[1]+1, (1,)) 69 | cy -= self.size[0]//2 70 | cx -= self.size[1]//2 71 | cy = torch.clip(cy, 0, y-self.size[0]).item() 72 | cx = torch.clip(cx, 0, x-self.size[1]).item() 73 | return int(cy), int(cx) 74 | 75 | def __repr__(self) -> str: 76 | return f"Crop(size={self.size}, probability={self.probability})" -------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/fliprot.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import itertools 3 | import torch 4 | 5 | from .base import BaseAugmentation 6 | from .utils import _flatten_axis 7 | 8 | def _subgroup_flips(ndim: int, axis: Optional[Tuple[int, ...]]=None) -> Tuple[Tuple[bool, ...], ...]: 9 | """Adapted from https://github.com/stardist/augmend/blob/main/augmend/transforms/affine.py not to depend on numpy 10 | iterate over the product subgroup (False,True) of given axis 11 | """ 12 | axis = _flatten_axis(ndim, axis) 13 | res = [False for _ in range(ndim)] 14 | for prod in itertools.product((False, True), repeat=len(axis)): 15 | for a, p in zip(axis, prod): 16 | res[a] = p 17 | yield tuple(res) 18 | 19 | def _fliprot_pts(pts: torch.Tensor, dims_to_flip: Tuple[int, ...], shape: 
Tuple[int, int], ndims: int) -> torch.Tensor: 20 | """Flip and rotate points accordingly to the flipping dimensions. 21 | 22 | Args: 23 | pts (torch.Tensor): points to be flipped and rotated. 24 | dims_to_flip (Tuple[int]): indices of the dimensions to be flipped. 25 | shape (Tuple[int, int]): shape of the image. 26 | 27 | Returns: 28 | torch.Tensor: flipped and rotated points. 29 | """ 30 | y, x = shape 31 | pts_fr = pts.clone() 32 | for dim in dims_to_flip: 33 | if dim == ndims-2: 34 | pts_fr[..., 0] = y - 1 - pts_fr[..., 0] 35 | elif dim == ndims-1: 36 | pts_fr[..., 1] = x - 1 - pts_fr[..., 1] 37 | return pts_fr 38 | 39 | class FlipRot90(BaseAugmentation): 40 | def __init__(self, probability: float=1.0) -> None: 41 | """Augmentation class for FlipRot90 augmentation. 42 | 43 | Args: 44 | probability (float): probability of applying the augmentation. Must be in [0, 1]. 45 | 46 | """ 47 | super().__init__(probability) 48 | 49 | def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 50 | """Applies FlipRot90 augmentation to the given image and points. 51 | """ 52 | # Randomly choose the spatial axis/axes to flip 53 | combs = tuple(_subgroup_flips(img.ndim, axis=(-2, -1))) 54 | idx = torch.randint(len(combs), (1,)).item() 55 | dims_to_flip = tuple(i for i, c in enumerate(combs[idx]) if c) 56 | 57 | # Return original image and points if no axis is flipped 58 | if len(dims_to_flip) == 0: 59 | return img, pts 60 | # Flip image and points 61 | 62 | img_fr = torch.flip(img, dims_to_flip) 63 | pts_fr = _fliprot_pts(pts, dims_to_flip, img.shape[-2:], ndims=img.ndim) 64 | return img_fr, pts_fr 65 | 66 | 67 | def __repr__(self) -> str: 68 | return f"FlipRot90(probability={self.probability})" -------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/intensity_shift.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | 4 | from .base import BaseAugmentation 5 | 6 | 7 | class IntensityScaleShift(BaseAugmentation): 8 | def __init__(self, scale: Tuple[float, float]=(.8, 1.2), 9 | shift: Tuple[float, float]=(-.1, .1), probability: float=1.0) -> None: 10 | """Augmentation class for shifting and scaling the intensity of the image. 11 | 12 | I = I * scale + shift 13 | 14 | Args: 15 | probability (float): probability of applying the augmentation. Must be in [0, 1]. 16 | scale (Tuple[int, int]): range of the scaling factor to apply to the image. 17 | shift (Tuple[int, int]): range of the shift to apply to the image. 18 | """ 19 | super().__init__(probability) 20 | assert len(scale) == 2 and all([isinstance(s, float) for s in scale]), "Scale must be a 2-length tuple of floating point numbers." 21 | assert len(shift) == 2 and all([isinstance(s, float) for s in shift]), "Shift must be a 2-length tuple of floating point numbers." 22 | assert scale[0] <= scale[1], "First element of scale must be smaller or equal to the second element." 23 | assert shift[0] <= shift[1], "First element of shift must be smaller or equal to the second element." 24 | self._scale = scale 25 | self._shift = shift 26 | 27 | 28 | def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 29 | """Applies IntensityScaleShift augmentation to the given image and points. 30 | Args: 31 | img (torch.Tensor): image to be augmented. 32 | pts (torch.Tensor): points to be augmented. 
33 | Returns: 34 | Tuple[torch.Tensor, torch.Tensor]: augmented image and points. 35 | """ 36 | # Randomly choose the scaling factor and shift 37 | sampled_scale = self._sample_scale(img.device) 38 | sampled_shift = self._sample_shift(img.device) 39 | return sampled_scale*img + sampled_shift, pts 40 | 41 | def _sample_scale(self, device): 42 | return torch.empty(1, dtype=torch.float32, device=device).uniform_(*self._scale) 43 | 44 | def _sample_shift(self, device): 45 | return torch.empty(1, dtype=torch.float32, device=device).uniform_(*self._shift) 46 | 47 | def __repr__(self) -> str: 48 | return f"IntensityScaleShift(scale={self._scale}, shift={self._shift}, probability={self.probability})" -------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/noise.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import Tuple 3 | import torch 4 | 5 | from .base import BaseAugmentation 6 | 7 | 8 | class GaussianNoise(BaseAugmentation): 9 | def __init__(self, sigma: Tuple[float, float]=(0., 0.05), probability: float=1.0) -> None: 10 | """Augmentation class for adding Gaussian noise to the image. 11 | 12 | I = I + n, with n ~ N(0, sigma^2) 13 | 14 | Args: 15 | sigma (Tuple[float, float]): range from which the noise standard deviation is sampled. 16 | probability (float): probability of applying the augmentation. Must be in [0, 1]. 17 | 18 | """ 19 | super().__init__(probability) 20 | assert len(sigma) == 2 and all([isinstance(s, Number) for s in sigma]), "Sigma must be a 2-length tuple of floating point numbers." 21 | self._sigma = sigma 22 | 23 | def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 24 | """Applies GaussianNoise augmentation to the given image and points. 25 | Args: 26 | img (torch.Tensor): image to be augmented. 27 | pts (torch.Tensor): points to be augmented. 28 | Returns: 29 | Tuple[torch.Tensor, torch.Tensor]: augmented image and points. 30 | """ 31 | # Randomly sample the noise standard deviation 32 | sampled_sigma = self._sample_sigma() 33 | noise = torch.randn_like(img, dtype=torch.float32, device=img.device) * sampled_sigma 34 | return img + noise, pts 35 | 36 | def _sample_sigma(self): 37 | return torch.empty(1, dtype=torch.float32).uniform_(*self._sigma) 38 | 39 | def __repr__(self) -> str: 40 | return f"GaussianNoise(sigma={self._sigma}, probability={self.probability})" 41 | 42 | class SaltAndPepperNoise(BaseAugmentation): 43 | def __init__(self, prob_pepper: Tuple[float, float]=(0., 0.001), prob_salt: Tuple[float, float]=(0., 0.001), probability: float=1.0): 44 | """Augmentation class for addition of salt and pepper noise to the image. 45 | 46 | Each pixel is randomly set to the minimum (pepper) or maximum (salt) value of the image with a given probability. 47 | 48 | Args: 49 | probability (float): probability of applying the augmentation. Must be in [0, 1]. 50 | prob_pepper (Tuple[float, float]): range of the potential probability of adding pepper noise to the image. 51 | prob_salt (Tuple[float, float]): range of the potential probability of adding salt noise to the image.
41 | 
42 | class SaltAndPepperNoise(BaseAugmentation):
43 |     def __init__(self, prob_pepper: Tuple[float, float]=(0., 0.001), prob_salt: Tuple[float, float]=(0., 0.001), probability: float=1.0):
44 |         """Augmentation class for adding salt-and-pepper noise to the image.
45 | 
46 |         Each pixel is randomly set to the minimum (pepper) or maximum (salt) value of the image with a given probability.
47 | 
48 |         Args:
49 |             probability (float): probability of applying the augmentation. Must be in [0, 1].
50 |             prob_pepper (Tuple[float, float]): range from which the per-pixel probability of pepper noise is sampled.
51 |             prob_salt (Tuple[float, float]): range from which the per-pixel probability of salt noise is sampled.
52 |         """
53 |         super().__init__(probability)
54 | 
55 |         assert len(prob_pepper) == 2 and \
56 |             all([isinstance(s, Number) for s in prob_pepper]) and \
57 |             all([s >= 0 and s <= 1 for s in prob_pepper]), "prob_pepper must be a 2-length tuple of numbers between zero and one."
58 | 
59 |         assert len(prob_salt) == 2 and \
60 |             all([isinstance(s, Number) for s in prob_salt]) and \
61 |             all([s >= 0 and s <= 1 for s in prob_salt]), "prob_salt must be a 2-length tuple of numbers between zero and one."
62 | 
63 |         assert prob_pepper[0] <= prob_pepper[1], "First element of prob_pepper must be smaller than or equal to the second element."
64 |         assert prob_salt[0] <= prob_salt[1], "First element of prob_salt must be smaller than or equal to the second element."
65 | 
66 |         self._prob_salt = prob_salt
67 |         self._prob_pepper = prob_pepper
68 | 
69 |     def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
70 |         sampled_prob_pepper = self._sample_prob_pepper()
71 |         sampled_prob_salt = self._sample_prob_salt()
72 | 
73 |         pepper_mask = torch.empty_like(img, dtype=torch.float32).uniform_(0, 1)
74 |         salt_mask = torch.empty_like(img, dtype=torch.float32).uniform_(0, 1)
75 | 
76 |         pepper_value = img.min()
77 |         salt_value = img.max()
78 | 
79 |         # Pixels whose uniform draw falls below the sampled probability are corrupted
80 |         img_sp = img.clone()
81 |         img_sp[pepper_mask < sampled_prob_pepper] = pepper_value
82 |         img_sp[salt_mask < sampled_prob_salt] = salt_value
83 |         return img_sp, pts
84 | 
85 |     def _sample_prob_pepper(self):
86 |         return torch.empty(1, dtype=torch.float32).uniform_(*self._prob_pepper)
87 | 
88 |     def _sample_prob_salt(self):
89 |         return torch.empty(1, dtype=torch.float32).uniform_(*self._prob_salt)
90 | 
91 | 
92 |     def __repr__(self) -> str:
93 |         return f"SaltAndPepperNoise(prob_pepper={self._prob_pepper}, prob_salt={self._prob_salt}, probability={self.probability})"
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/rotation.py: --------------------------------------------------------------------------------
1 | from numbers import Number
2 | from torchvision.transforms import InterpolationMode
3 | from typing import Literal, Optional, Tuple, Union
4 | import math
5 | import torch
6 | import torchvision.transforms.functional as tvf
7 | 
8 | from .base import BaseAugmentation
9 | from .utils import _filter_points_idx
10 | 
11 | def _affine_rotation_matrix(phi: float, center: tuple[float]) -> torch.Tensor:
12 |     cy, cx = center
13 |     translation_mat = torch.FloatTensor([
14 |         [1., 0., -cy],
15 |         [0., 1., -cx],
16 |         [0., 0., 1.]
17 |     ])
18 |     translation_mat_inv = torch.FloatTensor([
19 |         [1., 0., cy],
20 |         [0., 1., cx],
21 |         [0., 0., 1.]
22 |     ])
23 | 
24 |     si, co = math.sin(phi), math.cos(phi)
25 | 
26 |     rot_matrix = torch.FloatTensor([
27 |         [co, -si, 0.],
28 |         [si, co, 0.],
29 |         [0., 0., 1.]
30 |     ])
31 | 
32 |     affine_mat = translation_mat_inv@rot_matrix@translation_mat
33 |     return affine_mat
34 | 
35 | 
36 | class Rotation(BaseAugmentation):
37 |     def __init__(self, order: Literal[0, 1]=1, angle: Optional[Union[Number, Tuple[Number, Number]]] = (-180, 180), probability: float=1.0) -> None:
38 |         """Augmentation class for random rotations.
39 | 
40 |         Args:
41 |             probability (float): probability of applying the augmentation. Must be in [0, 1].
42 |             order (Literal[0, 1]): order of interpolation. Use 0 for nearest neighbor, 1 for bilinear.
43 |             angle (Optional[Union[Number, Tuple[Number, Number]]]): rotation angle range (in degrees). A single number a is interpreted as (-a, a); None as (-180, 180). Defaults to (-180, 180).
44 | 45 | """ 46 | super().__init__(probability) 47 | self._order = int(order) 48 | if self._order == 0: 49 | self._interp_mode = InterpolationMode.NEAREST 50 | elif self._order == 1: 51 | self._interp_mode = InterpolationMode.BILINEAR 52 | else: 53 | raise ValueError("Order must be 0 or 1.") 54 | if angle is None: 55 | angle = (-180, 180) 56 | elif isinstance(angle, Number): 57 | angle = (-angle, angle) 58 | elif isinstance(angle, tuple) and len(angle) == 2 and all(isinstance(x, Number) for x in angle): 59 | angle = angle 60 | else: 61 | raise ValueError("angle must be either a number or a tuple of two numbers") 62 | self._angle = angle 63 | if all(phi == 0 for phi in angle): 64 | raise ValueError("Angle range cannot be (0, 0).") 65 | 66 | def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 67 | # Generate random rotation angle 68 | phi_deg = self._sample_angle() 69 | phi_rad = phi_deg * math.pi / 180. 70 | 71 | # Rotate image 72 | img_r = tvf.rotate(img, phi_deg, interpolation=self._interp_mode) 73 | 74 | # Generate affine transformation matrix 75 | y, x = img.shape[-2:] 76 | center_y, center_x = (y-1)/2, (x-1)/2 77 | 78 | affine_mat = _affine_rotation_matrix(phi_rad, (center_y, center_x)).to(img.device) 79 | # Rotate points 80 | if pts.shape[2] == 2: # no class labels 81 | should_add_cls_label = False 82 | affine_coords = torch.cat([pts.float(), torch.ones((*pts.shape[:-1],1), device=img.device)], axis=-1) # Euclidean -> homogeneous coordinates 83 | else: 84 | should_add_cls_label = True 85 | affine_coords = torch.cat([pts[..., :-1].float(), torch.ones((*pts.shape[:-1],1), device=img.device)], axis=-1) # Euclidean -> homogeneous coordinates 86 | 87 | pts_r = (affine_coords@affine_mat.T) 88 | pts_r = pts_r[..., :-1] # Homogeneous -> Euclidean coordinates 89 | if should_add_cls_label: 90 | pts_r = torch.cat([pts_r, pts[..., -1:]], axis=-1) 91 | 92 | idxs_in = _filter_points_idx(pts_r, img_r.shape[-2:]) 93 | return img_r, pts_r[idxs_in].view(*pts.shape[:-2], -1, pts.shape[-1]) 94 | 95 | def _sample_angle(self): 96 | return torch.FloatTensor(1).uniform_(*self._angle).item() 97 | 98 | 99 | def __repr__(self) -> str: 100 | return f"Rotation(angle={self._angle}, probability={self.probability})" -------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/scale.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from torchvision.transforms import InterpolationMode 3 | from typing import Literal, Optional, Tuple 4 | import torch 5 | import torchvision.transforms.functional as tvf 6 | 7 | from .base import BaseAugmentation 8 | from .utils import _filter_points_idx 9 | 10 | def _affine_scaling_matrix(scaling_factor: float, center:tuple[float]) -> torch.Tensor: 11 | """Generate a centered scale matrix transformation acting on 2D euclidean coordinates 12 | 13 | Args: 14 | scaling_factor (float): scaling factor 15 | 16 | Returns: 17 | torch.Tensor: scale transformation matrix 18 | """ 19 | cy, cx = center 20 | translation_mat = torch.FloatTensor([ 21 | [1., 0., -cy], 22 | [0., 1., -cx], 23 | [0., 0., 1.] 24 | ]) 25 | translation_mat_inv = torch.FloatTensor([ 26 | [1., 0., cy], 27 | [0., 1., cx], 28 | [0., 0., 1.] 29 | ]) 30 | scaling_mat = torch.FloatTensor([ 31 | [scaling_factor, 0., 0.], 32 | [0., scaling_factor, 0.], 33 | [0., 0., 1.] 
34 | ]) 35 | affine_mat = translation_mat_inv@scaling_mat@translation_mat 36 | return affine_mat 37 | 38 | 39 | class IsotropicScale(BaseAugmentation): 40 | def __init__(self, order: Literal[0, 1]=1, 41 | scaling_factor: Optional[Tuple[Number, Number]] = (0.5,2.0), probability: float=1.0) -> None: 42 | """Augmentation class for random isotropic scaling 43 | 44 | Args: 45 | probability (float): probability of applying the augmentation. Must be in [0, 1]. 46 | order (Literal[0, 1]): order of interpolation. Use 0 for nearest neighbor, 1 for bilinear. 47 | scaling_factor (Optional[Tuple[Number, Number]], optional): scaling factor range to be sampled. If None, scaling factor is randomly sampled from [0.5, 2.0]. Defaults to None. 48 | """ 49 | super().__init__(probability) 50 | self._order = int(order) 51 | if self._order == 0: 52 | self._interp_mode = InterpolationMode.NEAREST 53 | elif self._order == 1: 54 | self._interp_mode = InterpolationMode.BILINEAR 55 | else: 56 | raise ValueError("Order must be 0 or 1.") 57 | 58 | if all(isinstance(x, Number) for x in scaling_factor) and len(scaling_factor) == 2: 59 | scaling_factor = scaling_factor 60 | else: 61 | raise ValueError("Scaling factor must be a 2-length tuple of numbers.") 62 | self._scaling_factor = scaling_factor 63 | if any(sf <= 0 for sf in self._scaling_factor) or self._scaling_factor[0] > self._scaling_factor[1]: 64 | raise ValueError("Scaling factor must be in (0, +inf) and the first element must be smaller than the second one.") 65 | 66 | 67 | def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 68 | """Apply the augmentation to the input image and points 69 | 70 | Args: 71 | img (torch.Tensor): input image tensor of shape (C, H, W) 72 | pts (torch.Tensor): input points tensor of shape (N, 2) or (N, 3) 73 | 74 | Returns: 75 | Tuple[torch.Tensor, torch.Tensor]: transformed image tensor of shape (C, H, W) and transformed points tensor of shape (N, 2) or (N, 3) 76 | """ 77 | # Sample scaling factor 78 | sampled_scaling_factor = self._sample_scaling_factor() 79 | img_scaled = tvf.affine(img, 80 | angle=0, 81 | translate=(0,0), 82 | scale=sampled_scaling_factor, 83 | shear=0, 84 | interpolation=self._interp_mode) 85 | 86 | # Generate affine transformation matrix 87 | y, x = img.shape[-2:] 88 | center_y, center_x = (y-1)/2, (x-1)/2 89 | affine_mat = _affine_scaling_matrix(sampled_scaling_factor, (center_y, center_x)).to(img.device) 90 | 91 | # Scale points 92 | if pts.shape[2] == 2: # no class labels 93 | should_add_cls_label = False 94 | affine_coords = torch.cat([pts.float(), torch.ones((*pts.shape[:-1],1), device=img.device)], axis=-1) # Euclidean -> homogeneous coordinates 95 | else: 96 | should_add_cls_label = True 97 | affine_coords = torch.cat([pts[..., :-1].float(), torch.ones((*pts.shape[:-1],1), device=img.device)], axis=-1) # Euclidean -> homogeneous coordinates 98 | pts_scaled = (affine_coords@affine_mat.T) 99 | pts_scaled = pts_scaled[..., :-1] # Homogeneous -> Euclidean coordinates 100 | if should_add_cls_label: 101 | pts_scaled = torch.cat([pts_scaled, pts[..., -1:]], axis=-1) 102 | 103 | 104 | idxs_in = _filter_points_idx(pts_scaled, img_scaled.shape[-2:]) 105 | return img_scaled, pts_scaled[idxs_in].view(*pts.shape[:-2], -1, pts.shape[-1]) 106 | 107 | def _sample_scaling_factor(self) -> float: 108 | return torch.FloatTensor(1).uniform_(*self._scaling_factor).item() 109 | 110 | def __repr__(self) -> str: 111 | return f"IsotropicScale(scaling_factor={self._scaling_factor}, 
probability={self.probability})"
112 | 
113 | class AnisotropicScale(BaseAugmentation):
114 |     # TODO
115 |     def __init__(self, probability: float):
116 |         raise NotImplementedError("AnisotropicScale is not implemented yet")
117 | 
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/translation.py: --------------------------------------------------------------------------------
1 | from torchvision.transforms import InterpolationMode
2 | from typing import Literal, Tuple, Union
3 | import torch
4 | import torchvision.transforms.functional as tvf
5 | from numbers import Number
6 | from .base import BaseAugmentation
7 | from .utils import _filter_points_idx
8 | 
9 | class Translation(BaseAugmentation):
10 |     def __init__(self, order: Literal[0, 1]=1, shift: Union[int, Tuple[int, int]]=(-5, 5), probability: float=1.0) -> None:
11 |         """Augmentation class for random translations.
12 | 
13 |         Args:
14 |             probability (float): probability of applying the augmentation. Must be in [0, 1].
15 |             order (Literal[0, 1]): order of interpolation. Use 0 for nearest neighbor, 1 for bilinear.
16 |             shift (Union[int, Tuple[int, int]]): range (in pixels) from which the translation offsets are sampled. A single integer s is interpreted as (-s, s). Defaults to (-5, 5).
17 | 
18 |         """
19 |         super().__init__(probability)
20 |         self._order = int(order)
21 |         if self._order == 0:
22 |             self._interp_mode = InterpolationMode.NEAREST
23 |         elif self._order == 1:
24 |             self._interp_mode = InterpolationMode.BILINEAR
25 |         else:
26 |             raise ValueError("Order must be 0 or 1.")
27 | 
28 |         if isinstance(shift, Number):
29 |             shift = (-shift, shift)
30 |         elif all(isinstance(x, int) for x in shift) and len(shift) == 2:
31 |             shift = shift
32 |         else:
33 |             raise ValueError("Shift must be either a single integer or a 2-length tuple of integers.")
34 |         self._shift = shift
35 |         if all(x == 0 for x in shift):
36 |             raise ValueError("Shift range cannot be (0, 0).")
37 | 
38 | 
39 |     def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
40 |         """Apply the augmentation to the input image and points
41 | 
42 |         Args:
43 |             img (torch.Tensor): input image tensor of shape (C, H, W)
44 |             pts (torch.Tensor): input points tensor of shape (N, 2) or (N, 3)
45 | 
46 |         Returns:
47 |             Tuple[torch.Tensor, torch.Tensor]: transformed image tensor of shape (C, H, W) and transformed points tensor of shape (N, 2) or (N, 3)
48 |         """
49 |         # Sample shift
50 |         sampled_translation = self._sample_shift(img.device)
51 |         tr_x, tr_y = sampled_translation
52 |         # Skip if no translation
53 |         if tr_y == 0 and tr_x == 0:
54 |             return img, pts
55 | 
56 |         # Apply translation to image and points
57 |         img_translated = tvf.affine(img,
58 |                                     angle=0,
59 |                                     translate=(tr_y, tr_x), # First axis arg is horizontal shift, second is vertical shift
60 |                                     scale=1,
61 |                                     shear=0,
62 |                                     interpolation=self._interp_mode)
63 | 
64 |         pts_translated = pts[...,:-1] + sampled_translation
65 |         pts_translated = torch.cat([pts_translated, pts[...,-1:]], dim=-1)
66 |         idxs_in = _filter_points_idx(pts_translated, img_translated.shape[-2:])
67 |         pts_translated = pts_translated[idxs_in].view(*pts.shape[:-2], -1, pts.shape[-1])
68 |         return img_translated, pts_translated
69 | 
70 | 
71 |     def _sample_shift(self, device: torch.device) -> torch.Tensor:
72 |         """Sample a random shift from the shift range (upper bound exclusive, as in torch.randint)
73 | 
74 |         Returns:
75 |             torch.Tensor: random shift
76 |         """
77 |         return torch.randint(*self._shift, size=(2,), device=device)
78 | 
79 | 
80 |     def __repr__(self) -> str:
81 |         return f"Translation(shift={self._shift}, probability={self.probability})"
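# A minimal sketch of Translation via a direct `apply` call (hedged: the last
# point column is treated as a class label and carried through untouched, so a
# label column is included here):
#
#   import torch
#   from spotiflow.augmentations.transforms.translation import Translation
#
#   aug = Translation(shift=(-5, 5))
#   img = torch.rand(1, 64, 64)
#   pts = torch.tensor([[10., 20., 0.]])  # (y, x, class label)
#   img_t, pts_t = aug.apply(img, pts)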
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms/utils.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def _filter_points_idx(pts: torch.Tensor, shape: tuple[int]) -> torch.Tensor:
4 |     """Returns a boolean mask of the points that lie within the image boundaries"""
5 |     if pts.shape[-1] == len(shape):
6 |         return torch.all(torch.logical_and(pts >= 0, pts < torch.tensor(shape, device=pts.device)), dim=-1)
7 |     elif pts.shape[-1] == len(shape)+1: # Last dimension is a class label, ignore it
8 |         # Compare only the spatial coordinates; the label column is carried along by the caller
9 |         return torch.all(torch.logical_and(pts[..., :-1] >= 0, pts[..., :-1] < torch.tensor(shape, device=pts.device)), dim=-1)
10 |     else:
11 |         raise ValueError(f"Points should have shape (N, {len(shape)}) or (N, {len(shape)+1}) if class labels are given")
12 | 
13 | 
14 | def _flatten_axis(ndim, axis=None):
15 |     """Adapted from https://github.com/stardist/augmend/blob/main/augmend/transforms/affine.py not to depend on numpy
16 |     converts axis to a flattened tuple
17 |     e.g.
18 |     flatten_axis(3, axis = None) = (0,1,2)
19 |     flatten_axis(4, axis = (-2,-1)) = (2,3)
20 |     """
21 | 
22 |     # allow for e.g. axis = -1, axis = None, ...
23 |     all_axis = tuple(range(ndim))
24 | 
25 |     if axis is None:
26 |         axis = tuple(all_axis)
27 |     else:
28 |         if isinstance(axis, int):
29 |             axis = [axis, ]
30 |         elif isinstance(axis, tuple):
31 |             axis = list(axis)
32 |         if max(axis) > max(all_axis):
33 |             raise ValueError("axis = %s too large" % max(axis))
34 |         axis = tuple([all_axis[i] for i in axis])
35 |     return axis
36 | 
37 | 
38 | def _generate_img_from_points(points, shape, sigma=1.):
39 |     """Adapted from https://github.com/weigertlab/spotipy-torch/blob/main/spotipy_torch/utils/utils.py"""
40 |     import numpy as np
41 |     from scipy.spatial.distance import cdist
42 |     import networkx as nx
43 |     import scipy.ndimage as ndi
44 |     def _filter_shape(points, shape, idxr_array=None):
45 |         """ returns all values in "points" that are inside the shape as given by the indexer array
46 |         if the indexer array is None, then the array to be filtered itself is used
47 |         """
48 |         if idxr_array is None:
49 |             idxr_array = points.copy()
50 |         assert idxr_array.ndim==2 and idxr_array.shape[1]==2
51 |         idx = np.all(np.logical_and(idxr_array >= 0, idxr_array < np.array(shape)), axis=1)
52 |         return points[idx]
53 | 
54 |     x = np.zeros(shape, np.float32)
55 |     points = np.asarray(points).astype(np.int32)
56 |     assert points.ndim==2 and points.shape[1]==2
57 | 
58 |     points = _filter_shape(points, shape)
59 | 
60 |     if len(points)==0:
61 |         return x
62 |     # Build a proximity graph and draw mutually close points in separate passes
63 |     # (maximal independent sets), so each Gaussian peak can be normalized to 1
64 |     # before taking the pixelwise maximum.
65 |     D = cdist(points, points)
66 |     A = D < 8*sigma+1
67 |     np.fill_diagonal(A, False)
68 |     G = nx.from_numpy_array(A)
69 |     x = np.zeros(shape, np.float32)
70 |     while len(G)>0:
71 |         inds = nx.maximal_independent_set(G)
72 |         gauss = np.zeros(shape, np.float32)
73 |         gauss[tuple(points[inds].T)] = 1
74 |         g = ndi.gaussian_filter(gauss, sigma, mode="constant")
75 |         g /= np.max(g)
76 |         x = np.maximum(x,g)
77 |         G.remove_nodes_from(inds)
78 |     return x
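# A quick sketch of `_filter_points_idx` semantics: it returns a boolean mask
# that is True for points whose spatial coordinates lie inside `shape`, e.g.
#
#   import torch
#   pts = torch.tensor([[5., 5.], [-1., 3.], [7., 9.]])
#   mask = _filter_points_idx(pts, (8, 8))  # tensor([ True, False, False])
#   inside = pts[mask]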
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms3d/__init__.py: --------------------------------------------------------------------------------
1 | from .crop import Crop3D
2 | from .fliprot import FlipRot903D
3 | from .intensity_shift import IntensityScaleShift3D
4 | from .noise import GaussianNoise3D, SaltAndPepperNoise3D
5 | from .rotation import RotationYX3D
6 | from .translation import TranslationYX3D
7 | 
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms3d/crop.py: --------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import torch
4 | 
5 | from ..transforms.base import BaseAugmentation
6 | from ..transforms.utils import _filter_points_idx
7 | 
8 | 
9 | class Crop3D(BaseAugmentation):
10 |     def __init__(self, size: Tuple[int, int, int], probability: float=1.0, point_priority: float=0):
11 |         """Augmentation class for random 3D crops.
12 | 
13 |         Args:
14 |             probability (float): probability of applying the augmentation. Must be in [0, 1].
15 |             size (Tuple[int, int, int]): size of the crop in (z, y, x) format.
16 |             point_priority (float): prioritizes crops centered around keypoints. Must be in [0, 1].
17 |         """
18 |         super().__init__(probability)
19 |         assert len(size) == 3 and all([isinstance(s, int) for s in size]), "Size must be a 3-length tuple of integers"
20 |         self._size = size
21 |         self._point_priority = point_priority
22 | 
23 |     @property
24 |     def size(self) -> Tuple[int, int, int]:
25 |         return self._size
26 | 
27 |     @property
28 |     def point_priority(self) -> float:
29 |         return self._point_priority
30 | 
31 |     def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
32 |         # Generate random front-top-left anchor
33 |         z, y, x = img.shape[-3:]
34 |         assert z >= self.size[0] and y >= self.size[1] and x >= self.size[2], "Image is smaller than crop size"
35 |         cz, cy, cx = self._generate_tl_anchor(z, y, x, pts)
36 | 
37 |         # Crop volume
38 |         img_c = img[..., cz:cz+self.size[0], cy:cy+self.size[1], cx:cx+self.size[2]]
39 | 
40 |         # Crop points (shift into the crop's coordinate frame, then drop points that fall outside)
41 |         pts_c = pts - torch.FloatTensor([cz, cy, cx])
42 |         idxs_in = _filter_points_idx(pts_c, self.size)
43 |         return img_c, pts_c[idxs_in].view(*pts.shape[:-2], -1, pts.shape[-1])
44 | 
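    # A minimal usage sketch (hedged): with point_priority > 0, crops are biased
    # towards keypoints; shapes assumed to be (C, D, H, W) for the volume and
    # (N, 3) in (z, y, x) order for the points, e.g.
    #
    #   crop = Crop3D(size=(32, 64, 64), point_priority=0.8)
    #   vol_c, pts_c = crop.apply(vol, pts)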
45 |     def _generate_tl_anchor(self, z: int, y: int, x: int, pts: torch.Tensor) -> Tuple[int, int, int]:
46 |         prob = torch.FloatTensor(1).uniform_(0, 1).item()
47 | 
48 |         if prob > self.point_priority:
49 |             # Randomly generate top-left anchor
50 |             cz, cy, cx = torch.FloatTensor(1).uniform_(0, z-self.size[0]).item(), torch.FloatTensor(1).uniform_(0, y-self._size[1]).item(), torch.FloatTensor(1).uniform_(0, x-self._size[2]).item()
51 |             return int(cz), int(cy), int(cx)
52 |         else:
53 |             width = self.size[0]//4, self.size[1]//4, self.size[2]//4
54 |             # Remove points that are not anchor candidates
55 |             valid_pt_coords = pts[(pts[..., 0] >= self.size[0]//2-width[0]) & (pts[..., 0] < z-self.size[0]//2+width[0]) & (pts[..., 1] >= self.size[1]//2-width[1]) & (pts[..., 1] < y-self.size[1]//2+width[1]) & (pts[..., 2] >= self.size[2]//2-width[2]) & (pts[..., 2] < x-self.size[2]//2+width[2])]
56 |             if valid_pt_coords.shape[0] == 0:
57 |                 # sample randomly if no points are valid
58 |                 cz, cy, cx = torch.FloatTensor(1).uniform_(0, z-self.size[0]).item(), torch.FloatTensor(1).uniform_(0, y-self._size[1]).item(), torch.FloatTensor(1).uniform_(0, x-self._size[2]).item()
59 |             else:
60 |                 # select a point and jitter the crop center around it
61 |                 center_idx = torch.randint(0, valid_pt_coords.shape[0], (1,)).item()
62 |                 cz, cy, cx = valid_pt_coords[center_idx]
63 |                 cz = cz + torch.randint(-width[0], width[0]+1, (1,))
64 |                 cy = cy + torch.randint(-width[1], width[1]+1, (1,))
65 |                 cx = cx + torch.randint(-width[2], width[2]+1, (1,))
66 |                 cz -= self.size[0]//2
67 |                 cy -= self.size[1]//2
68 |                 cx -= self.size[2]//2
69 |                 cz = torch.clip(cz, 0, z-self.size[0]).item()
70 |                 cy = torch.clip(cy, 0, y-self.size[1]).item()
71 |                 cx = torch.clip(cx, 0, x-self.size[2]).item()
72 |             return int(cz), int(cy), int(cx)
73 | 
74 |     def __repr__(self) -> str:
75 |         return f"Crop3D(size={self.size}, probability={self.probability})"
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms3d/fliprot.py: --------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import itertools
3 | import torch
4 | 
5 | from ..transforms.base import BaseAugmentation
6 | from ..transforms.utils import _flatten_axis
7 | 
8 | def _subgroup_flips(ndim: int, axis: Optional[Tuple[int, ...]]=None) -> Tuple[Tuple[bool, ...], ...]:
9 |     """Adapted from https://github.com/stardist/augmend/blob/main/augmend/transforms/affine.py not to depend on numpy
10 |     iterate over the product subgroup (False,True) of the given axes
11 |     """
12 |     axis = _flatten_axis(ndim, axis)
13 |     res = [False for _ in range(ndim)]
14 |     for prod in itertools.product((False, True), repeat=len(axis)):
15 |         for a, p in zip(axis, prod):
16 |             res[a] = p
17 |         yield tuple(res)
18 | 
19 | def _fliprot_pts(pts: torch.Tensor, dims_to_flip: Tuple[int, ...], shape: Tuple[int, int, int], ndims: int) -> torch.Tensor:
20 |     """Flip and rotate points according to the flipped dimensions.
21 | 
22 |     Args:
23 |         pts (torch.Tensor): points to be flipped and rotated.
24 |         dims_to_flip (Tuple[int]): indices of the dimensions to be flipped.
25 |         shape (Tuple[int, int, int]): shape of the volume.
26 | 
27 |     Returns:
28 |         torch.Tensor: flipped and rotated points.
29 |     """
30 |     z, y, x = shape
31 |     pts_fr = pts.clone()
32 |     for dim in dims_to_flip:
33 |         if dim == ndims-3:
34 |             pts_fr[..., 0] = z - 1 - pts_fr[..., 0]
35 |         elif dim == ndims-2:
36 |             pts_fr[..., 1] = y - 1 - pts_fr[..., 1]
37 |         elif dim == ndims-1:
38 |             pts_fr[..., 2] = x - 1 - pts_fr[..., 2]
39 |     return pts_fr
40 | 
41 | class FlipRot903D(BaseAugmentation):
42 |     def __init__(self, probability: float=1.0) -> None:
43 |         """Augmentation class for random 3D axis-aligned flips combined with 90-degree rotations (FlipRot903D).
44 | 
45 |         Args:
46 |             probability (float): probability of applying the augmentation. Must be in [0, 1].
47 |         """
48 |         super().__init__(probability)
49 | 
50 |     def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
51 |         """Applies the FlipRot903D augmentation to the given image and points.
52 |         """
53 |         # Randomly choose the spatial axis/axes to flip
54 |         combs = tuple(_subgroup_flips(img.ndim, axis=(-3, -2, -1)))
55 |         idx = torch.randint(len(combs), (1,)).item()
56 |         dims_to_flip = tuple(i for i, c in enumerate(combs[idx]) if c)
57 | 
58 |         # Return original image and points if no axis is flipped
59 |         if len(dims_to_flip) == 0:
60 |             return img, pts
61 |         # Flip image and points
62 | 
63 |         img_fr = torch.flip(img, dims_to_flip)
64 |         pts_fr = _fliprot_pts(pts, dims_to_flip, img.shape[-3:], ndims=img.ndim)
65 |         return img_fr, pts_fr
66 | 
67 | 
68 |     def __repr__(self) -> str:
69 |         return f"FlipRot903D(probability={self.probability})"
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms3d/intensity_shift.py: --------------------------------------------------------------------------------
1 | 
2 | """Create a dummy IntensityScaleShift3D alias simply for API coherence"""
3 | from ..transforms import IntensityScaleShift
4 | 
5 | IntensityScaleShift3D = IntensityScaleShift
6 | 
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms3d/noise.py: --------------------------------------------------------------------------------
1 | 
2 | """Create dummy GaussianNoise3D, SaltAndPepperNoise3D aliases simply for API coherence"""
3 | from ..transforms import GaussianNoise, SaltAndPepperNoise
4 | 
5 | GaussianNoise3D = GaussianNoise
6 | SaltAndPepperNoise3D = SaltAndPepperNoise
7 | 
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms3d/rotation.py: --------------------------------------------------------------------------------
1 | from numbers import Number
2 | from torchvision.transforms import InterpolationMode
3 | from typing import Literal, Optional, Tuple, Union
4 | import math
5 | import torch
6 | import torchvision.transforms.functional as tvf
7 | 
8 | from ..transforms.base import BaseAugmentation
9 | from ..transforms.utils import _filter_points_idx
10 | 
11 | def _affine_rotation_matrix(phi: float, center: tuple[float]) -> torch.Tensor:
12 |     cz, cy, cx = center
13 |     translation_mat = torch.FloatTensor([
14 |         [1., 0., 0., -cz],
15 |         [0., 1., 0., -cy],
16 |         [0., 0., 1., -cx],
17 |         [0., 0., 0., 1.]
18 |     ])
19 |     translation_mat_inv = torch.FloatTensor([
20 |         [1., 0., 0., cz],
21 |         [0., 1., 0., cy],
22 |         [0., 0., 1., cx],
23 |         [0., 0., 0., 1.]
24 |     ])
25 | 
26 |     si, co = math.sin(phi), math.cos(phi)
27 | 
28 |     rot_matrix_yx = torch.FloatTensor([
29 |         [1., 0., 0., 0.],
30 |         [0., co, si, 0.],
31 |         [0., -si, co, 0.],
32 |         [0., 0., 0., 1.]
33 |     ])
34 |     affine_mat = translation_mat_inv@rot_matrix_yx@translation_mat
35 |     return affine_mat
36 | 
37 | 
38 | class RotationYX3D(BaseAugmentation):
39 |     def __init__(self, order: Literal[0, 1]=1, angle: Optional[Union[Number, Tuple[Number, Number]]] = (-180, 180), probability: float=1.0) -> None:
40 |         """Augmentation class for random rotations in the YX plane.
41 | 
42 |         Args:
43 |             probability (float): probability of applying the augmentation. Must be in [0, 1].
44 |             order (Literal[0, 1]): order of interpolation. Use 0 for nearest neighbor, 1 for bilinear.
45 |             angle (Optional[Union[Number, Tuple[Number, Number]]]): rotation angle range (in degrees). A single number a is interpreted as (-a, a); None as (-180, 180). Defaults to (-180, 180).
46 | 
47 |         """
48 |         super().__init__(probability)
49 |         self._order = int(order)
50 |         if self._order == 0:
51 |             self._interp_mode = InterpolationMode.NEAREST
52 |         elif self._order == 1:
53 |             self._interp_mode = InterpolationMode.BILINEAR
54 |         else:
55 |             raise ValueError("Order must be 0 or 1.")
56 |         if angle is None:
57 |             angle = (-180, 180)
58 |         elif isinstance(angle, Number):
59 |             angle = (-angle, angle)
60 |         elif isinstance(angle, tuple) and len(angle) == 2 and all(isinstance(x, Number) for x in angle):
61 |             angle = angle
62 |         else:
63 |             raise ValueError("angle must be either a number or a tuple of two numbers")
64 |         self._angle = angle
65 |         if all(phi == 0 for phi in angle):
66 |             raise ValueError("Angle range cannot be (0, 0).")
67 | 
68 |     def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
69 |         # Generate random rotation angle
70 |         phi_deg = self._sample_angle()
71 |         phi_rad = phi_deg * math.pi / 180.
72 | 
73 |         # Rotate image (a leading batch axis of size 1 is squeezed before rotating and restored afterwards)
74 |         assert img.ndim < 5 or img.shape[0] == 1, "Leading axis dimensionality should be 1 for volumetric data"
75 | 
76 |         if should_unsqueeze := (img.ndim == 5):
77 |             img = img.squeeze(0)
78 |         img_r = tvf.rotate(img, phi_deg, interpolation=self._interp_mode)
79 | 
80 |         if should_unsqueeze:
81 |             img_r = img_r.unsqueeze(0)
82 | 
83 |         # Generate affine transformation matrix (rotation is around the YX center; z is left unchanged)
84 |         y, x = img.shape[-2:]
85 |         center_z, center_y, center_x = 0, (y-1)/2, (x-1)/2
86 | 
87 |         affine_mat = _affine_rotation_matrix(-phi_rad, (center_z, center_y, center_x)).to(img.device)
88 |         # Rotate points
89 |         affine_coords = torch.cat([pts.float(), torch.ones((*pts.shape[:-1],1), device=img.device)], axis=-1) # Euclidean -> homogeneous coordinates
90 |         pts_r = (affine_coords@affine_mat.T)
91 |         pts_r = pts_r[..., :-1] # Homogeneous -> Euclidean coordinates
92 | 
93 |         idxs_in = _filter_points_idx(pts_r, img_r.shape[-3:])
94 | 
95 |         return img_r, pts_r[idxs_in].view(*pts.shape[:-2], -1, pts.shape[-1])
96 | 
97 |     def _sample_angle(self):
98 |         return torch.FloatTensor(1).uniform_(*self._angle).item()
99 | 
100 | 
101 |     def __repr__(self) -> str:
102 |         return f"RotationYX3D(angle={self._angle}, probability={self.probability})"
-------------------------------------------------------------------------------- /spotiflow/augmentations/transforms3d/translation.py: --------------------------------------------------------------------------------
1 | from torchvision.transforms import InterpolationMode
2 | from typing import Literal, Tuple, Union
3 | import torch
4 | import torchvision.transforms.functional as tvf
5 | from numbers import Number
6 | from ..transforms.base import BaseAugmentation
7 | from ..transforms.utils import _filter_points_idx
8 | 
9 | class TranslationYX3D(BaseAugmentation):
10 |     def __init__(self, order: Literal[0, 1]=1, shift: Union[int, Tuple[int, int]]=(-5, 5), probability: float=1.0) -> None:
11 |         """Augmentation class for random translation in the YX plane.
12 | 
13 |         Args:
14 |             probability (float): probability of applying the augmentation. Must be in [0, 1].
15 |             order (Literal[0, 1]): order of interpolation. Use 0 for nearest neighbor, 1 for bilinear.
16 |             shift (Union[int, Tuple[int, int]]): range (in pixels) from which the YX translation offsets are sampled. A single integer s is interpreted as (-s, s). Defaults to (-5, 5).
17 | 18 | """ 19 | super().__init__(probability) 20 | self._order = int(order) 21 | if self._order == 0: 22 | self._interp_mode = InterpolationMode.NEAREST 23 | elif self._order == 1: 24 | self._interp_mode = InterpolationMode.BILINEAR 25 | else: 26 | raise ValueError("Order must be 0 or 1.") 27 | 28 | if isinstance(shift, Number): 29 | shift = (-shift, shift) 30 | elif all(isinstance(x, int) for x in shift) and len(shift) == 2: 31 | shift = shift 32 | else: 33 | raise ValueError("Shift must be either a single integer or a 2-length tuple of integers.") 34 | self._shift = shift 35 | if all(x == 0 for x in shift): 36 | raise ValueError("Shift range cannot be (0, 0).") 37 | 38 | 39 | def apply(self, img: torch.Tensor, pts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 40 | """Apply the augmentation to the input image and points 41 | 42 | Args: 43 | img (torch.Tensor): input image tensor of shape ((B), C, D, H, W) 44 | pts (torch.Tensor): input points tensor of shape (N, 2) or (N, 3) 45 | 46 | Returns: 47 | Tuple[torch.Tensor, torch.Tensor]: transformed image tensor of shape (C, H, W) and transformed points tensor of shape (N, 2) or (N, 3) 48 | """ 49 | # Sample shift 50 | sampled_translation = self._sample_shift(img.device) 51 | tr_x, tr_y = sampled_translation 52 | # Skip if no translation 53 | if tr_y == 0 and tr_x == 0: 54 | return img, pts 55 | 56 | if img.ndim == 5: 57 | img = img.squeeze(0) 58 | 59 | # Apply translation to image and points 60 | img_translated = tvf.affine(img, 61 | angle=0, 62 | translate=(tr_y, tr_x), # First axis arg is horizontal shift, second is vertical shift 63 | scale=1, 64 | shear=0, 65 | interpolation=self._interp_mode) 66 | 67 | 68 | pts_translated = pts + torch.concat((torch.zeros(1, device=img.device), sampled_translation), axis=0) 69 | idxs_in = _filter_points_idx(pts_translated, img_translated.shape[-3:]) 70 | pts_translated = pts_translated[idxs_in].view(*pts.shape[:-2], -1, pts.shape[-1]) 71 | return img_translated, pts_translated 72 | 73 | 74 | def _sample_shift(self, device: torch.device) -> torch.Tensor: 75 | """Sample a random YX shift from the shift range 76 | 77 | Returns: 78 | torch.Tensor: random shift 79 | """ 80 | return torch.randint(*self._shift, size=(2,), device=device) 81 | 82 | 83 | def __repr__(self) -> str: 84 | return f"Translation(shift={self._shift}, probability={self.probability})" -------------------------------------------------------------------------------- /spotiflow/data/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | from .spots import SpotsDataset, collate_spots 3 | from .spots3d import Spots3DDataset 4 | 5 | from ..sample_data import * -------------------------------------------------------------------------------- /spotiflow/data/spots3d.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Dict, Literal, Optional, Sequence, Union 3 | from typing_extensions import Self 4 | from tqdm.auto import tqdm 5 | 6 | from skimage import io 7 | import logging 8 | import numpy as np 9 | import sys 10 | import torch 11 | import tifffile 12 | from itertools import chain 13 | import pandas as pd 14 | 15 | from .spots import SpotsDataset 16 | from .. 
import utils 17 | 18 | log = logging.getLogger(__name__) 19 | log.setLevel(logging.INFO) 20 | 21 | console_handler = logging.StreamHandler(sys.stdout) 22 | console_handler.setLevel(logging.INFO) 23 | formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s") 24 | console_handler.setFormatter(formatter) 25 | log.addHandler(console_handler) 26 | 27 | 28 | class Spots3DDataset(SpotsDataset): 29 | """Base spot dataset class instantiated with loaded images and centers. 30 | 31 | Example: 32 | 33 | from spotiflow.data import Spots3DDataset 34 | from spotiflow.augmentations import Pipeline, transforms3d 35 | 36 | augmenter = transforms3d.Crop3D(probability=1, size=(128, 128, 128)) 37 | data = Spots3DDataset(imgs, centers, augmenter=augmenter) 38 | 39 | """ 40 | def __getitem__(self, idx: int) -> Dict: 41 | img, centers = self.images[idx], self._centers[idx] 42 | 43 | if self._defer_normalization: 44 | img = self._normalizer(img) 45 | 46 | img = torch.from_numpy(img.copy()).unsqueeze(0) # Add B dimension 47 | centers = torch.from_numpy(centers.copy()).unsqueeze(0) # Add B dimension 48 | 49 | assert img.ndim in (4, 5) # Images should be in BCDWH or BDHW format 50 | if img.ndim == 4: 51 | img = img.unsqueeze(1) # Add C dimension 52 | 53 | img, centers = self.augmenter(img, centers) 54 | img, centers = img.squeeze(0), centers.squeeze(0) # Remove B dimension 55 | 56 | if self._compute_flow: 57 | flow = utils.points_to_flow3d( 58 | centers.numpy(), img.shape[-3:], sigma=self._sigma, grid=self._grid, 59 | ).transpose((3, 0, 1, 2)) 60 | flow = torch.from_numpy(flow).float() 61 | 62 | 63 | heatmap_lv0 = utils.points_to_prob3d( 64 | centers.numpy(), img.shape[-3:], mode=self._mode, sigma=self._sigma, grid=self._grid, 65 | ) 66 | 67 | # Build target at different resolution levels 68 | heatmaps = [ 69 | utils.multiscale_decimate(heatmap_lv0, ds, is_3d=True) 70 | for ds in self._downsample_factors 71 | ] 72 | 73 | # Cast to tensor and add channel dimension 74 | ret_obj = {"img": img.float(), "pts": centers.float()} 75 | 76 | if self._compute_flow: 77 | ret_obj.update({"flow": flow}) 78 | 79 | ret_obj.update( 80 | { 81 | f"heatmap_lv{lv}": torch.from_numpy(heatmap.copy()).unsqueeze(0) 82 | for lv, heatmap in enumerate(heatmaps) 83 | } 84 | ) 85 | return ret_obj 86 | 87 | #FIXME: duplicated code, should be gone when class labels are allowed in 3D 88 | @classmethod 89 | def from_folder( 90 | cls, 91 | path: Union[Path, str], 92 | augmenter: Optional[Callable] = None, 93 | downsample_factors: Sequence[int] = (1,), 94 | sigma: float = 1.0, 95 | image_extensions: Sequence[str] = ("tif", "tiff", "png", "jpg", "jpeg"), 96 | mode: str = "max", 97 | max_files: Optional[int] = None, 98 | compute_flow: bool = False, 99 | normalizer: Optional[Union[Callable, Literal["auto"]]] = "auto", 100 | random_state: Optional[int] = None, 101 | add_class_label: bool = False, 102 | grid: Optional[Sequence[int]] = None, 103 | ) -> Self: 104 | """Build dataset from folder. Images and centers are loaded from disk and normalized. 105 | 106 | Args: 107 | path (Union[Path, str]): Path to folder containing images (with given extensions) and centers. 108 | augmenter (Callable): Augmenter function. 109 | downsample_factors (Sequence[int], optional): Downsample factors. Defaults to (1,). 110 | sigma (float, optional): Sigma of Gaussian kernel to generate heatmap. Defaults to 1. 111 | image_extensions (Sequence[str], optional): Image extensions to look for in images. Defaults to ("tif", "tiff", "png", "jpg", "jpeg"). 
112 | mode (str, optional): Mode of heatmap generation. Defaults to "max". 113 | max_files (Optional[int], optional): Maximum number of files to load. Defaults to None (all of them). 114 | compute_flow (bool, optional): Whether to compute flow from centers. Defaults to False. 115 | normalizer (Optional[Union[Callable, Literal["auto"]]], optional): Normalizer function. Defaults to "auto" (percentile-based normalization with p_min=1 and p_max=99.8). 116 | random_state (Optional[int], optional): Random state used when shuffling file names when "max_files" is not None. Defaults to None. 117 | 118 | Returns: 119 | Self: Dataset instance. 120 | """ 121 | assert not add_class_label, "add_class_label not supported for 3D datasets yet." 122 | if isinstance(path, str): 123 | path = Path(path) 124 | image_files = sorted(path.glob("*.tif")) 125 | center_files = sorted(path.glob("*.csv")) 126 | 127 | image_files = sorted( 128 | tuple(chain(*tuple(path.glob(f"*.{ext}") for ext in image_extensions))) 129 | ) 130 | 131 | if max_files is not None: 132 | rng = np.random.default_rng( 133 | random_state if random_state is not None else 42 134 | ) 135 | idx = np.arange(len(image_files)) 136 | rng.shuffle(idx) 137 | image_files = [image_files[i] for i in idx[:max_files]] 138 | center_files = [center_files[i] for i in idx[:max_files]] 139 | 140 | if not len(image_files) == len(center_files): 141 | raise ValueError( 142 | f"Different number of images and centers found! {len(image_files)} images, {len(center_files)} centers." 143 | ) 144 | 145 | 146 | images = [io.imread(img) for img in tqdm(image_files, desc="Loading images")] 147 | 148 | centers = [ 149 | utils.read_coords_csv3d(center).astype(np.float32) 150 | for center in tqdm(center_files, desc="Loading centers") 151 | ] 152 | 153 | return cls( 154 | images=images, 155 | centers=centers, 156 | augmenter=augmenter, 157 | downsample_factors=downsample_factors, 158 | sigma=sigma, 159 | mode=mode, 160 | compute_flow=compute_flow, 161 | image_files=image_files, 162 | normalizer=normalizer, 163 | add_class_label=add_class_label, 164 | grid=grid, 165 | ) 166 | 167 | 168 | def save(self, path, prefix="img_"): 169 | path = Path(path) 170 | path.mkdir(exist_ok=True, parents=True) 171 | for i, (x, y) in tqdm( 172 | enumerate(zip(self.images, self._centers)), desc="Saving", total=len(self) 173 | ): 174 | tifffile.imwrite(path / f"{prefix}{i:05d}.tif", x) 175 | pd.DataFrame(y, columns=("Z", "Y", "X")).to_csv(path / f"{prefix}{i:05d}.csv") 176 | -------------------------------------------------------------------------------- /spotiflow/lib/external/nanoflann/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Software License Agreement (BSD License) 2 | 3 | Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved. 4 | Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved. 5 | Copyright 2011-2016 Jose Luis Blanco (joseluisblancoc@gmail.com). All rights reserved. 6 | 7 | THE BSD LICENSE 8 | 9 | Redistribution and use in source and binary forms, with or without 10 | modification, are permitted provided that the following conditions 11 | are met: 12 | 13 | 1. Redistributions of source code must retain the above copyright 14 | notice, this list of conditions and the following disclaimer. 15 | 2. 
Redistributions in binary form must reproduce the above copyright 16 | notice, this list of conditions and the following disclaimer in the 17 | documentation and/or other materials provided with the distribution. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 20 | IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 21 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 22 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 23 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 24 | NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 28 | THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /spotiflow/lib/filters.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "numpy/arrayobject.h" 10 | 11 | #ifdef _OPENMP 12 | #include 13 | #endif 14 | 15 | 16 | inline int clip(int n, int lower, int upper) 17 | { 18 | return std::max(lower, std::min(n, upper)); 19 | } 20 | 21 | 22 | void _max_filter_horiz(float *src, float * dst, const int kernel_size, const int Nx, const int Ny){ 23 | #pragma omp parallel for 24 | for (int i = 0; i < Ny; i++){ 25 | for (int j = 0; j < Nx; j++){ 26 | float max = -1e10; 27 | for (int k = -kernel_size; k <= kernel_size; k++){ 28 | const int j2 = clip(j + k, 0, Nx - 1); 29 | const float val = src[i * Nx + j2]; 30 | if (val > max) 31 | max = val; 32 | } 33 | dst[i * Nx + j] = max; 34 | } 35 | } 36 | } 37 | 38 | void _max_filter_vert(float *src, float * dst, const int kernel_size, const int Nx, const int Ny){ 39 | #pragma omp parallel for 40 | for (int j = 0; j < Nx; j++){ 41 | for (int i = 0; i < Ny; i++){ 42 | float max = -1e10; 43 | for (int k = -kernel_size; k <= kernel_size; k++){ 44 | const int i2 = clip(i + k, 0, Ny - 1); 45 | const float val = src[i2 * Nx + j]; 46 | if (val > max) 47 | max = val; 48 | } 49 | dst[i * Nx + j] = max; 50 | } 51 | } 52 | } 53 | 54 | 55 | void _transpose(float *src, float *dst, const int Nx, const int Ny){ 56 | #pragma omp parallel for 57 | for (int j = 0; j < Nx; j++){ 58 | for (int i = 0; i < Ny; i++){ 59 | dst[j * Ny + i] = src[i * Nx + j]; 60 | } 61 | } 62 | } 63 | static PyObject *c_maximum_filter_2d_float(PyObject *self, PyObject *args) 64 | { 65 | 66 | // #ifdef _OPENMP 67 | // const int nthreads = omp_get_max_threads(); 68 | // std::cout << "Using " << nthreads << " thread(s)" << std::endl; 69 | // #endif 70 | 71 | // #ifdef __APPLE__ 72 | // #pragma omp parallel for 73 | // #else 74 | // #pragma omp parallel for schedule(dynamic) 75 | // #endif 76 | // for (int i = 0; i < 32; i++){ 77 | // std::cout << "Hello from thread " << omp_get_thread_num() << std::endl; 78 | // }; 79 | 80 | PyArrayObject *src = NULL; 81 | int kernel_size; 82 | int max_threads; 83 | 84 | if (!PyArg_ParseTuple(args, "O!ii", &PyArray_Type, &src, &kernel_size, &max_threads)) 85 | return NULL; 86 | 87 | #ifdef _OPENMP 88 | omp_set_num_threads(max_threads); 89 | #endif 90 | 91 | npy_intp *dims = PyArray_DIMS(src); 92 | const long Ny = dims[0]; 93 | const long Nx = 
dims[1]; 94 | 95 | PyArrayObject *dst = (PyArrayObject *)PyArray_SimpleNew(2, dims, NPY_FLOAT32); 96 | 97 | float *src_data = (float *)PyArray_DATA(src); 98 | float *dst_data = (float *)PyArray_DATA(dst); 99 | float *tmp = new float[Nx * Ny]; 100 | 101 | _max_filter_horiz(src_data, tmp, kernel_size, Nx, Ny); 102 | _max_filter_vert(tmp, dst_data, kernel_size, Nx, Ny); 103 | 104 | // _max_filter_horiz(src_data, tmp, kernel_size, Nx, Ny); 105 | // _transpose(tmp, tmp2, Nx, Ny); 106 | // _max_filter_horiz(tmp2, tmp, kernel_size, Nx, Ny); 107 | // _transpose(tmp, dst_data, Nx, Ny); 108 | 109 | delete[] tmp; 110 | 111 | 112 | return PyArray_Return(dst); 113 | } 114 | 115 | //------------------------------------------------------------------------ 116 | 117 | static struct PyMethodDef methods[] = { 118 | {"c_maximum_filter_2d_float", c_maximum_filter_2d_float, METH_VARARGS, "point max filter"}, 119 | {NULL, NULL, 0, NULL} 120 | 121 | }; 122 | 123 | static struct PyModuleDef moduledef = { 124 | PyModuleDef_HEAD_INIT, 125 | "filters", 126 | NULL, 127 | -1, 128 | methods, 129 | NULL, NULL, NULL, NULL}; 130 | 131 | PyMODINIT_FUNC PyInit_filters(void) 132 | { 133 | import_array(); 134 | return PyModule_Create(&moduledef); 135 | } 136 | -------------------------------------------------------------------------------- /spotiflow/lib/filters3d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "numpy/arrayobject.h" 10 | 11 | #ifdef _OPENMP 12 | #include 13 | #endif 14 | 15 | 16 | inline int clip(int n, int lower, int upper) 17 | { 18 | return std::max(lower, std::min(n, upper)); 19 | } 20 | 21 | 22 | void _max_filter_horiz(float *src, float * dst, const int kernel_size, const int Nx, const int Ny, const int Nz){ 23 | #pragma omp parallel for 24 | for (int i = 0; i < Ny; i++){ 25 | for (int j = 0; j < Nx; j++){ 26 | for (int k = 0; k < Nz; k++) { 27 | float max = -1e10; 28 | for (int l = -kernel_size; l <= kernel_size; l++){ 29 | const int j2 = clip(j + l, 0, Nx - 1); 30 | const float val = src[k * Ny * Nx + i * Nx + j2]; 31 | if (val > max) 32 | max = val; 33 | } 34 | dst[k * Ny * Nx + i * Nx + j] = max; 35 | } 36 | } 37 | } 38 | } 39 | 40 | void _max_filter_vert(float *src, float * dst, const int kernel_size, const int Nx, const int Ny, const int Nz) { 41 | #pragma omp parallel for 42 | for (int j = 0; j < Nx; j++){ 43 | for (int i = 0; i < Ny; i++){ 44 | for (int k = 0; k < Nz; k++) { 45 | float max = -1e10; 46 | for (int l = -kernel_size; l <= kernel_size; l++){ 47 | const int i2 = clip(i + l, 0, Ny - 1); 48 | const float val = src[k * Ny * Nx + i2 * Nx + j]; 49 | if (val > max) 50 | max = val; 51 | } 52 | dst[k * Ny * Nx + i * Nx + j] = max; 53 | } 54 | } 55 | } 56 | } 57 | 58 | void _max_filter_depth(float *src, float * dst, const int kernel_size, const int Nx, const int Ny, const int Nz) { 59 | #pragma omp parallel for 60 | for (int k = 0; k < Nz; k++){ 61 | for (int i = 0; i < Ny; i++){ 62 | for (int j = 0; j < Nx; j++){ 63 | float max = -1e10; 64 | for (int l = -kernel_size; l <= kernel_size; l++){ 65 | const int k2 = clip(k + l, 0, Nz - 1); 66 | const float val = src[k2 * Ny * Nx + i * Nx + j]; 67 | if (val > max) 68 | max = val; 69 | } 70 | dst[k * Ny * Nx + i * Nx + j] = max; 71 | } 72 | } 73 | } 74 | } 75 | 76 | 77 | 78 | static PyObject *c_maximum_filter_3d_float(PyObject *self, PyObject *args) 79 | { 80 | PyArrayObject *src = NULL; 81 | int 
kernel_size; 82 | int max_threads; 83 | 84 | if (!PyArg_ParseTuple(args, "O!ii", &PyArray_Type, &src, &kernel_size, &max_threads)) 85 | return NULL; 86 | 87 | #ifdef _OPENMP 88 | omp_set_num_threads(max_threads); 89 | #endif 90 | 91 | npy_intp *dims = PyArray_DIMS(src); 92 | const long Nz = dims[0]; 93 | const long Ny = dims[1]; 94 | const long Nx = dims[2]; 95 | 96 | PyArrayObject *dst = (PyArrayObject *)PyArray_SimpleNew(3, dims, NPY_FLOAT32); 97 | 98 | float *src_data = (float *)PyArray_DATA(src); 99 | float *dst_data = (float *)PyArray_DATA(dst); 100 | 101 | float *tmp1 = new float[Nx * Ny * Nz]; 102 | float *tmp2 = new float[Nx * Ny * Nz]; 103 | 104 | _max_filter_horiz(src_data, tmp1, kernel_size, Nx, Ny, Nz); 105 | _max_filter_vert(tmp1, tmp2, kernel_size, Nx, Ny, Nz); 106 | _max_filter_depth(tmp2, dst_data, kernel_size, Nx, Ny, Nz); 107 | 108 | 109 | delete[] tmp1; 110 | delete[] tmp2; 111 | 112 | 113 | return PyArray_Return(dst); 114 | } 115 | 116 | //------------------------------------------------------------------------ 117 | 118 | static struct PyMethodDef methods[] = { 119 | {"c_maximum_filter_3d_float", c_maximum_filter_3d_float, METH_VARARGS, "point max filter 3d"}, 120 | {NULL, NULL, 0, NULL} 121 | 122 | }; 123 | 124 | static struct PyModuleDef moduledef = { 125 | PyModuleDef_HEAD_INIT, 126 | "filters3d", 127 | NULL, 128 | -1, 129 | methods, 130 | NULL, NULL, NULL, NULL}; 131 | 132 | PyMODINIT_FUNC PyInit_filters3d(void) 133 | { 134 | import_array(); 135 | return PyModule_Create(&moduledef); 136 | } 137 | -------------------------------------------------------------------------------- /spotiflow/lib/point_nms.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "numpy/arrayobject.h" 10 | 11 | #ifdef _OPENMP 12 | #include 13 | #endif 14 | 15 | #include 16 | 17 | inline int clip(int n, int lower, int upper) 18 | { 19 | return std::max(lower, std::min(n, upper)); 20 | } 21 | 22 | template struct Point2D 23 | { 24 | T x, y; 25 | }; 26 | 27 | 28 | 29 | template 30 | struct PointCloud2D 31 | { 32 | 33 | std::vector> pts; 34 | // Must return the number of data points 35 | inline size_t kdtree_get_point_count() const { return pts.size(); } 36 | // Returns the dim'th component of the idx'th point in the class: 37 | // Since this is inlined and the "dim" argument is typically an immediate value, the 38 | // "if/else's" are actually solved at compile time. 39 | inline T kdtree_get_pt(const size_t idx, const size_t dim) const 40 | { 41 | if (dim == 0) 42 | return pts[idx].x; 43 | else 44 | return pts[idx].y; 45 | } 46 | // Optional bounding-box computation: return false to default to a standard bbox computation loop. 47 | // Return true if the BBOX was already computed by the class and returned in "bb" so it can be avoided to redo it again. 48 | // Look at bb.size() to find out the expected dimensionality (e.g. 
2 or 3 for point clouds) 49 | template 50 | bool kdtree_get_bbox(BBOX & /* bb */) const { return false; } 51 | }; 52 | 53 | inline int round_to_int(float r) 54 | { 55 | return (int)lrint(r); 56 | } 57 | 58 | 59 | static PyObject *c_point_nms_2d(PyObject *self, PyObject *args) 60 | { 61 | 62 | PyArrayObject *points = NULL; 63 | float min_distance; 64 | 65 | if (!PyArg_ParseTuple(args, "O!f", &PyArray_Type, &points, &min_distance)) 66 | return NULL; 67 | 68 | npy_intp *dims = PyArray_DIMS(points); 69 | const float min_distance_squared = min_distance * min_distance; 70 | const long n_points = dims[0]; 71 | 72 | npy_intp dims_dst[1]; 73 | dims_dst[0] = n_points; 74 | PyArrayObject *dst = (PyArrayObject *)PyArray_SimpleNew(1, dims_dst, NPY_BOOL); 75 | 76 | 77 | // std::cout << "dims[0]: " << dims[0] << std::endl; 78 | // std::cout << "dims[1]: " << dims[1] << std::endl; 79 | // std::cout << "min_distance: " << min_distance << std::endl; 80 | 81 | 82 | // build kdtree 83 | 84 | PointCloud2D cloud; 85 | 86 | cloud.pts.resize(dims[0]); 87 | for (long i = 0; i < n_points; i++) 88 | { 89 | cloud.pts[i].y = *(float *)PyArray_GETPTR2(points, i, 0); 90 | cloud.pts[i].x = *(float *)PyArray_GETPTR2(points, i, 1); 91 | } 92 | 93 | // construct a kd-tree: 94 | typedef nanoflann::KDTreeSingleIndexAdaptor< 95 | nanoflann::L2_Simple_Adaptor>, 96 | PointCloud2D, 2> 97 | my_kd_tree_t; 98 | 99 | // build the index from points 100 | my_kd_tree_t index(2, cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10 /* max leaf */)); 101 | 102 | index.buildIndex(); 103 | 104 | 105 | // #ifdef __APPLE__ 106 | // #pragma omp parallel for 107 | // #else 108 | // #pragma omp parallel for schedule(dynamic) 109 | // #endif 110 | 111 | // for (long k = 0; k < n_points; k++) 112 | // { 113 | // const float x = index.dataset.kdtree_get_pt(k,0); 114 | // const float y = index.dataset.kdtree_get_pt(k,1); 115 | 116 | // // if (k != i){ 117 | // std::cout << "Index: " << k << " (y,x) = "<< y << " " << x << std::endl; 118 | // } 119 | 120 | 121 | bool * suppressed = new bool[n_points]; 122 | for (long i = 0; i < n_points; i++) 123 | { 124 | suppressed[i] = false; 125 | } 126 | 127 | std::vector> results; 128 | float query_point[2]; 129 | nanoflann::SearchParams params; 130 | 131 | for (long i = 0; i < n_points; i++) 132 | { 133 | if (suppressed[i]){ 134 | continue; 135 | } 136 | query_point[0] = *(float *)PyArray_GETPTR2(points, i, 1); 137 | query_point[1] = *(float *)PyArray_GETPTR2(points, i, 0); 138 | std::vector> ret_matches; 139 | const size_t n_matches = index.radiusSearch(&query_point[0], min_distance_squared, ret_matches, params); 140 | 141 | // std::cout << "----- " << i << " (y,x) = " << query_point[0] << ", " << query_point[1] << " n_matches: " << n_matches << std::endl; 142 | 143 | for (long j = 0; j < n_matches; j++) 144 | { 145 | const long k = ret_matches[j].first; 146 | const float dist = ret_matches[j].second; 147 | if ((k != i) && (dist < min_distance_squared)) { 148 | // std::cout << "suppressed: " << k << " "<< *(float *)PyArray_GETPTR2(points, k, 0) << " " << *(float *)PyArray_GETPTR2(points, k, 1) << " distance " << dist << std::endl; 149 | suppressed[k] = true; 150 | } 151 | } 152 | 153 | 154 | } 155 | 156 | for (long i = 0; i < n_points; i++) 157 | { 158 | *(bool *)PyArray_GETPTR1(dst, i) = !suppressed[i]; 159 | } 160 | 161 | delete [] suppressed; 162 | return PyArray_Return(dst); 163 | } 164 | 165 | //------------------------------------------------------------------------ 166 | 167 | static struct PyMethodDef 
methods[] = { 168 | {"c_point_nms_2d", c_point_nms_2d, METH_VARARGS, "point nms"}, 169 | {NULL, NULL, 0, NULL} 170 | 171 | }; 172 | 173 | static struct PyModuleDef moduledef = { 174 | PyModuleDef_HEAD_INIT, 175 | "point_nms", 176 | NULL, 177 | -1, 178 | methods, 179 | NULL, NULL, NULL, NULL}; 180 | 181 | PyMODINIT_FUNC PyInit_point_nms(void) 182 | { 183 | import_array(); 184 | return PyModule_Create(&moduledef); 185 | } 186 | -------------------------------------------------------------------------------- /spotiflow/lib/point_nms3d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "numpy/arrayobject.h" 10 | 11 | #ifdef _OPENMP 12 | #include 13 | #endif 14 | 15 | #include 16 | 17 | inline int clip(int n, int lower, int upper) 18 | { 19 | return std::max(lower, std::min(n, upper)); 20 | } 21 | 22 | template struct Point3D 23 | { 24 | T x, y, z; 25 | }; 26 | 27 | 28 | 29 | template 30 | struct PointCloud3D 31 | { 32 | 33 | std::vector> pts; 34 | // Must return the number of data points 35 | inline size_t kdtree_get_point_count() const { return pts.size(); } 36 | // Returns the dim'th component of the idx'th point in the class: 37 | // Since this is inlined and the "dim" argument is typically an immediate value, the 38 | // "if/else's" are actually solved at compile time. 39 | inline T kdtree_get_pt(const size_t idx, const size_t dim) const 40 | { 41 | if (dim == 0) 42 | return pts[idx].x; 43 | else if (dim == 1) 44 | return pts[idx].y; 45 | else 46 | return pts[idx].z; 47 | } 48 | // Optional bounding-box computation: return false to default to a standard bbox computation loop. 49 | // Return true if the BBOX was already computed by the class and returned in "bb" so it can be avoided to redo it again. 50 | // Look at bb.size() to find out the expected dimensionality (e.g. 
2 or 3 for point clouds) 51 | template 52 | bool kdtree_get_bbox(BBOX & /* bb */) const { return false; } 53 | }; 54 | 55 | inline int round_to_int(float r) 56 | { 57 | return (int)lrint(r); 58 | } 59 | 60 | 61 | static PyObject *c_point_nms_3d(PyObject *self, PyObject *args) 62 | { 63 | 64 | PyArrayObject *points = NULL; 65 | float min_distance; 66 | 67 | if (!PyArg_ParseTuple(args, "O!f", &PyArray_Type, &points, &min_distance)) 68 | return NULL; 69 | 70 | npy_intp *dims = PyArray_DIMS(points); 71 | const float min_distance_squared = min_distance * min_distance; 72 | const long n_points = dims[0]; 73 | 74 | npy_intp dims_dst[1]; 75 | dims_dst[0] = n_points; 76 | PyArrayObject *dst = (PyArrayObject *)PyArray_SimpleNew(1, dims_dst, NPY_BOOL); 77 | 78 | 79 | // std::cout << "dims[0]: " << dims[0] << std::endl; 80 | // std::cout << "dims[1]: " << dims[1] << std::endl; 81 | // std::cout << "min_distance: " << min_distance << std::endl; 82 | 83 | 84 | // build kdtree 85 | 86 | PointCloud3D cloud; 87 | 88 | cloud.pts.resize(dims[0]); 89 | for (long i = 0; i < n_points; i++) 90 | { 91 | cloud.pts[i].z = *(float *)PyArray_GETPTR2(points, i, 0); 92 | cloud.pts[i].y = *(float *)PyArray_GETPTR2(points, i, 1); 93 | cloud.pts[i].x = *(float *)PyArray_GETPTR2(points, i, 2); 94 | } 95 | 96 | // construct a kd-tree: 97 | typedef nanoflann::KDTreeSingleIndexAdaptor< 98 | nanoflann::L2_Simple_Adaptor>, 99 | PointCloud3D, 3> 100 | my_kd_tree_t; 101 | 102 | // build the index from points 103 | my_kd_tree_t index(3, cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10 /* max leaf */)); 104 | 105 | index.buildIndex(); 106 | 107 | 108 | // #ifdef __APPLE__ 109 | // #pragma omp parallel for 110 | // #else 111 | // #pragma omp parallel for schedule(dynamic) 112 | // #endif 113 | 114 | // for (long k = 0; k < n_points; k++) 115 | // { 116 | // const float x = index.dataset.kdtree_get_pt(k,0); 117 | // const float y = index.dataset.kdtree_get_pt(k,1); 118 | 119 | // // if (k != i){ 120 | // std::cout << "Index: " << k << " (y,x) = "<< y << " " << x << std::endl; 121 | // } 122 | 123 | 124 | bool * suppressed = new bool[n_points]; 125 | for (long i = 0; i < n_points; i++) 126 | { 127 | suppressed[i] = false; 128 | } 129 | 130 | std::vector> results; 131 | float query_point[3]; 132 | nanoflann::SearchParams params; 133 | 134 | for (long i = 0; i < n_points; i++) 135 | { 136 | if (suppressed[i]){ 137 | continue; 138 | } 139 | query_point[0] = *(float *)PyArray_GETPTR2(points, i, 2); 140 | query_point[1] = *(float *)PyArray_GETPTR2(points, i, 1); 141 | query_point[2] = *(float *)PyArray_GETPTR2(points, i, 0); 142 | std::vector> ret_matches; 143 | const size_t n_matches = index.radiusSearch(&query_point[0], min_distance_squared, ret_matches, params); 144 | 145 | // std::cout << "----- " << i << " (y,x) = " << query_point[0] << ", " << query_point[1] << " n_matches: " << n_matches << std::endl; 146 | 147 | for (long j = 0; j < n_matches; j++) 148 | { 149 | const long k = ret_matches[j].first; 150 | const float dist = ret_matches[j].second; 151 | if ((k != i) && (dist < min_distance_squared)) { 152 | // std::cout << "suppressed: " << k << " "<< *(float *)PyArray_GETPTR2(points, k, 0) << " (y,x) = " << *(float *)PyArray_GETPTR2(points, k, 0) << " distance " << dist << std::endl; 153 | suppressed[k] = true; 154 | } 155 | } 156 | 157 | 158 | } 159 | 160 | for (long i = 0; i < n_points; i++) 161 | { 162 | *(bool *)PyArray_GETPTR1(dst, i) = !suppressed[i]; 163 | } 164 | 165 | delete [] suppressed; 166 | return 
PyArray_Return(dst); 167 | } 168 | 169 | //------------------------------------------------------------------------ 170 | 171 | static struct PyMethodDef methods[] = { 172 | {"c_point_nms_3d", c_point_nms_3d, METH_VARARGS, "point nms 3D"}, 173 | {NULL, NULL, 0, NULL} 174 | 175 | }; 176 | 177 | static struct PyModuleDef moduledef = { 178 | PyModuleDef_HEAD_INIT, 179 | "point_nms3d", 180 | NULL, 181 | -1, 182 | methods, 183 | NULL, NULL, NULL, NULL}; 184 | 185 | PyMODINIT_FUNC PyInit_point_nms3d(void) 186 | { 187 | import_array(); 188 | return PyModule_Create(&moduledef); 189 | } 190 | -------------------------------------------------------------------------------- /spotiflow/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import ResNetBackbone, UNetBackbone 2 | from .config import SpotiflowModelConfig, SpotiflowTrainingConfig 3 | from .post import FeaturePyramidNetwork, MultiHeadProcessor 4 | from .spotiflow import Spotiflow 5 | from .trainer import CustomEarlyStopping, SpotiflowModelCheckpoint 6 | -------------------------------------------------------------------------------- /spotiflow/model/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNetBackbone 2 | from .unet import UNetBackbone -------------------------------------------------------------------------------- /spotiflow/model/bg_remover.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BackgroundRemover(nn.Module): 6 | """Remove the background of an input image I_in by subtracting a low-pass filtered 7 | version of the image (I_low) from it. That is, I_out = I_in - I_low. The convolving filter is 8 | a large-radius Gaussian kernel. Note that this module is disabled by default. 9 | 10 | 11 | Args: 12 | n_channels (int, optional): number of channels in the input image. Defaults to 1. 13 | radius (int, optional): Gaussian filter radius (in px.). Defaults to 51.
14 | """ 15 | def __init__(self, n_channels: int=1, radius: int=51) -> None: 16 | super().__init__() 17 | assert n_channels == 1, "Only 1-channel images are currently supported" 18 | assert radius % 2 == 1, "Radius must be odd" 19 | self._half_radius = int(radius)//2 20 | h = torch.exp(-torch.linspace(-2,2,2*self._half_radius+1)**2).float() 21 | h /= torch.sum(h) 22 | 23 | self.register_buffer("wy", h.reshape((1,1,len(h),1))) 24 | self.register_buffer("wx", h.reshape((1,1,1,len(h)))) 25 | 26 | self.wy.requires_grad = False 27 | self.wx.requires_grad = False 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | y = F.pad(x, pad=tuple(4*[self._half_radius]), mode="reflect") 31 | y = F.conv2d(y, weight=self.wy, stride=1, padding="valid") 32 | y = F.conv2d(y, weight=self.wx, stride=1, padding="valid") 33 | return x - y -------------------------------------------------------------------------------- /spotiflow/model/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptive_wing import AdaptiveWingLoss -------------------------------------------------------------------------------- /spotiflow/model/losses/adaptive_wing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import scipy.ndimage 4 | import numpy as np 5 | 6 | class AdaptiveWingLoss(nn.Module): 7 | """ 8 | Adaptive Wing loss (Wang et al., ICCV 2019) 9 | 10 | Args: 11 | theta (float): Threshold between linear and non linear loss. 12 | alpha (float): Used to adapt loss shape to input shape and make loss smooth at 0 (background). 13 | It needs to be slightly above 2 to maintain ideal properties. 14 | omega (float): Multiplicating factor for non linear part of the loss. 15 | epsilon (float): factor to avoid gradient explosion. 
It must not be too small. 16 | reduction (str): reduction applied to the loss; one of "none", "sum" or "mean". 17 | """ 18 | 19 | def __init__(self, theta=0.5, alpha=2.1, omega=14, epsilon=1, reduction="none"): 20 | super().__init__() 21 | self.theta = theta 22 | self.alpha = alpha 23 | self.omega = omega 24 | self.epsilon = epsilon 25 | self._reduction = reduction 26 | 27 | def forward(self, input, target): 28 | A = self.omega * (1 / (1 + torch.pow(self.theta / self.epsilon, 29 | self.alpha - target))) * \ 30 | (self.alpha - target) * torch.pow(self.theta / self.epsilon, 31 | self.alpha - target - 1) * (1 / self.epsilon) 32 | C = (self.theta * A - self.omega * torch.log(1 + torch.pow(self.theta / self.epsilon, self.alpha - target))) 33 | 34 | abs_diff = torch.abs(input - target) 35 | idx_small = abs_diff < self.theta 36 | idx_large = abs_diff >= self.theta 37 | 38 | loss = torch.zeros_like(input) 39 | 40 | loss[idx_small] = self.omega*torch.log(1+torch.pow(abs_diff[idx_small]/self.epsilon, self.alpha-target[idx_small])) 41 | loss[idx_large] = A[idx_large]*abs_diff[idx_large] - C[idx_large] 42 | if self._reduction == "none": 43 | return loss 44 | elif self._reduction == "sum": 45 | return loss.sum() 46 | elif self._reduction == "mean": 47 | return loss.mean() 48 | else: 49 | raise ValueError(f"Invalid reduction {self._reduction}") 50 | 51 | 52 | if __name__ == "__main__": 53 | inp = torch.clip(torch.randn(3, 1, 512, 512), 0, 1) 54 | target = torch.clip(torch.randn(3, 1, 512, 512), 0, 1) 55 | loss = AdaptiveWingLoss(reduction="sum") 56 | print(loss(inp, target)) 57 | -------------------------------------------------------------------------------- /spotiflow/model/pretrained.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | from ..utils import NotRegisteredError 7 | from ..utils.get_file import get_file 8 | 9 | 10 | @dataclass 11 | class RegisteredModel: 12 | """ 13 | Dataclass to store information about a registered model. 14 | 15 | url: the url of the zipped model folder 16 | md5_hash: the md5 hash of the zipped model folder 17 | """ 18 | 19 | url: str 20 | md5_hash: str 21 | is_3d: bool 22 | 23 | def list_registered(): 24 | return list(_REGISTERED.keys()) 25 | 26 | 27 | def _default_cache_dir(): 28 | default_cache_dir = os.getenv("SPOTIFLOW_CACHE_DIR", None) 29 | if default_cache_dir is None: 30 | return Path("~").expanduser() / ".spotiflow" / "models" 31 | default_cache_dir = Path(default_cache_dir) 32 | if default_cache_dir.stem != "models": 33 | default_cache_dir = default_cache_dir / "models" 34 | return default_cache_dir 35 | 36 | def get_pretrained_model_path(name: str, cache_dir: Optional[Path] = None) -> Path: 37 | """ 38 | Downloads and extracts the pretrained model with the given name. 39 | The model is downloaded and extracted into the given cache_dir. If it is not given, the model 40 | is downloaded to ~/.spotiflow/models and extracted to ~/.spotiflow/models/name. 41 | """ 42 | if name not in _REGISTERED: 43 | raise NotRegisteredError(f"No pretrained model named {name} found.
Available models: {','.join(sorted(list_registered()))}") 44 | model = _REGISTERED[name] 45 | path = Path( 46 | get_file( 47 | fname=f"{name}.zip", 48 | origin=model.url, 49 | file_hash=model.md5_hash, 50 | cache_dir=_default_cache_dir() if cache_dir is None else cache_dir, 51 | cache_subdir="", 52 | extract=True, 53 | ) 54 | ) 55 | return path.parent / name 56 | 57 | 58 | _REGISTERED = { 59 | "hybiss": RegisteredModel( 60 | url="https://drive.switch.ch/index.php/s/O4hqFSSGX6veLwa/download", 61 | md5_hash="254afa97c137d0bd74fd9c1827f0e323", 62 | is_3d=False, 63 | ), 64 | "general": RegisteredModel( 65 | url="https://drive.switch.ch/index.php/s/6AoTEgpIAeQMRvX/download", 66 | md5_hash="9dd31a36b737204e91b040515e3d899e", 67 | is_3d=False, 68 | ), 69 | "synth_complex": RegisteredModel( 70 | url="https://drive.switch.ch/index.php/s/CiCjNJaJzpVVD2M/download", 71 | md5_hash="d692fa21da47e4a50b4c52f49442508b", 72 | is_3d=False, 73 | ), 74 | "synth_3d": RegisteredModel( 75 | url="https://drive.switch.ch/index.php/s/VhDqgDoHc11yP6v/download", 76 | md5_hash="a031f1284590886fbae37dc583c0270d", 77 | is_3d=True, 78 | ), 79 | "smfish_3d": RegisteredModel( 80 | url="https://drive.switch.ch/index.php/s/Vym7tqiORZOP5Zt/download", 81 | md5_hash="c5ab30ba3b9ccb07b4c34442d1b5b615", 82 | is_3d=True, 83 | ) 84 | } 85 | -------------------------------------------------------------------------------- /spotiflow/sample_data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import load_dataset 2 | 3 | def __abspath(path): 4 | import os 5 | base_path = os.path.abspath(os.path.dirname(__file__)) 6 | return os.path.join(base_path, path) 7 | 8 | def test_image_hybiss_2d(): 9 | # TODO: proper docstring after paper is published :) 10 | """ Single test HybISS image from the Spotiflow paper (doi.org/10.1101/2024.02.01.578426) 11 | """ 12 | from tifffile import imread 13 | img = imread(__abspath("images/img_hybiss_2d.tif")) 14 | return img 15 | 16 | def test_image_terra_2d(): 17 | # TODO: proper docstring after paper is published :) 18 | """ Single test Terra frame from the Spotiflow paper (doi.org/10.1101/2024.02.01.578426) 19 | """ 20 | from tifffile import imread 21 | img = imread(__abspath("images/img_terra_2d.tif")) 22 | return img 23 | 24 | def test_timelapse_telomeres_2d(): 25 | # TODO: proper docstring after paper is published :) 26 | """Timelapse of telomeres from the Spotiflow paper (doi.org/10.1101/2024.02.01.578426) 27 | """ 28 | from tifffile import imread 29 | img = imread(__abspath("images/timelapse_telomeres_2d.tif")) 30 | return img 31 | 32 | def test_image_synth_3d(): 33 | # TODO: proper docstring after paper is published :) 34 | """ Single synthetic volumetric stack from the Spotiflow paper (doi.org/10.1101/2024.02.01.578426) 35 | """ 36 | from tifffile import imread 37 | img = imread(__abspath("images/img_synth_3d.tif")) 38 | return img 39 | -------------------------------------------------------------------------------- /spotiflow/sample_data/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | 6 | from ..utils import NotRegisteredError, get_data 7 | from ..utils.get_file import get_file 8 | 9 | 10 | @dataclass 11 | class RegisteredDataset: 12 | """ 13 | Dataclass to store information about a registered dataset. 
14 | 15 | url: the url of the zipped dataset folder 16 | md5_hash: the md5 hash of the zipped dataset folder 17 | """ 18 | 19 | url: str 20 | md5_hash: str 21 | is_3d: bool 22 | 23 | 24 | def list_registered(): 25 | return list(_REGISTERED.keys()) 26 | 27 | def _default_cache_dir(): 28 | default_cache_dir = os.getenv("SPOTIFLOW_CACHE_DIR", None) 29 | if default_cache_dir is None: 30 | return Path("~").expanduser() / ".spotiflow" / "datasets" 31 | default_cache_dir = Path(default_cache_dir) 32 | if default_cache_dir.stem != "datasets": 33 | default_cache_dir = default_cache_dir / "datasets" 34 | return default_cache_dir 35 | 36 | 37 | def get_training_datasets_path(name: str, cache_dir: Optional[Path] = None) -> Path: 38 | """ 39 | Downloads and extracts the training dataset with the given name. 40 | The dataset is downloaded to the given cache_dir. If not given, it 41 | will be downloaded to ~/.spotiflow/datasets and extracted to ~/.spotiflow/datasets/name. 42 | """ 43 | if name not in _REGISTERED: 44 | raise NotRegisteredError(f"No training dataset named {name} found. Available datasets: {','.join(sorted(list_registered()))}") 45 | dataset = _REGISTERED[name] 46 | path = Path( 47 | get_file( 48 | fname=f"{name}.zip", 49 | origin=dataset.url, 50 | file_hash=dataset.md5_hash, 51 | cache_dir=_default_cache_dir() if cache_dir is None else cache_dir, 52 | cache_subdir="", 53 | extract=True, 54 | ) 55 | ) 56 | return path.parent / name 57 | 58 | def load_dataset(name: str, include_test: bool=False, cache_dir: Optional[Union[Path, str]] = None): 59 | """ 60 | Downloads and extracts the training dataset with the given name. 61 | The dataset is downloaded to ~/.spotiflow/datasets and extracted to ~/.spotiflow/datasets/name. 62 | 63 | Args: 64 | name (str): the name of the dataset to load. 65 | include_test (bool, optional): whether to include the test set in the returned data. Defaults to False. 66 | cache_dir (Optional[Union[Path, str]], optional): directory to cache the model. Defaults to None. If None, will use the default cache directory (given by the env var SPOTIFLOW_CACHE_DIR if set, otherwise ~/.spotiflow). 67 | """ 68 | if name not in _REGISTERED: 69 | raise NotRegisteredError(f"No training dataset named {name} found. 
Available datasets: {','.join(sorted(list_registered()))}") 70 | if cache_dir is not None and isinstance(cache_dir, str): 71 | cache_dir = Path(cache_dir) 72 | dataset = _REGISTERED[name] 73 | path = get_training_datasets_path(name, cache_dir=cache_dir) 74 | return get_data(path, include_test=include_test, is_3d=dataset.is_3d) 75 | 76 | 77 | _REGISTERED = { 78 | "synth_complex": RegisteredDataset( 79 | url="https://drive.switch.ch/index.php/s/aWdxUHULLkLLtqS/download", 80 | md5_hash="5f44b03603fe1733ac0f2340a69ae238", 81 | is_3d=False, 82 | ), 83 | "merfish": RegisteredDataset( 84 | url="https://drive.switch.ch/index.php/s/fsjOypn4ICpSF2w/download", 85 | md5_hash="17fcdbd12cc71630e4f49652ded837c7", 86 | is_3d=False, 87 | ), 88 | "synth_3d": RegisteredDataset( 89 | url="https://drive.switch.ch/index.php/s/EemgJK1Bno8c3n4/download", 90 | md5_hash="f1715515763288362ee3351caca02825", 91 | is_3d=True, 92 | ), 93 | } 94 | -------------------------------------------------------------------------------- /spotiflow/sample_data/images/img_hybiss_2d.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/spotiflow/sample_data/images/img_hybiss_2d.tif -------------------------------------------------------------------------------- /spotiflow/sample_data/images/img_synth_3d.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/spotiflow/sample_data/images/img_synth_3d.tif -------------------------------------------------------------------------------- /spotiflow/sample_data/images/img_terra_2d.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/spotiflow/sample_data/images/img_terra_2d.tif -------------------------------------------------------------------------------- /spotiflow/sample_data/images/timelapse_telomeres_2d.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/weigertlab/spotiflow/8545981b0189c4faac0bfa6278376493fe6462b7/spotiflow/sample_data/images/timelapse_telomeres_2d.tif -------------------------------------------------------------------------------- /spotiflow/starfish/__init__.py: -------------------------------------------------------------------------------- 1 | from .spotiflow_wrapper import SpotiflowDetector -------------------------------------------------------------------------------- /spotiflow/test/test_model_saveload.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from spotiflow.model import Spotiflow, SpotiflowModelConfig 6 | from tempfile import TemporaryDirectory 7 | 8 | 9 | def equal_states_dict(state_dict_1, state_dict_2): 10 | if len(state_dict_1) != len(state_dict_2): 11 | return False 12 | 13 | for ((k_1, v_1), (k_2, v_2)) in zip( 14 | state_dict_1.items(), state_dict_2.items() 15 | ): 16 | if k_1 != k_2: 17 | return False 18 | 19 | if not torch.allclose(v_1, v_2): 20 | return False 21 | return True 22 | 23 | 24 | DEVICE_STR = "cpu" 25 | 26 | np.random.seed(42) 27 | torch.random.manual_seed(42) 28 | torch.backends.cudnn.benchmark = False 29 | torch.use_deterministic_algorithms(True) 30 | 31 | 32 | 
@pytest.mark.parametrize("in_channels", (1, 3)) 33 | def test_save_load( 34 | in_channels: int, 35 | ): 36 | model_config = SpotiflowModelConfig( 37 | levels=2, 38 | in_channels=in_channels, 39 | out_channels=1, 40 | background_remover=in_channels==1, 41 | ) 42 | model = Spotiflow(model_config) 43 | 44 | with TemporaryDirectory() as tmpdir: 45 | model.save(tmpdir, which="best", update_thresholds=True) 46 | model_same = Spotiflow.from_folder(tmpdir, map_location=DEVICE_STR) 47 | with pytest.raises(AssertionError): 48 | _ = Spotiflow.from_folder(tmpdir, map_location=DEVICE_STR, which="notexist") 49 | 50 | assert model_same.config == model.config, "Model configs are not equal" 51 | assert equal_states_dict(model.state_dict(), model_same.state_dict()), "Model states are not equal" 52 | 53 | -------------------------------------------------------------------------------- /spotiflow/test/test_prediction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from spotiflow.model import Spotiflow, SpotiflowModelConfig 6 | from typing import Tuple 7 | 8 | 9 | AVAILABLE_DEVICES = [None, "auto", "cpu"] 10 | if torch.cuda.is_available(): 11 | AVAILABLE_DEVICES += ["cuda"] 12 | if torch.backends.mps.is_available(): 13 | AVAILABLE_DEVICES += ["mps"] 14 | 15 | AVAILABLE_DEVICES = tuple(AVAILABLE_DEVICES) 16 | 17 | np.random.seed(42) 18 | torch.random.manual_seed(42) 19 | torch.backends.cudnn.benchmark = False 20 | torch.use_deterministic_algorithms(True) 21 | 22 | 23 | @pytest.mark.parametrize("img_size", ((64, 64), (101, 241))) 24 | @pytest.mark.parametrize("n_tiles", ((1,1), (2,2), (2,2,1))) 25 | @pytest.mark.parametrize("scale", (1/3, 1., 2.)) 26 | @pytest.mark.parametrize("in_channels", (1, 3, 5)) 27 | @pytest.mark.parametrize("device", AVAILABLE_DEVICES) 28 | @pytest.mark.parametrize("subpix", (True, False)) 29 | def test_predict(img_size: Tuple[int, int], 30 | n_tiles: Tuple[int, ...], 31 | scale: float, 32 | in_channels: int, 33 | device: str, 34 | subpix: bool, 35 | ): 36 | img = np.random.randn(*img_size, in_channels).astype(np.float32) 37 | model_config = SpotiflowModelConfig( 38 | levels=2, 39 | in_channels=in_channels, 40 | out_channels=1, 41 | ) 42 | model = Spotiflow(model_config) 43 | orig_device = str(next(model.parameters()).device) 44 | 45 | wrong_scale, not_implemented = False, False 46 | if scale < 1: 47 | inv_scale = int(1/scale) 48 | wrong_scale = any(s % inv_scale != 0 for s in img_size) 49 | 50 | if scale != 1 and subpix: 51 | not_implemented = True 52 | 53 | 54 | if not wrong_scale and not not_implemented: 55 | pred, details = model.predict( 56 | img, 57 | n_tiles=n_tiles, 58 | scale=scale, 59 | verbose=False, 60 | device=device, 61 | subpix=subpix, 62 | ) 63 | if "cuda" in AVAILABLE_DEVICES and device in ("auto", "cuda"): 64 | assert str(next(model.parameters()).device).startswith("cuda") 65 | elif "mps" in AVAILABLE_DEVICES and device in ("auto", "mps"): 66 | assert str(next(model.parameters()).device).startswith("mps") 67 | elif device is None: 68 | assert str(next(model.parameters()).device).startswith(orig_device) 69 | else: 70 | assert str(next(model.parameters()).device).startswith("cpu") 71 | assert all(p==s for p, s in zip(details.heatmap.shape, img_size)), f"Wrong heatmap shape: expected {img_size}, got {details.heatmap.shape}" 72 | if pred.shape[0] > 0: 73 | assert pred.min() >= 0, "Point detection coordinates should be non-negative" 74 | assert pred.max() < img_size[0] or 
pred.max() < img_size[1], "Point detection coordinates should be within the image dimensions" 75 | elif wrong_scale and not not_implemented: 76 | with pytest.raises(AssertionError): 77 | pred, details = model.predict( 78 | img, 79 | n_tiles=n_tiles, 80 | scale=scale, 81 | verbose=False, 82 | device=device, 83 | subpix=subpix, 84 | ) 85 | elif not_implemented: 86 | with pytest.raises(NotImplementedError): 87 | pred, details = model.predict( 88 | img, 89 | n_tiles=n_tiles, 90 | scale=scale, 91 | verbose=False, 92 | device=device, 93 | subpix=subpix, 94 | ) 95 | 96 | 97 | if __name__ == "__main__": 98 | test_predict( 99 | img_size=(101, 241), 100 | n_tiles=(1,1,1), 101 | # n_tiles=(2, 2, 1), 102 | scale=1., 103 | # scale=1/3, 104 | in_channels=3, 105 | device=None, 106 | subpix=True, 107 | ) 108 | 109 | -------------------------------------------------------------------------------- /spotiflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | 3 | from .fitting import estimate_params 4 | from .matching import * 5 | from .parallel import tile_iterator 6 | from .peaks import * 7 | from .utils import * 8 | 9 | 10 | class NotRegisteredError(Exception): 11 | """Custom exception to be raised when a model or dataset is not registered. 12 | """ 13 | pass 14 | -------------------------------------------------------------------------------- /spotiflow/utils/matching.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from types import SimpleNamespace 4 | from typing import Optional 5 | 6 | import numpy as np 7 | from scipy.optimize import linear_sum_assignment 8 | from scipy.spatial.distance import cdist 9 | 10 | log = logging.getLogger(__name__) 11 | log.setLevel(logging.INFO) 12 | 13 | console_handler = logging.StreamHandler(sys.stdout) 14 | console_handler.setLevel(logging.INFO) 15 | formatter = logging.Formatter("%(levelname)s:%(name)s:%(message)s") 16 | console_handler.setFormatter(formatter) 17 | log.addHandler(console_handler) 18 | 19 | 20 | def points_matching( 21 | p1, 22 | p2, 23 | cutoff_distance=3, 24 | eps=1e-8, 25 | class_label_p1: Optional[int] = None, 26 | class_label_p2: Optional[int] = None, 27 | ): 28 | """finds the matching between p1 and p2 that minimizes the sum of squared distances""" 29 | if class_label_p1 is not None:  # each point set is filtered independently by its class label 30 | p1 = p1[p1[:, -1] == class_label_p1][:, :-1] 31 | if class_label_p2 is not None: 32 | p2 = p2[p2[:, -1] == class_label_p2][:, :-1] 33 | 34 | assert p1.shape[1] == p2.shape[1], "Last dimensions of p1, p2 must match" 35 | 36 | if len(p1) == 0 or len(p2) == 0: 37 | D = np.zeros((0, 0)) 38 | else: 39 | D = cdist(p1, p2, metric="sqeuclidean") 40 | 41 | if D.size > 0: 42 | D[D > cutoff_distance**2] = 1e10 * (1 + D.max()) 43 | 44 | i, j = linear_sum_assignment(D) 45 | valid = D[i, j] <= cutoff_distance**2 46 | 47 | i, j = i[valid], j[valid] 48 | 49 | res = SimpleNamespace() 50 | 51 | tp = len(i) 52 | fp = len(p2) - tp 53 | fn = len(p1) - tp 54 | res.tp = tp 55 | res.fp = fp 56 | res.fn = fn 57 | 58 | # when there is no tp and we don't predict anything the accuracy should be 1 not 0 59 | tp_eps = tp + eps 60 | 61 | res.accuracy = tp_eps / (tp_eps + fp + fn) if tp_eps > 0 else 0 62 | res.precision = tp_eps / (tp_eps + fp) if tp_eps > 0 else 0 63 | res.recall = tp_eps / (tp_eps + fn) if tp_eps > 0 else 0 64 | res.f1 = (2 * tp_eps) / (2 * tp_eps + fp + fn) if tp_eps > 0 else 0 65 | res.dist = np.sqrt(D[i, j])
66 | res.mean_dist = np.mean(res.dist) if len(res.dist) > 0 else 0 67 | 68 | pq_num = np.sum(cutoff_distance - res.dist) / cutoff_distance 69 | pq_den = tp_eps + fp / 2 + fn / 2 70 | res.panoptic_quality = pq_num / pq_den if tp_eps > 0 else 0 71 | 72 | res.false_negatives = tuple(set(range(len(p1))).difference(set(i))) 73 | res.false_positives = tuple(set(range(len(p2))).difference(set(j))) 74 | res.matched_pairs = tuple(zip(i, j)) 75 | return res 76 | 77 | 78 | def points_matching_dataset( 79 | p1s, 80 | p2s, 81 | cutoff_distance=3, 82 | by_image=True, 83 | eps=1e-8, 84 | class_label_p1: Optional[int] = None, 85 | class_label_p2: Optional[int] = None, 86 | ): 87 | """ 88 | by_image is True -> metrics are computed by image and then averaged 89 | by_image is False -> TP/FP/FN are aggregated and only then are metrics computed 90 | """ 91 | stats = tuple( 92 | points_matching( 93 | p1, 94 | p2, 95 | cutoff_distance=cutoff_distance, 96 | eps=eps, 97 | class_label_p1=class_label_p1, 98 | class_label_p2=class_label_p2, 99 | ) 100 | for p1, p2 in zip(p1s, p2s) 101 | ) 102 | 103 | if by_image: 104 | res = dict() 105 | for k, v in vars(stats[0]).items(): 106 | if np.isscalar(v): 107 | res[k] = np.mean([vars(s)[k] for s in stats]) 108 | return SimpleNamespace(**res) 109 | else: 110 | res = SimpleNamespace() 111 | res.tp = 0 112 | res.fp = 0 113 | res.fn = 0 114 | 115 | for s in stats: 116 | for k in ("tp", "fp", "fn"): 117 | setattr(res, k, getattr(res, k) + getattr(s, k)) 118 | 119 | dists = np.concatenate([s.dist for s in stats]) 120 | 121 | tp_eps = res.tp + eps 122 | res.accuracy = tp_eps / (tp_eps + res.fp + res.fn) if tp_eps > 0 else 0 123 | res.precision = tp_eps / (tp_eps + res.fp) if tp_eps > 0 else 0 124 | res.recall = tp_eps / (tp_eps + res.fn) if tp_eps > 0 else 0 125 | res.f1 = (2 * tp_eps) / (2 * tp_eps + res.fp + res.fn) if tp_eps > 0 else 0 126 | 127 | pq_num = np.sum(cutoff_distance - dists) / cutoff_distance 128 | pq_den = tp_eps + res.fp / 2 + res.fn / 2 129 | 130 | res.panoptic_quality = pq_num / pq_den if tp_eps > 0 else 0 131 | res.mean_dist = np.mean(dists) if len(dists) > 0 else 0 132 | return res 133 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | from spotiflow.data import SpotsDataset 2 | from utils import example_data 3 | 4 | 5 | if __name__ == "__main__": 6 | 7 | imgs, points = example_data() 8 | 9 | data = SpotsDataset(imgs, points, downsample_factors=(1, 2)) 10 | 11 | out = data[0] 12 | -------------------------------------------------------------------------------- /tests/test_fit.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from spotiflow.utils import points_to_prob, estimate_params 4 | from spotiflow.model import Spotiflow 5 | 6 | def test_fit2d(return_results: bool=False): 7 | np.random.seed(42) 8 | 9 | n_points=64 10 | points = np.random.randint(20,245-20, (n_points,2)) 11 | sigmas = np.random.uniform(1, 5, n_points) 12 | 13 | x = points_to_prob(points, (256,256), sigma=sigmas, mode='sum') 14 | 15 | x += .2+0.05*np.random.normal(0, 1, x.shape) 16 | 17 | params = estimate_params(x, points) 18 | assert params is not None 19 | if return_results: 20 | return x, sigmas, params 21 | 22 | def test_fit3d(return_results: bool=False): 23 | 24 | np.random.seed(42) 25 | ndim=3 26 | 27 | n_points=64 28 | points = np.random.randint(20,128-20, (n_points,ndim)) 29 | 
sigmas = np.random.uniform(1, 5, n_points) 30 | 31 | x = points_to_prob(points, (128,)*ndim, sigma=sigmas, mode='sum') 32 | 33 | x += .2+0.05*np.random.normal(0, 1, x.shape) 34 | 35 | params = estimate_params(x, points) 36 | assert params is not None 37 | if return_results: 38 | return x, sigmas, params 39 | 40 | if __name__ == "__main__": 41 | 42 | 43 | x, sigmas, params = test_fit3d(return_results=True) 44 | 45 | 46 | model = Spotiflow.from_pretrained("synth_3d") 47 | 48 | img = np.clip(200*x, 0,255).astype(np.uint8) 49 | 50 | points, details = model.predict(img, fit_params=True) -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | from spotiflow.model import Spotiflow 2 | from spotiflow.utils import points_matching, points_to_flow, flow_to_vector 3 | from spotiflow.sample_data import test_image_hybiss_2d as _sample_image 4 | from utils import example_data 5 | 6 | 7 | def test_model(): 8 | model = Spotiflow.from_pretrained("hybiss", map_location="cpu") 9 | x = _sample_image() 10 | points, _ = model.predict(x, device="cpu") 11 | 12 | 13 | def test_flow(): 14 | X, P = example_data(64, sigma=3, noise=0.01) 15 | 16 | model = Spotiflow.from_pretrained("hybiss") 17 | p1, details1 = model.predict(X[0], subpix=False) 18 | p2, details2 = model.predict(X[0], subpix=True) 19 | p3, details3 = model.predict(X[0], subpix=1) 20 | 21 | s1 = points_matching(P[0], p1) 22 | s2 = points_matching(P[0], p2) 23 | s3 = points_matching(P[0], p3) 24 | 25 | f0 = points_to_flow(P[0], sigma=model.config.sigma, shape=X[0].shape) 26 | flow_to_vector(f0, sigma=model.config.sigma) 27 | 28 | print(f"mean error w/o subpix {s1.mean_dist:.4f}") 29 | print(f"mean error with subpix {s2.mean_dist:.4f}") 30 | print(f"mean error with subpix 1 {s3.mean_dist:.4f}") 31 | 32 | 33 | if __name__ == "__main__": 34 | # test_model() 35 | test_flow() 36 | -------------------------------------------------------------------------------- /tests/test_peaks.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import pytest 6 | from spotiflow.utils import points_from_heatmap_flow, points_to_flow, points_to_prob 7 | from spotiflow.utils.matching import points_matching 8 | 9 | 10 | def round_trip(points:np.ndarray, grid: Tuple[int], sigma: float=1.5): 11 | """ Test round trip of points through the flow field and back""" 12 | points = np.asarray(points) 13 | ndim = points.shape[1] 14 | shape = (2*int(points.max()),)*ndim 15 | # get heatmap and flow 16 | heatmap = points_to_prob(points, shape=shape, sigma=sigma, mode="max", grid=grid) 17 | flow = points_to_flow(points, shape, sigma=sigma, grid=grid) 18 | 19 | points_new = points_from_heatmap_flow(heatmap, flow, sigma=sigma, grid=grid) 20 | 21 | return SimpleNamespace(points=points, points_new=points_new, heatmap=heatmap, flow=flow, sigma=sigma) 22 | 23 | @pytest.mark.parametrize("ndim", (2, 3)) 24 | @pytest.mark.parametrize("grid", (None, (2, 2))) 25 | def test_prob_flow_roundtrip(ndim, grid, debug:bool=False): 26 | points = np.stack(np.meshgrid(*tuple(np.linspace(10,48,4) for _ in range(ndim)), indexing="ij"), axis=-1).reshape(-1, ndim) 27 | if ndim == 3 and grid is not None: 28 | grid = (*grid, 2) 29 | 30 | points = points + np.random.uniform(-1, 1, points.shape) 31 | if ndim == 2 and (grid is not None or (isinstance(grid, tuple) and 
any(g > 1 for g in grid))): 32 | with pytest.raises(NotImplementedError): 33 | _ = round_trip(points, grid=grid) 34 | else: 35 | out = round_trip(points, grid=grid) 36 | 37 | diff = points_matching(out.points, out.points_new).mean_dist 38 | 39 | print(f"Max diff: {diff:4f}") 40 | if debug: 41 | import napari 42 | v = napari.Viewer() 43 | v.add_points(out.points, name="points", size=5, face_color="green") 44 | v.add_points(out.points_new, name="points_new", size=5, face_color="red") 45 | else: 46 | assert diff < 1e-3, f"Max diff: {diff:4f}" 47 | 48 | 49 | 50 | if __name__ == "__main__": 51 | 52 | # works 53 | test_prob_flow_roundtrip(ndim=3, grid=None) 54 | 55 | # works 56 | test_prob_flow_roundtrip(ndim=3, grid=(2,2,2)) 57 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | from spotiflow.data import SpotsDataset 2 | from spotiflow.model import SpotiflowModelConfig, SpotiflowTrainingConfig, Spotiflow 3 | import lightning.pytorch as pl 4 | import torch 5 | from utils import example_data 6 | 7 | 8 | if __name__ == "__main__": 9 | 10 | device = "cuda" if torch.cuda.is_available() else "mps" 11 | 12 | X, P = example_data(64, sigma=3, noise=0.01) 13 | Xv, Pv = example_data(4, sigma=3, noise=0.01) 14 | 15 | backbone = "unet" 16 | batch_norm = False 17 | batch_norm = True 18 | 19 | compute_flow = True 20 | # compute_flow = False 21 | 22 | n_levels = 3 23 | 24 | sigma = 1.5 25 | 26 | config = SpotiflowModelConfig( 27 | backbone=backbone, 28 | levels=n_levels, 29 | compute_flow=compute_flow, 30 | mode="slim", 31 | sigma=sigma, 32 | fmap_inc_factor=2, 33 | background_remover=False, 34 | batch_norm=batch_norm, 35 | ) 36 | 37 | train_config = SpotiflowTrainingConfig(num_epochs=100, pos_weight=10, batch_size=4) 38 | 39 | data = SpotsDataset( 40 | X, 41 | P, 42 | compute_flow=compute_flow, 43 | sigma=sigma, 44 | downsample_factors=(1, 2, 4, 8)[:n_levels], 45 | ) 46 | data_v = SpotsDataset( 47 | Xv, 48 | Pv, 49 | compute_flow=compute_flow, 50 | sigma=sigma, 51 | downsample_factors=(1, 2, 4, 8)[:n_levels], 52 | ) 53 | 54 | model = Spotiflow(config) 55 | 56 | print(f"Total params: {sum(p.numel() for p in model.parameters())}") 57 | 58 | logger = pl.loggers.TensorBoardLogger( 59 | save_dir="foo", 60 | name=f"{backbone}_batch_norm_{batch_norm}_flow_{compute_flow}", 61 | ) 62 | 63 | model.fit(data, data_v, train_config, device, logger=logger, deterministic=False) 64 | -------------------------------------------------------------------------------- /tests/test_training_simple.py: -------------------------------------------------------------------------------- 1 | from spotiflow.model import SpotiflowModelConfig, Spotiflow 2 | import torch 3 | from utils import example_data 4 | 5 | 6 | if __name__ == "__main__": 7 | 8 | device = "cuda" if torch.cuda.is_available() else "mps" 9 | 10 | X, P = example_data(64, sigma=3, noise=0.01) 11 | Xv, Pv = example_data(4, sigma=3, noise=0.01) 12 | 13 | config = SpotiflowModelConfig() 14 | 15 | model = Spotiflow(config) 16 | 17 | print(f"Total params: {sum(p.numel() for p in model.parameters())}") 18 | 19 | model.fit(X, P, Xv, Pv, save_dir="tmp", train_config={"num_epochs": 10}) 20 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from spotiflow.utils import points_to_prob 3 | 4 | 
def example_data(n_samples: int = 10, size: int = 256, noise: float = 0.2, sigma=2): 5 | def _single(): 6 | p = np.random.uniform(0, size - 56, (np.random.randint(18, 24), 2))  # keep points away from the border (0..200 for the default size=256) 7 | 8 | p = p.astype(int) + 0.5 9 | x = points_to_prob(p, (size, size), sigma=sigma)  # honor the size argument instead of a hard-coded 256 10 | x = x + noise * np.random.normal(0, 1, x.shape) 11 | return x, p 12 | 13 | X, P = tuple(zip(*tuple(_single() for _ in range(n_samples)))) 14 | return X, P 15 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # For more information about tox, see https://tox.readthedocs.io/en/latest/ 2 | [tox] 3 | envlist = py{39,310}-{linux,macos} 4 | isolated_build=true 5 | 6 | [gh-actions] 7 | python = 8 | 3.9: py39 9 | 3.10: py310 10 | 11 | [gh-actions:env] 12 | PLATFORM = 13 | ubuntu-latest: linux 14 | macos-latest: macos 15 | 16 | [testenv] 17 | platform = 18 | macos: darwin 19 | linux: linux 20 | passenv = 21 | CI 22 | GITHUB_ACTIONS 23 | extras = 24 | testing 25 | commands = pytest -v --color=yes --cov=spotiflow --cov-report=xml --------------------------------------------------------------------------------
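A few hedged usage sketches for the components listed above follow. First, the point_nms3d extension: it exposes a single function, c_point_nms_3d(points, min_distance), which takes an (N, 3) float32 array of (z, y, x) coordinates and returns a boolean keep-mask. The Python import path of the compiled module below is an assumption (only the module and function names are fixed by the C source's PyModuleDef); also note that the C loop walks points in the order given and suppresses later neighbors of earlier points, which suggests callers pre-sort detections by score so that stronger spots suppress weaker ones.

import numpy as np
# assumed install location of the compiled extension built from
# spotiflow/lib/point_nms3d.cpp; only the names are fixed by the C source
from spotiflow.lib.point_nms3d import c_point_nms_3d

pts = np.random.uniform(0, 100, (500, 3)).astype(np.float32)  # rows are (z, y, x)
keep = c_point_nms_3d(pts, 5.0)  # True where a point survives radius-5 suppression
print(pts[keep].shape)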
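A minimal sketch of applying the BackgroundRemover module from spotiflow/model/bg_remover.py. It assumes only what the class itself shows: a single-channel (B, 1, H, W) tensor and an odd Gaussian radius; the synthetic input is illustrative.

import torch
from spotiflow.model.bg_remover import BackgroundRemover

# a bright spot sitting on a smooth intensity gradient
img = torch.zeros(1, 1, 128, 128)
img += 0.5 * torch.linspace(0, 1, 128).reshape(1, 1, 128, 1)  # slowly varying background
img[0, 0, 64, 64] = 5.0  # spot-like peak

remover = BackgroundRemover(n_channels=1, radius=51)
with torch.no_grad():
    out = remover(img)  # I_out = I_in - I_low (separable Gaussian low-pass)

# the gradient is largely removed while the peak remains
print(out[0, 0, 64, 64].item(), out.abs().mean().item())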
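Putting pretrained.py and sample_data together gives the following inference sketch. Model loading mirrors tests/test_model.py, the model name "general" comes from the _REGISTERED dict above, and the cache location follows _default_cache_dir, including the SPOTIFLOW_CACHE_DIR override.

import os

# optional: redirect downloads; this env var is read by _default_cache_dir()
os.environ.setdefault("SPOTIFLOW_CACHE_DIR", os.path.expanduser("~/.spotiflow"))

from spotiflow.model import Spotiflow
from spotiflow.sample_data import test_image_hybiss_2d

img = test_image_hybiss_2d()  # 2D HybISS test image bundled with the package
model = Spotiflow.from_pretrained("general", map_location="cpu")  # any name in list_registered()
points, details = model.predict(img, device="cpu")  # one (y, x) row per detected spot

print(len(points), details.heatmap.shape)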
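Similarly for the dataset registry in spotiflow/sample_data/datasets.py. Everything here appears in the source above except the structure of the object returned by get_data, which this sketch deliberately leaves opaque.

from spotiflow.sample_data import load_dataset
from spotiflow.sample_data.datasets import get_training_datasets_path, list_registered

print(list_registered())  # registered names, e.g. 'synth_complex', 'merfish', 'synth_3d'

# download (if not cached) and resolve the extracted folder,
# e.g. ~/.spotiflow/datasets/synth_complex by default
path = get_training_datasets_path("synth_complex")

# load into memory; include_test=True also returns the held-out split
data = load_dataset("synth_complex", include_test=True)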
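Finally, a small evaluation sketch built on spotiflow/utils/matching.py. The toy coordinates are made up, but the functions, arguments and result fields (tp, fp, fn, f1, mean_dist, ...) all appear in the source above.

import numpy as np
from spotiflow.utils import points_matching, points_matching_dataset

gt = np.array([[10.0, 10.0], [20.0, 30.0], [40.0, 40.0]])
pred = np.array([[10.5, 10.2], [21.0, 29.0], [80.0, 80.0]])

res = points_matching(gt, pred, cutoff_distance=3)
print(res.tp, res.fp, res.fn)  # -> 2 1 1: two matches, one spurious, one missed
print(f"f1={res.f1:.3f}, mean_dist={res.mean_dist:.3f}")

# dataset-level aggregation: metrics averaged per image vs. pooled TP/FP/FN
res_by_image = points_matching_dataset([gt], [pred], cutoff_distance=3, by_image=True)
res_pooled = points_matching_dataset([gt], [pred], cutoff_distance=3, by_image=False)
print(res_by_image.f1, res_pooled.f1)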