├── .gitattributes
├── .github
│   └── workflows
│       ├── build.yaml
│       └── lint.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
│   ├── Figure.png
│   ├── WSI_intro.png
│   ├── logo.ai
│   ├── logo.svg
│   ├── logo.webp
│   └── logo@3x.png
├── docs
│   ├── clean_up.py
│   └── source
│       ├── _static
│       │   ├── cemm-logo.svg
│       │   ├── custom.css
│       │   ├── logo.svg
│       │   └── logo@3x.png
│       ├── _templates
│       │   ├── autosummary
│       │   │   └── class.rst
│       │   └── models.rst
│       ├── api
│       │   ├── .gitignore
│       │   ├── cv.rst
│       │   ├── index.rst
│       │   ├── io.rst
│       │   ├── models.rst
│       │   ├── plotting.rst
│       │   ├── preprocess.rst
│       │   ├── segmentation.rst
│       │   └── tools.rst
│       ├── conf.py
│       ├── contributing.rst
│       ├── contributors.rst
│       ├── index.rst
│       ├── installation.rst
│       └── tutorials
│           ├── .gitignore
│           ├── 00_intro_wsi.ipynb
│           ├── 01_preprocessing.ipynb
│           ├── 02_feature_extraction.ipynb
│           ├── 03_multiple_slides.ipynb
│           ├── 04_genomics_integration.ipynb
│           ├── 05_cell-segmentation.ipynb
│           ├── 05_training_models.ipynb
│           ├── 06_visualization.ipynb
│           ├── 07_zero-shot-learning.ipynb
│           ├── index.rst
│           └── matplotlibrc
├── pyproject.toml
├── src
│   └── lazyslide
│       ├── __init__.py
│       ├── __main__.py
│       ├── _const.py
│       ├── _utils.py
│       ├── cv
│       │   ├── __init__.py
│       │   ├── mask.py
│       │   ├── scorer
│       │   │   ├── __init__.py
│       │   │   ├── base.py
│       │   │   ├── focuslitenn
│       │   │   │   ├── __init__.py
│       │   │   │   ├── focuslitenn-2kernel-mse.pt
│       │   │   │   └── model.py
│       │   │   ├── module.py
│       │   │   └── utils.py
│       │   ├── tiles_merger.py
│       │   └── transform
│       │       ├── __init__.py
│       │       ├── compose.py
│       │       └── mods.py
│       ├── datasets
│       │   ├── __init__.py
│       │   └── _sample.py
│       ├── io
│       │   ├── __init__.py
│       │   └── _annotaiton.py
│       ├── metrics.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── _model_registry.py
│       │   ├── _utils.py
│       │   ├── base.py
│       │   ├── model_registry.csv
│       │   ├── multimodal
│       │   │   ├── __init__.py
│       │   │   ├── conch.py
│       │   │   ├── plip.py
│       │   │   ├── prism.py
│       │   │   └── titan.py
│       │   ├── segmentation
│       │   │   ├── __init__.py
│       │   │   ├── cellpose.py
│       │   │   ├── grandqc.py
│       │   │   ├── instanseg.py
│       │   │   ├── nulite
│       │   │   │   ├── __init__.py
│       │   │   │   ├── api.py
│       │   │   │   └── model.py
│       │   │   ├── postprocess.py
│       │   │   ├── sam.py
│       │   │   └── smp.py
│       │   └── vision
│       │       ├── __init__.py
│       │       ├── base.py
│       │       ├── conch.py
│       │       ├── gigapath.py
│       │       ├── h_optimus.py
│       │       ├── hibou.py
│       │       ├── midnight.py
│       │       ├── phikon.py
│       │       ├── plip.py
│       │       ├── uni.py
│       │       └── virchow.py
│       ├── plotting
│       │   ├── __init__.py
│       │   ├── _api.py
│       │   └── _wsi_viewer.py
│       ├── preprocess
│       │   ├── __init__.py
│       │   ├── _graph.py
│       │   ├── _tiles.py
│       │   ├── _tissue.py
│       │   └── _utils.py
│       ├── py.typed
│       ├── segmentation
│       │   ├── __init__.py
│       │   ├── _artifact.py
│       │   ├── _cell.py
│       │   ├── _seg_runner.py
│       │   ├── _tissue.py
│       │   └── _zero_shot.py
│       └── tools
│           ├── __init__.py
│           ├── _domain.py
│           ├── _features.py
│           ├── _signatures.py
│           ├── _spatial_features.py
│           ├── _text_annotate.py
│           ├── _tissue_props.py
│           └── _zero_shot.py
├── tests
│   ├── conftest.py
│   ├── data
│   │   └── CMU-1-Small-Region.svs
│   ├── test_cv.py
│   ├── test_datasets.py
│   ├── test_pp.py
│   └── test_tl.py
├── uv.lock
└── workflow
    ├── main.nf
    └── modules
        └── qc
            └── main.nf
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-detectable=false
--------------------------------------------------------------------------------
/.github/workflows/build.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | paths:
6 | - '.github/**'
7 | - 'src/lazyslide/**'
8 | - 'tests/**'
9 | - 'pyproject.toml'
10 | pull_request:
11 | paths:
12 | - '.github/**'
13 | - 'src/lazyslide/**'
14 | - 'tests/**'
15 | - 'pyproject.toml'
16 |
17 | jobs:
18 | Test:
19 | strategy:
20 | fail-fast: false
21 | matrix:
22 | os: [ubuntu-latest, windows-latest, macos-latest]
23 | python-version: ['3.11', '3.12', '3.13']
24 |
25 | runs-on: ${{ matrix.os }}
26 | steps:
27 | - uses: actions/checkout@v4
28 | - name: Set up uv
29 | uses: astral-sh/setup-uv@v5
30 | with:
31 | python-version: ${{ matrix.python-version }}
32 | enable-cache: true
33 | cache-dependency-glob: "uv.lock"
34 | - name: Install project
35 | run: uv sync --dev
36 | - name: Tests
37 | run: |
38 | uv run task test-ci
39 |
40 | Upload_to_pypi:
41 | runs-on: ubuntu-latest
42 | permissions:
43 | id-token: write
44 | steps:
45 | - uses: actions/checkout@v4
46 | - name: Setup uv
47 | uses: astral-sh/setup-uv@v5
48 | with:
49 | python-version: '3.12'
50 | enable-cache: true
51 | cache-dependency-glob: "uv.lock"
52 |
53 | - name: Publish to test pypi
54 | run: |
55 | uv build
56 | uv publish --publish-url https://test.pypi.org/legacy/ || exit 0
57 |
58 | - name: Publish to pypi
59 | if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags/v')
60 | run: |
61 | uv build
62 | uv publish
--------------------------------------------------------------------------------
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | name: Lint with Ruff
2 | on: [push, pull_request]
3 | jobs:
4 | ruff:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v4
8 | - uses: chartboost/ruff-action@v1
9 | with:
10 | src: "src/lazyslide"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/build/
73 | docs/.jupyter_cache/
74 | docs/jupyter_execute
75 |
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # Ruff cache
159 | .ruff_cache/
160 |
161 | # PyCharm
162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
164 | # and can be added to the global gitignore or merged into this file. For a more nuclear
165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
166 | .idea/
167 | .DS_Store
168 |
169 | work/
170 | .nextflow.log*
171 | data/
172 | checkpoints/
173 | annotations/
174 | # JetBrains AI Agent
175 | .junie/
176 | pretrained_models/
177 | figures/
178 | sample_data/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/uv-pre-commit
3 | # uv version.
4 | rev: 0.5.26
5 | hooks:
6 | - id: uv-lock
7 | - repo: https://github.com/astral-sh/ruff-pre-commit
8 | # Ruff version.
9 | rev: v0.6.5
10 | hooks:
11 | # Run the linter.
12 | - id: ruff
13 | args: [ --fix ]
14 | # Run the formatter.
15 | - id: ruff-format
16 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
2 | version: 2
3 |
4 | build:
5 | os: ubuntu-22.04
6 | tools:
7 | python: "3.12"
8 | jobs:
9 | post_install:
10 | - pip install uv
11 | - UV_PROJECT_ENVIRONMENT=$READTHEDOCS_VIRTUALENV_PATH uv sync --all-extras --group docs --link-mode=copy
12 |
13 | sphinx:
14 | configuration: docs/source/conf.py
15 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | We welcome contributions to this project.
4 |
5 |
6 | ## For core contributors
7 |
8 | Please do not commit directly to the `main` branch.
9 | Instead, create a new branch for your changes and submit a pull request.
10 |
11 | ### How to set up your development environment
12 |
13 | 1. Clone the repository
14 |
15 | ```bash
16 | git clone https://github.com/rendeirolab/LazySlide.git
17 | # or
18 | gh repo clone rendeirolab/LazySlide
19 | ```
20 |
21 | 2. Checkout a new branch
22 |
23 | ```bash
24 | git checkout -b my-new-branch
25 | ```
26 |
27 | 3. We use [uv](https://docs.astral.sh/uv/) to manage our development environment.
28 |
29 | ```bash
30 | uv lock
31 | uv run pre-commit install
32 | ```
33 |
34 | We use [pre-commit](https://pre-commit.com/) to run code formatting and linting checks before each commit.
35 |
36 | 4. Start an IPython/Jupyter session
37 |
38 | ```bash
39 | uv run --with ipython ipython
40 | # or
41 | uv run --with jupyter jupyter lab
42 | ```
43 |
44 | 5. Make your changes
45 |
46 | 6. (If needed) Add a test case and then run the tests
47 |
48 | ```bash
49 | uv run task test
50 | ```
51 |
52 | 7. (If needed) Update the documentation
53 |
54 | To build the documentation, use:
55 |
56 | ```bash
57 | # Build doc with cache
58 | uv run task doc-build
59 | # Fresh build
60 | uv run task doc-clean-build
61 | ```
62 |
63 | To serve the documentation, use:
64 |
65 | ```bash
66 | uv run task doc-serve
67 | ```
68 |
69 | This will start a local server at [http://localhost:8000](http://localhost:8000).
70 |
71 | 8. Commit your changes and push them to your fork
72 |
73 | 9. Submit a pull request
74 |
75 |
76 | ## How to report bugs
77 |
78 |
79 | ## How to suggest enhancements
80 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Rendeiro Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LazySlide
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | Accessible and interoperable whole slide image analysis
10 |
11 |
12 | [Documentation](https://lazyslide.readthedocs.io/en/stable)
13 | 
14 | 
15 | 
16 |
17 | [Installation](https://lazyslide.readthedocs.io/en/stable/installation.html) |
18 | [Tutorials](https://lazyslide.readthedocs.io/en/stable/tutorials/index.html)
19 |
20 | LazySlide is a Python framework for whole slide image (WSI) analysis, designed to integrate seamlessly with the scverse
21 | ecosystem.
22 |
23 | By adopting standardized data structures and APIs familiar to the single-cell and genomics community, LazySlide enables
24 | intuitive, interoperable, and reproducible workflows for histological analysis. It supports a range of tasks from basic
25 | preprocessing to advanced deep learning applications, facilitating the integration of histopathology into modern
26 | computational biology.
27 |
28 | ## Key features
29 |
30 | - **Interoperability**: Built on top of SpatialData, ensuring compatibility with scverse tools like scanpy, anndata, and
31 | squidpy.
32 | - **Accessibility**: User-friendly APIs that cater to both beginners and experts in digital pathology.
33 | - **Scalability**: Efficient handling of large WSIs, enabling high-throughput analyses.
34 | - **Multimodal integration**: Combine histological data with transcriptomics, genomics, and textual annotations.
35 | - **Foundation model support**: Native integration with state-of-the-art models (e.g., UNI, CONCH, Gigapath, Virchow)
36 | for tasks like zero-shot classification and captioning.
37 | - **Deep learning ready**: Provides PyTorch dataloaders for seamless integration into machine learning pipelines.
38 |
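For example, the foundation models bundled with LazySlide can be listed through its model registry. A minimal sketch (see the API reference for the full registry):

```python
import lazyslide as zs

# Registry of the vision, multimodal, and segmentation models shipped with LazySlide
print(zs.models.list_models())
```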
39 | 
40 |
41 | ## Documentation
42 |
43 | Comprehensive documentation is available at [https://lazyslide.readthedocs.io](https://lazyslide.readthedocs.io). It
44 | includes tutorials, API references, and guides to help you get started.
45 |
46 | ## Installation
47 |
48 | LazySlide is available on the [PyPI index](https://pypi.org/project/lazyslide). This means that you can get it with your
49 | favourite package manager:
50 |
51 | - `pip install lazyslide` or
52 | - `uv add lazyslide`
53 |
54 | For full instructions, please refer to
55 | the [Installation page in the documentation](https://lazyslide.readthedocs.io/en/stable/installation.html).
56 |
57 | ## Quick start
58 |
59 | With a few lines of code, you can quickly process a whole slide image (tissue segmentation, tessellation, feature
60 | extraction):
61 |
62 | ```python
63 | import lazyslide as zs
64 |
65 | wsi = zs.datasets.sample()
66 |
67 | # Pipeline
68 | zs.pp.find_tissues(wsi)
69 | zs.pp.tile_tissues(wsi, tile_px=256, mpp=0.5)
70 | zs.tl.feature_extraction(wsi, model='resnet50')
71 |
72 | # Access the features
73 | features = wsi['resnet50_tiles']
74 | ```
75 |
76 | ## Contributing
77 |
78 | We welcome contributions from the community. Please refer to our [contributing guide](CONTRIBUTING.md) for guidelines on
79 | how to contribute.
80 |
81 | ## Licence
82 |
83 | LazySlide is released under the [MIT License](LICENSE).
84 |
--------------------------------------------------------------------------------
/assets/Figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/assets/Figure.png
--------------------------------------------------------------------------------
/assets/WSI_intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/assets/WSI_intro.png
--------------------------------------------------------------------------------
/assets/logo.ai:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/assets/logo.ai
--------------------------------------------------------------------------------
/assets/logo.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/assets/logo.webp
--------------------------------------------------------------------------------
/assets/logo@3x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/assets/logo@3x.png
--------------------------------------------------------------------------------
/docs/clean_up.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import shutil
3 |
4 | root = Path(__file__).parent # ./docs
5 |
6 | target_folders = [
7 | root / "build",
8 | root / "source" / "api" / "_autogen",
9 | root / "jupyter_execute",
10 | ]
11 |
12 |
13 | if __name__ == "__main__":
14 | for folder in target_folders:
15 | if folder.exists():
16 | shutil.rmtree(folder)
17 |
--------------------------------------------------------------------------------
/docs/source/_static/custom.css:
--------------------------------------------------------------------------------
1 | html[data-theme="light"] {
2 | --pst-color-primary: #C68FE6;
3 | --pst-color-secondary: #FFCD05;
4 | --pst-color-link: #C68FE6;
5 | --pst-color-inline-code: rgb(96, 141, 130);
6 | }
7 |
8 | html[data-theme="dark"] {
9 | --pst-color-primary: #C68FE6;
10 | --pst-color-secondary: #FFCD05;
11 | }
12 |
13 | /* Change the highlight color, increase contrast*/
14 | html[data-theme="light"] .highlight .hll {
15 | background-color: #fcf427;
16 | }
17 |
18 | .cell_output img {
19 | height: auto !important;
20 | }
21 |
22 | .navbar-brand .logo__image {
23 | height: 150px;
24 | }
--------------------------------------------------------------------------------
/docs/source/_static/logo@3x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/docs/source/_static/logo@3x.png
--------------------------------------------------------------------------------
/docs/source/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :special-members: __call__
--------------------------------------------------------------------------------
/docs/source/_templates/models.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline }}
2 |
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 |
7 | {% if objtype in ['class'] %}
8 | .. auto{{ objtype }}:: {{ objname }}
9 | :show-inheritance:
10 | :special-members: __call__
11 |
12 | {% else %}
13 | .. auto{{ objtype }}:: {{ objname }}
14 |
15 | {% endif %}
--------------------------------------------------------------------------------
/docs/source/api/.gitignore:
--------------------------------------------------------------------------------
1 | _autogen/
--------------------------------------------------------------------------------
/docs/source/api/cv.rst:
--------------------------------------------------------------------------------
1 | Computer vision utilities
2 | -------------------------
3 |
4 | Scorers
5 | ~~~~~~~
6 |
7 | .. currentmodule:: lazyslide.cv.scorer
8 |
9 | .. autosummary::
10 | :toctree: _autogen
11 | :nosignatures:
12 |
13 | FocusLite
14 | Contrast
15 | SplitRGB
16 | Redness
17 | Brightness
18 | ScorerBase
19 |
20 |
21 | Mask
22 | ~~~~
23 |
24 | .. currentmodule:: lazyslide.cv
25 |
26 | .. autosummary::
27 | :toctree: _autogen
28 | :nosignatures:
29 |
30 | Mask
31 | BinaryMask
32 | MultiLabelMask
33 | MultiClassMask
34 |
35 |
36 | Polygon merging
37 | ~~~~~~~~~~~~~~~
38 |
39 | .. currentmodule:: lazyslide.cv
40 |
41 | .. autosummary::
42 | :toctree: _autogen
43 | :nosignatures:
44 |
45 | merge_polygons
46 |
--------------------------------------------------------------------------------
/docs/source/api/index.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :hidden:
7 |
8 | preprocess
9 | tools
10 | plotting
11 | segmentation
12 | io
13 | models
14 | cv
15 |
16 |
17 | .. grid:: 1 2 2 2
18 | :gutter: 2
19 |
20 | .. grid-item-card:: Preprocessing
21 | :link: preprocess
22 | :link-type: doc
23 |
24 | Preprocessing functions for WSI
25 |
26 | .. grid-item-card:: Tools
27 | :link: tools
28 | :link-type: doc
29 |
30 | Tools for WSI analysis
31 |
32 | .. grid-item-card:: Plotting
33 | :link: plotting
34 | :link-type: doc
35 |
36 | Plotting functions for WSI
37 |
38 | .. grid-item-card:: Segmentation
39 | :link: segmentation
40 | :link-type: doc
41 |
42 | Segmentation tasks on WSI
43 |
44 | .. grid-item-card:: Models
45 | :link: models
46 | :link-type: doc
47 |
48 | Models for WSI analysis
49 |
50 | .. grid-item-card:: Computer Vision
51 | :link: cv
52 | :link-type: doc
53 |
54 | Computer Vision utilities for WSI
55 |
56 | .. grid-item-card:: IO
57 | :link: io
58 | :link-type: doc
59 |
60 | IO for annotations
61 |
--------------------------------------------------------------------------------
/docs/source/api/io.rst:
--------------------------------------------------------------------------------
1 | IO :code:`io`
2 | -------------
3 |
4 | .. currentmodule:: lazyslide
5 |
6 | .. autosummary::
7 | :toctree: _autogen
8 | :nosignatures:
9 |
10 | io.load_annotations
11 | io.export_annotations
--------------------------------------------------------------------------------
/docs/source/api/models.rst:
--------------------------------------------------------------------------------
1 | Models
2 | ------
3 |
4 | .. currentmodule:: lazyslide.models
5 |
6 | .. autosummary::
7 | :toctree: _autogen
8 | :nosignatures:
9 |
10 | list_models
11 |
12 |
13 | Vision Models
14 | ~~~~~~~~~~~~~
15 |
16 | .. currentmodule:: lazyslide.models.vision
17 |
18 | .. autosummary::
19 | :toctree: _autogen
20 | :nosignatures:
21 |
22 | UNI
23 | UNI2
24 | GigaPath
25 | PLIPVision
26 | CONCHVision
27 | Virchow
28 | Virchow2
29 | Phikon
30 | PhikonV2
31 | HOptimus0
32 | HOptimus1
33 | H0Mini
34 |
35 |
36 | Image-Text Models
37 | ~~~~~~~~~~~~~~~~~~~
38 |
39 | .. currentmodule:: lazyslide.models.multimodal
40 |
41 | .. autosummary::
42 | :toctree: _autogen
43 | :nosignatures:
44 |
45 | PLIP
46 | CONCH
47 | Titan
48 | Prism
49 |
50 |
51 | Segmentation Models
52 | ~~~~~~~~~~~~~~~~~~~
53 |
54 | .. currentmodule:: lazyslide.models.segmentation
55 |
56 | .. autosummary::
57 | :toctree: _autogen
58 | :nosignatures:
59 |
60 | Instanseg
61 | NuLite
62 | GrandQCTissue
63 | GrandQCArtifact
64 | SMPBase
65 |
66 | Base Models
67 | ~~~~~~~~~~~
68 |
69 | .. currentmodule:: lazyslide.models.base
70 |
71 | .. autosummary::
72 | :toctree: _autogen
73 | :nosignatures:
74 |
75 | ModelBase
76 | ImageModel
77 | ImageTextModel
78 | SegmentationModel
79 | SlideEncoderModel
80 | TimmModel
81 |
--------------------------------------------------------------------------------
/docs/source/api/plotting.rst:
--------------------------------------------------------------------------------
1 | Plotting: :code:`pl`
2 | --------------------
3 |
4 | .. currentmodule:: lazyslide
5 |
6 | .. autosummary::
7 | :toctree: _autogen
8 | :nosignatures:
9 |
10 | pl.tissue
11 | pl.tiles
12 | pl.annotations
13 | pl.WSIViewer
14 |
--------------------------------------------------------------------------------
/docs/source/api/preprocess.rst:
--------------------------------------------------------------------------------
1 | Preprocessing: :code:`pp`
2 | -------------------------
3 |
4 | .. currentmodule:: lazyslide
5 |
6 | .. autosummary::
7 | :toctree: _autogen
8 | :nosignatures:
9 |
10 | pp.find_tissues
11 | pp.score_tissues
12 | pp.tile_tissues
13 | pp.score_tiles
14 | pp.tile_graph
15 |
--------------------------------------------------------------------------------
/docs/source/api/segmentation.rst:
--------------------------------------------------------------------------------
1 | Segmentation :code:`seg`
2 | -------------------------
3 |
4 | .. currentmodule:: lazyslide
5 |
6 | .. autosummary::
7 | :toctree: _autogen
8 | :nosignatures:
9 |
10 | seg.cells
11 | seg.nulite
12 | seg.semantic
13 | seg.tissue
14 | seg.artifact
15 |
--------------------------------------------------------------------------------
/docs/source/api/tools.rst:
--------------------------------------------------------------------------------
1 | Tools: :code:`tl`
2 | -----------------
3 |
4 |
5 | Image Embedding
6 | ~~~~~~~~~~~~~~~
7 |
8 | .. currentmodule:: lazyslide
9 |
10 | .. autosummary::
11 | :toctree: _autogen
12 | :nosignatures:
13 |
14 | tl.feature_extraction
15 | tl.feature_aggregation
16 | tl.spatial_features
17 | tl.feature_utag
18 |
19 | Tissue Geometry
20 | ~~~~~~~~~~~~~~~
21 |
22 | .. currentmodule:: lazyslide
23 |
24 | .. autosummary::
25 | :toctree: _autogen
26 | :nosignatures:
27 |
28 | tl.tissue_props
29 |
30 |
31 | Tissue Spatial Domain
32 | ~~~~~~~~~~~~~~~~~~~~~
33 |
34 | .. currentmodule:: lazyslide
35 |
36 | .. autosummary::
37 | :toctree: _autogen
38 | :nosignatures:
39 |
40 | tl.spatial_domain
41 | tl.tile_shaper
42 |
43 |
44 | Multi-Modal Analysis
45 | ~~~~~~~~~~~~~~~~~~~~
46 |
47 | .. currentmodule:: lazyslide
48 |
49 | .. autosummary::
50 | :toctree: _autogen
51 | :nosignatures:
52 |
53 | tl.text_embedding
54 | tl.text_image_similarity
55 |
56 |
57 | Zero-shot Learning
58 | ~~~~~~~~~~~~~~~~~~
59 |
60 | .. currentmodule:: lazyslide
61 |
62 | .. autosummary::
63 | :toctree: _autogen
64 | :nosignatures:
65 |
66 | tl.zero_shot_score
67 | tl.slide_caption
68 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | import lazyslide
4 |
5 | project = "LazySlide"
6 | copyright = f"{datetime.now().year}, Rendeiro Lab"
7 | author = "LazySlide Contributors"
8 | release = lazyslide.__version__
9 |
10 | # -- General configuration ---------------------------------------------------
11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
12 |
13 | extensions = [
14 | "numpydoc",
15 | "sphinx.ext.autodoc",
16 | "sphinx.ext.autosummary",
17 | "sphinx.ext.autosectionlabel",
18 | "matplotlib.sphinxext.plot_directive",
19 | "sphinx.ext.intersphinx",
20 | "sphinx_design",
21 | "sphinx_copybutton",
22 | "myst_nb",
23 | "sphinx_contributors",
24 | ]
25 | autoclass_content = "class"
26 | autodoc_docstring_signature = True
27 | autodoc_default_options = {
28 | "members": True,
29 | "show-inheritance": True,
30 | "no-undoc-members": True,
31 | "special-members": "__call__",
32 | "exclude-members": "__init__, __weakref__",
33 | "class-doc-from": "class",
34 | }
35 | autodoc_typehints = "none"
36 | # setting autosummary
37 | autosummary_generate = True
38 | numpydoc_show_class_members = False
39 | add_module_names = False
40 |
41 | templates_path = ["_templates"]
42 | exclude_patterns = []
43 |
44 |
45 | html_theme = "sphinx_book_theme"
46 | html_static_path = ["_static"]
47 | html_logo = "_static/logo@3x.png"
48 | html_css_files = ["custom.css"]
49 | html_theme_options = {
50 | "repository_url": "https://github.com/rendeirolab/LazySlide",
51 | "navigation_with_keys": True,
52 | "show_prev_next": False,
53 | }
54 | # html_sidebars = {"installation": [], "cli": []}
55 |
56 | nb_output_stderr = "remove"
57 | nb_execution_mode = "off"
58 | nb_merge_streams = True
59 | myst_enable_extensions = [
60 | "colon_fence",
61 | "html_image",
62 | ]
63 |
64 | copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5," r"8}: "
65 | copybutton_prompt_is_regexp = True
66 |
67 | # Plot directive
68 | plot_include_source = True
69 | plot_html_show_source_link = False
70 | plot_html_show_formats = False
71 | plot_formats = [("png", 200)]
72 |
73 | intersphinx_mapping = {
74 | "wsidata": ("https://wsidata.readthedocs.io/en/latest", None),
75 | "torch": ("https://pytorch.org/docs/stable/", None),
76 | }
77 |
--------------------------------------------------------------------------------
/docs/source/contributing.rst:
--------------------------------------------------------------------------------
1 | Contributing
2 | ============
3 |
4 | We welcome contributions to the LazySlide project. This document provides guidelines for contributing to the project.
5 |
6 | Project overview
7 | ----------------
8 |
9 | LazySlide is a modularized and scalable whole slide image analysis toolkit. The project is structured as follows:
10 |
11 | - ``src/lazyslide``: Main package code
12 | - ``tests``: Test files
13 | - ``docs``: Documentation
14 |
15 |
16 | For core contributors
17 | ---------------------
18 |
19 | Please do not commit directly to the ``main`` branch.
20 | Instead, create a new branch for your changes and submit a pull request.
21 |
22 | Set up development environment
23 | ------------------------------
24 |
25 | We use `uv <https://docs.astral.sh/uv/>`_ to manage our development environment.
26 | Please make sure you have it installed before proceeding.
27 |
28 | 1. Clone the repository::
29 |
30 | git clone https://github.com/rendeirolab/lazyslide.git
31 | # or
32 | gh repo clone rendeirolab/lazyslide
33 |
34 | 2. Checkout a new branch::
35 |
36 | git checkout -b my-new-branch
37 |
38 | 3. We use `uv <https://docs.astral.sh/uv/>`_ to manage our development environment::
39 |
40 | uv lock
41 | uv run pre-commit install
42 |
43 | We use `pre-commit <https://pre-commit.com/>`_ to run code formatting and linting checks before each commit.
44 |
45 | 4. Start an IPython/Jupyter session::
46 |
47 | uv run --with ipython ipython
48 | # or
49 | uv run --with jupyter jupyter lab
50 |
51 | 5. Make your changes.
52 |
53 | Testing
54 | -------
55 |
56 | LazySlide uses pytest for testing. Tests are located in the ``tests`` directory.
57 |
58 | To run all tests::
59 |
60 | uv run task test
61 |
62 | To run a specific test file::
63 |
64 | uv run python -m pytest tests/test_example.py
65 |
66 | When adding new tests:
67 |
68 | 1. Create a new file in the ``tests`` directory with a name starting with ``test_``.
69 | 2. Import pytest and the module you want to test.
70 | 3. Write test functions with names starting with ``test_``.
71 | 4. Use assertions to verify expected behavior.
72 |
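A minimal sketch of such a test, following the steps above (the file name, the ``wsi.shapes`` access, and the asserted ``tissues`` key are illustrative assumptions; adapt them to the functionality under test)::

    # tests/test_tissue.py
    import lazyslide as zs


    def test_find_tissues():
        wsi = zs.datasets.sample()
        zs.pp.find_tissues(wsi)
        # Tissue polygons are expected under the default "tissues" key
        assert "tissues" in wsi.shapes
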
73 | Code style and development guidelines
74 | -------------------------------------
75 |
76 | LazySlide uses `ruff <https://docs.astral.sh/ruff/>`_ for both linting and formatting.
77 | The configuration is defined in ``pyproject.toml`` and enforced through pre-commit hooks.
78 |
79 | To format code::
80 |
81 | uv run task fmt
82 | # or
83 | ruff format docs/source src/lazyslide tests
84 |
85 | Documentation
86 | -------------
87 |
88 | Documentation is built using Sphinx and is located in the ``docs`` directory.
89 |
90 | To build the documentation::
91 |
92 | # Build doc with cache
93 | uv run task doc-build
94 | # Fresh build
95 | uv run task doc-clean-build
96 |
97 | To serve the documentation locally::
98 |
99 | uv run task doc-serve
100 |
101 | This will start a local server at http://localhost:8000.
102 |
103 | Documentation is written in reStructuredText (.rst) and Jupyter notebooks (.ipynb) using the myst-nb extension.
104 |
105 | Submitting changes
106 | ------------------
107 |
108 | 1. Commit your changes and push them to your branch.
109 | 2. Create a pull request on GitHub.
110 | 3. Ensure all CI checks pass.
111 | 4. Wait for a review from a maintainer.
112 |
113 | Reporting issues
114 | ----------------
115 |
116 | If you encounter a bug or have a feature request, please open an issue on the
117 | `GitHub repository <https://github.com/rendeirolab/LazySlide>`_.
118 |
119 | When reporting a bug, please include:
120 |
121 | - A clear description of the issue
122 | - Steps to reproduce the problem
123 | - Expected behavior
124 | - Actual behavior
125 | - Any relevant logs or error messages
126 | - Your environment (OS, Python version, package versions)
127 |
--------------------------------------------------------------------------------
/docs/source/contributors.rst:
--------------------------------------------------------------------------------
1 | Contributors
2 | ============
3 |
4 |
5 | .. card:: Rendeiro Lab
6 |
7 | LazySlide is developed by `Rendeiro Lab `_
8 |     at the `CeMM Research Center for Molecular Medicine <https://cemm.at/>`_.
9 |
10 | .. image:: _static/cemm-logo.svg
11 | :width: 200px
12 | :align: center
13 |
14 |
15 | Developers
16 | ----------
17 |
18 | - `Yimin Zheng `_, lead developer.
19 | - `Ernesto Abila `_, developer.
20 | - `Andre Rendeiro `_, lab leader, guidance and support.
21 |
22 | .. contributors:: rendeirolab/LazySlide
23 | :avatars:
24 |
25 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | LazySlide: Accessible and interoperable whole slide image analysis
2 | ==================================================================
3 |
4 | .. grid:: 1 2 2 2
5 |
6 | .. grid-item::
7 | :columns: 12 4 4 4
8 |
9 | .. image:: _static/logo@3x.png
10 | :align: center
11 | :width: 150px
12 |
13 | .. grid-item::
14 | :columns: 12 8 8 8
15 | :child-align: center
16 |
17 | **LazySlide** is a Python framework for whole slide image (WSI) analysis,
18 | designed to integrate seamlessly with the `scverse`_ ecosystem.
19 |
20 | By adopting standardized data structures and APIs familiar to the single-cell and genomics community,
21 | LazySlide enables intuitive, interoperable, and reproducible workflows for histological analysis.
22 | It supports a range of tasks from basic preprocessing to advanced deep learning applications,
23 | facilitating the integration of histopathology into modern computational biology.
24 |
25 | Key features
26 | ------------
27 |
28 | * **Interoperability**: Built on top of `SpatialData`_, ensuring compatibility with scverse tools like `Scanpy`_, `Anndata`_, and `Squidpy`_. Check out `WSIData`_ for more details.
29 | * **Accessibility**: User-friendly APIs that cater to both beginners and experts in digital pathology.
30 | * **Scalability**: Efficient handling of large WSIs, enabling high-throughput analyses.
31 | * **Multimodal integration**: Combine histological data with transcriptomics, genomics, and textual annotations.
32 | * **Foundation model support**: Native integration with state-of-the-art models (e.g., UNI, CONCH, Gigapath, Virchow) for tasks like zero-shot classification and captioning.
33 | * **Deep learning ready**: Provides PyTorch dataloaders for seamless integration into machine learning pipelines.
34 |
35 | Whether you're a novice in digital pathology or an expert computational biologist, LazySlide provides a scalable and modular foundation to accelerate AI-driven discovery in tissue biology and pathology.
36 |
37 | .. image:: https://github.com/rendeirolab/LazySlide/raw/main/assets/Figure.png
38 |
39 | |
40 |
41 | .. toctree::
42 | :maxdepth: 1
43 | :hidden:
44 |
45 | installation
46 | tutorials/index
47 | api/index
48 | contributing
49 | contributors
50 |
51 |
52 | .. grid:: 1 2 2 2
53 | :gutter: 2
54 |
55 | .. grid-item-card:: Installation
56 | :link: installation
57 | :link-type: doc
58 |
59 | How to install LazySlide
60 |
61 | .. grid-item-card:: Tutorials
62 | :link: tutorials/index
63 | :link-type: doc
64 |
65 | Get started with LazySlide
66 |
67 | .. grid-item-card:: Contributing
68 | :link: contributing
69 | :link-type: doc
70 |
71 | Contribute to LazySlide
72 |
73 | .. grid-item-card:: Contributors
74 | :link: contributors
75 | :link-type: doc
76 |
77 | The team behind LazySlide
78 |
79 | .. _scverse: https://scverse.org/
80 | .. _WSIData: https://wsidata.readthedocs.io/
81 | .. _SpatialData: https://spatialdata.scverse.org/
82 | .. _Scanpy: https://scanpy.readthedocs.io/
83 | .. _Anndata: https://anndata.readthedocs.io/
84 | .. _Squidpy: https://squidpy.readthedocs.io/
85 |
86 |
--------------------------------------------------------------------------------
/docs/source/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | You can install :code:`lazyslide` with whichever package manager you prefer.
5 |
6 | .. tab-set::
7 |
8 | .. tab-item:: PyPI
9 |
10 | The default installation.
11 |
12 | .. code-block:: bash
13 |
14 | pip install lazyslide
15 |
16 | .. tab-item:: uv
17 |
18 | .. code-block:: bash
19 |
20 | uv add lazyslide
21 |
22 | .. tab-item:: Conda
23 |
24 | .. warning::
25 |
26 | Not available yet.
27 |
28 | .. code-block:: bash
29 |
30 | conda install -c conda-forge lazyslide
31 |
32 | .. tab-item:: Mamba
33 |
34 | .. warning::
35 |
36 | Not available yet.
37 |
38 | .. code-block:: bash
39 |
40 | mamba install lazyslide
41 |
42 | .. tab-item:: Development
43 |
44 | If you want to install the latest version from the GitHub repository, you can use the following command:
45 |
46 | .. code-block:: bash
47 |
48 | pip install git+https://github.com/rendeirolab/lazyslide.git
49 |
50 |
51 | Installation of slide readers
52 | -----------------------------
53 |
54 | LazySlide uses :code:`wsidata` to handle the IO with the slide files.
55 | To support different file formats, you need to install corresponding slide readers.
56 | The reader will be automatically detected by :code:`wsidata` when you open the slide file.
57 |
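For example, opening a slide goes through :code:`wsidata` and no reader needs to be specified explicitly (a minimal sketch; the file path below is a placeholder):

.. code-block:: python

    import lazyslide as zs

    # wsidata picks a suitable reader (TiffSlide, OpenSlide, ...) based on the file
    wsi = zs.open_wsi("path/to/slide.svs")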
58 |
59 | .. tab-set::
60 |
61 | .. tab-item:: TiffSlide
62 |
63 | `TiffSlide <https://github.com/Bayer-Group/tiffslide>`_ is a cloud-native openslide-python replacement
64 | based on tifffile.
65 |
66 | TiffSlide is installed by default. You don't need to install it manually.
67 |
68 | .. code-block:: bash
69 |
70 | pip install tiffslide
71 |
72 | .. tab-item:: OpenSlide
73 |
74 | `OpenSlide <https://openslide.org/>`_ is a C library that provides a simple interface to read whole-slide images.
75 |
76 | OpenSlide is installed by default. You don't need to install it manually.
77 | 
78 | But you can always install it from PyPI:
79 |
80 | .. code-block:: bash
81 |
82 | pip install openslide-python openslide-bin
83 |
84 | In case your OpenSlide installation is not working, you can install it manually.
85 |
86 | For Linux and OSX users, it's suggested that you install :code:`openslide` with conda or mamba:
87 |
88 | .. code-block:: bash
89 |
90 | conda install -c conda-forge openslide-python
91 | # or
92 | mamba install -c conda-forge openslide-python
93 |
94 |
95 | For Windows users, you need to download compiled :code:`openslide` from
96 | `GitHub Release `_.
97 | If you open the folder, you should find a :code:`bin` folder.
98 |
99 | Make sure you point Python to the :code:`bin` folder so it can locate the :code:`openslide` binary.
100 | You need to run the following code to import :code:`openslide`;
101 | it is suggested to run this code before anything else:
102 |
103 | .. code-block:: python
104 |
105 | import os
106 | with os.add_dll_directory("path/to/openslide/bin"):
107 | import openslide
108 |
109 | .. tab-item:: BioFormats
110 |
111 | `BioFormats <https://www.openmicroscopy.org/bio-formats/>`_ is a standalone Java library
112 | for reading and writing life sciences image file formats.
113 |
114 | `scyjava <https://github.com/scijava/scyjava>`_ is used to interact with the BioFormats library.
115 |
116 | .. code-block:: bash
117 |
118 | pip install scyjava
119 |
120 | .. tab-item:: CuCIM
121 |
122 | `CuCIM <https://github.com/rapidsai/cucim>`_ is a GPU-accelerated image I/O library.
123 |
124 | .. warning::
125 |
126 | CuCIM support is not available yet.
127 |
128 | Please refer to the `CuCIM GitHub <https://github.com/rapidsai/cucim>`_.
--------------------------------------------------------------------------------
/docs/source/tutorials/.gitignore:
--------------------------------------------------------------------------------
1 | /tmp
2 | GTEx*
3 | *.sha256
--------------------------------------------------------------------------------
/docs/source/tutorials/05_training_models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "1df73331906bcf35",
6 | "metadata": {},
7 | "source": [
8 | "# Training deep learning models with LazySlide"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "6a559156-6909-46c5-bd0c-597fc02f2fe5",
14 | "metadata": {},
15 | "source": [
16 | "## Classification task"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "3b777679-2963-476c-8136-1b5d17ac33ee",
22 | "metadata": {},
23 | "source": [
24 | "## Segmentation task"
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "id": "a774bc01-fc85-41ac-9d02-0f0914e6e804",
30 | "metadata": {},
31 | "source": [
32 | "## Tissue generative model"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "id": "4ad9b7e0-f76f-4ad3-b922-5e1c4815f316",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": []
42 | }
43 | ],
44 | "metadata": {
45 | "kernelspec": {
46 | "display_name": "Python 3 (ipykernel)",
47 | "language": "python",
48 | "name": "python3"
49 | },
50 | "language_info": {
51 | "codemirror_mode": {
52 | "name": "ipython",
53 | "version": 3
54 | },
55 | "file_extension": ".py",
56 | "mimetype": "text/x-python",
57 | "name": "python",
58 | "nbconvert_exporter": "python",
59 | "pygments_lexer": "ipython3",
60 | "version": "3.12.8"
61 | },
62 | "widgets": {
63 | "application/vnd.jupyter.widget-state+json": {
64 | "state": {},
65 | "version_major": 2,
66 | "version_minor": 0
67 | }
68 | }
69 | },
70 | "nbformat": 4,
71 | "nbformat_minor": 5
72 | }
73 |
--------------------------------------------------------------------------------
/docs/source/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | Tutorials
2 | =========
3 |
4 | Here is a list of tutorials that will help you get started with LazySlide.
5 |
6 | .. toctree::
7 | :hidden:
8 | :maxdepth: 1
9 |
10 | 00_intro_wsi
11 | 01_preprocessing
12 | 02_feature_extraction
13 | 03_multiple_slides
14 | 04_genomics_integration
15 | 05_cell-segmentation
16 | 06_visualization
17 | 07_zero-shot-learning
18 |
19 | .. card:: Introduction to WSI
20 |
21 | :doc:`00_intro_wsi`
22 |
23 | .. card:: Preprocessing
24 |
25 | :doc:`01_preprocessing`
26 |
27 | .. card:: Feature extraction and spatial analysis
28 |
29 | :doc:`02_feature_extraction`
30 |
31 | .. card:: Working with multiple slides
32 |
33 | :doc:`03_multiple_slides`
34 |
35 | .. card:: Integration with RNA-seq
36 |
37 | :doc:`04_genomics_integration`
38 |
39 | .. card:: Cell segmentation
40 |
41 | :doc:`05_cell-segmentation`
42 |
43 | .. card:: WSI visualization in LazySlide
44 |
45 | :doc:`06_visualization`
46 |
47 | .. card:: Zero-shot learning with LazySlide
48 |
49 | :doc:`07_zero-shot-learning`
50 |
51 |
--------------------------------------------------------------------------------
/docs/source/tutorials/matplotlibrc:
--------------------------------------------------------------------------------
1 | pdf.fonttype: 42
2 | svg.fonttype: none
3 | font.family: sans-serif
4 | font.sans-serif: Arial
5 | font.size: 10.0
6 | figure.figsize: 4.0, 4.0
7 | savefig.dpi: 300 # figure dots per inch or 'figure'
8 | savefig.facecolor: none # figure face color when saving
9 | savefig.edgecolor: none # figure edge color when saving
10 | savefig.bbox: tight
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "lazyslide"
7 | description = "Modularized and scalable whole slide image analysis"
8 | readme = "README.md"
9 | requires-python = ">=3.10"
10 | license = "MIT"
11 | authors = [
12 | {name = "Yimin Zheng", email = "yzheng@cemm.at"},
13 | {name = "Ernesto Abila", email = "eabila@cemm.at"},
14 | {name = "André F. Rendeiro", email = "arendeiro@cemm.at"},
15 | ]
16 | keywords = ["histopathology", "whole slide image", "image analysis", "segmentation", "deep learning"]
17 | classifiers = [
18 | "Development Status :: 3 - Alpha",
19 | "Intended Audience :: Science/Research",
20 | "License :: OSI Approved :: MIT License",
21 | "Natural Language :: English",
22 | "Operating System :: OS Independent",
23 | "Programming Language :: Python :: 3",
24 | "Topic :: File Formats",
25 | "Topic :: Scientific/Engineering :: Bio-Informatics",
26 | ]
27 | Documentation = "https://lazyslide.readthedocs.io"
28 | repository = "https://github.com/rendeirolab/lazyslide"
29 | dynamic = ["version"]
30 | dependencies = [
31 | "wsidata>=0.6.0",
32 | "scikit-learn>=1.0",
33 | "matplotlib>=3.9.0",
34 | "matplotlib-scalebar>=0.9.0",
35 | "legendkit>=0.3.4",
36 | "rich>=13.0.0",
37 | "cyclopts>=3.0.0",
38 | "timm>=1.0.3",
39 | "torch>=2.0.0",
40 | "seaborn>=0.12.2",
41 | "psutil>=5.9.0",
42 | ]
43 |
44 | [project.optional-dependencies]
45 | all = [
46 | "scipy>=1.15.1",
47 | "scanpy>=1.10.4",
48 | "torchvision>=0.15", # >0.15
49 | "torchstain>=1.4.1",
50 | "transformers>=4.49.0",
51 | ]
52 |
53 | # Define entry points
54 | [project.scripts]
55 | lazyslide = "lazyslide.__main__:app"
56 | zs = "lazyslide.__main__:app"
57 |
58 | [tool.hatch.version]
59 | path = "src/lazyslide/__init__.py"
60 |
61 | [tool.hatch.build.targets.sdist]
62 | exclude = [
63 | "docs",
64 | "data",
65 | "assets",
66 | "tests",
67 | "scripts",
68 | ".readthedocs.yaml",
69 | ".github",
70 | ".gitignore",
71 | ]
72 | include = [
73 | "README.md",
74 | "LICENSE",
75 | "pyproject.toml",
76 | "src/lazyslide",
77 | ]
78 |
79 | [tool.hatch.build.targets.wheel]
80 | packages = ["src/lazyslide", "README.md", "LICENSE", "pyproject.toml"]
81 |
82 | [tool.hatch.metadata]
83 | allow-direct-references = true
84 |
85 | [tool.ruff]
86 | lint.ignore = ["F401"]
87 | line-length = 88
88 |
89 | [tool.ruff.lint.per-file-ignores]
90 | "tests/test_example.py" = ["E402"]
91 | "tests/test_loader.py" = ["E402"]
92 |
93 | [tool.mypy]
94 | ignore_missing_imports = true
95 |
96 | [tool.taskipy.tasks]
97 | hello = "echo Hello, World!"
98 | test = "pytest tests --disable-warnings"
99 | test-ci = "python -X faulthandler -m pytest tests -v --tb=short --disable-warnings"
100 | doc-build = "sphinx-build -b html docs/source docs/build"
101 | doc-clean-build = "python docs/clean_up.py && sphinx-build -b html docs/source docs/build"
102 | doc-serve = "python -m http.server -d docs/build"
103 | fmt = "ruff format docs/source src/lazyslide tests"
104 |
105 | [tool.uv]
106 | default-groups = ["dev", "docs", "tutorials", "model"]
107 |
108 | [dependency-groups]
109 | dev = [
110 | "jupyterlab>=4.3.5",
111 | "pytest>=8.3.4",
112 | "pre-commit>=4.1.0",
113 | "ruff>=0.9.4",
114 | "taskipy>=1.14.1",
115 | "torchvision>=0.21.0",
116 | "torchstain>=1.4.1",
117 | "matplotlib>=3.10.0",
118 | "matplotlib-scalebar>=0.9.0",
119 | "scikit-learn>=1.6.1",
120 | "scanpy>=1.10.4",
121 | "scipy>=1.15.1",
122 | "segmentation-models-pytorch>=0.4.0",
123 | "albumentations>=2.0.3",
124 | "spatialdata-plot>=0.2.9",
125 | "scyjava>=1.12.0",
126 | ]
127 | docs = [
128 | "sphinx>=8.1.3",
129 | "sphinx-copybutton>=0.5.2",
130 | "sphinx-design>=0.6.1",
131 | "myst-nb>=1.1.2",
132 | "numpydoc>=1.8.0",
133 | "pydata-sphinx-theme>=0.16.1",
134 | "sphinx>=8.1.3",
135 | "sphinx-copybutton>=0.5.2",
136 | "sphinx-design>=0.6.1",
137 | "sphinx-book-theme>=1.1.3",
138 | "sphinx-contributors>=0.2.7",
139 | ]
140 | tutorials = [
141 | "igraph>=0.11.8",
142 | "ipywidgets>=8.1.5",
143 | "marsilea>=0.5.1",
144 | "parse>=1.20.2",
145 | "gseapy>=1.1.7",
146 | "mpl-fontkit>=0.5.1",
147 | "matplotlib-venn>=1.1.2",
148 | "muon>=0.1.7",
149 | "mofapy2>=0.7.2",
150 | "pypalettes>=0.1.5",
151 | "bokeh>=3.7.2",
152 | "dask-jobqueue>=0.9.0",
153 | ]
154 | napari = [
155 | "napari[all]>=0.5.6",
156 | "napari-spatialdata>=0.5.5",
157 | "spatialdata-plot>=0.2.9",
158 | ]
159 | model = [
160 | "einops>=0.8.1",
161 | "einops-exts>=0.0.4",
162 | "environs>=14.1.1",
163 | "sacremoses>=0.1.1",
164 | "conch",
165 | "transformers>=4.49.0",
166 | ]
167 |
168 |
169 | [tool.uv.sources]
170 | # wsidata = { git = "https://github.com/rendeirolab/wsidata", branch = "main" }
171 | # wsidata = { path = "../wsidata", editable = true }
172 | conch = { git = "https://github.com/mahmoodlab/CONCH.git" }
173 |
174 | [tool.uv.workspace]
175 | members = ["scripts/grandqc/artifacts_detection"]
176 |
177 | [tool.pytest.ini_options]
178 | filterwarnings = [
179 | "ignore::UserWarning"
180 | ]
181 |
--------------------------------------------------------------------------------
/src/lazyslide/__init__.py:
--------------------------------------------------------------------------------
1 | """Efficient and Scalable Whole Slide Image (WSI) processing library."""
2 |
3 | __version__ = "0.6.0"
4 |
5 |
6 | import sys
7 |
8 | # Re-export the public API
9 | from wsidata import open_wsi, agg_wsi
10 |
11 | from . import cv
12 | from . import io
13 | from . import models
14 | from . import plotting as pl
15 | from . import preprocess as pp
16 | from . import segmentation as seg
17 | from . import tools as tl
18 | from . import datasets
19 | from . import metrics
20 |
21 | # Inject the aliases into the current module
22 | sys.modules.update({f"{__name__}.{m}": globals()[m] for m in ["tl", "pp", "pl", "seg"]})
23 | del sys
24 |
25 |
26 | __all__ = [
27 | "open_wsi",
28 | "agg_wsi",
29 | "pp",
30 | "tl",
31 | "pl",
32 | "seg",
33 | "cv",
34 | "models",
35 | "io",
36 | ]
37 |
--------------------------------------------------------------------------------
/src/lazyslide/_const.py:
--------------------------------------------------------------------------------
1 | class Key:
2 | tissue_qc = "qc"
3 | tile_qc = "qc"
4 | tissue: str = "tissues"
5 | tissue_id: str = "tissue_id"
6 | tiles = "tiles"
7 | tile_spec: str = "tile_spec"
8 | annotations: str = "annotations"
9 |
10 | @classmethod
11 | def tile_graph(cls, name):
12 | return f"{name}_graph"
13 |
14 | @classmethod
15 | def feature(cls, name, tile_key=None):
16 | tile_key = tile_key or cls.tiles
17 | return f"{name}_{tile_key}"
18 |
19 | @classmethod
20 | def feature_slide(cls, name, tile_key=None):
21 | tile_key = tile_key or cls.tiles
22 | return f"{name}_{tile_key}_slide"
23 |
--------------------------------------------------------------------------------
/src/lazyslide/_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import inspect
4 | import os
5 | from functools import wraps
6 | from types import FrameType
7 |
8 | from rich.console import Console
9 |
10 | console = Console()
11 |
12 |
13 | def get_torch_device():
14 | """Automatically get the torch device"""
15 | import torch
16 |
17 | if torch.cuda.is_available():
18 | device = torch.device("cuda")
19 | elif torch.backends.mps.is_available():
20 | device = torch.device("mps")
21 | else:
22 | device = torch.device("cpu")
23 | return device
24 |
25 |
26 | def default_pbar(disable=False):
27 | """Get the default progress bar"""
28 | from rich.progress import Progress
29 | from rich.progress import (
30 | TextColumn,
31 | BarColumn,
32 | TaskProgressColumn,
33 | TimeRemainingColumn,
34 | )
35 |
36 | return Progress(
37 | TextColumn("[progress.description]{task.description}"),
38 | BarColumn(bar_width=30),
39 | TaskProgressColumn(),
40 | TimeRemainingColumn(compact=True, elapsed_when_finished=True),
41 | disable=disable,
42 | console=console,
43 | transient=True,
44 | )
45 |
46 |
47 | def chunker(seq, num_workers):
48 | avg = len(seq) / num_workers
49 | out = []
50 | last = 0.0
51 |
52 | while last < len(seq):
53 | out.append(seq[int(last) : int(last + avg)])
54 | last += avg
55 |
56 | return out
57 |
58 |
59 | def find_stack_level() -> int:
60 | """
61 | Find the first place in the stack that is not inside pandas
62 | (tests notwithstanding).
63 | """
64 |
65 | import pandas as pd
66 |
67 | pkg_dir = os.path.dirname(pd.__file__)
68 | test_dir = os.path.join(pkg_dir, "tests")
69 |
70 | # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
71 | frame: FrameType | None = inspect.currentframe()
72 | try:
73 | n = 0
74 | while frame:
75 | filename = inspect.getfile(frame)
76 | if filename.startswith(pkg_dir) and not filename.startswith(test_dir):
77 | frame = frame.f_back
78 | n += 1
79 | else:
80 | break
81 | finally:
82 | # See note in
83 | # https://docs.python.org/3/library/inspect.html#inspect.Traceback
84 | del frame
85 | return n
86 |
87 |
88 | def _param_doc(param_type, param_text):
89 | return f"""{param_type}\n\t{param_text}"""
90 |
91 |
92 | PARAMS_DOCSTRING = {
93 | "wsi": _param_doc(
94 |         param_type=":class:`WSIData <wsidata.WSIData>`",
95 | param_text="The WSIData object to work on.",
96 | ),
97 | "key_added": _param_doc(
98 | param_type="str, default: '{key_added}'",
99 | param_text="The key to save the result in the WSIData object.",
100 | ),
101 | }
102 |
103 |
104 | def _doc(obj=None, *, key_added: str = None):
105 | """
106 | A decorator to inject docstring to an object by replacing the placeholder in docstring by looking up a dict.
107 | """
108 |
109 | def decorator(obj):
110 | if obj.__doc__ is not None:
111 | if key_added is not None:
112 | PARAMS_DOCSTRING["key_added"] = PARAMS_DOCSTRING["key_added"].format(
113 | key_added=key_added
114 | )
115 | obj.__doc__ = obj.__doc__.format(**PARAMS_DOCSTRING)
116 |
117 | @wraps(obj)
118 | def wrapper(*args, **kwargs):
119 | return obj(*args, **kwargs)
120 |
121 | return wrapper
122 |
123 | if obj is None:
124 | return decorator
125 | else:
126 | return decorator(obj)
127 |
--------------------------------------------------------------------------------
/src/lazyslide/cv/__init__.py:
--------------------------------------------------------------------------------
1 | from .mask import Mask, BinaryMask, MultiLabelMask, MultiClassMask
2 | from .tiles_merger import merge_polygons
3 |
--------------------------------------------------------------------------------
/src/lazyslide/cv/scorer/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ScorerBase, ComposeScorer
2 | from .focuslitenn import FocusLite
3 | from .module import Contrast, SplitRGB, Redness, Brightness
4 |
--------------------------------------------------------------------------------
/src/lazyslide/cv/scorer/base.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | ScoreResult = namedtuple("ScoreResult", ["scores", "qc"])
4 |
5 |
6 | class ScorerBase:
7 | """
8 | Base class for all scorers.
9 |
10 |     All scorers operate on a single patch.
11 |
12 | Image -> float
13 | """
14 |
15 | def __call__(self, patch, mask=None):
16 |         return self.apply(patch, mask=mask)
17 |
18 | def __repr__(self):
19 | return f"{self.__class__.__name__}()"
20 |
21 | def apply(self, patch, mask=None) -> ScoreResult:
22 |         """Return the scores and a bool indicating whether the patch passes QC."""
23 | raise NotImplementedError
24 |
25 |
26 | class ComposeScorer(ScorerBase):
27 | """
28 | Compose multiple scorers into one.
29 |
30 | Parameters
31 | ----------
32 | scorers : List[ScorerBase]
33 | List of scorers to be composed.
34 | """
35 |
36 | def __init__(self, scorers):
37 | self.scorers = scorers
38 |
39 | def apply(self, patch, mask=None) -> ScoreResult:
40 | scores = {}
41 | qc = True
42 | for scorer in self.scorers:
43 | score, _qc = scorer.apply(patch, mask)
44 | scores.update(score)
45 | qc &= _qc
46 | return ScoreResult(scores=scores, qc=qc)
47 |
--------------------------------------------------------------------------------
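A minimal sketch of a custom scorer built on ScorerBase. The class name, channel choice, and threshold below are illustrative only, not part of the library:

import numpy as np

from lazyslide.cv.scorer.base import ScorerBase, ScoreResult


class Greenness(ScorerBase):
    """Score the mean green-channel intensity; QC passes when the patch is not too green."""

    def __init__(self, threshold=200):
        self.threshold = threshold

    def apply(self, patch, mask=None) -> ScoreResult:
        green = patch[..., 1][mask].mean() if mask is not None else patch[..., 1].mean()
        return ScoreResult(scores={"greenness": green}, qc=green < self.threshold)


patch = np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8)
print(Greenness()(patch))  # ScoreResult(scores={'greenness': ...}, qc=True)

--------------------------------------------------------------------------------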
/src/lazyslide/cv/scorer/focuslitenn/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import FocusLite
2 |
--------------------------------------------------------------------------------
/src/lazyslide/cv/scorer/focuslitenn/focuslitenn-2kernel-mse.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/src/lazyslide/cv/scorer/focuslitenn/focuslitenn-2kernel-mse.pt
--------------------------------------------------------------------------------
/src/lazyslide/cv/scorer/focuslitenn/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from pathlib import Path
3 |
4 | import numpy as np
5 |
6 | from lazyslide.cv.scorer.base import ScorerBase, ScoreResult
7 |
8 | try:
9 | import torch
10 |
11 | class FocusLiteNN(torch.nn.Module):
12 | """
13 | A FocusLiteNN model for filtering out-of-focus regions in whole slide images.
14 | """
15 |
16 | def __init__(self, num_channel=2):
17 | super().__init__()
18 | self.num_channel = num_channel
19 | self.conv = torch.nn.Conv2d(
20 | 3, self.num_channel, 7, stride=5, padding=1
21 | ) # 47x47
22 | self.maxpool = torch.nn.MaxPool2d(kernel_size=47)
23 | if self.num_channel > 1:
24 | self.fc = torch.nn.Conv2d(self.num_channel, 1, 1, stride=1, padding=0)
25 |
26 | for m in self.modules():
27 | if isinstance(m, torch.nn.Conv2d):
28 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
29 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
30 |
31 | def forward(self, x):
32 | batch_size = x.size()[0]
33 |
34 | x = self.conv(x)
35 | x = -self.maxpool(-x) # minpooling
36 | if self.num_channel > 1:
37 | x = self.fc(x)
38 | x = x.view(batch_size, -1)
39 |
40 | return x
41 | except ImportError:
42 |
43 | class FocusLiteNN:
44 | def __init__(self, *args, **kwargs):
45 | raise ImportError(
46 |                 "FocusLiteNN requires torch. You can install it using `pip install torch`. "
47 | "Please restart the kernel after installation."
48 | )
49 |
50 |
51 | def load_focuslite_model(device="cpu"):
52 | model = FocusLiteNN()
53 | if not hasattr(model, "forward"):
54 | raise ModuleNotFoundError("To use Focuslite, you need to install pytorch")
55 | ckpt = torch.load(
56 | Path(__file__).parent / "focuslitenn-2kernel-mse.pt",
57 | map_location=device,
58 | weights_only=True,
59 | )
60 | model.load_state_dict(ckpt["state_dict"])
61 | model.eval()
62 | # model = torch.compile(model)
63 | return model
64 |
65 |
66 | class FocusLite(ScorerBase):
67 | # The device must be CPU, otherwise this module cannot be serialized
68 | def __init__(self, threshold=3, device="cpu"):
69 | from torchvision.transforms import ToTensor, Resize
70 |
71 | # threshold should be between 1 and 12
72 | if not (1 <= threshold <= 12):
73 | raise ValueError("threshold should be between 1 and 12")
74 | self.threshold = threshold
75 | self.model = load_focuslite_model(device)
76 | self.to_tensor = ToTensor()
77 | self.resize = Resize((256, 256), antialias=False)
78 |
79 | def apply(self, patch, mask=None):
80 |         """Lower scores indicate a sharper (more in-focus) patch; QC passes when the score is below the threshold."""
81 | arr = self.to_tensor(patch)
82 | # If the image is not big enough, resize it
83 | if arr.shape[1] < 256 or arr.shape[2] < 256:
84 | arr = self.resize(arr)
85 | arr = torch.stack([arr], dim=0)
86 | score = self.model(arr)
87 | score = max(0, np.mean(torch.squeeze(score.cpu().data, dim=1).numpy()))
88 | return ScoreResult(scores={"focus": score}, qc=score < self.threshold)
89 |
--------------------------------------------------------------------------------
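A minimal sketch of scoring a patch with FocusLite (requires torch and torchvision; the bundled focuslitenn-2kernel-mse.pt weights are loaded automatically). The random array stands in for a real tile:

import numpy as np

from lazyslide.cv.scorer import FocusLite

patch = np.random.randint(0, 255, size=(512, 512, 3), dtype=np.uint8)
scorer = FocusLite(threshold=3)
result = scorer.apply(patch)
print(result.scores["focus"], result.qc)  # qc is True when the blur score is below the threshold

--------------------------------------------------------------------------------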
/src/lazyslide/cv/scorer/module.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | from .base import ScorerBase, ScoreResult
5 | from .utils import dtype_limits
6 |
7 |
8 | class SplitRGB(ScorerBase):
9 |     """
10 |     Calculate per-channel RGB intensities of a patch.
11 | 
12 |     Each channel is aggregated (the mean by default) over the pixel values.
13 | 
14 |     The patch needs to be in shape (H, W, 3) for dim='xyc' or (3, H, W) for dim='cyx'.
15 | 
16 |     Parameters
17 |     ----------
18 |     threshold : (int, int, int), default: (0, 0, 0)
19 |         Per-channel (R, G, B) thresholds used for the QC decision.
20 | 
21 |     """
22 |
23 | def __init__(
24 | self,
25 | threshold: (int, int, int) = (
26 | 0,
27 | 0,
28 | 0,
29 | ),
30 | method="mean",
31 | dim="xyc",
32 | ):
33 | self.threshold = np.array(threshold)
34 | self.method = method
35 | self.dim = dim
36 | if dim == "xyc":
37 | self.func = self._score_xyc
38 | elif dim == "cyx":
39 | self.func = self._score_cyx
40 | else:
41 | raise ValueError(f"Unknown dim {dim}, should be 'xyc' or 'cyx'")
42 |
43 |     def _score_xyc(self, patch, mask=None):
44 |         if mask is not None:
45 |             img = patch[mask]
46 |             c_int = getattr(img, self.method)(axis=0)
47 |         else:
48 |             c_int = getattr(patch, self.method)(axis=(0, 1))
49 |         return {"red": c_int[0], "green": c_int[1], "blue": c_int[2]}
50 |
51 | def _score_cyx(self, patch, mask=None):
52 | if mask is not None:
53 | c_int = [patch[c][mask].mean() for c in range(3)]
54 | else:
55 | c_int = [patch[c].mean() for c in range(3)]
56 | return {"red": c_int[0], "green": c_int[1], "blue": c_int[2]}
57 |
58 |     def apply(self, patch, mask=None):
59 |         scores = self.func(patch, mask)  # {"red": ..., "green": ..., "blue": ...}
60 |         return ScoreResult(scores=scores, qc=bool((np.array(list(scores.values())) > self.threshold).all()))
61 |
62 |
63 | class Redness(SplitRGB):
64 | def __init__(self, red_threshold=0.5, **kwargs):
65 | self.red_threshold = red_threshold
66 | super().__init__(**kwargs)
67 |
68 | def apply(self, patch, mask=None):
69 | scores = self.func(patch, mask)
70 | return ScoreResult(
71 | scores={"redness": scores["red"]}, qc=scores["red"] > self.red_threshold
72 | )
73 |
74 |
75 | class Brightness(ScorerBase):
76 | def __init__(self, threshold=235):
77 | self.threshold = threshold
78 |
79 | def apply(self, patch, mask=None) -> ScoreResult:
80 | if mask is not None:
81 | bright = patch[mask].mean()
82 | else:
83 | bright = patch.mean()
84 | return ScoreResult(scores={"brightness": bright}, qc=bright < self.threshold)
85 |
86 |
87 | class Contrast(ScorerBase):
88 |     """
89 |     Calculate the contrast of a patch.
90 | 
91 |     Contrast is the spread between the lower and upper intensity percentiles, as a fraction of the dtype range.
92 | 
93 |     Parameters
94 |     ----------
95 |     fraction_threshold : float, default: 0.05
96 |         Minimum fraction of the dtype range required to pass QC.
97 |     """
98 |
99 | def __init__(
100 | self,
101 | fraction_threshold=0.05,
102 | lower_percentile=1,
103 | upper_percentile=99,
104 | ):
105 | self.fraction_threshold = fraction_threshold
106 | self.lower_percentile = lower_percentile
107 | self.upper_percentile = upper_percentile
108 |
109 | def apply(self, patch, mask=None):
110 | patch = np.asarray(patch)
111 | if patch.dtype == bool:
112 | ratio = int((patch.max() == 1) and (patch.min() == 0))
113 | elif patch.ndim == 3:
114 | if patch.shape[2] == 4:
115 | patch = cv2.cvtColor(patch, cv2.COLOR_RGBA2RGB)
116 | if patch.shape[2] == 3:
117 | patch = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
118 |
119 | dlimits = dtype_limits(patch, clip_negative=False)
120 | limits = np.percentile(
121 | patch, [self.lower_percentile, self.upper_percentile]
122 | )
123 | ratio = (limits[1] - limits[0]) / (dlimits[1] - dlimits[0])
124 | else:
125 |             raise NotImplementedError("Only (H, W, 3)/(H, W, 4) images or boolean masks are supported")
126 |
127 | return ScoreResult(
128 | scores={"contrast": ratio}, qc=ratio > self.fraction_threshold
129 | )
130 |
131 |
132 | class Sharpness(ScorerBase):
133 | """
134 | Calculate the sharpness of a patch.
135 |
136 | Sharpness is calculated as the variance of the Laplacian of the pixel values.
137 |
138 | Parameters
139 | ----------
140 | threshold : float
141 | Threshold to determine if a patch is sharp or not.
142 | """
143 |
144 | def __init__(self, threshold: float = 0.5):
145 | self.threshold = threshold
146 |
147 | def apply(self, patch, mask=None):
148 | score = cv2.Laplacian(patch, cv2.CV_64F).var()
149 | return ScoreResult(scores={"sharpness": score}, qc=score > self.threshold)
150 |
151 |
152 | class Sobel(ScorerBase):
153 |     """
154 |     Calculate the Sobel edge response of a patch.
155 | 
156 |     The score is the variance of the Sobel-filtered pixel values.
157 | 
158 |     Parameters
159 |     ----------
160 |     threshold : float
161 |         Threshold to determine if a patch is sharp or not.
162 |     """
163 |
164 | name = "sobel"
165 |
166 | def __init__(self, threshold: float = 0.5):
167 | self.threshold = threshold
168 |
169 | def apply(self, patch, mask=None):
170 |         score = cv2.Sobel(patch, cv2.CV_64F, 1, 1, ksize=3).var()
171 | return ScoreResult(scores={"sobel": score}, qc=score > self.threshold)
172 |
173 |
174 | class Canny(ScorerBase):
175 |     """
176 |     Calculate the Canny edge response of a patch.
177 | 
178 |     The score is the variance of the Canny edge map of the pixel values.
179 | 
180 |     Parameters
181 |     ----------
182 |     threshold : float
183 |         Threshold to determine if a patch is sharp or not.
184 |     """
185 |
186 | name = "canny"
187 |
188 | def __init__(self, threshold: float = 0.5):
189 | self.threshold = threshold
190 |
191 | def apply(self, patch, mask=None):
192 |         score = cv2.Canny(patch, 100, 200).var()  # cv2.Canny requires two hysteresis thresholds
193 | return ScoreResult(scores={"canny": score}, qc=score > self.threshold)
194 |
--------------------------------------------------------------------------------
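A minimal sketch composing several of the scorers above with ComposeScorer (requires numpy and opencv-python; the random array stands in for a real tile):

import numpy as np

from lazyslide.cv.scorer import Brightness, ComposeScorer, Contrast

patch = np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8)

scorer = ComposeScorer([Brightness(threshold=235), Contrast()])
result = scorer(patch)
print(result.scores)  # {'brightness': ..., 'contrast': ...}
print(result.qc)      # True only if every individual scorer passed

--------------------------------------------------------------------------------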
/src/lazyslide/cv/scorer/utils.py:
--------------------------------------------------------------------------------
1 | # This is copied from https://github.com/scikit-image/scikit-image/blob/v0.24.0/skimage/util/dtype.py
2 | import warnings
3 |
4 | import numpy as np
5 |
6 | _integer_types = (
7 | np.int8,
8 | np.byte,
9 | np.int16,
10 | np.short,
11 | np.int32,
12 | np.int64,
13 | np.longlong,
14 | np.int_,
15 | np.intp,
16 | np.intc,
17 | int,
18 | np.uint8,
19 | np.ubyte,
20 | np.uint16,
21 | np.ushort,
22 | np.uint32,
23 | np.uint64,
24 | np.ulonglong,
25 | np.uint,
26 | np.uintp,
27 | np.uintc,
28 | )
29 | _integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types}
30 | dtype_range = {
31 | bool: (False, True),
32 | np.bool_: (False, True),
33 | float: (-1, 1),
34 | np.float16: (-1, 1),
35 | np.float32: (-1, 1),
36 | np.float64: (-1, 1),
37 | }
38 |
39 | with warnings.catch_warnings():
40 | warnings.filterwarnings("ignore", category=DeprecationWarning)
41 |
42 | # np.bool8 is a deprecated alias of np.bool_
43 | if hasattr(np, "bool8"):
44 | dtype_range[np.bool8] = (False, True)
45 |
46 | dtype_range.update(_integer_ranges)
47 |
48 | _supported_types = list(dtype_range.keys())
49 |
50 |
51 | def dtype_limits(image, clip_negative=False):
52 | """Return intensity limits, i.e. (min, max) tuple, of the image's dtype.
53 |
54 | Parameters
55 | ----------
56 | image : ndarray
57 | Input image.
58 | clip_negative : bool, optional
59 | If True, clip the negative range (i.e. return 0 for min intensity)
60 | even if the image dtype allows negative values.
61 |
62 | Returns
63 | -------
64 | imin, imax : tuple
65 | Lower and upper intensity limits.
66 | """
67 | imin, imax = dtype_range[image.dtype.type]
68 | if clip_negative:
69 | imin = 0
70 | return imin, imax
71 |
--------------------------------------------------------------------------------
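A quick check of dtype_limits on common dtypes (a minimal sketch):

import numpy as np

from lazyslide.cv.scorer.utils import dtype_limits

print(dtype_limits(np.zeros((4, 4), dtype=np.uint8)))                      # (0, 255)
print(dtype_limits(np.zeros((4, 4), dtype=np.float32)))                    # (-1, 1)
print(dtype_limits(np.zeros((4, 4), dtype=np.int16), clip_negative=True))  # (0, 32767)

--------------------------------------------------------------------------------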
/src/lazyslide/cv/tiles_merger.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from itertools import combinations
4 |
5 | import geopandas as gpd
6 | import numpy as np
7 | from shapely.ops import unary_union
8 | from shapely.strtree import STRtree
9 |
10 |
11 | class PolygonMerger:
12 | """
13 | Merge polygons from different tiles.
14 |
15 | If the polygons are overlapping/touching, the overlapping regions are merged.
16 |
17 | If probabilities exist, the probabilities are averaged weighted by the area of the polygons.
18 |
19 | Parameters
20 | ----------
21 |     gdf : :class:`GeoDataFrame <geopandas.GeoDataFrame>`
22 |         The GeoDataFrame containing the polygons.
23 |     class_col : str, default: None
24 |         The column that specifies the class names of the polygons.
25 |     prob_col : str, default: None
26 |         The column that specifies the probabilities of the polygons.
27 |     buffer_px : float, default: 0
28 |         The buffer size applied to the polygons when testing for intersection.
29 |     drop_overlap : float, default: 0.9
30 |         The overlap ratio above which the smaller of two overlapping class polygons is dropped.
31 |
32 | """
33 |
34 | def __init__(
35 | self,
36 | gdf: gpd.GeoDataFrame,
37 | class_col: str = None,
38 | prob_col: str = None,
39 | buffer_px: float = 0,
40 | drop_overlap: float = 0.9,
41 | ):
42 | self.gdf = gdf
43 | self.class_col = class_col
44 | self.prob_col = prob_col
45 | self.buffer_px = buffer_px
46 | self.drop_overlap = drop_overlap
47 |
48 | self._has_class = class_col in gdf.columns if class_col else False
49 | self._has_prob = prob_col in gdf.columns if prob_col else False
50 | self._preprocessed_polygons = self._preprocess_polys()
51 | self._merged_polygons = None
52 |
53 | def _preprocess_polys(self):
54 | """Preprocess the polygons."""
55 | new_gdf = self.gdf.copy()
56 | if self.buffer_px > 0:
57 | new_gdf["geometry"] = self.gdf["geometry"].buffer(self.buffer_px)
58 | # Filter out invalid and empty geometries efficiently
59 | return new_gdf[new_gdf["geometry"].is_valid & ~new_gdf["geometry"].is_empty]
60 |
61 | def _merge_overlap(self, gdf: gpd.GeoDataFrame):
62 | """
63 | Merge the overlapping polygons recursively.
64 |
65 | This function has no assumptions about the class or probability
66 | """
67 | pass
68 |
69 | def _tree_merge(self, gdf: gpd.GeoDataFrame):
70 | polygons = gdf["geometry"].tolist()
71 | tree = STRtree(polygons)
72 | visited = set()
73 | merged = []
74 |
75 |         for i, geom in enumerate(polygons):
76 |             if i in visited:
77 | continue
78 |
79 | groups_ix = tree.query(geom, predicate="intersects")
80 | groups_ix = set([g for g in groups_ix if g not in visited])
81 | if len(groups_ix) == 0:
82 | continue
83 | else:
84 | # continue finding other polygons that intersect with the group
85 | # until the group size is stable
86 | current_group_size = len(groups_ix)
87 | while True:
88 | new_groups_ix = set()
89 | for ix in groups_ix:
90 | c_groups_ix = tree.query(polygons[ix], predicate="intersects")
91 | c_groups_ix = [g for g in c_groups_ix if g not in visited]
92 | new_groups_ix.update(c_groups_ix)
93 | groups_ix.update(new_groups_ix)
94 | if len(groups_ix) == current_group_size:
95 | break
96 | current_group_size = len(groups_ix)
97 |
98 | # Sort the group index
99 | groups_ix = np.sort(list(groups_ix))
100 |
101 | # Merge the group
102 | merged_geoms = [] # (polygon, row_ix, groups_ix)
103 |
104 | if len(groups_ix) == 1:
105 | ix = groups_ix[0]
106 | m_geoms = polygons[ix]
107 | merged_geoms.append((m_geoms, ix, groups_ix))
108 | else:
109 | m_geoms = [polygons[g] for g in groups_ix]
110 | if self._has_class:
111 | ref_df = gpd.GeoDataFrame(
112 | {
113 | "names": [gdf[self.class_col].values[g] for g in groups_ix],
114 | "index": groups_ix,
115 | "geometry": m_geoms,
116 | }
117 | )
118 |
119 | # {class_name: polygon}
120 | named_polys = (
121 | ref_df[["names", "geometry"]]
122 | .groupby("names")
123 | .apply(unary_union)
124 | .to_dict()
125 | )
126 |
127 | if self.drop_overlap > 0:
128 |                         # If two class instances overlap by more than drop_overlap
129 |                         # (90% by default), the smaller one is removed
130 | while len(named_polys) > 1:
131 | names = list(named_polys.keys())
132 | combs = combinations(names, 2)
133 | for n1, n2 in combs:
134 | if n1 in named_polys and n2 in named_polys:
135 | p1, p2 = named_polys[n1], named_polys[n2]
136 | if p1.intersection(p2).is_empty:
137 | continue
138 | area, drop = (
139 | (p1.area, n1)
140 | if p1.area < p2.area
141 | else (p2.area, n2)
142 | )
143 |                                     intersection = p1.intersection(p2).area
144 |                                     overlap_ratio = intersection / area
145 | if overlap_ratio > self.drop_overlap:
146 | del named_polys[drop]
147 | break
148 | for n, p in named_polys.items():
149 | gs = ref_df[ref_df["names"] == n]["index"].tolist()
150 | merged_geoms.append((p, gs[0], gs))
151 | else:
152 | m_geoms = unary_union(m_geoms)
153 | merged_geoms.append((m_geoms, groups_ix[0], groups_ix))
154 | # Postprocess the merged polygon
155 | for m_geom, ix, gs_ix in merged_geoms:
156 | if self.buffer_px > 0:
157 | m_geom = m_geom.buffer(-self.buffer_px).buffer(0)
158 |                 if m_geom.is_valid and not m_geom.is_empty:
159 | m_data = gdf.iloc[ix].copy()
160 | m_data["geometry"] = m_geom
161 | if self._has_prob:
162 | gs_gdf = gdf.iloc[gs_ix]
163 | m_data[self.prob_col] = np.average(
164 | gs_gdf[self.prob_col], weights=gs_gdf["geometry"].area
165 | )
166 | merged.append(m_data)
167 | for g in groups_ix:
168 | visited.add(g)
169 | return gpd.GeoDataFrame(merged)
170 |
171 | def merge(self):
172 | """Launch the merging process."""
173 | self._merged_polygons = self._tree_merge(self._preprocessed_polygons)
174 |
175 | @property
176 | def merged_polygons(self):
177 | return self._merged_polygons
178 |
179 |
180 | def merge_polygons(
181 | gdf: gpd.GeoDataFrame,
182 | class_col: str = None,
183 | prob_col: str = None,
184 | buffer_px: float = 0,
185 | drop_overlap: float = 0.9,
186 | ):
187 | merger = PolygonMerger(gdf, class_col, prob_col, buffer_px, drop_overlap)
188 | merger.merge()
189 | return merger.merged_polygons
190 |
191 |
192 | merge_polygons.__doc__ = PolygonMerger.__doc__
193 |
--------------------------------------------------------------------------------
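A minimal sketch of merging polygons that span tile borders. Two touching boxes are unioned and their probabilities are averaged, weighted by area; the isolated box is kept as-is:

import geopandas as gpd
from shapely.geometry import box

from lazyslide.cv.tiles_merger import merge_polygons

gdf = gpd.GeoDataFrame(
    {"prob": [0.9, 0.7, 0.8]},
    geometry=[box(0, 0, 10, 10), box(8, 0, 18, 10), box(30, 30, 40, 40)],
)
merged = merge_polygons(gdf, prob_col="prob")
print(merged)  # two rows: one merged polygon with an area-weighted probability, one untouched box

--------------------------------------------------------------------------------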
/src/lazyslide/cv/transform/__init__.py:
--------------------------------------------------------------------------------
1 | """This module is heavily inspired by both torchvision and PathML."""
2 |
3 | from .compose import TissueDetectionHE
4 |
5 | from .mods import (
6 | MedianBlur,
7 | GaussianBlur,
8 | BoxBlur,
9 | MorphOpen,
10 | MorphClose,
11 | BinaryThreshold,
12 | ArtifactFilterThreshold,
13 | Compose,
14 | )
15 |
--------------------------------------------------------------------------------
/src/lazyslide/cv/transform/compose.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from shapely import Polygon
3 |
4 | from .mods import (
5 | Transform,
6 | MedianBlur,
7 | MorphClose,
8 | ArtifactFilterThreshold,
9 | BinaryThreshold,
10 | ForegroundDetection,
11 | )
12 |
13 |
14 | class TissueDetectionHE(Transform):
15 |     """
16 |     Detect tissue regions from an H&E stained slide.
17 |     Applies a median blur, then thresholding, then morphological closing, and finally foreground detection.
18 | 
19 |     Parameters
20 |     ----------
21 |     use_saturation : bool
22 |         Whether to convert to HSV and use the saturation channel for tissue detection;
23 |         if False, use the greyscale image. Only applies when filter_artifacts is False. Defaults to False.
24 |     blur_ksize : int
25 |         Kernel size used for median blurring. Defaults to 17.
26 |     threshold : int
27 |         Threshold for binary thresholding. If None, uses Otsu's method. Defaults to 7.
28 |     morph_n_iter : int
29 |         Number of iterations of morphological closing to apply. Defaults to 3.
30 |     morph_k_size : int
31 |         Kernel size for morphological closing. Defaults to 7.
32 |     min_tissue_area, min_hole_area, detect_holes, filter_artifacts
33 |         Foreground-detection options; see ForegroundDetection and ArtifactFilterThreshold.
34 |     """
35 |
36 | def __init__(
37 | self,
38 | use_saturation=False,
39 | blur_ksize=17,
40 | threshold=7,
41 | morph_n_iter=3,
42 | morph_k_size=7,
43 | min_tissue_area=0.01,
44 | min_hole_area=0.0001,
45 | detect_holes=True,
46 | filter_artifacts=True,
47 | ):
48 | self.set_params(
49 | use_saturation=use_saturation,
50 | blur_ksize=blur_ksize,
51 | threshold=threshold,
52 | morph_n_iter=morph_n_iter,
53 | morph_k_size=morph_k_size,
54 | min_tissue_area=min_tissue_area,
55 | min_hole_area=min_hole_area,
56 | detect_holes=detect_holes,
57 | filter_artifacts=filter_artifacts,
58 | )
59 |
60 | if filter_artifacts:
61 | thresholder = ArtifactFilterThreshold(threshold=threshold)
62 | else:
63 | if threshold is None:
64 | thresholder = BinaryThreshold(use_otsu=True)
65 | else:
66 | thresholder = BinaryThreshold(use_otsu=False, threshold=threshold)
67 |
68 | foreground = ForegroundDetection(
69 | min_foreground_area=min_tissue_area,
70 | min_hole_area=min_hole_area,
71 | detect_holes=detect_holes,
72 | )
73 |
74 | self.pipeline = [
75 | MedianBlur(kernel_size=blur_ksize),
76 | thresholder,
77 | # MorphOpen(kernel_size=morph_k_size, n_iterations=morph_n_iter),
78 | MorphClose(kernel_size=morph_k_size, n_iterations=morph_n_iter),
79 | foreground,
80 | ]
81 |
82 | def apply(self, image):
83 | filter_artifacts = self.params["filter_artifacts"]
84 | use_saturation = self.params["use_saturation"]
85 |
86 | if not filter_artifacts:
87 | if use_saturation:
88 | image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)[:, :, 1]
89 | else:
90 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
91 |
92 | for p in self.pipeline:
93 | image = p.apply(image)
94 | return image
95 |
96 |
97 | class Mask2Polygon(Transform):
98 | """
99 | Convert binary mask to polygon.
100 |
101 | Parameters
102 | ----------
103 | min_area : int
104 | Minimum area of detected regions to be included in the polygon.
105 | """
106 |
107 | def __init__(
108 | self,
109 | min_area=0,
110 | morph_k_size=7,
111 | morph_n_iter=3,
112 | min_tissue_area=0.01,
113 | min_hole_area=0.0001,
114 | detect_holes=True,
115 | ):
116 | self.set_params(min_area=min_area)
117 |
118 | self.pipeline = [
119 | # MorphOpen(kernel_size=morph_k_size, n_iterations=morph_n_iter),
120 | MorphClose(kernel_size=morph_k_size, n_iterations=morph_n_iter),
121 | ForegroundDetection(
122 | min_foreground_area=min_tissue_area,
123 | min_hole_area=min_hole_area,
124 | detect_holes=detect_holes,
125 | ),
126 | ]
127 |
128 | def apply(self, mask):
129 | min_area = self.params["min_area"]
130 |
131 | for p in self.pipeline:
132 | try:
133 | mask = p.apply(mask)
134 | except Exception as e:
135 | print(self.__class__.__name__, e)
136 |
137 | tissue_instances = mask
138 | polygons = []
139 | if len(tissue_instances) == 0:
140 | return []
141 | for tissue in tissue_instances:
142 | shell = tissue.contour
143 | if len(tissue.holes) == 0:
144 | tissue_poly = Polygon(shell)
145 | else:
146 | holes = [hole for hole in tissue.holes]
147 | tissue_poly = Polygon(shell, holes=holes)
148 | if tissue_poly.area < min_area:
149 | continue
150 | polygons.append(tissue_poly)
151 | return polygons
152 |
--------------------------------------------------------------------------------
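A minimal sketch of running the H&E tissue-detection pipeline. The random array below only stands in for a downsampled slide thumbnail; on pure noise the detector may simply find no foreground:

import numpy as np

from lazyslide.cv.transform import TissueDetectionHE

thumbnail = np.random.randint(0, 255, size=(512, 512, 3), dtype=np.uint8)
detector = TissueDetectionHE()
tissue = detector.apply(thumbnail)  # output of ForegroundDetection (the detected tissue regions)

--------------------------------------------------------------------------------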
/src/lazyslide/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from ._sample import (
2 | sample,
3 | gtex_artery,
4 | lung_carcinoma,
5 | )
6 |
--------------------------------------------------------------------------------
/src/lazyslide/datasets/_sample.py:
--------------------------------------------------------------------------------
1 | import pooch
2 | from wsidata import open_wsi
3 |
4 | ENTRY = pooch.create(
5 | path=pooch.os_cache("lazyslide"),
6 | base_url="https://lazyslide.blob.core.windows.net/lazyslide-data",
7 | registry={
8 | "sample.svs": "sha256:ed92d5a9f2e86df67640d6f92ce3e231419ce127131697fbbce42ad5e002c8a7",
9 | "sample.zarr.zip": "sha256:075a3ab61e6958673d79612cc29796a92cf875ad049fc1fe5780587968635378",
10 | "GTEX-1117F-0526.svs": "sha256:222ab7f2bb42dcd0bcfaccd910cb13be452b453499e6117ab553aa6cd60a135e",
11 | "GTEX-1117F-0526.zarr.zip": "sha256:2323b656322d2dcc7e9d18aaf586b39a88bf8f2a3959f642f109eb54268f3732",
12 | "lung_carcinoma.ndpi": "sha256:3297b0a564f22940208c61caaca56d97ba81c9b6b7816ebc4042a087e557f85e",
13 | "lung_carcinoma.zarr.zip": "sha256:0a8ccfc608f55624b473c6711b55739c3279d3b6fc5b654395dfc23b010bf866",
14 | },
15 | )
16 |
17 | logger = pooch.get_logger()
18 | logger.setLevel("WARNING")
19 |
20 |
21 | def _load_dataset(slide_file, zarr_file, with_data=True, pbar=False):
22 | slide = ENTRY.fetch(slide_file)
23 | _ = ENTRY.fetch(
24 | zarr_file,
25 | progressbar=pbar,
26 | processor=pooch.Unzip(extract_dir=zarr_file.rstrip(".zip")),
27 | )
28 | store = "auto" if with_data else None
29 | return open_wsi(slide, store=store)
30 |
31 |
32 | def sample(with_data: bool = True, pbar: bool = False):
33 | """
34 | Load a small sample slide (~1.9 MB).
35 |
36 | Source: https://openslide.cs.cmu.edu/download/openslide-testdata/Aperio/CMU-1-Small-Region.svs
37 |
38 | Parameters
39 | ----------
40 | with_data : bool, default: True
41 | Whether to load the associated zarr storage data.
42 | pbar : bool, default: False
43 | Whether to show the progress bar.
44 |
45 | """
46 | return _load_dataset(
47 | "sample.svs", "sample.zarr.zip", with_data=with_data, pbar=pbar
48 | )
49 |
50 |
51 | def gtex_artery(with_data: bool = True, pbar: bool = False):
52 | """
53 | A GTEX artery slide.
54 |
55 | Source: https://gtexportal.org/home/histologyPage, GTEX-1117F-0526
56 |
57 | Parameters
58 | ----------
59 | with_data : bool, default: True
60 | Whether to load the associated zarr storage data.
61 | pbar : bool, default: False
62 | Whether to show the progress bar.
63 |
64 | """
65 | return _load_dataset(
66 | "GTEX-1117F-0526.svs",
67 | "GTEX-1117F-0526.zarr.zip",
68 | with_data=with_data,
69 | pbar=pbar,
70 | )
71 |
72 |
73 | def lung_carcinoma(with_data: bool = True, pbar: bool = False):
74 | """
75 | A lung carcinoma slide.
76 |
77 | Source: https://idr.openmicroscopy.org/webclient/img_detail/9846318/?dataset=10801
78 |
79 | Parameters
80 | ----------
81 | with_data : bool, default: True
82 | Whether to load the associated zarr storage data.
83 | pbar : bool, default: False
84 | Whether to show the progress bar.
85 |
86 | """
87 |
88 | return _load_dataset(
89 | "lung_carcinoma.ndpi", "lung_carcinoma.zarr.zip", with_data=with_data, pbar=pbar
90 | )
91 |
--------------------------------------------------------------------------------
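A minimal sketch of loading the bundled sample slide (the first call downloads ~2 MB and caches it via pooch):

from lazyslide.datasets import sample

wsi = sample(with_data=True, pbar=True)
print(wsi)

--------------------------------------------------------------------------------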
/src/lazyslide/io/__init__.py:
--------------------------------------------------------------------------------
1 | from ._annotaiton import load_annotations, export_annotations
2 |
--------------------------------------------------------------------------------
/src/lazyslide/io/_annotaiton.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | from itertools import cycle
5 | from pathlib import Path
6 | from typing import List, Literal, Mapping, Iterable
7 |
8 | import pandas as pd
9 | from geopandas import GeoDataFrame
10 | from wsidata import WSIData
11 | from wsidata.io import update_shapes_data, add_shapes
12 |
13 | from lazyslide._const import Key
14 |
15 |
16 | def _in_bounds_transform(wsi: WSIData, annos: GeoDataFrame, reverse: bool = False):
17 | from functools import partial
18 | from shapely.affinity import translate
19 |
20 | xoff, yoff, _, _ = wsi.properties.bounds
21 | if reverse:
22 | xoff, yoff = -xoff, -yoff
23 | trans = partial(translate, xoff=xoff, yoff=yoff)
24 | annos["geometry"] = annos["geometry"].apply(lambda x: trans(x))
25 | return annos
26 |
27 |
28 | def load_annotations(
29 | wsi: WSIData,
30 | annotations: str | Path | GeoDataFrame = None,
31 | *,
32 | explode: bool = True,
33 | in_bounds: bool = False,
34 | join_with: str | List[str] = Key.tissue,
35 | join_to: str = None,
36 | json_flatten: str | List[str] = "classification",
37 | min_area: float = 1e2,
38 | key_added: str = "annotations",
39 | ):
40 | """Load the annotation file and add it to the WSIData
41 |
42 | Parameters
43 | ----------
44 |     wsi : :class:`WSIData <wsidata.WSIData>`
45 | The WSIData object to work on.
46 | annotations : str, Path, GeoDataFrame
47 | The path to the annotation file or the GeoDataFrame.
48 | explode : bool, default: True
49 | Whether to explode the annotations.
50 | in_bounds : bool, default: False
51 | Whether to move the annotations to the slide bounds.
52 | join_with : str, List[str], default: 'tissues'
53 | The key to join the annotations with.
54 | join_to : str, default: None
55 | The key to join the annotations to.
56 |     json_flatten : str, default: "classification"
57 |         The column(s) containing JSON data to flatten; columns that do not exist are ignored.
58 |         "classification" is the default column for QuPath annotations.
59 | min_area : float, default: 1e2
60 | The minimum area of the annotation.
61 | key_added : str, default: 'annotations'
62 | The key to store the annotations.
63 |
64 | """
65 | import geopandas as gpd
66 |
67 | if isinstance(annotations, (str, Path)):
68 | geo_path = Path(annotations)
69 | anno_df = gpd.read_file(geo_path)
70 | elif isinstance(annotations, GeoDataFrame):
71 | anno_df = annotations
72 | else:
73 | raise ValueError(f"Invalid annotations: {annotations}")
74 |
75 | # remove crs
76 | anno_df.crs = None
77 |
78 | if explode:
79 | anno_df = (
80 | anno_df.explode()
81 | .assign(**{"__area__": lambda x: x.geometry.area})
82 | .query(f"__area__ > {min_area}")
83 | .drop(columns=["__area__"], errors="ignore")
84 | .reset_index(drop=True)
85 | )
86 |
87 | if json_flatten is not None:
88 |
89 | def flatten_json(x):
90 | if isinstance(x, dict):
91 | return x
92 | elif isinstance(x, str):
93 | try:
94 | return json.loads(x)
95 | except json.JSONDecodeError:
96 | return {}
97 |
98 | if isinstance(json_flatten, str):
99 | json_flatten = [json_flatten]
100 | for col in json_flatten:
101 | if col in anno_df.columns:
102 | anno_df[col] = anno_df[col].apply(flatten_json)
103 | anno_df = anno_df.join(
104 | anno_df[col].apply(pd.Series).add_prefix(f"{col}_")
105 | )
106 | anno_df.drop(columns=[col], inplace=True)
107 |
108 | if in_bounds:
109 | anno_df = _in_bounds_transform(wsi, anno_df)
110 |
111 | # get tiles
112 | if isinstance(join_with, str):
113 | join_with = [join_with]
114 |
115 | join_anno_df = anno_df.copy()
116 | for key in join_with:
117 | if key in wsi:
118 | shapes_df = wsi[key]
119 | # join the annotations with the tiles
120 | join_anno_df = (
121 | gpd.sjoin(shapes_df, join_anno_df, how="right", predicate="intersects")
122 | .reset_index(drop=True)
123 | .drop(columns=["index_left"])
124 | )
125 | add_shapes(wsi, key_added, join_anno_df)
126 |
127 | # TODO: still Buggy
128 | if join_to is not None:
129 | if join_to in wsi:
130 | shapes_df = wsi[join_to]
131 | # join the annotations with the tiles
132 | shapes_df = (
133 | gpd.sjoin(
134 | shapes_df[["geometry"]], anno_df, how="left", predicate="intersects"
135 | )
136 | .reset_index(drop=True)
137 | .drop(columns=["index_right"], errors="ignore")
138 | )
139 | update_shapes_data(wsi, join_to, shapes_df)
140 |
141 |
142 | def export_annotations(
143 | wsi: WSIData,
144 | key: str,
145 | *,
146 | in_bounds: bool = False,
147 | classes: str = None,
148 | colors: str | Mapping = None,
149 | format: Literal["qupath"] = "qupath",
150 | file: str | Path = None,
151 | ):
152 | """
153 | Export the annotations
154 |
155 | Parameters
156 | ----------
157 |     wsi : :class:`WSIData <wsidata.WSIData>`
158 | The WSIData object to work on.
159 | key : str
160 | The key to export.
161 | in_bounds : bool, default: False
162 | Whether to move the annotations to the slide bounds.
163 | classes : str, default: None
164 | The column to use for the classification.
165 | If None, the classification will be ignored.
166 | colors : str, Mapping, default: None
167 | The column to use for the color.
168 | If None, the color will be ignored.
169 | format : str, default: 'qupath'
170 | The format to export.
171 | Currently only 'qupath' is supported.
172 | file : str, Path, default: None
173 | The file to save the annotations.
174 | If None, the annotations will not be saved.
175 |
176 |
177 | """
178 | gdf = wsi.shapes[key].copy()
179 | if in_bounds:
180 | gdf = _in_bounds_transform(wsi, gdf, reverse=True)
181 |
182 | if format == "qupath":
183 | # Prepare classification column
184 | import json
185 |
186 | if classes is not None:
187 | class_values = gdf[classes]
188 |
189 | if colors is None:
190 | # Assign default colors
191 | colors = cycle(
192 | [
193 | "#1B9E77", # Teal Green
194 | "#D95F02", # Burnt Orange
195 | "#7570B3", # Deep Lavender
196 | "#E7298A", # Magenta
197 | "#66A61E", # Olive Green
198 | "#E6AB02", # Goldenrod
199 | "#A6761D", # Earthy Brown
200 | "#666666", # Charcoal Gray
201 | "#1F78B4", # Cool Blue
202 | ]
203 | )
204 |
205 | if colors is not None:
206 | color_values = cycle([])
207 |             if isinstance(colors, str):
208 |                 color_values = gdf[colors]
209 |             elif isinstance(colors, Iterable) and not isinstance(colors, Mapping):
210 |                 # if a sequence of colors, map them to the unique class values
211 |                 colors = dict(zip(pd.unique(class_values), colors))
212 |             elif not isinstance(colors, Mapping):
213 |                 raise ValueError(f"Invalid colors: {colors}")
214 |
215 | if isinstance(colors, Mapping):
216 | color_values = map(lambda x: colors.get(x, None), gdf[classes])
217 |
218 |                 # convert color to an RGB tuple
219 | from matplotlib.colors import to_rgb
220 |
221 | color_values = map(
222 | lambda x: tuple(int(255 * c) for c in to_rgb(x))
223 | if x is not None
224 | else None,
225 | color_values,
226 | )
227 |
228 | classifications = []
229 | for class_value, color_value in zip(class_values, color_values):
230 | json_string = json.dumps({"name": class_value, "color": color_value})
231 | classifications.append(json_string)
232 | gdf["classification"] = classifications
233 |
234 | if file is not None:
235 | gdf.to_file(file)
236 |
237 | return gdf
238 |
--------------------------------------------------------------------------------
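A minimal sketch of round-tripping QuPath annotations. The file name "annotations.geojson" and the "classification_name" column (produced when QuPath classification dicts are flattened) are illustrative only:

from lazyslide.datasets import sample
from lazyslide.io import export_annotations, load_annotations

wsi = sample()
load_annotations(wsi, "annotations.geojson", key_added="annotations")

# Write the shapes back to a QuPath-readable GeoJSON, colored by class
export_annotations(
    wsi, "annotations", classes="classification_name", file="annotations_out.geojson"
)

--------------------------------------------------------------------------------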
/src/lazyslide/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from anndata import AnnData
3 |
4 |
5 | def topk_score(
6 | matrix: np.ndarray | AnnData,
7 | k: int = 5,
8 | agg_method: str = "max",
9 | ) -> np.ndarray:
10 | """
11 | Get the top k score from a feature x class matrix.
12 |
13 | Parameters
14 | ----------
15 | matrix : np.ndarray | AnnData
16 | The input matrix. Feature x class.
17 | k : int, default: 5
18 | The number of top scores to return.
19 | agg_method : str, default: "max"
20 | The method to use for aggregation.
21 | Can be "max", "mean", "median" or "sum".
22 |
23 | Returns
24 | -------
25 | np.ndarray
26 | The top k scores.
27 |
28 | """
29 | if isinstance(matrix, AnnData):
30 | matrix = matrix.X
31 |
32 | top_k_score = np.sort(matrix, axis=0)[-k:]
33 | score = getattr(np, agg_method)(top_k_score, axis=0)
34 | return score
35 |
--------------------------------------------------------------------------------
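A small worked example of topk_score on a 4-feature x 2-class matrix:

import numpy as np

from lazyslide.metrics import topk_score

matrix = np.array(
    [
        [0.1, 0.9],
        [0.4, 0.2],
        [0.8, 0.6],
        [0.3, 0.7],
    ]
)
# top-2 per column are [[0.4, 0.7], [0.8, 0.9]]; taking the max gives [0.8, 0.9]
print(topk_score(matrix, k=2, agg_method="max"))

--------------------------------------------------------------------------------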
/src/lazyslide/models/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Type
2 |
3 | from . import multimodal
4 | from . import segmentation
5 | from . import vision
6 | from .base import (
7 | ModelBase,
8 | ImageModel,
9 | ImageTextModel,
10 | SegmentationModel,
11 | SlideEncoderModel,
12 | TimmModel,
13 | )
14 |
15 | from ._model_registry import MODEL_REGISTRY, list_models
16 |
--------------------------------------------------------------------------------
/src/lazyslide/models/_model_registry.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from pathlib import Path
4 | from typing import Type, List
5 |
6 | import pandas as pd
7 |
8 | from . import ModelBase
9 | from . import multimodal
10 | from . import segmentation
11 | from . import vision
12 |
13 |
14 | class ModelTask(Enum):
15 | vision = "vision"
16 | segmentation = "segmentation"
17 | multimodal = "multimodal"
18 |
19 |
20 | @dataclass
21 | class ModelCard:
22 | name: str
23 | model_type: ModelTask
24 | module: Type[ModelBase]
25 | github_url: str = None
26 | hf_url: str = None
27 | paper_url: str = None
28 | description: str = None
29 | keys: List[str] = None
30 |
31 | def __post_init__(self):
32 | try:
33 | inject_doc = str(self)
34 | origin_doc = self.module.__doc__
35 | if origin_doc is None:
36 | origin_doc = ""
37 | else:
38 | origin_doc = f"\n\n{origin_doc}"
39 | self.module.__doc__ = f"{inject_doc}{origin_doc}"
40 | except AttributeError:
41 | # If the module does not have a __doc__ attribute, skip the injection
42 | pass
43 |
44 | if self.keys is None:
45 | self.keys = [self.name.lower()]
46 |
47 | def __str__(self):
48 | skeleton = ""
49 | if self.github_url is not None:
50 | skeleton += f":octicon:`mark-github;1em;` `GitHub <{self.github_url}>`__ \\"
51 | if self.hf_url is not None:
52 | skeleton += f"🤗 `Hugging Face <{self.hf_url}>`__ \\"
53 | if self.paper_url is not None:
54 | skeleton += f" :octicon:`book;1em;` `Paper <{self.paper_url}>`__"
55 | if self.description is not None:
56 | skeleton += f"\n| {self.description}"
57 |
58 | return skeleton
59 |
60 |
61 | MODEL_REGISTRY = {}
62 |
63 | MODEL_DB = pd.read_csv(f"{Path(__file__).parent}/model_registry.csv")
64 | _modules = {
65 | ModelTask.vision: vision,
66 | ModelTask.segmentation: segmentation,
67 | ModelTask.multimodal: multimodal,
68 | }
69 |
70 | for _, row in MODEL_DB.iterrows():
71 | model_type = ModelTask(row["model_type"])
72 | card = ModelCard(
73 | name=row["name"],
74 | model_type=model_type,
75 | module=getattr(_modules[model_type], row["module"]),
76 | github_url=None if pd.isna(row["github_url"]) else row["github_url"],
77 | hf_url=None if pd.isna(row["hf_url"]) else row["hf_url"],
78 | paper_url=None if pd.isna(row["paper_url"]) else row["paper_url"],
79 | description=None if pd.isna(row["description"]) else row["description"],
80 | )
81 | keys = [i.strip() for i in row["keys"].split(",")] if row["keys"] else []
82 | for key in keys:
83 | MODEL_REGISTRY[key] = card
84 |
85 |
86 | def list_models(task: ModelTask = None):
87 | """List all available models.
88 |
89 | If you want to get models for feature extraction,
90 | you can use task='vision' or task='multimodal'.
91 |
92 | Parameters
93 | ----------
94 | task : {'vision', 'segmentation', 'multimodal'}, default: None
95 | The task to filter the models. If None, return all models.
96 |
97 | Returns
98 | -------
99 | list
100 | A list of model names.
101 |
102 | """
103 |     if task is None:
104 |         return list(MODEL_REGISTRY.keys())
105 |     try:
106 |         task = ModelTask(task)
107 |     except ValueError:
108 |         raise ValueError(
109 |             f"Unknown task: {task}. "
110 |             "Available tasks are: vision, segmentation, multimodal."
111 |         ) from None
112 |     return [
113 |         name
114 |         for name, model in MODEL_REGISTRY.items()
115 |         if model.model_type == task
116 |     ]
118 |
--------------------------------------------------------------------------------
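A minimal sketch of querying the model registry (importing lazyslide.models pulls in torch and the optional model dependencies):

from lazyslide.models import list_models

print(list_models())                     # every registered key, e.g. "uni", "conch", "gigapath", ...
print(list_models(task="segmentation"))  # e.g. "nulite", "instanseg", "grandqc-tissue", "grandqc-artifact"

--------------------------------------------------------------------------------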
/src/lazyslide/models/_utils.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 |
3 | import torch
4 |
5 |
6 | def _fake_class(name, deps, inject=""):
7 | def init(self, *args, **kwargs):
8 | raise ImportError(
9 | f"To use {name}, you need to install {', '.join(deps)}."
10 | f"{inject}"
11 | "Please restart the kernel after installation."
12 | )
13 |
14 | # Dynamically create the class
15 | new_class = type(name, (object,), {"__init__": init})
16 |
17 | return new_class
18 |
19 |
20 | @contextmanager
21 | def hf_access(name):
22 | """
23 | Context manager for Hugging Face access.
24 | """
25 | from huggingface_hub.errors import GatedRepoError
26 |
27 | try:
28 | yield
29 | except GatedRepoError as e:
30 | raise GatedRepoError(
31 | f"You don't have access to {name}. Please request access to the model on HuggingFace. "
32 | "After access granted, please login to HuggingFace with huggingface-cli on this machine "
33 | "with a token that has access to this model. "
34 | "You may also pass token as an argument in LazySlide, however, this is not recommended."
35 | ) from e
36 |
37 |
38 | def get_default_transform():
39 | """The default transform for the model."""
40 | from torchvision.transforms import InterpolationMode
41 | from torchvision.transforms.v2 import (
42 | Compose,
43 | Normalize,
44 | CenterCrop,
45 | ToImage,
46 | ToDtype,
47 | Resize,
48 | )
49 |
50 | transforms = [
51 | ToImage(),
52 | Resize(
53 | size=(224, 224),
54 | interpolation=InterpolationMode.BICUBIC,
55 | max_size=None,
56 | antialias=True,
57 | ),
58 | CenterCrop(224),
59 | ToDtype(dtype=torch.float32, scale=True),
60 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
61 | ]
62 | return Compose(transforms)
63 |
--------------------------------------------------------------------------------
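A minimal sketch of the default preprocessing transform (requires a torchvision version that provides the v2 transforms):

import numpy as np

from lazyslide.models._utils import get_default_transform

transform = get_default_transform()
patch = np.random.randint(0, 255, size=(512, 512, 3), dtype=np.uint8)
tensor = transform(patch)
print(tensor.shape, tensor.dtype)  # torch.Size([3, 224, 224]) torch.float32

--------------------------------------------------------------------------------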
/src/lazyslide/models/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from pathlib import Path
4 | from typing import Callable
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from lazyslide.models._utils import hf_access, get_default_transform
10 |
11 |
12 | class ModelBase:
13 | model: torch.nn.Module
14 | name: str = "ModelBase"
15 | is_restricted: bool = False
16 |
17 | def get_transform(self):
18 | return None
19 |
20 | def to(self, device):
21 | self.model.to(device)
22 | return self
23 |
24 | @staticmethod
25 | def load_weights(url, progress=True):
26 | from timm.models.hub import download_cached_file
27 |
28 | return Path(download_cached_file(url, progress=progress))
29 |
30 |
31 | class ImageModel(ModelBase):
32 | # TODO: Add a config that specify the recommended input tile size and mpp
33 |
34 | def get_transform(self):
35 | import torch
36 | from torchvision.transforms.v2 import (
37 | Compose,
38 | ToImage,
39 | ToDtype,
40 | Resize,
41 | Normalize,
42 | )
43 |
44 | return Compose(
45 | [
46 | ToImage(),
47 | ToDtype(dtype=torch.float32, scale=True),
48 | Resize(size=(224, 224), antialias=False),
49 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
50 | ]
51 | )
52 |
53 | def encode_image(self, image) -> np.ndarray[np.float32]:
54 | raise NotImplementedError
55 |
56 | def __call__(self, image):
57 | return self.encode_image(image)
58 |
59 |
60 | class TimmModel(ImageModel):
61 | def __init__(self, name, token=None, compile=False, compile_kws=None, **kwargs):
62 | import timm
63 | from huggingface_hub import login
64 |
65 | if token is not None:
66 | login(token)
67 |
68 | default_kws = {"pretrained": True, "num_classes": 0}
69 | default_kws.update(kwargs)
70 |
71 | with hf_access(name):
72 | self.model = timm.create_model(name, **default_kws)
73 |
74 | if compile:
75 | if compile_kws is None:
76 | compile_kws = {}
77 | self.compiled_model = torch.compile(self.model, **compile_kws)
78 |
79 | def get_transform(self):
80 | return get_default_transform()
81 |
82 |     @torch.inference_mode()
83 |     def encode_image(self, image):
84 |         # the inference_mode decorator already disables autograd tracking
85 |         return self.model(image).cpu().detach().numpy()
86 |
87 |
88 | class SlideEncoderModel(ModelBase):
89 | def encode_slide(self, embeddings, coords=None):
90 | raise NotImplementedError
91 |
92 |
93 | class ImageTextModel(ImageModel):
94 | def encode_image(self, image):
95 | """This should return the image feature before normalize."""
96 | raise NotImplementedError
97 |
98 | def encode_text(self, text):
99 | raise NotImplementedError
100 |
101 | def tokenize(self, text):
102 | raise NotImplementedError
103 |
104 |
105 | class SegmentationModel(ModelBase):
106 | CLASS_MAPPING = None
107 |
108 | def get_transform(self):
109 | import torch
110 | from torchvision.transforms.v2 import Compose, ToImage, ToDtype, Normalize
111 |
112 | return Compose(
113 | [
114 | ToImage(),
115 | ToDtype(dtype=torch.float32, scale=True),
116 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
117 | ]
118 | )
119 |
120 | def segment(self, image):
121 | raise NotImplementedError
122 |
123 | def get_postprocess(self) -> Callable | None:
124 | return None
125 |
--------------------------------------------------------------------------------
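A minimal sketch of wrapping a timm backbone with TimmModel. The "resnet18" name is purely illustrative (it downloads ImageNet weights); the pathology foundation models listed in the registry are gated on the Hugging Face Hub and require an access token:

import numpy as np

from lazyslide.models.base import TimmModel

model = TimmModel("resnet18")
transform = model.get_transform()

patch = np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8)
batch = transform(patch).unsqueeze(0)  # (1, 3, 224, 224) float tensor
features = model.encode_image(batch)   # numpy array, shape (1, 512) for resnet18
print(features.shape)

--------------------------------------------------------------------------------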
/src/lazyslide/models/model_registry.csv:
--------------------------------------------------------------------------------
1 | name,keys,model_type,module,github_url,hf_url,paper_url,description
2 | CONCH,conch,multimodal,CONCH,https://github.com/mahmoodlab/CONCH,https://huggingface.co/MahmoodLab/conch,https://doi.org/10.1038/s41591-024-02856-4,Multimodal foundation model
3 | PLIP,plip,multimodal,PLIP,https://github.com/PathologyFoundation/plip,https://huggingface.co/vinid/plip,https://doi.org/10.1038/s41591-023-02504-3,Multimodal foundation model
4 | Prism,prism,multimodal,Prism,https://github.com/mahmoodlab/PRISM,https://huggingface.co/paige-ai/Prism,https://doi.org/10.48550/arXiv.2405.10254,Slide-Level multimodal generative model
5 | Titan,"titan, conch_v1.5",multimodal,Titan,https://github.com/mahmoodlab/TITAN,https://huggingface.co/MahmoodLab/TITAN,https://doi.org/10.48550/arXiv.2411.19666,Multimodal foundation model
6 | Uni,uni,vision,UNI,https://github.com/mahmoodlab/UNI,https://huggingface.co/MahmoodLab/UNI,https://doi.org/10.1038/s41591-024-02857-3,Vision foundation model
7 | Uni2,uni2,vision,UNI2,https://github.com/mahmoodlab/UNI,https://huggingface.co/MahmoodLab/UNI2-h,https://doi.org/10.1038/s41591-024-02857-3,Vision foundation model
8 | GigaPath,gigapath,vision,GigaPath,https://github.com/prov-gigapath/prov-gigapath,https://huggingface.co/prov-gigapath/prov-gigapath,https://doi.org/10.1038/s41586-024-07441-w,Vision foundation model
9 | Virchow,virchow,vision,Virchow,,https://huggingface.co/paige-ai/Virchow,https://doi.org/10.1038/s41591-024-03141-0,Vision foundation model
10 | Virchow2,virchow2,vision,Virchow2,,https://huggingface.co/paige-ai/Virchow2,https://doi.org/10.48550/arXiv.2408.00738,Vision foundation model
11 | Phikon,phikon,vision,Phikon,https://github.com/owkin/HistoSSLscaling/,https://huggingface.co/owkin/phikon,https://doi.org/10.1101/2023.07.21.23292757,Vision foundation model
12 | PhikonV2,phikonv2,vision,PhikonV2,https://github.com/owkin,https://huggingface.co/owkin/phikon-v2,https://doi.org/10.48550/arXiv.2409.09173,Vision foundation model
13 | H-optimus-0,h-optimus-0,vision,HOptimus0,https://github.com/bioptimus,https://huggingface.co/bioptimus/H-optimus-0,,Vision foundation model
14 | H-optimus-1,h-optimus-1,vision,HOptimus1,https://github.com/bioptimus,https://huggingface.co/bioptimus/H-optimus-1,,Vision foundation model
15 | H0-mini,h0-mini,vision,H0Mini,https://github.com/bioptimus,https://huggingface.co/bioptimus/H0-mini,https://doi.org/10.48550/arXiv.2501.16239,Vision foundation model
16 | CONCHVision,conch_vision,vision,CONCHVision,https://github.com/mahmoodlab/CONCH,https://huggingface.co/MahmoodLab/conch,https://doi.org/10.1038/s41591-024-02856-4,Multimodal foundation model
17 | PLIPVision,plip_vision,vision,PLIPVision,https://github.com/PathologyFoundation/plip,https://huggingface.co/vinid/plip,https://doi.org/10.1038/s41591-023-02504-3,Multimodal foundation model
18 | NuLite,nulite,segmentation,NuLite,https://github.com/CosmoIknosLab/NuLite,,https://doi.org/10.48550/arXiv.2408.01797,Cell segmentation and classification
19 | InstanSeg,instanseg,segmentation,Instanseg,https://github.com/instanseg/instanseg,,https://doi.org/10.48550/arXiv.2408.15954,Cell segmentation
20 | GrandQC-Tissue,grandqc-tissue,segmentation,GrandQCTissue,https://github.com/cpath-ukk/grandqc,,https://doi.org/10.1038/s41467-024-54769-y,Tissue segmentation
21 | GrandQC-Artifact,grandqc-artifact,segmentation,GrandQCArtifact,https://github.com/cpath-ukk/grandqc,,https://doi.org/10.1038/s41467-024-54769-y,Artifact segmentation
22 | Midnight,midnight,vision,Midnight,https://github.com/kaiko-ai/midnight,https://huggingface.co/kaiko-ai/midnight,https://doi.org/10.48550/arXiv.2504.05186,Vision foundation model
23 | HibouB,hibou-b,vision,HibouB,https://github.com/HistAI/hibou/tree/main,https://huggingface.co/histai/hibou-b,https://doi.org/10.48550/arXiv.2406.05074,Foundation Vision Transformer
24 | HibouL,hibou-l,vision,HibouL,https://github.com/HistAI/hibou/tree/main,https://huggingface.co/histai/hibou-l,https://doi.org/10.48550/arXiv.2406.05074,Foundation Vision Transformer
--------------------------------------------------------------------------------
/src/lazyslide/models/multimodal/__init__.py:
--------------------------------------------------------------------------------
1 | from .conch import CONCH
2 | from .plip import PLIP
3 | from .titan import Titan
4 | from .prism import Prism
5 |
--------------------------------------------------------------------------------
/src/lazyslide/models/multimodal/conch.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .._utils import hf_access
4 | from ..base import ImageTextModel
5 |
6 |
7 | class CONCH(ImageTextModel):
8 | def __init__(self, model_path=None, token=None):
9 | try:
10 | from conch.open_clip_custom import create_model_from_pretrained
11 | from conch.open_clip_custom import get_tokenizer
12 | except ImportError:
13 | raise ImportError(
14 | "Conch is not installed. You can install it using "
15 | "`pip install git+https://github.com/mahmoodlab/CONCH.git`."
16 | )
17 |
18 | if model_path is None:
19 | model_path = "hf_hub:MahmoodLab/conch"
20 |
21 | with hf_access(model_path):
22 | self.model, self.processor = create_model_from_pretrained(
23 | "conch_ViT-B-16", model_path, hf_auth_token=token
24 | )
25 | self.tokenizer = get_tokenizer()
26 |
27 | @torch.inference_mode()
28 | def encode_image(self, image):
29 | if not isinstance(image, torch.Tensor):
30 | image = self.processor(image)
31 | if image.dim() == 3:
32 | image = image.unsqueeze(0)
33 |
34 | image_feature = self.model.encode_image(
35 | image, normalize=True, proj_contrast=True
36 | )
37 | return image_feature
38 |
39 | def tokenize(self, text):
40 | from conch.open_clip_custom import tokenize
41 |
42 | return tokenize(self.tokenizer, text)
43 |
44 | @torch.inference_mode()
45 | def encode_text(self, text):
46 | encode_texts = self.tokenize(text)
47 | text_feature = self.model.encode_text(encode_texts)
48 | return text_feature
49 |
--------------------------------------------------------------------------------
/src/lazyslide/models/multimodal/plip.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/PathologyFoundation/plip/blob/main/plip.py
2 |
3 | import torch
4 |
5 | from .._utils import hf_access
6 | from ..base import ImageTextModel
7 |
8 |
9 | class PLIP(ImageTextModel):
10 | def __init__(self, model_path=None, token=None):
11 | try:
12 | from transformers import CLIPModel, CLIPProcessor
13 | except ImportError:
14 | raise ImportError(
15 | "Please install the 'transformers' package to use the PLIP model"
16 | )
17 |
18 | if model_path is None:
19 | model_path = "vinid/plip"
20 |
21 | with hf_access(model_path):
22 | self.model = CLIPModel.from_pretrained(model_path, use_auth_token=token)
23 | self.processor = CLIPProcessor.from_pretrained(
24 | model_path, use_auth_token=token
25 | )
26 |
27 | def get_transform(self):
28 | return None
29 |
30 | @torch.inference_mode()
31 | def encode_image(self, image):
32 | inputs = self.processor(images=image, return_tensors="pt")
33 | inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
34 | image_features = self.model.get_image_features(**inputs)
35 | image_features = torch.nn.functional.normalize(image_features, p=2, dim=-1)
36 | return image_features
37 |
38 | @torch.inference_mode()
39 | def encode_text(self, text):
40 | inputs = self.processor(
41 | text=text,
42 | return_tensors="pt",
43 | max_length=77,
44 | padding="max_length",
45 | truncation=True,
46 | )
47 | inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
48 | text_features = self.model.get_text_features(**inputs)
49 | return text_features
50 |
--------------------------------------------------------------------------------
/src/lazyslide/models/multimodal/prism.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 |
5 | from .._utils import hf_access
6 | from ..base import ModelBase
7 |
8 |
9 | class Prism(ModelBase):
10 | def __init__(self, model_path=None, token=None):
11 | from transformers import AutoModel
12 |
13 | # Suppress warnings from transformers
14 | with warnings.catch_warnings(), hf_access(model_path):
15 | warnings.simplefilter("ignore")
16 |
17 | self.model = AutoModel.from_pretrained(
18 | "paige-ai/Prism",
19 | trust_remote_code=True,
20 | token=token,
21 | )
22 |
23 | @torch.inference_mode()
24 | def encode_slide(self, embeddings, coords=None) -> dict:
25 | # Make sure the embeddings has a batch dimension
26 | if len(embeddings.shape) == 2:
27 | embeddings = embeddings.unsqueeze(0)
28 | return self.model.slide_representations(embeddings)
29 |
30 | @torch.inference_mode()
31 | def score(
32 | self,
33 | slide_embedding,
34 | prompts: list[list[str]],
35 | ):
36 |         if len(prompts) == 0:
37 |             raise ValueError("prompts must contain at least one group of prompt strings")
38 |
39 | device = self.model.device
40 |
41 | # Flatten all prompts and track indices for class reconstruction
42 | flat_prompts = []
43 | group_lengths = []
44 | for group in prompts:
45 | flat_prompts.extend(group)
46 | group_lengths.append(len(group))
47 |
48 | token_ids = self.model.tokenize(flat_prompts)[:, :-1].to(device)
49 |
50 | dummy_image_latents = torch.empty(
51 | (len(flat_prompts), 1, self.model.text_decoder.context_dim), device=device
52 | )
53 | decoder_out = self.model.text_decoder(token_ids, dummy_image_latents)
54 |
55 | text_proj = self.model.text_to_latents(decoder_out["text_embedding"])
56 | image_proj = self.model.img_to_latents(slide_embedding)
57 |
58 | sim = torch.einsum("i d, j d -> i j", image_proj, text_proj) # (image, prompt)
59 | sim = sim * self.model.temperature.exp()
60 | zero_shot_probs = torch.softmax(
61 | sim.to(torch.float), dim=-1
62 | ) # (Bi, total_prompts)
63 |
64 | # Sum probabilities per group (class)
65 | class_probs = []
66 | start = 0
67 | for length in group_lengths:
68 | end = start + length
69 | class_probs.append(zero_shot_probs[:, start:end].sum(dim=-1, keepdim=True))
70 | start = end
71 |
72 | probs = torch.cat(class_probs, dim=-1)
73 | return probs.detach().cpu().numpy()
74 |
75 | @torch.inference_mode()
76 | def caption(
77 | self,
78 | img_latents,
79 | prompt: list[str],
80 | max_length: int = 100,
81 | ):
82 | genned_ids = self.model.generate(
83 | self.model.tokenize(prompt).to(self.model.device),
84 | key_value_states=img_latents,
85 | do_sample=False,
86 | num_beams=5,
87 | num_beam_groups=1,
88 | max_length=max_length,
89 | )
90 | genned_caption = self.model.untokenize(genned_ids)
91 |
92 | return genned_caption
93 |
--------------------------------------------------------------------------------
/src/lazyslide/models/multimodal/titan.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .._utils import hf_access
4 | from ..base import ImageModel
5 |
6 |
7 | class Titan(ImageModel):
8 | name = "titan"
9 |
10 | TEMPLATES = [
11 | "CLASSNAME.",
12 | "an image of CLASSNAME.",
13 | "the image shows CLASSNAME.",
14 | "the image displays CLASSNAME.",
15 | "the image exhibits CLASSNAME.",
16 | "an example of CLASSNAME.",
17 | "CLASSNAME is shown.",
18 | "this is CLASSNAME.",
19 | "I observe CLASSNAME.",
20 | "the pathology image shows CLASSNAME.",
21 | "a pathology image shows CLASSNAME.",
22 | "the pathology slide shows CLASSNAME.",
23 | "shows CLASSNAME.",
24 | "contains CLASSNAME.",
25 | "presence of CLASSNAME.",
26 | "CLASSNAME is present.",
27 | "CLASSNAME is observed.",
28 | "the pathology image reveals CLASSNAME.",
29 | "a microscopic image of showing CLASSNAME.",
30 | "histology shows CLASSNAME.",
31 | "CLASSNAME can be seen.",
32 | "the tissue shows CLASSNAME.",
33 | "CLASSNAME is identified.",
34 | ]
35 |
36 | def __init__(self, model_path=None, token=None):
37 | from transformers import AutoModel
38 |
39 | with hf_access(model_path):
40 | self.model = AutoModel.from_pretrained(
41 | "MahmoodLab/TITAN",
42 | add_pooling_layer=False,
43 | use_auth_token=token,
44 | trust_remote_code=True,
45 | )
46 | self.conch, self.conch_transform = self.model.return_conch()
47 |
48 | def to(self, device):
49 | super().to(device)
50 | self.conch.to(device)
51 |
52 | def get_transform(self):
53 | from torchvision.transforms import InterpolationMode
54 | from torchvision.transforms.v2 import (
55 | Resize,
56 | CenterCrop,
57 | ToImage,
58 | ToDtype,
59 | Normalize,
60 | Compose,
61 | )
62 |
63 | return Compose(
64 | [
65 | ToImage(),
66 | Resize(448, interpolation=InterpolationMode.BICUBIC, antialias=True),
67 | CenterCrop(448),
68 | ToDtype(dtype=torch.float32, scale=True),
69 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
70 | ]
71 | )
72 |
73 | @torch.inference_mode()
74 | def encode_image(self, image):
75 | image_feature = self.conch(image)
76 | return image_feature.detach().cpu().numpy()
77 |
78 | @torch.inference_mode()
79 | def encode_slide(self, embeddings, coords=None, base_tile_size=None, **kwargs):
80 | slide_embeddings = self.model.encode_slide_from_patch_features(
81 | embeddings, coords, base_tile_size
82 | )
83 | return slide_embeddings.detach().cpu().numpy()
84 |
85 | @torch.inference_mode()
86 | def score(
87 | self, slide_embeddings, prompts: list[str], template: str = None, **kwargs
88 | ):
89 | if template is None:
90 | template = self.TEMPLATES
91 |
92 | classifier = self.model.zero_shot_classifier(prompts, template)
93 | scores = self.model.zero_shot(slide_embeddings, classifier)
94 | return scores.squeeze(0).detach().cpu().numpy()
95 |
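
The `CLASSNAME` templates above are expanded into one prompt list per class before scoring. The expansion itself is handled inside `zero_shot_classifier`, so the snippet below only illustrates the string substitution involved (the class names are hypothetical):

class_names = ["tumor", "normal tissue"]
prompts_per_class = [
    [template.replace("CLASSNAME", name) for template in Titan.TEMPLATES]
    for name in class_names
]
# e.g. prompts_per_class[0][1] == "an image of tumor."
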
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | from .instanseg import Instanseg
2 | from .nulite import NuLite
3 | from .grandqc import GrandQCTissue, GrandQCArtifact
4 | from .postprocess import (
5 | instanseg_postprocess,
6 | semanticseg_postprocess,
7 | )
8 | from .smp import SMPBase
9 | from .sam import SAM
10 |
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/cellpose.py:
--------------------------------------------------------------------------------
1 | from lazyslide.models.base import SegmentationModel
2 |
3 |
4 | class Cellpose(SegmentationModel):
5 | def __init__(self, model_type="nuclei"):
6 | from cellpose import models
7 |
8 | self.cellpose_model = models.Cellpose(model_type=model_type, gpu=False)
9 |
10 | def to(self, device):
11 | self.cellpose_model.device = device
12 |
13 | def get_transform(self):
14 | return None
15 |
16 | def segment(self, image):
17 | masks, flows, styles = self.cellpose_model.eval(
18 | image, diameter=30, channels=[0, 0]
19 | )
20 | return masks
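
A minimal usage sketch for the wrapper above, assuming the `cellpose` package is installed and `tile` stands in for an RGB tile read from a slide:

import numpy as np

model = Cellpose(model_type="nuclei")
tile = np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8)  # stand-in for a real tile
masks = model.segment(tile)  # integer label image: 0 = background, 1..N = detected nuclei
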
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/grandqc.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import torch
4 | from lazyslide.models.base import SegmentationModel
5 | from lazyslide.models.segmentation.postprocess import semanticseg_postprocess
6 | from lazyslide.models.segmentation.smp import SMPBase
7 |
8 |
9 | class GrandQCArtifact(SegmentationModel):
10 | CLASS_MAPPING = {
11 | 0: "Background",
12 | 1: "Normal Tissue",
13 | 2: "Fold",
14 | 3: "Darkspot & Foreign Object",
15 | 4: "PenMarking",
16 | 5: "Edge & Air Bubble",
17 | 6: "Out of Focus",
18 | 7: "Background",
19 | }
20 |
21 | def __init__(self, model: Literal["5x", "7x", "10x"] = "7x"):
22 | from huggingface_hub import hf_hub_download
23 |
24 | weights_map = {
25 | "5x": "GrandQC_MPP2_traced.pt",
26 | "7x": "GrandQC_MPP15_traced.pt",
27 | "10x": "GrandQC_MPP1_traced.pt",
28 | }
29 | weights = hf_hub_download(
30 | "RendeiroLab/LazySlide-models", f"grandqc/{weights_map[model]}"
31 | )
32 |
33 | self.model = torch.jit.load(weights)
34 |
35 | def get_transform(self):
36 | import torch
37 | from torchvision.transforms.v2 import (
38 | Compose,
39 | ToImage,
40 | ToDtype,
41 | Normalize,
42 | )
43 |
44 | return Compose(
45 | [
46 | ToImage(),
47 | ToDtype(dtype=torch.float32, scale=True),
48 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
49 | ]
50 | )
51 |
52 | @torch.inference_mode()
53 | def segment(self, image):
54 | out = self.model(image)
55 | return out.detach().cpu().numpy()
56 |
57 | def get_postprocess(self):
58 | return semanticseg_postprocess
59 |
60 |
61 | class GrandQCTissue(SMPBase):
62 | CLASS_MAPPING = {
63 | 0: "Background",
64 | 1: "Tissue",
65 | }
66 |
67 | def __init__(self):
68 | from huggingface_hub import hf_hub_download
69 |
70 | weights = hf_hub_download(
71 | "RendeiroLab/LazySlide-models", "grandqc/Tissue_Detection_MPP10.pth"
72 | )
73 |
74 | super().__init__(
75 | arch="unetplusplus",
76 | encoder_name="timm-efficientnet-b0",
77 | encoder_weights="imagenet",
78 | in_channels=3,
79 | classes=2,
80 | activation=None,
81 | )
82 | self.model.load_state_dict(
83 | torch.load(weights, map_location=torch.device("cpu"), weights_only=True)
84 | )
85 | self.model.eval()
86 |
87 | @torch.inference_mode()
88 | def segment(self, image):
89 | return self.model.predict(image)
90 |
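
For orientation, a small sketch (hypothetical shapes) of how the raw per-class logits returned by `GrandQCArtifact.segment` relate to `CLASS_MAPPING`; in the pipeline this mapping is handled by `semanticseg_postprocess`:

import numpy as np

logits = np.random.randn(8, 512, 512)             # (n_classes, H, W) output for a single tile
label_map = logits.argmax(axis=0)                 # per-pixel class index
present = {i: GrandQCArtifact.CLASS_MAPPING[i] for i in np.unique(label_map)}
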
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/instanseg.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Callable
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from lazyslide.models.base import SegmentationModel
9 | from .postprocess import instanseg_postprocess
10 |
11 |
12 | class PercentileNormalize:
13 | def __call__(self, image: torch.Tensor) -> torch.Tensor:
14 | # image shape should be [C, H, W]
15 | for c in range(image.shape[0]):
16 | channel = image[c]
17 | min_i = torch.quantile(channel.flatten(), 0.001)
18 | max_i = torch.quantile(channel.flatten(), 0.999)
19 | image[c] = (channel - min_i) / max(1e-3, max_i - min_i)
20 | return image
21 |
22 | def __repr__(self):
23 | return self.__class__.__name__ + "()"
24 |
25 |
26 | class Instanseg(SegmentationModel):
27 | """Apply the InstaSeg model to the input image."""
28 |
29 | _base_mpp = 0.5
30 |
31 | def __init__(self, model_file=None):
32 | from huggingface_hub import hf_hub_download
33 |
34 | if model_file is None:  # only download the default weights when no local file is given
35 |     model_file = hf_hub_download(
36 |         "RendeiroLab/LazySlide-models", "instanseg/instanseg_v0_1_0.pt")
37 |
38 | self.model = torch.jit.load(model_file, map_location="cpu")
39 |
40 | def get_transform(self):
41 | from torchvision.transforms.v2 import ToImage, ToDtype, Compose
42 |
43 | return Compose(
44 | [
45 | ToImage(), # Converts numpy or PIL to torch.Tensor in [C, H, W] format
46 | ToDtype(dtype=torch.float32, scale=False),
47 | PercentileNormalize(),
48 | ]
49 | )
50 |
51 | @torch.inference_mode()
52 | def segment(self, image):
53 | # with torch.inference_mode():
54 | out = self.model(image)
55 | return out.squeeze().cpu().numpy().astype(np.uint16)
56 |
57 | def get_postprocess(self) -> Callable | None:
58 | return instanseg_postprocess
59 |
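
A quick, self-contained check of the `PercentileNormalize` transform defined above (note that it modifies the tensor in place, hence the `clone`):

import torch

img = torch.rand(3, 256, 256) * 255                # stand-in for an unnormalized RGB tile
normed = PercentileNormalize()(img.clone())
print(normed.min().item(), normed.max().item())    # roughly [0, 1]; the 0.1% tails may fall slightly outside
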
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/nulite/__init__.py:
--------------------------------------------------------------------------------
1 | from .api import NuLite
2 |
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/nulite/api.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import geopandas as gpd
7 |
8 | from lazyslide.cv import Mask
9 | from lazyslide.models.base import SegmentationModel
10 |
11 | from .model import NuLite as NuLiteModel
12 |
13 |
14 | class NuLite(SegmentationModel):
15 | def __init__(
16 | self,
17 | variant: Literal["H", "M", "T"] = "H",
18 | ):
19 | from huggingface_hub import hf_hub_download
20 |
21 | model_file = hf_hub_download(
22 | "RendeiroLab/LazySlide-models", f"nulite/NuLite-{variant}-Weights.pth"
23 | )
24 |
25 | weights = torch.load(model_file, map_location="cpu")
26 |
27 | config = weights["config"]
28 | self.model = NuLiteModel(
29 | config["data.num_nuclei_classes"],
30 | config["data.num_tissue_classes"],
31 | config["model.backbone"],
32 | )
33 | self.model.load_state_dict(weights["model_state_dict"])
34 |
35 | def get_transform(self):
36 | from torchvision.transforms.v2 import ToImage, ToDtype, Normalize, Compose
37 |
38 | return Compose(
39 | [
40 | ToImage(),
41 | ToDtype(dtype=torch.float32, scale=True),
42 | Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
43 | ]
44 | )
45 |
46 | @torch.inference_mode()
47 | def segment(self, image):
48 | return self.model.forward(image, retrieve_tokens=True)
49 |
50 | def get_postprocess(self):
51 | return nulite_postprocess
52 |
53 |
54 | CLASS_MAPPING = {
55 | 0: "Background",
56 | 1: "Neoplastic",
57 | 2: "Inflammatory",
58 | 3: "Connective",
59 | 4: "Dead",
60 | 5: "Epithelial",
61 | }
62 |
63 |
64 | def nulite_postprocess(
65 | output,
66 | ksize: int = 11,
67 | min_object_size: int = 3,
68 | nucleus_size: tuple[int, int] = (20, 5000),
69 | ) -> gpd.GeoDataFrame:
70 | """Postprocess the NuLite model output into a GeoDataFrame of typed nucleus polygons."""
71 |
72 | binary_mask = output["nuclei_binary_map"].softmax(0).detach().cpu().numpy()[1]
73 | hv_map = output["hv_map"].detach().cpu().numpy()
74 | type_prob_map = (
75 | output["nuclei_type_map"].softmax(0).detach().cpu().numpy()[1::]
76 | ) # to skip background
77 |
78 | _, blb = cv2.threshold(binary_mask.astype(np.float32), 0.5, 1, cv2.THRESH_BINARY)
79 | blb = blb.astype(np.uint8)
80 |
81 | # Remove small objects based on connected components.
82 | # Use cv2.connectedComponentsWithStats to label regions and filter by area.
83 | num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(blb, connectivity=8)
84 | min_size = min_object_size  # Minimum pixel area to keep an object
85 | blb_clean = np.zeros_like(blb)
86 | for label in range(1, num_labels): # label 0 is the background.
87 | if stats[label, cv2.CC_STAT_AREA] >= min_size:
88 | blb_clean[labels == label] = 1
89 |
90 | h_map, v_map = hv_map
91 | # STEP 2: Normalize directional maps
92 | h_dir_norm = cv2.normalize(
93 | h_map, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX
94 | ).astype(np.float32)
95 | v_dir_norm = cv2.normalize(
96 | v_map, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX
97 | ).astype(np.float32)
98 |
99 | # STEP 3: Compute edges using Sobel operators
100 | # ksize = 11 # Kernel size for Sobel operators; adjust for edge sensitivity.
101 | sobelh = cv2.Sobel(h_dir_norm, cv2.CV_64F, dx=1, dy=0, ksize=ksize)
102 | sobelv = cv2.Sobel(v_dir_norm, cv2.CV_64F, dx=0, dy=1, ksize=ksize)
103 |
104 | # Normalize the edge responses and invert them to prepare for the "distance" map.
105 | sobelh_norm = 1 - cv2.normalize(
106 | sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX
107 | )
108 | sobelv_norm = 1 - cv2.normalize(
109 | sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX
110 | )
111 |
112 | # Combine edge images by taking the maximum value at each pixel.
113 | overall = np.maximum(sobelh_norm, sobelv_norm)
114 |
115 | # Remove non-nuclei regions from the edge map.
116 | overall = overall - (1 - blb_clean.astype(np.float32))
117 | overall[overall < 0] = 0 # Set negative values to zero
118 |
119 | # STEP 4: Create an inverse “distance” map for watershed
120 | # The idea is to make the centers of nuclei correspond to local minima.
121 | # dist = (1.0 - overall) * blb_clean.astype(np.float32)
122 | # dist = -cv2.GaussianBlur(dist, (3, 3), 0)
123 |
124 | # STEP 5: Create markers for watershed (seed regions)
125 | # Identify the nucleus interior by thresholding the overall edge image.
126 | _, overall_bin = cv2.threshold(overall, 0.4, 1, cv2.THRESH_BINARY)
127 | overall_bin = overall_bin.astype(np.uint8)
128 |
129 | # Subtract the boundaries from the clean binary mask
130 | marker = blb_clean - overall_bin
131 | marker[marker < 0] = 0
132 |
133 | # Fill holes and do a morphological closing to smooth marker regions.
134 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
135 | marker_closed = cv2.morphologyEx(marker, cv2.MORPH_CLOSE, kernel)
136 |
137 | # Again, remove tiny markers using connected component analysis.
138 | num_labels, markers, stats, _ = cv2.connectedComponentsWithStats(
139 | marker_closed, connectivity=8
140 | )
141 | object_size = 10 # Minimum size (in pixels) for a marker region
142 | markers_clean = np.zeros_like(markers, dtype=np.int32)
143 | for label in range(1, num_labels):
144 | if stats[label, cv2.CC_STAT_AREA] >= object_size:
145 | markers_clean[markers == label] = label
146 |
147 | # STEP 6: Apply the Watershed algorithm using only OpenCV
148 | # The watershed function in OpenCV requires a 3-channel image.
149 | # Here, we build a dummy 3-channel (RGB) image from our binary mask (for visualization/masking purposes).
150 | dummy_img = cv2.cvtColor((blb_clean * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR)
151 |
152 | # Watershed modifies the marker image in place.
153 | # The boundaries between segmented regions will be marked with -1.
154 | cv2.watershed(dummy_img, markers_clean)
155 |
156 | unique_labels = np.unique(markers_clean)
157 | final_seg = np.zeros_like(markers_clean, dtype=np.int32)
158 | cells = []
159 | nucleus_size_min, nucleus_size_max = nucleus_size
160 | for lbl in unique_labels:
161 | if lbl <= 1: # Skip background (-1) and unknown (1)
162 | continue
163 | mask = markers_clean == lbl
164 | x, y = np.where(mask)
165 | area = len(x)
166 |
167 | if nucleus_size_min <= area <= nucleus_size_max:
168 | probs = type_prob_map[:, x, y].mean(1)
169 | class_ix = np.argmax(probs)
170 | class_prob = type_prob_map[class_ix, x, y].mean()
171 | m = Mask.from_array(mask.astype(np.uint8))
172 | poly = m.to_polygons()[0]
173 | cells.append([CLASS_MAPPING[class_ix + 1], class_prob, poly])
174 | final_seg[markers_clean == lbl] = lbl
175 | return gpd.GeoDataFrame(cells, columns=["name", "prob", "geometry"])
176 |
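
Downstream, the GeoDataFrame produced by the postprocess above is easy to summarise; a short sketch assuming `cells_gdf` holds the returned frame:

# cells_gdf has one row per nucleus with columns "name", "prob", "geometry"
counts = cells_gdf["name"].value_counts()               # nuclei per predicted type
mean_conf = cells_gdf.groupby("name")["prob"].mean()    # average class probability per type
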
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/postprocess.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import geopandas as gpd
3 | import numpy as np
4 |
5 |
6 | def instanseg_postprocess(
7 | mask: np.ndarray,
8 | ):
9 | """
10 | Postprocess the mask to get the cell polygons.
11 |
12 | If a cell label yields multiple polygons, only the largest one is kept.
13 |
14 | Parameters
15 | ----------
16 | mask: np.ndarray
17 | The mask array.
18 |
19 | """
20 | from lazyslide.cv import MultiLabelMask
21 |
22 | mmask = MultiLabelMask(mask)
23 | polys = mmask.to_polygons(min_area=5, detect_holes=False)
24 | cells = []
25 | for k, vs in polys.items():
26 | if len(vs) == 0:
27 | continue
28 | elif len(vs) == 1:
29 | cell = vs[0]
30 | else:
31 | # Get the largest polygon
32 | svs = sorted(vs, key=lambda x: x.area)
33 | cell = svs[-1]
34 |
35 | cells.append(cell)
36 |
37 | container = {"geometry": cells}
38 | return gpd.GeoDataFrame(container)
39 |
40 |
41 | def semanticseg_postprocess(
42 | probs: np.ndarray,
43 | ignore_index: list[int] = None,
44 | min_area: int = 5,
45 | mapping: dict = None,
46 | ):
47 | from lazyslide.cv import MultiLabelMask
48 |
49 | mask = np.argmax(probs, axis=0).astype(np.uint8)
50 | mmask = MultiLabelMask(mask)
51 | polys = mmask.to_polygons(ignore_index=ignore_index, min_area=min_area)
52 | data = []
53 | for k, vs in polys.items():
54 | for v in vs:
55 | empty_mask = np.zeros_like(mask)
56 |
57 | cv2.drawContours( # noqa
58 | empty_mask,
59 | [np.array(v.exterior.coords).astype(np.int32)],
60 | -1,
61 | 1,
62 | thickness=cv2.FILLED,
63 | )
64 |
65 | prob = np.mean(probs[k][empty_mask == 1])
66 | class_name = k
67 | if mapping is not None:
68 | class_name = mapping[k]
69 | data.append([class_name, prob, v])
70 |
71 | return gpd.GeoDataFrame(data, columns=["class", "prob", "geometry"])
72 |
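
The core of `semanticseg_postprocess` — a per-pixel argmax followed by averaging the winning class's probability inside each region — fits in a few lines of plain NumPy (illustrative only; the real function also vectorises regions into polygons via `MultiLabelMask`):

import numpy as np

probs = np.random.rand(3, 64, 64)      # hypothetical (n_classes, H, W) probability maps
mask = probs.argmax(axis=0)            # per-pixel class index
for k in np.unique(mask):
    region = mask == k
    print(k, probs[k][region].mean())  # mean probability of class k inside its own regions
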
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/sam.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ..base import SegmentationModel
4 |
5 |
6 | class SAM(SegmentationModel):
7 | SAM_VARIENTS = [
8 | "facebook/sam-vit-base",
9 | "facebook/sam-vit-large",
10 | "facebook/sam-vit-huge",
11 | ]
12 |
13 | SAM_HQ_VARIENTS = [
14 | "syscv-community/sam-hq-vit-base",
15 | "syscv-community/sam-hq-vit-large",
16 | "syscv-community/sam-hq-vit-huge",
17 | ]
18 |
19 | def __init__(self, variant="facebook/sam-vit-base", model_path=None, token=None):
20 | self.variant = variant
21 | if variant in self.SAM_VARIENTS:
22 | from transformers import SamModel, SamProcessor
23 |
24 | self.model = SamModel.from_pretrained(variant, use_auth_token=token)
25 | self.processor = SamProcessor.from_pretrained(variant, use_auth_token=token)
26 | self._is_hq = False
27 |
28 | elif variant in self.SAM_HQ_VARIENTS:
29 | from transformers import SamHQModel, SamHQProcessor
30 |
31 | self.model = SamHQModel.from_pretrained(variant, use_auth_token=token)
32 | self.processor = SamHQProcessor.from_pretrained(
33 | variant, use_auth_token=token
34 | )
35 | self._is_hq = True
36 | else:
37 | raise ValueError(
38 | f"Unsupported SAM variant: {variant}. "
39 | f"Choose from {self.SAM_VARIENTS + self.SAM_HQ_VARIENTS}."
40 | )
41 |
42 | def get_transform(self):
43 | return self.processor.image_processor
44 |
45 | @torch.inference_mode()
46 | def get_image_embedding(self, image) -> torch.Tensor:
47 | """
48 | Get the image embedding from the SAM model.
49 |
50 | Returns:
51 | torch.Tensor: Image embedding tensor of shape (1, C, H, W).
52 |
53 | """
54 | img_inputs = self.processor(image, return_tensors="pt").to(self.model.device)
55 |
56 | with torch.inference_mode():
57 | embeddings = self.model.get_image_embeddings(img_inputs["pixel_values"])
58 | if self._is_hq:
59 | embeddings = embeddings[0]
60 | return embeddings.detach().cpu()
61 |
62 | @torch.inference_mode()
63 | def segment(
64 | self,
65 | image,
66 | image_embedding=None,
67 | input_points=None,
68 | input_labels=None,
69 | input_boxes=None,
70 | segmentation_maps=None,
71 | multimask_output=False,
72 | ) -> torch.Tensor:
73 | """
74 | Segment the input image using the SAM model.
75 |
76 | Args:
77 | image (torch.Tensor): Input image tensor of shape (C, H, W).
78 |
79 | Returns:
80 | torch.Tensor: Segmentation mask tensor of shape (H, W).
81 | """
82 | inputs = self.processor(
83 | image,
84 | input_points=input_points,
85 | input_labels=input_labels,
86 | input_boxes=input_boxes,
87 | segmentation_maps=segmentation_maps,
88 | return_tensors="pt",
89 | )
90 | if image_embedding is not None:
91 | del inputs["pixel_values"]
92 | inputs["image_embeddings"] = image_embedding
93 |
94 | for k, v in inputs.items():
95 | if isinstance(v, torch.Tensor) and v.dtype == torch.float64:
96 | inputs[k] = v.to(dtype=torch.float32)
97 |
98 | inputs = inputs.to(self.model.device)
99 | outputs = self.model(**inputs, multimask_output=multimask_output)
100 | masks = self.processor.image_processor.post_process_masks(
101 | outputs.pred_masks.cpu(),
102 | inputs["original_sizes"].cpu(),
103 | inputs["reshaped_input_sizes"].cpu(),
104 | mask_threshold=0,
105 | )
106 | return masks[0]
107 |
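
A hedged usage sketch for the wrapper above, prompting with a single foreground point (the image is a random stand-in; real inputs would be tiles from a slide):

import numpy as np

sam = SAM("facebook/sam-vit-base")
image = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
masks = sam.segment(
    image,
    input_points=[[[256, 256]]],   # one (x, y) point for one image
    input_labels=[[1]],            # 1 marks the point as foreground
)
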
--------------------------------------------------------------------------------
/src/lazyslide/models/segmentation/smp.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Callable
4 |
5 | import torch
6 |
7 | from lazyslide.models.base import SegmentationModel
8 | from lazyslide.models.segmentation.postprocess import semanticseg_postprocess
9 |
10 |
11 | class SMPBase(SegmentationModel):
12 | """This is a base class for any models from segmentation models pytorch"""
13 |
14 | def __init__(
15 | self,
16 | arch: str = "unetplusplus",
17 | encoder_name: str = "timm-efficientnet-b0",
18 | encoder_weights: str = "imagenet",
19 | in_channels: int = 3,
20 | classes: int = 3,
21 | **kwargs,
22 | ):
23 | try:
24 | import segmentation_models_pytorch as smp
25 | except ModuleNotFoundError:
26 | raise ModuleNotFoundError(
27 | "Please install segmentation_models_pytorch to use this model."
28 | )
29 |
30 | self.encoder_name = encoder_name
31 | self.encoder_weights = encoder_weights
32 |
33 | self.model = smp.create_model(
34 | arch=arch,
35 | encoder_name=encoder_name,
36 | encoder_weights=encoder_weights,
37 | in_channels=in_channels,
38 | classes=classes,
39 | **kwargs,
40 | )
41 |
42 | def get_transform(self):
43 | from torchvision.transforms.v2 import Compose, ToImage, ToDtype, Normalize
44 |
45 | # default_fn = smp.encoders.get_preprocessing_fn(
46 | # self.encoder_name, self.encoder_weights
47 | # )
48 |
49 | return Compose(
50 | [
51 | ToImage(),
52 | ToDtype(torch.float32, scale=True),
53 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
54 | # default_fn
55 | ]
56 | )
57 |
58 | def get_postprocess(self) -> Callable:
59 | return semanticseg_postprocess
60 |
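
A hedged sketch of how a custom semantic segmentation model could build on `SMPBase` (the class name and arguments below are hypothetical; `GrandQCTissue` above follows the same pattern):

import torch

class MyTissueModel(SMPBase):
    def __init__(self):
        super().__init__(arch="unet", encoder_name="resnet18", classes=2)

    @torch.inference_mode()
    def segment(self, image):
        # image: (B, 3, H, W) float tensor produced by get_transform()
        return self.model.predict(image)
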
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/__init__.py:
--------------------------------------------------------------------------------
1 | from .conch import CONCHVision
2 | from .gigapath import GigaPath, GigaPathSlideEncoder
3 | from .plip import PLIPVision
4 | from .uni import UNI, UNI2
5 | from .virchow import Virchow, Virchow2
6 | from .phikon import Phikon, PhikonV2
7 | from .h_optimus import HOptimus0, HOptimus1, H0Mini
8 | from .midnight import Midnight
9 | from .hibou import HibouB, HibouL
10 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/base.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/src/lazyslide/models/vision/base.py
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/conch.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lazyslide.models._utils import hf_access
4 | from lazyslide.models.base import ImageModel
5 |
6 |
7 | class CONCHVision(ImageModel):
8 | def __init__(self, model_path=None, token=None):
9 | try:
10 | from conch.open_clip_custom import create_model_from_pretrained
11 | except ImportError:
12 | raise ImportError(
13 | "Conch is not installed. You can install it using "
14 | "`pip install git+https://github.com/mahmoodlab/CONCH.git`."
15 | )
16 |
17 | with hf_access("conch_ViT-B-16"):
18 | self.model, self.processor = create_model_from_pretrained(
19 | "conch_ViT-B-16", model_path, hf_auth_token=token
20 | )
21 |
22 | def get_transform(self):
23 | return None
24 |
25 | @torch.inference_mode()
26 | def encode_image(self, image):
27 | if not isinstance(image, torch.Tensor):
28 | image = self.processor(image)
29 | if image.dim() == 3:
30 | image = image.unsqueeze(0)
31 |
32 | image_feature = self.model.encode_image(
33 | image, normalize=False, proj_contrast=False
34 | )
35 | return image_feature
36 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/gigapath.py:
--------------------------------------------------------------------------------
1 | from platformdirs import user_cache_path
2 |
3 | from lazyslide.models.base import SlideEncoderModel, TimmModel
4 |
5 |
6 | class GigaPath(TimmModel):
7 | name = "GigaPath"
8 |
9 | def __init__(self, model_path=None, token=None):
10 | # Version check
11 | import timm
12 |
13 | try:
14 | from packaging import version
15 |
16 | timm_version = version.parse(timm.__version__)
17 | minimum_version = version.parse("1.0.3")
18 | if timm_version < minimum_version:
19 | raise ImportError(
20 | f"Gigapath needs timm >= 1.0.3. You have version {timm_version}."
21 | f"Run `pip install --upgrade timm` to install the latest version."
22 | )
23 | # If packaging is not installed, skip the version check
24 | except ModuleNotFoundError:
25 | pass
26 |
27 | super().__init__("hf_hub:prov-gigapath/prov-gigapath", token=token)
28 |
29 |
30 | class GigaPathSlideEncoder(SlideEncoderModel):
31 | def __init__(self, model_path=None, token=None):
32 | from huggingface_hub import login
33 |
34 | super().__init__()
35 |
36 | if token is not None:
37 | login(token)
38 |
39 | from gigapath.slide_encoder import create_model
40 |
41 | model = create_model(
42 | "hf_hub:prov-gigapath/prov-gigapath",
43 | "gigapath_slide_enc12l768d",
44 | 1536,
45 | local_dir=str(user_cache_path("lazyslide")),
46 | )
47 | self.model = model
48 |
49 | def encode_slide(self, tile_embed, coordinates):
50 | return self.model(tile_embed, coordinates).squeeze()
51 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/h_optimus.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lazyslide.models.base import TimmModel
4 |
5 |
6 | def get_hoptimus_transform():
7 | from torchvision.transforms.v2 import (
8 | Compose,
9 | ToImage,
10 | Resize,
11 | CenterCrop,
12 | ToDtype,
13 | Normalize,
14 | )
15 | from torchvision.transforms import InterpolationMode
16 |
17 | return Compose(
18 | [
19 | ToImage(),
20 | Resize(
21 | size=(224, 224),
22 | interpolation=InterpolationMode.BICUBIC,
23 | max_size=None,
24 | antialias=True,
25 | ),
26 | CenterCrop(224),
27 | ToDtype(dtype=torch.float32, scale=True),
28 | Normalize(
29 | mean=(0.707223, 0.578729, 0.703617), std=(0.211883, 0.230117, 0.177517)
30 | ),
31 | ]
32 | )
33 |
34 |
35 | class HOptimus0(TimmModel):
36 | name = "H-optimus-0"
37 |
38 | def __init__(self, model_path=None, token=None):
39 | super().__init__(
40 | "hf-hub:bioptimus/H-optimus-0",
41 | pretrained=True,
42 | init_values=1e-5,
43 | dynamic_img_size=False,
44 | token=token,
45 | )
46 |
47 | def get_transform(self):
48 | return get_hoptimus_transform()
49 |
50 |
51 | class HOptimus1(TimmModel):
52 | name = "H-optimus-1"
53 |
54 | def __init__(self, model_path=None, token=None):
55 | super().__init__(
56 | "hf-hub:bioptimus/H-optimus-1",
57 | pretrained=True,
58 | init_values=1e-5,
59 | dynamic_img_size=False,
60 | token=token,
61 | )
62 |
63 | def get_transform(self):
64 | return get_hoptimus_transform()
65 |
66 |
67 | class H0Mini(TimmModel):
68 | name = "H0-mini"
69 |
70 | def __init__(self, model_path=None, token=None):
71 | import timm
72 |
73 | super().__init__(
74 | "hf-hub:bioptimus/H0-mini",
75 | pretrained=True,
76 | mlp_layer=timm.layers.SwiGLUPacked,
77 | act_layer=torch.nn.SiLU,
78 | token=token,
79 | )
80 |
81 | def get_transform(self):
82 | return get_hoptimus_transform()
83 |
84 | @torch.inference_mode()
85 | def encode_image(self, image):
86 | output = self.model(image)
87 | # CLS token features (1, 768):
88 | cls_features = output[:, 0]
89 | # Patch token features (1, 256, 768):
90 | patch_token_features = output[:, self.model.num_prefix_tokens :]
91 | # Concatenate the CLS token features with the mean of the patch token
92 | # features (1, 1536):
93 | concatenated_features = torch.cat(
94 | [cls_features, patch_token_features.mean(1)], dim=-1
95 | )
96 | return concatenated_features.cpu().detach().numpy()
97 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/hibou.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lazyslide.models._utils import hf_access
4 | from lazyslide.models.base import ImageModel
5 |
6 |
7 | class Hibou(ImageModel):
8 | def __init__(self, hibou_version: str, model_path=None, token=None):
9 | try:
10 | from transformers import AutoModel
11 | except ImportError:
12 | raise ImportError(
13 | "transformers is not installed. You can install it using "
14 | "`pip install transformers`."
15 | )
16 |
17 | self.version = hibou_version
18 |
19 | with hf_access(f"histai/{self.version}"):
20 | self.model = AutoModel.from_pretrained(
21 | f"histai/{self.version}", trust_remote_code=True
22 | )
23 |
24 | def get_transform(self):
25 | from torchvision.transforms.v2 import (
26 | Compose,
27 | ToImage,
28 | Resize,
29 | CenterCrop,
30 | ToDtype,
31 | Normalize,
32 | )
33 | from torchvision.transforms import InterpolationMode
34 |
35 | return Compose(
36 | [
37 | ToImage(),
38 | Resize(
39 | size=(224, 224),
40 | interpolation=InterpolationMode.BICUBIC,
41 | max_size=None,
42 | antialias=True,
43 | ),
44 | CenterCrop(224),
45 | ToDtype(dtype=torch.float32, scale=True),
46 | Normalize(mean=(0.7068, 0.5755, 0.722), std=(0.195, 0.2316, 0.1816)),
47 | ]
48 | )
49 |
50 | @torch.inference_mode()
51 | def encode_image(self, image):
52 | image_features = self.model(pixel_values=image)
53 | return image_features.pooler_output
54 |
55 |
56 | class HibouB(Hibou):
57 | def __init__(self, token=None, model_path=None):
58 | super().__init__(hibou_version="hibou-b", token=token)
59 |
60 |
61 | class HibouL(Hibou):
62 | def __init__(self, token=None, model_path=None):
63 | super().__init__(hibou_version="hibou-l", token=token)
64 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/midnight.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lazyslide.models._utils import hf_access
4 | from lazyslide.models.base import ImageModel
5 |
6 |
7 | class Midnight(ImageModel):
8 | def __init__(self, model_path=None, token=None):
9 | try:
10 | from transformers import AutoModel
11 | except ImportError:
12 | raise ImportError(
13 | "transformers is not installed. You can install it using "
14 | "`pip install transformers`."
15 | )
16 |
17 | with hf_access("kaiko-ai/midnight"):
18 | self.model = AutoModel.from_pretrained("kaiko-ai/midnight")
19 |
20 | def get_transform(self):
21 | from torchvision.transforms import v2
22 |
23 | return v2.Compose(
24 | [
25 | v2.ToImage(),
26 | v2.Resize(224),
27 | v2.CenterCrop(224),
28 | v2.ToDtype(dtype=torch.float32, scale=True),
29 | v2.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
30 | ]
31 | )
32 |
33 | @staticmethod
34 | def extract_classification_embedding(tensor):
35 | cls_embedding = tensor[:, 0, :]
36 | patch_embedding = tensor[:, 1:, :].mean(dim=1)
37 | return torch.cat([cls_embedding, patch_embedding], dim=-1)
38 |
39 | @torch.inference_mode()
40 | def encode_image(self, image):
41 | output = self.model(image).last_hidden_state
42 | image_feature = self.extract_classification_embedding(output)
43 | return image_feature
44 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/phikon.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from lazyslide.models._utils import hf_access
5 | from lazyslide.models.base import ImageModel
6 |
7 |
8 | class Phikon(ImageModel):
9 | name = "phikon"
10 |
11 | def __init__(self, model_path=None, token=None):
12 | from transformers import AutoImageProcessor, ViTModel
13 |
14 | with hf_access("owkin/phikon"):
15 | self.model = ViTModel.from_pretrained(
16 | "owkin/phikon",
17 | add_pooling_layer=False,
18 | use_auth_token=token,
19 | )
20 | self.img_processor = AutoImageProcessor.from_pretrained(
21 | "owkin/phikon", use_fast=True
22 | )
23 |
24 | def get_transform(self):
25 | return None
26 |
27 | @torch.inference_mode()
28 | def encode_image(self, image) -> np.ndarray[np.float32]:
29 | inputs = self.img_processor(images=image, return_tensors="pt")
30 | inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
31 | return self.model(**inputs).last_hidden_state[:, 0, :].cpu().detach().numpy()
32 |
33 |
34 | class PhikonV2(ImageModel):
35 | name = "phikon-v2"
36 |
37 | def __init__(self, model_path=None, token=None):
38 | from transformers import AutoImageProcessor, AutoModel
39 |
40 | with hf_access("owkin/phikon-v2"):
41 | self.model = AutoModel.from_pretrained(
42 | "owkin/phikon-v2",
43 | add_pooling_layer=False,
44 | use_auth_token=token,
45 | )
46 | self.img_processor = AutoImageProcessor.from_pretrained(
47 | "owkin/phikon-v2", use_fast=True
48 | )
49 |
50 | def get_transform(self):
51 | return None
52 |
53 | @torch.inference_mode()
54 | def encode_image(self, image) -> np.ndarray[np.float32]:
55 | inputs = self.img_processor(images=image, return_tensors="pt")
56 | inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
57 | return self.model(**inputs).last_hidden_state[:, 0, :].cpu().detach().numpy()
58 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/plip.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lazyslide.models._utils import hf_access
4 | from lazyslide.models.base import ImageModel
5 |
6 |
7 | class PLIPVision(ImageModel):
8 | def __init__(self, model_path=None, token=None):
9 | try:
10 | from transformers import CLIPVisionModelWithProjection, CLIPProcessor
11 | except ImportError:
12 | raise ImportError(
13 | "Please install the 'transformers' package to use the PLIP model"
14 | )
15 |
16 | super().__init__()
17 |
18 | if model_path is None:
19 | model_path = "vinid/plip"
20 |
21 | with hf_access(model_path):
22 | self.model = CLIPVisionModelWithProjection.from_pretrained(
23 | model_path, use_auth_token=token
24 | )
25 | self.processor = CLIPProcessor.from_pretrained(
26 | model_path, use_auth_token=token
27 | )
28 |
29 | def get_transform(self):
30 | return None
31 |
32 | @torch.inference_mode()
33 | def encode_image(self, image):
34 | inputs = self.processor(images=image, return_tensors="pt")
35 | inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
36 | image_features = self.model.get_image_features(**inputs)
37 | return image_features
38 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/uni.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lazyslide.models.base import TimmModel
4 |
5 |
6 | class UNI(TimmModel):
7 | def __init__(self, model_path=None, token=None):
8 | # from huggingface_hub import hf_hub_download
9 | # model_path = hf_hub_download("MahmoodLab/UNI", filename="pytorch_model.bin")
10 |
11 | if model_path is not None:
12 | super().__init__(
13 | "vit_large_patch16_224",
14 | token=token,
15 | img_size=224,
16 | patch_size=16,
17 | init_values=1e-5,
18 | num_classes=0,
19 | dynamic_img_size=True,
20 | pretrained=False,
21 | )
22 | self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
23 | else:
24 | super().__init__(
25 | "hf-hub:MahmoodLab/uni",
26 | token=token,
27 | init_values=1e-5,
28 | dynamic_img_size=True,
29 | )
30 |
31 |
32 | class UNI2(TimmModel):
33 | def __init__(self, model_path=None, token=None):
34 | import timm
35 |
36 | timm_kwargs = {
37 | "img_size": 224,
38 | "patch_size": 14,
39 | "depth": 24,
40 | "num_heads": 24,
41 | "init_values": 1e-5,
42 | "embed_dim": 1536,
43 | "mlp_ratio": 2.66667 * 2,
44 | "num_classes": 0,
45 | "no_embed_class": True,
46 | "mlp_layer": timm.layers.SwiGLUPacked,
47 | "act_layer": torch.nn.SiLU,
48 | "reg_tokens": 8,
49 | "dynamic_img_size": True,
50 | }
51 |
52 | # from huggingface_hub import hf_hub_download
53 | # model_path = hf_hub_download("MahmoodLab/UNI2-h", filename="pytorch_model.bin")
54 |
55 | if model_path is not None:
56 | super().__init__(
57 | "vit_giant_patch14_224", token=token, pretrained=False, **timm_kwargs
58 | )
59 | self.model.load_state_dict(
60 | torch.load(model_path, map_location="cpu"), strict=True
61 | )
62 | else:
63 | super().__init__("hf-hub:MahmoodLab/UNI2-h", **timm_kwargs)
64 |
--------------------------------------------------------------------------------
/src/lazyslide/models/vision/virchow.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lazyslide.models.base import TimmModel
4 |
5 |
6 | class Virchow(TimmModel):
7 | _hf_hub_id = "paige-ai/Virchow"
8 |
9 | def __init__(self, model_path=None, token=None):
10 | from timm.layers import SwiGLUPacked
11 |
12 | super().__init__(
13 | f"hf-hub:{self._hf_hub_id}",
14 | pretrained=True,
15 | mlp_layer=SwiGLUPacked,
16 | act_layer=torch.nn.SiLU,
17 | token=token,
18 | )
19 |
20 | @torch.inference_mode()
21 | def encode_image(self, img):
22 | output = self.model(img)
23 | # CLS token features (1, 1280):
24 | cls_features = output[:, 0]
25 | # Patch token features (1, 256, 1280):
26 | patch_features = output[:, self.model.num_prefix_tokens :]
27 | return torch.cat((cls_features, patch_features.mean(1)), dim=-1)
28 |
29 |
30 | class Virchow2(Virchow):
31 | _hf_hub_id = "paige-ai/Virchow2"
32 |
--------------------------------------------------------------------------------
/src/lazyslide/plotting/__init__.py:
--------------------------------------------------------------------------------
1 | from ._api import tissue, tiles, annotations
2 | from ._wsi_viewer import WSIViewer
3 |
--------------------------------------------------------------------------------
/src/lazyslide/preprocess/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ["find_tissues", "score_tissues", "tile_tissues", "score_tiles"]
2 |
3 | from ._graph import tile_graph
4 | from ._tiles import tile_tissues, score_tiles
5 | from ._tissue import find_tissues, score_tissues
6 |
--------------------------------------------------------------------------------
/src/lazyslide/preprocess/_graph.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import warnings
4 | from itertools import chain
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from anndata import AnnData
9 | from numba import njit
10 | from scipy.sparse import csr_matrix, spmatrix, isspmatrix_csr, SparseEfficiencyWarning
11 | from scipy.spatial import Delaunay
12 | from wsidata import WSIData
13 | from wsidata.io import add_table
14 |
15 | from lazyslide._const import Key
16 |
17 |
18 | def tile_graph(
19 | wsi: WSIData,
20 | n_neighs: int = 6,
21 | n_rings: int = 1,
22 | delaunay=False,
23 | transform: str = None,
24 | set_diag: bool = False,
25 | tile_key: str = Key.tiles,
26 | table_key: str = None,
27 | ):
28 | """
29 | Compute the spatial graph of the tiles.
30 |
31 | Parameters
32 | ----------
33 | wsi : :class:`WSIData `
34 | The WSIData object to work on.
35 | n_neighs : int, default: 6
36 | The number of neighbors to consider.
37 | n_rings : int, default: 1
38 | The number of rings to consider.
39 | delaunay : bool, default: False
40 | Whether to use Delaunay triangulation.
41 | transform : str, default: None
42 | The transformation to apply to the graph.
43 | set_diag : bool, default: False
44 | Whether to set the diagonal to 1.
45 | tile_key : str, default: 'tiles'
46 | The tile key.
47 | table_key : str, default: None
48 | The table key to store the graph.
49 |
50 | Returns
51 | -------
52 | The tiles with spatial connectivities and distances in an anndata format.
53 |
54 | - The spatial connectivities and distances of the tiles will be added to the :bdg-danger:`tables` slot of the spatial data object.
55 |
56 | Examples
57 | --------
58 | .. code-block:: python
59 |
60 | >>> import lazyslide as zs
61 | >>> wsi = zs.datasets.sample()
62 | >>> zs.pp.find_tissues(wsi)
63 | >>> zs.pp.tile_graph(wsi)
64 | >>> wsi['tile_graph']
65 |
66 |
67 | """
68 | coords = wsi[tile_key].bounds[["minx", "miny"]].values
69 | Adj, Dst = _spatial_neighbor(
70 | coords, n_neighs, delaunay, n_rings, transform, set_diag
71 | )
72 |
73 | conns_key = "spatial_connectivities"
74 | dists_key = "spatial_distances"
75 | neighbors_dict = {
76 | "connectivities_key": conns_key,
77 | "distances_key": dists_key,
78 | "params": {
79 | "n_neighbors": n_neighs,
80 | "transform": transform,
81 | },
82 | }
83 | # TODO: Store in an AnnData object
84 | if table_key is None:
85 | table_key = Key.tile_graph(tile_key)
86 | if table_key not in wsi:
87 | table = AnnData(
88 | obs=pd.DataFrame(index=np.arange(coords.shape[0], dtype=int).astype(str)),
89 | obsp={conns_key: Adj, dists_key: Dst},
90 | uns={"spatial": neighbors_dict},
91 | )
92 | add_table(wsi, table_key, table)
93 | else:
94 | table = wsi[table_key]
95 | table.obsp[conns_key] = Adj
96 | table.obsp[dists_key] = Dst
97 | table.uns["spatial"] = neighbors_dict
98 |
99 |
100 | def _spatial_neighbor(
101 | coords,
102 | n_neighs: int = 6,
103 | delaunay: bool = False,
104 | n_rings: int = 1,
105 | transform: str = None,
106 | set_diag: bool = False,
107 | ) -> tuple[csr_matrix, csr_matrix]:
108 | with warnings.catch_warnings():
109 | warnings.simplefilter("ignore", SparseEfficiencyWarning)
110 | Adj, Dst = _build_grid(
111 | coords,
112 | n_neighs=n_neighs,
113 | n_rings=n_rings,
114 | delaunay=delaunay,
115 | set_diag=set_diag,
116 | )
117 |
118 | Adj.eliminate_zeros()
119 | Dst.eliminate_zeros()
120 |
121 | # check transform
122 | if transform == "spectral":
123 | Adj = _transform_a_spectral(Adj)
124 | elif transform == "cosine":
125 | Adj = _transform_a_cosine(Adj)
126 | elif transform == "none" or transform is None:
127 | pass
128 | else:
129 | raise NotImplementedError(f"Transform `{transform}` is not yet implemented.")
130 |
131 | return Adj, Dst
132 |
133 |
134 | def _build_grid(
135 | coords,
136 | n_neighs: int,
137 | n_rings: int,
138 | delaunay: bool = False,
139 | set_diag: bool = False,
140 | ) -> tuple[csr_matrix, csr_matrix]:
141 | if n_rings > 1:
142 | Adj: csr_matrix = _build_connectivity(
143 | coords,
144 | n_neighs=n_neighs,
145 | neigh_correct=True,
146 | set_diag=True,
147 | delaunay=delaunay,
148 | return_distance=False,
149 | )
150 | Res, Walk = Adj, Adj
151 | for i in range(n_rings - 1):
152 | Walk = Walk @ Adj
153 | Walk[Res.nonzero()] = 0.0
154 | Walk.eliminate_zeros()
155 | Walk.data[:] = i + 2.0
156 | Res = Res + Walk
157 | Adj = Res
158 | Adj.setdiag(float(set_diag))
159 | Adj.eliminate_zeros()
160 |
161 | Dst = Adj.copy()
162 | Adj.data[:] = 1.0
163 | else:
164 | Adj = _build_connectivity(
165 | coords,
166 | n_neighs=n_neighs,
167 | neigh_correct=True,
168 | delaunay=delaunay,
169 | set_diag=set_diag,
170 | )
171 | Dst = Adj.copy()
172 |
173 | Dst.setdiag(0.0)
174 |
175 | return Adj, Dst
176 |
177 |
178 | def _build_connectivity(
179 | coords,
180 | n_neighs: int,
181 | radius: float | tuple[float, float] | None = None,
182 | delaunay: bool = False,
183 | neigh_correct: bool = False,
184 | set_diag: bool = False,
185 | return_distance: bool = False,
186 | ) -> csr_matrix | tuple[csr_matrix, csr_matrix]:
187 | from sklearn.metrics import euclidean_distances
188 | from sklearn.neighbors import NearestNeighbors
189 |
190 | N = coords.shape[0]
191 | if delaunay:
192 | tri = Delaunay(coords)
193 | indptr, indices = tri.vertex_neighbor_vertices
194 | Adj = csr_matrix(
195 | (np.ones_like(indices, dtype=np.float64), indices, indptr), shape=(N, N)
196 | )
197 |
198 | if return_distance:
199 | # fmt: off
200 | dists = np.array(list(chain(*(
201 | euclidean_distances(coords[indices[indptr[i]: indptr[i + 1]], :], coords[np.newaxis, i, :])
202 | for i in range(N)
203 | if len(indices[indptr[i]: indptr[i + 1]])
204 | )))).squeeze()
205 | Dst = csr_matrix((dists, indices, indptr), shape=(N, N))
206 | # fmt: on
207 | else:
208 | r = (
209 | 1
210 | if radius is None
211 | else radius
212 | if isinstance(radius, (int, float))
213 | else max(radius)
214 | )
215 | tree = NearestNeighbors(n_neighbors=n_neighs, radius=r, metric="euclidean")
216 | tree.fit(coords)
217 |
218 | if radius is None:
219 | dists, col_indices = tree.kneighbors()
220 | dists, col_indices = dists.reshape(-1), col_indices.reshape(-1)
221 | row_indices = np.repeat(np.arange(N), n_neighs)
222 | if neigh_correct:
223 | dist_cutoff = np.median(dists) * 1.3 # there's a small amount of sway
224 | mask = dists < dist_cutoff
225 | row_indices, col_indices, dists = (
226 | row_indices[mask],
227 | col_indices[mask],
228 | dists[mask],
229 | )
230 | else:
231 | dists, col_indices = tree.radius_neighbors()
232 | row_indices = np.repeat(np.arange(N), [len(x) for x in col_indices])
233 | dists = np.concatenate(dists)
234 | col_indices = np.concatenate(col_indices)
235 |
236 | Adj = csr_matrix(
237 | (np.ones_like(row_indices, dtype=np.float64), (row_indices, col_indices)),
238 | shape=(N, N),
239 | )
240 | if return_distance:
241 | Dst = csr_matrix((dists, (row_indices, col_indices)), shape=(N, N))
242 |
243 | # radius-based filtering needs same indices/indptr: do not remove 0s
244 | Adj.setdiag(1.0 if set_diag else Adj.diagonal())
245 | if return_distance:
246 | Dst.setdiag(0.0)
247 | return Adj, Dst
248 |
249 | return Adj
250 |
251 |
252 | @njit
253 | def outer(indices, indptr, degrees):
254 | res = np.empty_like(indices, dtype=np.float64)
255 | start = 0
256 | for i in range(len(indptr) - 1):
257 | ixs = indices[indptr[i] : indptr[i + 1]]
258 | res[start : start + len(ixs)] = degrees[i] * degrees[ixs]
259 | start += len(ixs)
260 |
261 | return res
262 |
263 |
264 | def _transform_a_spectral(a: spmatrix) -> spmatrix:
265 | if not isspmatrix_csr(a):
266 | a = a.tocsr()
267 | if not a.nnz:
268 | return a
269 |
270 | degrees = np.squeeze(np.array(np.sqrt(1.0 / a.sum(axis=0))))
271 | a = a.multiply(outer(a.indices, a.indptr, degrees))
272 | a.eliminate_zeros()
273 |
274 | return a
275 |
276 |
277 | def _transform_a_cosine(a: spmatrix) -> spmatrix:
278 | from sklearn.metrics.pairwise import cosine_similarity
279 |
280 | return cosine_similarity(a, dense_output=False)
281 |
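
The `n_rings` expansion in `_build_grid` rests on powers of the adjacency matrix: multiplying `Adj` by itself reaches second-ring neighbours. A tiny SciPy illustration of that idea on a 4-node path graph (not slide tiles):

import numpy as np
from scipy.sparse import csr_matrix

# Path graph 0-1-2-3: direct neighbours only.
adj = csr_matrix(np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
], dtype=float))

walk2 = adj @ adj       # walks of length two
walk2.setdiag(0)        # drop trivial "back to the same node" walks
print(walk2.toarray())  # nonzero at (0, 2), (1, 3), ...: second-ring neighbours
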
--------------------------------------------------------------------------------
/src/lazyslide/preprocess/_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from lazyslide.cv.scorer import ScorerBase
4 |
5 | Scorer = Union[ScorerBase, str]
6 |
7 |
8 | def get_scorer(scorers):
9 | from lazyslide.cv.scorer import (
10 | ScorerBase,
11 | ComposeScorer,
12 | FocusLite,
13 | Contrast,
14 | Brightness,
15 | Redness,
16 | )
17 |
18 | scorer_mapper = {
19 | "focus": FocusLite,
20 | "contrast": Contrast,
21 | "brightness": Brightness,
22 | "redness": Redness,
23 | }
24 |
25 | scorer_list = []
26 | for s in scorers:
27 | if isinstance(s, ScorerBase):
28 | scorer_list.append(s)
29 | elif isinstance(s, str):
30 | scorer = scorer_mapper.get(s)
31 | if scorer is None:
32 | raise ValueError(
33 | f"Unknown scorer {s}, "
34 | f"available scorers are {'.'.join(scorer_mapper.keys())}"
35 | )
36 | # The scorer should be initialized when used
37 | scorer_list.append(scorer())
38 | else:
39 | raise TypeError(f"Unknown scorer type {type(s)}")
40 | compose_scorer = ComposeScorer(scorer_list)
41 | return compose_scorer
42 |
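
A short usage sketch for `get_scorer`:

scorer = get_scorer(["focus", "contrast", "brightness"])
# -> ComposeScorer wrapping FocusLite(), Contrast() and Brightness() instances
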
--------------------------------------------------------------------------------
/src/lazyslide/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/src/lazyslide/py.typed
--------------------------------------------------------------------------------
/src/lazyslide/segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | from ._seg_runner import SegmentationRunner, semantic
2 | from ._cell import cells, nulite
3 | from ._artifact import artifact
4 | from ._tissue import tissue
5 | from ._zero_shot import zero_shot
6 |
--------------------------------------------------------------------------------
/src/lazyslide/segmentation/_artifact.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Literal
4 |
5 | from wsidata import WSIData
6 | from wsidata.io import add_shapes
7 |
8 | from lazyslide._const import Key
9 | from lazyslide._utils import get_torch_device
10 | from ._seg_runner import SegmentationRunner
11 | from ..models.segmentation import GrandQCArtifact
12 |
13 | # Define class mapping
14 | CLASS_MAPPING = {
15 | 1: "Normal Tissue",
16 | 2: "Fold",
17 | 3: "Dark spot & Foreign Object",
18 | 4: "PenMarking",
19 | 5: "Edge & Air Bubble",
20 | 6: "Out of Focus",
21 | 7: "Background",
22 | }
23 |
24 |
25 | def artifact(
26 | wsi: WSIData,
27 | tile_key: str,
28 | variants: Literal["grandqc_5x", "grandqc_7x", "grandqc_10x"] = "grandqc_7x",
29 | tissue_key: str = Key.tissue,
30 | batch_size: int = 4,
31 | num_workers: int = 0,
32 | device: str | None = None,
33 | key_added: str = "artifacts",
34 | ):
35 | """
36 | Artifact segmentation for the whole slide image.
37 |
38 | Run GrandQC artifact segmentation model on the whole slide image.
39 | The model is trained on 512x512 tiles at mpp = 2 (5x), 1.5 (7x), or 1 (10x).
40 |
41 | It can detect the following artifacts:
42 | - Fold
43 | - Darkspot & Foreign Object
44 | - Pen Marking
45 | - Edge & Air Bubble
46 | - Out of Focus
47 |
48 | Parameters
49 | ----------
50 | wsi : WSIData
51 | The whole slide image data.
52 | tile_key : str
53 | The key of the tile table.
54 | variants : {"grandqc_5x", "grandqc_7x", "grandqc_10x"}, default: "grandqc_7x"
55 | The model variant to use for segmentation.
56 | tissue_key : str, default: Key.tissue
57 | The key of the tissue table.
58 | batch_size : int, default: 4
59 | The batch size for segmentation.
60 | num_workers : int, default: 0
61 | The number of workers for data loading.
62 | device : str, default: None
63 | The device for the model.
64 | key_added : str, default: "artifacts"
65 | The key for the added artifact shapes.
66 |
67 | """
68 | if tissue_key not in wsi:
69 | raise ValueError(
70 | "Tissue segmentation is required before artifact segmentation."
71 | "Please run `pp.find_tissues` first."
72 | )
73 |
74 | if device is None:
75 | device = get_torch_device()
76 |
77 | model_mpp = {
78 | "grandqc_5x": 2,
79 | "grandqc_7x": 1.5,
80 | "grandqc_10x": 1,
81 | }
82 |
83 | mpp = model_mpp[variants]
84 |
85 | if tile_key is not None:
86 | # Check if the tile spec is compatible with the model
87 | spec = wsi.tile_spec(tile_key)
88 | if spec is None:
89 | raise ValueError(f"Tiles or tile spec for {tile_key} not found.")
90 | if spec.mpp != mpp:
91 | raise ValueError(
92 | f"Tile spec mpp {spec.mpp} is not "
93 | f"compatible with the model mpp {mpp}"
94 | )
95 | if spec.width != 512 or spec.height != 512:
96 | raise ValueError("Tile should be 512x512.")
97 |
98 | model = GrandQCArtifact(model=variants.removeprefix("grandqc_"))
99 |
100 | runner = SegmentationRunner(
101 | wsi,
102 | model,
103 | tile_key,
104 | transform=None,
105 | batch_size=batch_size,
106 | num_workers=num_workers,
107 | device=device,
108 | class_col="class",
109 | postprocess_kws={
110 | "ignore_index": [0, 1, 7], # Ignore background, normal tissue
111 | "mapping": CLASS_MAPPING,
112 | },
113 | )
114 | arts = runner.run()
115 | add_shapes(wsi, key=key_added, shapes=arts)
116 |
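
A hedged end-to-end sketch for `artifact` (assuming the top-level package exposes `pp`, `datasets` and `seg` aliases; the docstrings elsewhere in the repo confirm `pp` and `datasets`):

import lazyslide as zs

wsi = zs.datasets.sample()
zs.pp.find_tissues(wsi)
# ... create 512 x 512 tiles at mpp = 1.5 with zs.pp.tile_tissues (see its documentation for the exact arguments) ...
zs.seg.artifact(wsi, tile_key="tiles")
wsi["artifacts"]  # GeoDataFrame of detected artifact polygons
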
--------------------------------------------------------------------------------
/src/lazyslide/segmentation/_cell.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import warnings
4 |
5 | from wsidata import WSIData
6 | from wsidata.io import add_shapes
7 |
8 | from lazyslide.models import SegmentationModel
9 | from lazyslide.models.segmentation import Instanseg, NuLite
10 | from ._seg_runner import SegmentationRunner
11 | from .._const import Key
12 |
13 |
14 | def cells(
15 | wsi: WSIData,
16 | model: str | SegmentationModel = "instanseg",
17 | tile_key=Key.tiles,
18 | transform=None,
19 | batch_size=4,
20 | num_workers=0,
21 | device=None,
22 | key_added="cells",
23 | ):
24 | """Cell segmentation for the whole slide image.
25 |
26 | Tiles should be prepared before segmentation.
27 |
28 | Recommended tile setting:
29 | - **instanseg**: 512x512, mpp=0.5
30 |
31 | Parameters
32 | ----------
33 | wsi : WSIData
34 | The whole slide image data.
35 | model : str | SegmentationModel, default: "instanseg"
36 | The cell segmentation model.
37 | tile_key : str, default: "tiles"
38 | The key of the tile table.
39 | transform : callable, default: None
40 | The transformation for the input tiles.
41 | batch_size : int, default: 4
42 | The batch size for segmentation.
43 | num_workers : int, default: 0
44 | The number of workers for data loading.
45 | device : str, default: None
46 | The device for the model.
47 | key_added : str, default: "cells"
48 | The key for the added cell shapes.
49 |
50 | """
51 | if model == "instanseg":
52 | model = Instanseg()
53 | # Run tile check
54 | tile_spec = wsi.tile_spec(tile_key)
55 | check_mpp = tile_spec.mpp == 0.5
56 | check_size = tile_spec.height == 512 and tile_spec.width == 512
57 | if not check_mpp or not check_size:
58 | warnings.warn(
59 | f"To optimize the performance of Instanseg model, "
60 | f"the tile size should be 512x512 and the mpp should be 0.5. "
61 | f"Current tile size is {tile_spec.width}x{tile_spec.height} with {tile_spec.mpp} mpp."
62 | )
63 |
64 | runner = SegmentationRunner(
65 | wsi,
66 | model,
67 | tile_key,
68 | transform=transform,
69 | batch_size=batch_size,
70 | num_workers=num_workers,
71 | device=device,
72 | )
73 | cells = runner.run()
74 | # Add cells to the WSIData
75 | add_shapes(wsi, key=key_added, shapes=cells)
76 |
77 |
78 | def nulite(
79 | wsi: WSIData,
80 | tile_key="tiles",
81 | transform=None,
82 | batch_size=4,
83 | num_workers=0,
84 | device=None,
85 | key_added="cell_types",
86 | ):
87 | """Cell type segmentation for the whole slide image.
88 |
89 | Tiles should be prepared before segmentation.
90 |
91 | Recommended tile setting:
92 | - **nulite**: 512x512, mpp=0.5
93 |
94 | Parameters
95 | ----------
96 | wsi : WSIData
97 | The whole slide image data.
98 | tile_key : str, default: "tiles"
99 | The key of the tile table.
100 | transform : callable, default: None
101 | The transformation for the input tiles.
102 | batch_size : int, default: 4
103 | The batch size for segmentation.
104 | num_workers : int, default: 0
105 | The number of workers for data loading.
106 | device : str, default: None
107 | The device for the model.
108 | key_added : str, default: "cell_types"
109 | The key for the added cell type shapes.
110 |
111 | """
112 |
113 | model = NuLite()
114 |
115 | runner = SegmentationRunner(
116 | wsi,
117 | model,
118 | tile_key,
119 | transform=transform,
120 | batch_size=batch_size,
121 | num_workers=num_workers,
122 | device=device,
123 | )
124 | cells = runner.run()
125 | # Add cells to the WSIData
126 | add_shapes(wsi, key=key_added, shapes=cells)
127 |
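
In the same spirit, a hedged sketch for cell segmentation with the functions above (aliases as before; tiles are assumed to already exist at 512 x 512, mpp = 0.5):

import lazyslide as zs

wsi = zs.datasets.sample()
zs.pp.find_tissues(wsi)
# ... tile at 512 x 512, mpp = 0.5 with zs.pp.tile_tissues ...
zs.seg.cells(wsi)    # InstanSeg by default; polygons stored under "cells"
zs.seg.nulite(wsi)   # typed nuclei stored under "cell_types"
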
--------------------------------------------------------------------------------
/src/lazyslide/segmentation/_seg_runner.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from functools import partial
4 | from typing import Literal, Callable, Mapping
5 |
6 | import geopandas as gpd
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | from shapely.affinity import scale, translate
11 | from torch.utils.data import DataLoader
12 | from wsidata import WSIData
13 | from wsidata.io import add_shapes
14 |
15 | from lazyslide._const import Key
16 | from lazyslide._utils import default_pbar, get_torch_device
17 | from lazyslide.cv import merge_polygons
18 | from lazyslide.models.base import SegmentationModel
19 |
20 |
21 | def semantic(
22 | wsi: WSIData,
23 | model: SegmentationModel,
24 | tile_key=Key.tiles,
25 | transform=None,
26 | batch_size=4,
27 | num_workers=0,
28 | device=None,
29 | key_added="anatomical_structures",
30 | ):
31 | """
32 | Semantic segmentation for the whole slide image.
33 |
34 | Parameters
35 | ----------
36 | wsi : WSIData
37 | The whole slide image data.
38 | model : SegmentationModel
39 | The segmentation model.
40 | tile_key : str, default: "tiles"
41 | The key of the tile table.
42 | transform : callable, default: None
43 | The transformation for the input tiles.
44 | batch_size : int, default: 4
45 | The batch size for segmentation.
46 | num_workers : int, default: 0
47 | The number of workers for data loading.
48 | device : str, default: None
49 | The device for the model.
50 | key_added : str, default: "anatomical_structures"
51 | The key for the added instance shapes.
52 |
53 | """
54 | runner = SegmentationRunner(
55 | wsi=wsi,
56 | model=model,
57 | tile_key=tile_key,
58 | transform=transform,
59 | batch_size=batch_size,
60 | num_workers=num_workers,
61 | device=device,
62 | )
63 | shapes = runner.run()
64 | # Add the segmentation results to the WSIData
65 | add_shapes(wsi, key=key_added, shapes=shapes)
66 |
67 |
68 | class SegmentationRunner:
69 | """
70 | Segmentation runner for the whole slide image.
71 |
72 | Parameters
73 | ----------
74 | wsi : :class:`WSIData <wsidata.WSIData>`
75 | The whole slide image data.
76 | model : :class:`SegmentationModel <lazyslide.models.base.SegmentationModel>`
77 | The segmentation model.
78 | tile_key : str
79 | The key of the tile table.
80 | transform : callable, default: None
81 | The transformation for the input tiles.
82 | batch_size : int, default: 4
83 | The batch size for segmentation.
84 | num_workers : int, default: 0
85 | The number of workers for data loading.
86 | device : str, default: None
87 | The device for the model.
88 | postprocess_kws : dict, default: None
89 | The keyword arguments for the postprocess function defined in the model class
90 | dataloader_kws : dict, default: None
91 | The keyword arguments for the DataLoader.
92 | class_col : str, default: None
93 | The column name for the class in the output GeoDataFrame.
94 | prob_col : str, default: None
95 | The column name for the probability in the output GeoDataFrame.
96 | buffer_px : int, default: 0
97 | The buffer size in pixels for the polygons.
98 | drop_overlap : float, default: 0.9
99 | The overlap threshold for dropping polygons.
100 | pbar : bool, default: True
101 | Whether to show the progress bar.
102 |
103 | """
104 |
105 | def __init__(
106 | self,
107 | wsi: WSIData,
108 | model: SegmentationModel,
109 | tile_key: str,
110 | transform: Callable = None,
111 | batch_size: int = 4,
112 | num_workers: int = 0,
113 | device: str = None,
114 | postprocess_kws: dict = None,
115 | dataloader_kws: dict = None,
116 | class_col: str = None,
117 | prob_col: str = None,
118 | buffer_px: int = 0,
119 | drop_overlap: float = 0.9,
120 | pbar: bool = True,
121 | ):
122 | self.wsi = wsi
123 | self.model = model
124 | if device is None:
125 | device = get_torch_device()
126 | self.device = device
127 | self.tile_key = tile_key
128 | self.downsample = wsi.tile_spec(tile_key).base_downsample
129 |
130 | if transform is None:
131 | transform = model.get_transform()
132 | self.transform = transform
133 |
134 | if postprocess_kws is None:
135 | postprocess_kws = {}
136 | postprocess_fn = model.get_postprocess()
137 | self.postprocess_fn = partial(postprocess_fn, **postprocess_kws)
138 |
139 | if dataloader_kws is None:
140 | dataloader_kws = {}
141 | dataloader_kws.setdefault("num_workers", num_workers)
142 | dataloader_kws.setdefault("batch_size", batch_size)
143 | self.dataloader_kws = dataloader_kws
144 | self.merge_kws = dict(
145 | class_col=class_col,
146 | prob_col=prob_col,
147 | buffer_px=buffer_px,
148 | drop_overlap=drop_overlap,
149 | )
150 |
151 | self.pbar = pbar
152 |
153 | def _batch_postprocess(self, output, xs, ys):
154 | results = []
155 |
156 | if isinstance(output, (torch.Tensor, np.ndarray)):
157 | batches = zip(output, xs, ys)
158 | elif isinstance(output, tuple):
159 | batches = zip(list(zip(*output)), xs, ys)
160 | elif isinstance(output, Mapping):
161 | flattened = [
162 | dict(zip(output.keys(), values)) for values in zip(*output.values())
163 | ]
164 | batches = zip(flattened, xs, ys)
165 | else:
166 | raise NotImplementedError(f"Unsupported model output type {type(output)}")
167 |
168 | for batch, x, y in batches:
169 | result = self.postprocess_fn(batch)
170 | # The output of postprocess_fn is a gpd.GeoDataFrame
171 | # transform the polygons to the global coordinate
172 | polys = []
173 | for poly in result["geometry"]:
174 | poly = scale(
175 | poly, xfact=self.downsample, yfact=self.downsample, origin=(0, 0)
176 | )
177 | poly = translate(poly, xoff=x, yoff=y)
178 | polys.append(poly)
179 | result["geometry"] = polys
180 | if len(result) > 0:
181 | results.append(result)
182 |
183 | return results
184 |
185 | def __call__(self):
186 | dataset = self.wsi.ds.tile_images(
187 | tile_key=self.tile_key, transform=self.transform
188 | )
189 | dl = DataLoader(dataset, **self.dataloader_kws)
190 |
191 | # Move model to device
192 | if self.device is not None:
193 | self.model.to(self.device)
194 |
195 | with default_pbar(disable=not self.pbar) as progress_bar:
196 | task = progress_bar.add_task("Segmentation", total=len(dataset))
197 |
198 | results = []
199 | for chunk in dl:
200 | images = chunk["image"]
201 | xs, ys = np.asarray(chunk["x"]), np.asarray(chunk["y"])
202 | if self.device is not None:
203 | images = images.to(self.device)
204 | output = self.model.segment(images)
205 |
206 | rs = self._batch_postprocess(output, xs, ys)
207 | # Update only if the output is not empty
208 | results.extend(rs)
209 | progress_bar.update(task, advance=len(xs))
210 | polys_df = gpd.GeoDataFrame(pd.concat(results).reset_index(drop=True))
211 | progress_bar.update(task, description="Merging tiles...")
212 | # === Merge the polygons ===
213 | polys_df = merge_polygons(polys_df, **self.merge_kws)
214 | # === Refresh the progress bar ===
215 | progress_bar.update(task, description="Segmentation")
216 | progress_bar.refresh()
217 |
218 | polys_df = polys_df.explode().reset_index(drop=True)
219 | return polys_df
220 |
221 | def run(self):
222 | """
223 | Run the segmentation.
224 | """
225 | return self.__call__()
226 |
--------------------------------------------------------------------------------
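`SegmentationRunner` only touches four parts of the model interface: `get_transform()`, `get_postprocess()`, `to(device)`, and `segment(images)`, and the postprocess callable must return a `geopandas.GeoDataFrame` with a `geometry` column in tile-local coordinates (the runner rescales and translates the polygons itself in `_batch_postprocess`). A minimal sketch of a toy thresholding model satisfying that contract, written against the calls made above rather than the library's actual `SegmentationModel` base class:

    # Toy model compatible with SegmentationRunner; illustrative only.
    import geopandas as gpd
    import numpy as np
    import torch
    from shapely.geometry import box


    class ThresholdSegModel:
        def to(self, device):
            return self  # nothing to move in this toy model

        def get_transform(self):
            # Convert HxWxC uint8 tiles to float CHW tensors in [0, 1]
            return lambda img: torch.as_tensor(np.asarray(img)).permute(2, 0, 1).float() / 255.0

        def get_postprocess(self):
            # Called once per tile output; must return a GeoDataFrame with a "geometry" column.
            def postprocess(mask, min_pixels=10):
                fg = mask.squeeze().cpu().numpy() > 0.5
                if fg.sum() < min_pixels:
                    return gpd.GeoDataFrame({"geometry": []})
                ys, xs = fg.nonzero()
                return gpd.GeoDataFrame(
                    {"geometry": [box(xs.min(), ys.min(), xs.max() + 1, ys.max() + 1)]}
                )

            return postprocess

        def segment(self, images: torch.Tensor) -> torch.Tensor:
            # "Segmentation": flag pixels darker than the batch mean as foreground.
            gray = images.mean(dim=1, keepdim=True)
            return (gray < gray.mean()).float()

    # Usage, given a WSIData with prepared tiles:
    # shapes = SegmentationRunner(wsi, ThresholdSegModel(), "tiles").run()
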
/src/lazyslide/segmentation/_tissue.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from shapely.affinity import scale
7 | from wsidata import WSIData
8 | from wsidata.io import add_tissues
9 |
10 | from lazyslide._const import Key
11 | from lazyslide._utils import get_torch_device
12 | from lazyslide.cv import BinaryMask
13 | from lazyslide.models.segmentation import GrandQCTissue
14 |
15 |
16 | def tissue(
17 | wsi: WSIData,
18 | level: int = None,
19 | device: str | None = None,
20 | key_added: str = Key.tissue,
21 | ):
22 | """
23 | Segment tissue regions of the whole slide image with the GrandQC tissue model.
24 |
25 | Parameters
26 | ----------
27 | wsi: :class:`wsidata.WSIData`
28 | The whole slide image.
29 | level : int, default: None
30 | The level to segment the tissue.
31 | device : str, default: None
32 | The device to run the model.
33 | key_added : str, default: 'tissues'
34 | The key to add the tissue polygons.
35 |
36 | """
37 |
38 | if device is None:
39 | device = get_torch_device()
40 |
41 | props = wsi.properties
42 | if level is None:
43 | level_mpp = np.array(props.level_downsample) * props.mpp
44 | # Pick the level whose mpp is closest to 10
45 | level = np.argmin(np.abs(level_mpp - 10))
46 | shape = props.level_shape[level]
47 |
48 | model = GrandQCTissue()
49 | transform = model.get_transform()
50 |
51 | model.to(device)
52 |
53 | # Ensure the image size is multiple of 32
54 | # Calculate the nearest multiples of 32
55 | height, width = shape
56 | new_height = (height + 31) // 32 * 32
57 | new_width = (width + 31) // 32 * 32
58 | img = wsi.reader.get_region(0, 0, width, height, level=level)
59 | downsample = props.level_downsample[level]
60 |
61 | # We cannot read a padded image directly from the reader:
62 | # the reader pads on only two sides, which shifts the content off-center
63 | # and would introduce artifacts in the segmentation.
64 | # Instead, pad the image symmetrically on all four sides
65 | # so that the tissue stays centered.
66 |
67 | # Compute padding amounts
68 | top_pad = (new_height - height) // 2
69 | bottom_pad = new_height - height - top_pad
70 | left_pad = (new_width - width) // 2
71 | right_pad = new_width - width - left_pad
72 |
73 | # Apply padding
74 | img = np.pad(
75 | img,
76 | pad_width=((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)),
77 | mode="constant",
78 | constant_values=0, # Pad with black pixels
79 | )
80 |
81 | # Simulate JPEG compression
82 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
83 | result, img = cv2.imencode(".jpg", img, encode_param)
84 | img = cv2.imdecode(img, 1)
85 |
86 | img = torch.tensor(img).permute(2, 0, 1)
87 |
88 | img_t = transform(img).unsqueeze(0)
89 | img_t = img_t.to(device)
90 | pred = model.segment(img_t)
91 |
92 | pred = pred.squeeze().detach().cpu().numpy()
93 | mask = np.argmax(pred, axis=0).astype(np.uint8)
94 | # Flip the mask
95 | mask = 1 - mask
96 | polygons = BinaryMask(mask).to_polygons(
97 | min_area=1e-3,
98 | min_hole_area=1e-5,
99 | detect_holes=True,
100 | )
101 | polygons = [
102 | scale(p, xfact=downsample, yfact=downsample, origin=(0, 0)) for p in polygons
103 | ]
104 | add_tissues(wsi, key_added, polygons)
105 |
--------------------------------------------------------------------------------
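The padding above rounds each dimension up to the next multiple of 32 and splits the extra pixels evenly between the two sides, so the tissue stays centred. A standalone numpy sketch of just that arithmetic, using a dummy image:

    import numpy as np

    def pad_to_multiple_of_32(img: np.ndarray) -> np.ndarray:
        """Pad an HxWxC image symmetrically so that H and W become multiples of 32."""
        height, width = img.shape[:2]
        new_height = (height + 31) // 32 * 32
        new_width = (width + 31) // 32 * 32
        top = (new_height - height) // 2
        bottom = new_height - height - top
        left = (new_width - width) // 2
        right = new_width - width - left
        return np.pad(
            img,
            pad_width=((top, bottom), (left, right), (0, 0)),
            mode="constant",
            constant_values=0,  # black padding, as in the tissue segmentation above
        )

    img = np.zeros((1000, 750, 3), dtype=np.uint8)
    assert pad_to_multiple_of_32(img).shape[:2] == (1024, 768)
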
/src/lazyslide/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from ._domain import spatial_domain, tile_shaper
2 | from ._features import feature_extraction, feature_aggregation
3 | from ._signatures import RNALinker
4 | from ._text_annotate import text_embedding, text_image_similarity
5 | from ._tissue_props import tissue_props
6 | from ._spatial_features import spatial_features, feature_utag
7 | from ._zero_shot import zero_shot_score, slide_caption
8 |
--------------------------------------------------------------------------------
/src/lazyslide/tools/_domain.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from wsidata import WSIData
3 | from wsidata.io import update_shapes_data, add_shapes
4 |
5 | from lazyslide._const import Key
6 |
7 |
8 | def spatial_domain(
9 | wsi: WSIData,
10 | feature_key: str,
11 | tile_key: str = Key.tiles,
12 | layer: str = None,
13 | resolution: float = 0.1,
14 | key_added: str = "domain",
15 | ):
16 | """Return the unsupervised domain of the WSI"""
17 | try:
18 | import scanpy as sc
19 | except ImportError:
20 | raise ImportError(
21 | "Please install scanpy to use this function, try `pip install scanpy`."
22 | )
23 | feature_key = wsi._check_feature_key(feature_key, tile_key)
24 | adata = wsi.fetch.features_anndata(feature_key, tile_key, tile_graph=False)
25 | sc.pp.scale(adata, layer=layer)
26 | sc.pp.pca(adata, layer=layer)
27 | sc.pp.neighbors(adata)
28 | sc.tl.leiden(adata, flavor="igraph", key_added=key_added, resolution=resolution)
29 | # Add to tile table
30 | update_shapes_data(wsi, tile_key, {key_added: adata.obs[key_added].to_numpy()})
31 |
32 |
33 | def tile_shaper(
34 | wsi: WSIData,
35 | groupby: str = "domain",
36 | tile_key: str = Key.tiles,
37 | key_added: str = "domain_shapes",
38 | ):
39 | # """Return the domain shapes of the WSI
40 | # Parameters
41 | # ----------
42 | # wsi: :class:`WSIData `
43 | # The WSIData object.
44 | # groupby: str
45 | # The groupby key.
46 | # tile_key: str
47 | # The tile key.
48 | # key_added: str
49 | # The key to add the shapes to.
50 | #
51 | # Returns
52 | # -------
53 | # None
54 | # The shapes will be added to the WSIData object.
55 | # - The shapes will be added to the `domain_shapes` layer of the tile table.
56 | #
57 | # Examples
58 | # --------
59 | # .. code-block:: python
60 | #
61 | # >>> import lazyslide as zs
62 | # >>> wsi = zs.datasets.sample()
63 | # >>> zs.pp.find_tissues(wsi)
64 | # >>> zs.pp.tile_tissues(wsi, 256, mpp=0.5)
65 | # >>> zs.tl.feature_extraction(wsi, "resnet50")
66 | # >>> zs.pp.tile_graph(wsi)
67 | # >>> zs.tl.spatial_domain(wsi, layer="utag", feature_key="resnet50", resolution=0.3)
68 | # >>> zs.tl.tile_shaper(wsi)
69 | #
70 | # """
71 | import geopandas as gpd
72 | from lazyslide.cv import BinaryMask
73 | from shapely.affinity import scale, translate
74 |
75 | result = []
76 |
77 | tile_table = wsi[tile_key]
78 |
79 | spec = wsi.tile_spec(tile_key)
80 |
81 | # To avoid large memory allocation of mask, get domain in each tissue
82 | for _, tissue_group in tile_table.groupby("tissue_id"):
83 | for name, group in tissue_group.groupby(groupby):
84 | bounds = (group.bounds / spec.base_height).astype(int)
85 | minx, miny, maxx, maxy = (
86 | bounds["minx"].min(),
87 | bounds["miny"].min(),
88 | bounds["maxx"].max(),
89 | bounds["maxy"].max(),
90 | )
91 | w, h = int(maxx - minx), int(maxy - miny)
92 | mask = np.zeros((h, w), dtype=np.uint8)
93 | for _, row in bounds.iterrows():
94 | mask[row["miny"] - miny, row["minx"] - minx] = 1
95 | polys = BinaryMask(mask).to_polygons()
96 | # scale back
97 | polys = [
98 | scale(
99 | poly, xfact=spec.base_height, yfact=spec.base_height, origin=(0, 0)
100 | )
101 | for poly in polys
102 | ]
103 | # translate
104 | polys = [
105 | translate(
106 | poly, xoff=minx * spec.base_height, yoff=miny * spec.base_height
107 | )
108 | for poly in polys
109 | ]
110 | for poly in polys:
111 | result.append([name, poly])
112 |
113 | domain_shapes = gpd.GeoDataFrame(data=result, columns=[groupby, "geometry"])
114 | add_shapes(wsi, key_added, domain_shapes)
115 | # return domain_shapes
116 |
--------------------------------------------------------------------------------
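The core of `tile_shaper` is a rasterise-then-vectorise round trip: the tiles of one domain are written into a unit-resolution binary mask, the mask is polygonised, and the polygons are scaled and translated back into slide coordinates. A toy sketch of that round trip, relying only on the `BinaryMask.to_polygons()` call already used above (the tile origins are made up for illustration):

    import numpy as np
    from shapely.affinity import scale, translate
    from lazyslide.cv import BinaryMask

    tile_size = 256  # stands in for spec.base_height
    # Pixel origins of the tiles belonging to one domain (a 3x2 block of adjacent tiles)
    origins = [(512, 1024), (768, 1024), (1024, 1024), (512, 1280), (768, 1280), (1024, 1280)]

    cols = np.array([x // tile_size for x, _ in origins])
    rows = np.array([y // tile_size for _, y in origins])
    minx, miny = cols.min(), rows.min()
    mask = np.zeros((rows.max() - miny + 1, cols.max() - minx + 1), dtype=np.uint8)
    mask[rows - miny, cols - minx] = 1  # one pixel per tile

    polys = BinaryMask(mask).to_polygons()
    # Scale from tile units back to pixels, then shift to the block's origin
    polys = [
        translate(
            scale(p, xfact=tile_size, yfact=tile_size, origin=(0, 0)),
            xoff=minx * tile_size,
            yoff=miny * tile_size,
        )
        for p in polys
    ]
    print(polys[0].bounds)  # roughly the bounding box of the 3x2 tile block
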
/src/lazyslide/tools/_spatial_features.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | from wsidata import WSIData
5 |
6 | from lazyslide._const import Key
7 | from lazyslide._utils import find_stack_level
8 |
9 |
10 | def spatial_features(
11 | wsi: WSIData,
12 | feature_key: str,
13 | method: str = "smoothing",
14 | tile_key: str = Key.tiles,
15 | graph_key: str = None,
16 | layer_key: str = "spatial_features",
17 | ):
18 | """
19 | Integrate spatial tile context with vision features using spatial feature smoothing.
20 |
21 | Parameters
22 | ----------
23 | wsi : :class:`WSIData <wsidata.WSIData>`
24 | The WSIData object.
25 | feature_key : str
26 | The feature key.
27 | method : str, default: 'smoothing'
28 | The method used for spatial feature smoothing. Currently only 'smoothing' is supported.
29 | tile_key : str, default: 'tiles'
30 | The tile key.
31 | graph_key : str, optional
32 | The graph key. If None, defaults to '{tile_key}_graph'.
33 | layer_key : str, default: 'spatial_features'
34 | The key for the output layer in the feature table.
35 |
36 | Returns
37 | -------
38 | None. The transformed feature will be added to the `spatial_features` layer of the feature table.
39 |
40 | Examples
41 | --------
42 | .. code-block:: python
43 |
44 | >>> import lazyslide as zs
45 | >>> wsi = zs.datasets.sample()
46 | >>> zs.pp.find_tissues(wsi)
47 | >>> zs.pp.tile_tissues(wsi, 256, mpp=0.5)
48 | >>> zs.tl.feature_extraction(wsi, "resnet50")
49 | >>> zs.pp.tile_graph(wsi)
50 | >>> zs.tl.spatial_features(wsi, "resnet50")
51 | >>> wsi["resnet50"].layers["spatial_features"]
52 |
53 | """
54 | if method != "smoothing":
55 | raise ValueError(f"Unknown method '{method}'. Only 'smoothing' is currently supported.")
56 |
57 | # Get the spatial connectivity
58 | try:
59 | if graph_key is None:
60 | graph_key = f"{tile_key}_graph"
61 | A = wsi.tables[graph_key].obsp["spatial_connectivities"]
62 | except KeyError:
63 | raise ValueError(
64 | "The tile graph is needed to transform feature with spatial smoothing. Please run `pp.tile_graph` first."
65 | )
66 | A = A + np.eye(A.shape[0])
67 | # L1 norm for each row
68 | norms = np.sum(np.abs(A), axis=1)
69 | # Normalize the array
70 | A_norm = A / norms
71 |
72 | feature_key = wsi._check_feature_key(feature_key, tile_key)
73 | feature_X = wsi.tables[feature_key].X
74 | A_spatial = np.transpose(feature_X) @ A_norm
75 | A_spatial = np.transpose(A_spatial)
76 | wsi.tables[feature_key].layers[layer_key] = np.asarray(A_spatial)
77 |
78 |
79 | def feature_utag(
80 | wsi: WSIData,
81 | feature_key: str,
82 | tile_key: str = Key.tiles,
83 | graph_key: str = None,
84 | ):
85 | """
86 | Deprecated. Use :func:`spatial_features` instead.
87 | """
88 | warnings.warn(
89 | "`tl.feature_utag` is deprecated and will be removed after 0.8.0, "
90 | "please use `tl.spatial_features` instead.",
91 | stacklevel=find_stack_level(),
92 | )
93 | return spatial_features(wsi, feature_key, method="smoothing", tile_key=tile_key, graph_key=graph_key, layer_key="spatial_features")
94 |
--------------------------------------------------------------------------------
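Stripped of the WSIData plumbing, the smoothing above is: add self-loops to the tile adjacency matrix, L1-normalise its rows, and mix each tile's feature vector with its neighbours through a matrix product. A self-contained numpy sketch of those few lines (the toy graph is a simple chain of five tiles):

    import numpy as np

    rng = np.random.default_rng(0)
    n_tiles, n_features = 5, 8

    # Toy spatial connectivity (a chain of tiles) and random tile features
    A = np.zeros((n_tiles, n_tiles))
    for i, j in [(0, 1), (1, 2), (2, 3), (3, 4)]:
        A[i, j] = A[j, i] = 1
    X = rng.normal(size=(n_tiles, n_features))

    A = A + np.eye(n_tiles)                        # self-loops, as in spatial_features
    norms = np.sum(np.abs(A), axis=1, keepdims=True)
    A_norm = A / norms                             # L1-normalised rows

    X_smooth = (X.T @ A_norm).T                    # same product as in the function above
    assert np.allclose(X_smooth, A_norm.T @ X)     # algebraically equivalent form
    print(X_smooth.shape)                          # (5, 8): one smoothed vector per tile
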
/src/lazyslide/tools/_text_annotate.py:
--------------------------------------------------------------------------------
1 | from typing import List, Literal
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from wsidata import WSIData
6 | from wsidata.io import add_features
7 |
8 | from lazyslide._const import Key
9 |
10 |
11 | def text_embedding(
12 | texts: List[str],
13 | model: Literal["plip", "conch"] = "plip",
14 | ):
15 | """Embed the text into a vector in the text-vision co-embedding using
16 | `PLIP `_ or
17 | `CONCH `_.
18 |
19 | Parameters
20 | ----------
21 | texts : List[str]
22 | The list of texts.
23 | model : Literal["plip", "conch"], default: "plip"
24 | The text embedding model, either PLIP or CONCH
25 |
26 | Returns
27 | -------
28 | pd.DataFrame
29 | The embeddings of the texts, with texts as index.
30 |
31 | Examples
32 | --------
33 | .. code-block:: python
34 |
35 | >>> import lazyslide as zs
36 | >>> wsi = zs.datasets.sample()
37 | >>> zs.pp.find_tissues(wsi)
38 | >>> zs.pp.tile_tissues(wsi, 256, mpp=0.5, key_added="text_tiles")
39 | >>> zs.tl.feature_extraction(wsi, "plip", tile_key="text_tiles")
40 | >>> terms = ["mucosa", "submucosa", "muscularis", "lymphocyte"]
41 | >>> zs.tl.text_embedding(terms, model="plip")
42 |
43 | """
44 | import torch
45 |
46 | if model == "plip":
47 | from lazyslide.models.multimodal import PLIP
48 |
49 | model_ins = PLIP()
50 | elif model == "conch":
51 | from lazyslide.models.multimodal import CONCH
52 |
53 | model_ins = CONCH()
54 | else:
55 | raise ValueError(f"Invalid model: {model}")
56 |
57 | # use numpy record array to store the embeddings
58 | with torch.inference_mode():
59 | embeddings = model_ins.encode_text(texts).detach().cpu().numpy()
60 | return pd.DataFrame(embeddings, index=texts)
61 |
62 |
63 | def text_image_similarity(
64 | wsi: WSIData,
65 | text_embeddings: pd.DataFrame,
66 | model: Literal["plip", "conch"] = "plip",
67 | tile_key: str = Key.tiles,
68 | feature_key: str = None,
69 | key_added: str = None,
70 | ):
71 | """
72 | Compute the similarity between text and image.
73 |
74 | Parameters
75 | ----------
76 | wsi : WSIData
77 | The WSIData object.
78 | text_embeddings : pd.DataFrame
79 | The embeddings of the texts, with texts as index.
80 | You can use :func:`zs.tl.text_embedding ` to get the embeddings.
81 | model : Literal["plip", "conch"], default: "plip"
82 | The text embedding model.
83 | tile_key : str, default: 'tiles'
84 | The tile key.
85 | feature_key : str
86 | The feature key.
87 | key_added : str, optional
88 | The key for the added similarity scores. If None, defaults to '{feature_key}_text_similarity'.
89 | Returns
90 | -------
91 | None
92 |
93 | - The similarity scores will be added to :bdg-danger:`tables` slot of the spatial data object.
94 |
95 | Examples
96 | --------
97 | .. code-block:: python
98 | >>> import lazyslide as zs
99 | >>> wsi = zs.datasets.sample()
100 | >>> zs.pp.find_tissues(wsi)
101 | >>> zs.pp.tile_tissues(wsi, 256, mpp=0.5, key_added="text_tiles")
102 | >>> zs.tl.feature_extraction(wsi, "plip", tile_key="text_tiles")
103 | >>> terms = ["mucosa", "submucosa", "muscularis", "lymphocyte"]
104 | >>> embeddings = zs.tl.text_embedding(terms, model="plip")
105 | >>> zs.tl.text_image_similarity(wsi, embeddings, model="plip", tile_key="text_tiles")
106 |
107 | """
108 |
109 | if feature_key is None:
110 | feature_key = model
111 | feature_key = wsi._check_feature_key(feature_key, tile_key)
112 | key_added = key_added or f"{feature_key}_text_similarity"
113 |
114 | feature_X = wsi.tables[feature_key].X
115 | similarity_score = np.dot(text_embeddings.values, feature_X.T).T
116 |
117 | add_features(
118 | wsi,
119 | key_added,
120 | tile_key,
121 | similarity_score,
122 | var=pd.DataFrame(index=text_embeddings.index),
123 | )
124 |
--------------------------------------------------------------------------------
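The similarity step above is a single dot product between the text embedding matrix and the tile feature matrix, yielding one score per (tile, text) pair, which is then stored as a new feature table. A small numpy sketch of that step with random stand-ins for the real PLIP/CONCH embeddings (the embedding dimension is made up for illustration):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)
    terms = ["mucosa", "submucosa", "muscularis"]
    dim = 512  # illustrative embedding dimension, not the real model's size

    text_embeddings = pd.DataFrame(rng.normal(size=(len(terms), dim)), index=terms)
    tile_features = rng.normal(size=(1000, dim))  # one row per tile

    # Same operation as in text_image_similarity: (n_texts, dim) @ (dim, n_tiles), then transpose
    similarity = np.dot(text_embeddings.values, tile_features.T).T
    print(similarity.shape)  # (1000, 3): one score per tile and per term
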
/src/lazyslide/tools/_tissue_props.py:
--------------------------------------------------------------------------------
1 | from functools import cached_property
2 |
3 | import cv2
4 | import numpy as np
5 | import pandas as pd
6 | from wsidata import WSIData
7 | from wsidata.io import update_shapes_data
8 |
9 | from lazyslide._const import Key
10 |
11 |
12 | def point2shape(
13 | wsi: WSIData,
14 | key: str = "tiles",
15 | groupby: str = None,
16 | ):
17 | pass
18 |
19 |
20 | def tissue_props(
21 | wsi: WSIData,
22 | key: str = Key.tissue,
23 | ):
24 | """Compute a series of geometrical properties of tissue piecies
25 |
26 | - "area"
27 | - "area_filled"
28 | - "convex_area"
29 | - "solidity"
30 | - "convexity"
31 | - "axis_major_length"
32 | - "axis_minor_length"
33 | - "eccentricity"
34 | - "orientation"
35 | - "extent"
36 | - "perimeter"
37 | - "circularity"
38 |
39 | Parameters
40 | ----------
41 | wsi : :class:`WSIData <wsidata.WSIData>`
42 | The WSIData object.
43 | key : str
44 | The tissue key.
45 |
46 | Returns
47 | -------
48 | None
49 |
50 | - The tissue properties will be added to the same table as the tissue shapes.
51 |
52 | Examples
53 | --------
54 | .. code-block:: python
55 |
56 | >>> import lazyslide as zs
57 | >>> wsi = zs.datasets.sample()
58 | >>> zs.pp.find_tissues(wsi)
59 | >>> zs.tl.tissue_props(wsi)
60 | >>> wsi['tissues']
61 |
62 | """
63 |
64 | props = []
65 | cnts = []
66 | for tissue_contour in wsi.iter.tissue_contours(key):
67 | cnt = tissue_contour.contour
68 | holes = tissue_contour.holes
69 |
70 | cnt_array = np.asarray(cnt.exterior.coords.xy, dtype=np.int32).T
71 | holes_array = [
72 | np.asarray(h.exterior.coords.xy, dtype=np.int32).T for h in holes
73 | ]
74 |
75 | _props = contour_props(cnt_array, holes_array)
76 | cnts.append(cnt)
77 | props.append(_props)
78 |
79 | props = pd.DataFrame(props).to_dict(orient="list")
80 | update_shapes_data(wsi, key, props)
81 |
82 |
83 | class ContourProps:
84 | def __init__(self, cnt, holes=None):
85 | self.cnt = cnt
86 | self.holes = holes
87 |
88 | @cached_property
89 | def area_filled(self):
90 | return cv2.contourArea(self.cnt)
91 |
92 | @cached_property
93 | def area(self):
94 | """Area without holes."""
95 | if self.holes is None:
96 | return self.area_filled
97 | else:
98 | area = self.area_filled
99 | for hole in self.holes:
100 | area -= cv2.contourArea(hole)
101 | return area
102 |
103 | @cached_property
104 | def bbox(self):
105 | x, y, w, h = cv2.boundingRect(self.cnt)
106 | return x, y, w, h
107 |
108 | @cached_property
109 | def centroid(self):
110 | M = self.moments
111 | cX = int(M["m10"] / M["m00"])
112 | cY = int(M["m01"] / M["m00"])
113 | return cX, cY
114 |
115 | @cached_property
116 | def convex_hull(self):
117 | return cv2.convexHull(self.cnt)
118 |
119 | @cached_property
120 | def convex_area(self):
121 | return cv2.contourArea(self.convex_hull)
122 |
123 | @cached_property
124 | def solidity(self):
125 | """Solidity is the ratio of the contour area to the convex area."""
126 | if self.convex_area == 0:
127 | return 0
128 | return self.area / self.convex_area
129 |
130 | @cached_property
131 | def convexity(self):
132 | """Convexity is the ratio of the convex area to the contour area."""
133 | if self.area == 0:
134 | return 0
135 | return self.convex_area / self.area
136 |
137 | @cached_property
138 | def ellipse(self):
139 | return cv2.fitEllipse(self.cnt)
140 |
141 | @cached_property
142 | def axis_major_length(self):
143 | x1, x2 = self.ellipse[1]
144 | if x1 < x2:
145 | return x2
146 | return x1
147 |
148 | @cached_property
149 | def axis_minor_length(self):
150 | x1, x2 = self.ellipse[1]
151 | if x1 < x2:
152 | return x1
153 | return x2
154 |
155 | @cached_property
156 | def eccentricity(self):
157 | if self.axis_major_length == 0:
158 | return 0
159 | return np.sqrt(1 - (self.axis_minor_length**2) / (self.axis_major_length**2))
160 |
161 | @cached_property
162 | def orientation(self):
163 | return self.ellipse[2]
164 |
165 | @cached_property
166 | def extent(self):
167 | if self.area == 0:
168 | return 0
169 | return self.area / (self.bbox[2] * self.bbox[3])
170 |
171 | @cached_property
172 | def perimeter(self):
173 | return cv2.arcLength(self.cnt, True)
174 |
175 | @cached_property
176 | def circularity(self):
177 | if self.perimeter == 0:
178 | return 0
179 | return 4 * np.pi * self.area / (self.perimeter**2)
180 |
181 | @cached_property
182 | def moments(self):
183 | return cv2.moments(self.cnt)
184 |
185 | @cached_property
186 | def moments_hu(self):
187 | return cv2.HuMoments(self.moments)
188 |
189 | def __call__(self):
190 | props = {
191 | "area": self.area,
192 | "area_filled": self.area_filled,
193 | "convex_area": self.convex_area,
194 | "solidity": self.solidity,
195 | "convexity": self.convexity,
196 | "axis_major_length": self.axis_major_length,
197 | "axis_minor_length": self.axis_minor_length,
198 | "eccentricity": self.eccentricity,
199 | "orientation": self.orientation,
200 | "extent": self.extent,
201 | "perimeter": self.perimeter,
202 | "circularity": self.circularity,
203 | }
204 |
205 | for ix, box in enumerate(self.bbox):
206 | props[f"bbox-{ix}"] = box
207 |
208 | for ix, c in enumerate(self.centroid):
209 | props[f"centroid-{ix}"] = c
210 |
211 | for i, hu in enumerate(self.moments_hu):
212 | props[f"hu-{i}"] = hu[0]
213 |
214 | for key, value in self.moments.items():
215 | props[f"moment-{key}"] = value
216 |
217 | return props
218 |
219 |
220 | def contour_props(cnt: np.ndarray, holes=None):
221 | """Calculate the properties of a contour."""
222 | return ContourProps(cnt, holes)()
223 |
--------------------------------------------------------------------------------
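A quick way to sanity-check `ContourProps` is to feed `contour_props` a near-circular contour, for which circularity should be close to 1 and eccentricity small (note that `cv2.fitEllipse` needs at least five points, so a square would not do). A small sketch importing from the private module shown above, purely for illustration:

    import numpy as np
    from lazyslide.tools._tissue_props import contour_props

    # A 64-point polygon approximating a circle of radius 40 centred at (50, 50)
    theta = np.linspace(0, 2 * np.pi, 64, endpoint=False)
    cnt = np.stack(
        [50 + 40 * np.cos(theta), 50 + 40 * np.sin(theta)], axis=1
    ).astype(np.int32)

    props = contour_props(cnt)
    print(round(props["circularity"], 2))   # ~1.0 for a circle
    print(round(props["eccentricity"], 2))  # small for a circle (up to integer rounding)
    print(props["area"], props["perimeter"])
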
/src/lazyslide/tools/_zero_shot.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Sequence, List, Iterable
4 |
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from wsidata import WSIData
9 |
10 | from lazyslide._utils import get_torch_device
11 |
12 |
13 | def _preprocess_prompts(prompts: List[str | List[str]]) -> List[List[str]]:
14 | """
15 | Preprocess the prompts to ensure they are in the correct format.
16 | """
17 | processed_prompts = []
18 | for prompt in prompts:
19 | if isinstance(prompt, str):
20 | processed_prompts.append([prompt])
21 | elif isinstance(prompt, Iterable):
22 | processed_prompts.append(list(prompt))
23 | else:
24 | raise ValueError(f"Invalid prompt type: {type(prompt)}")
25 | return processed_prompts
26 |
27 |
28 | def _get_agg_info(
29 | wsi: WSIData,
30 | feature_key,
31 | agg_key: str = None,
32 | agg_by: str | Sequence[str] = None,
33 | ):
34 | if agg_key is None:
35 | if agg_by is None:
36 | agg_key = "agg_slide"
37 | else:
38 | if isinstance(agg_by, str):
39 | agg_by = [agg_by]
40 | agg_key = f"agg_{'_'.join(agg_by)}"
41 | agg_info = wsi[feature_key].uns["agg_ops"][agg_key]
42 | annos = None
43 | if "keys" in agg_info:
44 | annos = pd.DataFrame(
45 | data=agg_info["values"],
46 | columns=agg_info["keys"],
47 | )
48 | return agg_info, annos
49 |
50 |
51 | def zero_shot_score(
52 | wsi: WSIData,
53 | prompts: list[list[str]],
54 | feature_key,
55 | *,
56 | agg_key: str = None,
57 | agg_by: str | Sequence[str] = None,
58 | model: str = "prism",
59 | device: str = None,
60 | ):
61 | """
62 | Perform zero-shot classification on the WSI
63 |
64 | Supported models:
65 | - prism: `Prism model `_.
66 | - titan: `Titan model `_.
67 |
68 | Corresponding slide-level features are required for the model.
69 |
70 |
71 | Parameters
72 | ----------
73 | wsi : :class:`wsidata.WSIData`
74 | The WSI data object.
75 | prompts : array of str
76 | The text labels to classify. You can use a list of strings to
77 | add more information to one class.
78 | feature_key : str
79 | The tile features to be used.
80 | agg_key : str
81 | The aggregation key
82 | agg_by : str or list of str
83 | The aggregation keys that were used to create the slide features.
84 | model: {"prism", "titan"}
85 | The model to use for zero-shot classification.
86 | device : str
87 | The device to use for inference. If None, the default device will be used.
88 |
89 | Returns
90 | -------
91 | pd.DataFrame
92 | The classification results (probability). The columns are the text labels and the
93 | rows are the slide features.
94 |
95 | # - The classification results will be added to :bdg-danger:`tables` slot of the spatial data object.
96 |
97 | Examples
98 | --------
99 | .. code-block:: python
100 |
101 | >>> import lazyslide as zs
102 | >>> wsi = zs.datasets.lung_carcinoma(with_data=False)
103 | >>> zs.pp.find_tissues(wsi)
104 | >>> zs.pp.tile_tissues(wsi, 512, background_fraction=0.95, mpp=0.5)
105 | >>> zs.tl.feature_extraction(wsi, "virchow")
106 | >>> zs.tl.feature_aggregation(wsi, feature_key="virchow", encoder="prism")
107 | >>> print(zs.tl.zero_shot_score(wsi, ["lung adenocarcinoma", "normal lung"], feature_key="virchow_tiles"))
108 |
109 | """
110 | if device is None:
111 | device = get_torch_device()
112 |
113 | prompts = _preprocess_prompts(prompts)
114 |
115 | if model == "prism":
116 | from lazyslide.models.multimodal import Prism
117 |
118 | model = Prism()
119 | elif model == "titan":
120 | from lazyslide.models.multimodal import Titan
121 |
122 | model = Titan()
123 | model.to(device)
124 | # Get the embeddings from the WSI
125 | agg_info, annos = _get_agg_info(
126 | wsi,
127 | feature_key,
128 | agg_key=agg_key,
129 | agg_by=agg_by,
130 | )
131 |
132 | all_probs = []
133 | for ix, f in enumerate(agg_info["features"]):
134 | f = torch.tensor(f).unsqueeze(0).to(device)
135 | probs = model.score(f, prompts=prompts)
136 | all_probs.append(probs)
137 |
138 | all_probs = np.vstack(all_probs)
139 |
140 | named_prompts = [", ".join(p) for p in prompts]
141 | results = pd.DataFrame(
142 | data=all_probs,
143 | columns=named_prompts,
144 | )
145 | if annos is not None:
146 | results = pd.concat([annos, results], axis=1)
147 | return results
148 |
149 |
150 | def slide_caption(
151 | wsi: WSIData,
152 | prompt: list[str],
153 | feature_key,
154 | *,
155 | agg_key: str = None,
156 | agg_by: str | Sequence[str] = None,
157 | max_length: int = 100,
158 | model: str = "prism",
159 | device: str = None,
160 | ):
161 | """
162 | Generate captions for the slide.
163 |
164 | Parameters
165 | ----------
166 | wsi : :class:`wsidata.WSIData`
167 | The WSI data object.
168 | prompt : list of str
169 | The text instruction to generate the caption.
170 | feature_key : str
171 | The slide features to be used.
172 | agg_key : str
173 | The aggregation key
174 | agg_by : str or list of str
175 | The aggregation keys that were used to create the slide features.
176 | max_length : int
177 | The maximum length of the generated caption.
178 | model : {"prism"}
179 | The caption generation model to use.
180 | device : str
181 | The device to use for inference. If None, the default device will be used.
182 |
183 | """
184 |
185 | if device is None:
186 | device = get_torch_device()
187 |
188 | from lazyslide.models.multimodal import Prism
189 |
190 | model = Prism()
191 | model.to(device)
192 |
193 | agg_info, annos = _get_agg_info(
194 | wsi,
195 | feature_key,
196 | agg_key=agg_key,
197 | agg_by=agg_by,
198 | )
199 |
200 | captions = []
201 |
202 | for ix, lat in enumerate(agg_info["latents"]):
203 | lat = torch.tensor(lat).unsqueeze(0).to(device)
204 | caption = model.caption(
205 | lat,
206 | prompt=prompt,
207 | max_length=max_length,
208 | )
209 | captions.append(caption)
210 |
211 | results = pd.DataFrame(
212 | {
213 | "caption": captions,
214 | }
215 | )
216 | if annos is not None:
217 | results = pd.concat([annos, results], axis=1)
218 | return results
219 |
--------------------------------------------------------------------------------
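The prompt handling above accepts a mix of bare strings and groups of strings: a bare string becomes a single-member class, while a list of strings is kept together as one class (useful for giving several synonymous descriptions of the same label). A tiny illustration, importing the private helper only to show its behaviour:

    from lazyslide.tools._zero_shot import _preprocess_prompts

    prompts = [
        "normal lung",                                           # one description for this class
        ["lung adenocarcinoma", "adenocarcinoma of the lung"],   # two descriptions, still one class
    ]
    print(_preprocess_prompts(prompts))
    # [['normal lung'], ['lung adenocarcinoma', 'adenocarcinoma of the lung']]
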
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 | import torch
5 |
6 |
7 | class MockNet(torch.nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 |
11 | def forward(self, x):
12 | return torch.zeros(x.shape[0], 1000)
13 |
14 |
15 | @pytest.fixture(scope="session", autouse=True)
16 | def wsi():
17 | import lazyslide as zs
18 |
19 | return zs.datasets.gtex_artery()
20 |
21 |
22 | @pytest.fixture(scope="session")
23 | def tmp_path_session(tmp_path_factory):
24 | return tmp_path_factory.mktemp("session_tmp")
25 |
26 |
27 | @pytest.fixture(scope="session", autouse=True)
28 | def torch_model_file(tmp_path_session):
29 | model = MockNet()
30 | torch.save(model, tmp_path_session / "model.pt")
31 | return tmp_path_session / "model.pt"
32 |
33 |
34 | @pytest.fixture(scope="session", autouse=True)
35 | def torch_jit_file(tmp_path_session):
36 | model = MockNet()
37 | torch.jit.script(model).save(tmp_path_session / "jit_model.pt")
38 | return tmp_path_session / "jit_model.pt"
39 |
40 |
41 | def pytest_collection_modifyitems(config, items):
42 | if os.getenv("GITHUB_ACTIONS") == "true":
43 | skip_on_ci = pytest.mark.skip(reason="Skipped on GitHub CI")
44 | for item in items:
45 | if "skip_on_ci" in item.keywords:
46 | item.add_marker(skip_on_ci)
47 |
--------------------------------------------------------------------------------
/tests/data/CMU-1-Small-Region.svs:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rendeirolab/LazySlide/f39634cc994b3098b0933075b9d25ecd99b9014e/tests/data/CMU-1-Small-Region.svs
--------------------------------------------------------------------------------
/tests/test_cv.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 |
4 |
5 | np.random.seed(42)
6 |
7 | H, W = 100, 100
8 | N_CLASS = 5
9 | binary_mask = np.random.randint(0, 2, (H, W), dtype=np.uint8)
10 | multilabel_mask = np.random.randint(0, N_CLASS, (H, W), dtype=np.uint8)
11 | multiclass_mask = np.random.randint(
12 | 0,
13 | 2,
14 | (
15 | N_CLASS,
16 | H,
17 | W,
18 | ),
19 | dtype=np.uint8,
20 | )
21 |
22 |
23 | class TestMask:
24 | @pytest.mark.parametrize("mask", [binary_mask, multilabel_mask, multiclass_mask])
25 | def test_mask_to_polygon(self, mask):
26 | from lazyslide.cv.mask import Mask
27 |
28 | mask = Mask.from_array(mask)
29 | mask.to_polygons()
30 |
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | import lazyslide as zs
2 |
3 |
4 | def test_load_sample():
5 | wsi = zs.datasets.sample()
6 | assert wsi is not None
7 |
8 |
9 | def test_load_gtex_artery():
10 | wsi = zs.datasets.gtex_artery()
11 | assert wsi is not None
12 |
13 |
14 | def test_load_lung_carcinoma():
15 | wsi = zs.datasets.lung_carcinoma()
16 | assert wsi is not None
17 |
--------------------------------------------------------------------------------
/tests/test_pp.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import lazyslide as zs
4 |
5 |
6 | @pytest.mark.parametrize("detect_holes", [True, False])
7 | @pytest.mark.parametrize("key_added", ["tissue", "tissue2"])
8 | def test_pp_find_tissues(wsi, detect_holes, key_added):
9 | zs.pp.find_tissues(wsi, detect_holes=detect_holes, key_added=key_added)
10 |
11 | assert key_added in wsi.shapes
12 | if not detect_holes:
13 | tissue = wsi[key_added].geometry[0]
14 | assert len(tissue.interiors) == 0
15 |
16 |
17 | class TestPPTileTissues:
18 | def test_tile_px(self, wsi):
19 | zs.pp.find_tissues(wsi)
20 | zs.pp.tile_tissues(wsi, 256, key_added="tiles")
21 |
22 | def test_mpp(self, wsi):
23 | zs.pp.tile_tissues(wsi, 256, mpp=1, key_added="tiles1")
24 |
25 | @pytest.mark.xfail(raises=ValueError)
26 | def test_slide_mpp(self, wsi):
27 | zs.pp.tile_tissues(wsi, 256, slide_mpp=1, key_added="tiles2")
28 |
29 | def test_assert(self, wsi):
30 | s0 = len(wsi["tiles"])
31 | s1 = len(wsi["tiles1"])
32 |
33 | assert s0 > 0
34 | assert s1 < s0
35 |
--------------------------------------------------------------------------------
/tests/test_tl.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import lazyslide as zs
3 |
4 | TIMM_MODEL = "mobilenetv3_small_050"
5 |
6 |
7 | class TestFeatureExtraction:
8 | def test_load_model(self, wsi, torch_model_file):
9 | zs.pp.find_tissues(wsi)
10 | zs.pp.tile_tissues(wsi, 512)
11 | zs.tl.feature_extraction(wsi, model_path=torch_model_file)
12 | # Test feature aggregation
13 | zs.tl.feature_aggregation(wsi, feature_key="MockNet")
14 |
15 | def test_load_jit_model(self, wsi, torch_jit_file):
16 | zs.tl.feature_extraction(wsi, model_path=torch_jit_file)
17 |
18 | @pytest.mark.skip_on_ci
19 | def test_timm_model(self, wsi):
20 | zs.tl.feature_extraction(wsi, model=TIMM_MODEL)
21 |
--------------------------------------------------------------------------------
/workflow/main.nf:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env nextflow
2 | nextflow.enable.dsl = 2
3 |
4 | params.slide_table = null
5 | params.tile_px = 256
6 | params.report_dir = "reports"
7 | params.models = "resnet50"
8 |
9 | process PREPROCESS {
10 | publishDir params.report_dir, mode: 'move'
11 | // conda "${projectDir}/env.yaml"
12 |
13 | input:
14 | tuple val(wsi), val(storage)
15 | val tile_px
16 |
17 | output:
18 | path '*_report.txt', emit: report
19 | tuple val(wsi), val(storage), emit: slide
20 |
21 | script:
22 |
23 | def wsi_base = wsi.baseName
24 |
25 | """
26 | lazyslide preprocess ${wsi} ${tile_px} --output ${storage}
27 | touch ${wsi_base}_report.txt
28 | """
29 | }
30 |
31 | process FEATURE {
32 | // conda "${projectDir}/env.yaml"
33 |
34 | input:
35 | tuple val(wsi), val(storage)
36 | each model
37 |
38 | script:
39 | """
40 | lazyslide feature ${wsi} ${model} --output ${storage}
41 | """
42 | }
43 |
44 |
45 |
46 | workflow {
47 |
48 | log.info """
49 | ██ █████ ███████ ██ ██ ███████ ██ ██ ██████ ███████
50 | ██ ██ ██ ███ ██ ██ ██ ██ ██ ██ ██ ██
51 | ██ ███████ ███ ████ ███████ ██ ██ ██ ██ █████
52 | ██ ██ ██ ███ ██ ██ ██ ██ ██ ██ ██
53 | ███████ ██ ██ ███████ ██ ███████ ███████ ██ ██████ ███████
54 |
55 | ===================================================================
56 |
57 | Workflow information:
58 | Workflow: ${workflow.projectDir}
59 |
60 | Input parameters:
61 | Slide table: ${file(params.slide_table)}
62 |
63 | """
64 |
65 | slides_ch = Channel
66 | .fromPath( params.slide_table, checkIfExists: true )
67 | .splitCsv( header: true )
68 | .map { row ->
69 | def slide_file = file(row.file, checkIfExists: true)
70 | def slide_storage = row.storage
71 | if (row.storage == null) { slide_storage = slide_file.parent / slide_file.baseName + ".zarr" }
72 | return tuple(slide_file, slide_storage)
73 | }
74 |
75 | // slides_ch.view()
76 |
77 | out_ch = PREPROCESS(slides_ch, params.tile_px)
78 |
79 | // println "Ouput of PREPROCESS: "
80 | // out_ch.slide.view()
81 |
82 | models = Channel.of(params.models?.split(','))
83 |
84 | FEATURE(out_ch.slide, models)
85 |
86 | }
--------------------------------------------------------------------------------
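The workflow reads `params.slide_table` with `splitCsv(header: true)` and looks up `row.file` and `row.storage` in the map closure, so the slide table is a CSV with a `file` column (path to the WSI) and an optional `storage` column (the `.zarr` output path; when the column is absent, the workflow derives `<slide>.zarr` next to the slide). A minimal Python sketch that writes such a table; the column names come from the closure above, the paths are illustrative:

    import csv

    slides = [
        {"file": "/data/slides/sample_01.svs", "storage": "/data/zarr/sample_01.zarr"},
        {"file": "/data/slides/sample_02.svs", "storage": "/data/zarr/sample_02.zarr"},
    ]

    with open("slide_table.csv", "w", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=["file", "storage"])
        writer.writeheader()
        writer.writerows(slides)

    # Then, roughly: nextflow run workflow/main.nf --slide_table slide_table.csv --tile_px 256 --models resnet50
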
/workflow/modules/qc/main.nf:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env nextflow
2 | nextflow.enable.dsl = 2
3 |
4 | process SlideQC {
5 |
6 | input:
7 | val mpp
8 |
9 | path slide
10 |
11 | output:
12 | path("*.qc.csv") into qc_ch
13 |
14 | script:
15 | """
16 | lazyslide qc $slide
17 | """
18 | }
--------------------------------------------------------------------------------