├── notebooks ├── .gitkeep └── visualization.ipynb ├── .gitattributes ├── tests ├── __init__.py ├── metrics │ ├── test_retrieval.py │ └── test_cosine_similarity.py └── models │ └── test_extractors.py ├── plismbench ├── __init__.py ├── engine │ ├── extract │ │ ├── __init__.py │ │ ├── core.py │ │ ├── utils.py │ │ ├── extract_from_h5.py │ │ └── extract_from_png.py │ ├── __init__.py │ ├── cli.py │ └── evaluate.py ├── utils │ ├── __init__.py │ ├── core.py │ ├── aggregate.py │ ├── evaluate.py │ ├── metrics.py │ └── viz.py ├── metrics │ ├── __init__.py │ ├── base.py │ ├── cosine_similarity.py │ └── retrieval.py └── models │ ├── meta.py │ ├── microsoft.py │ ├── standford.py │ ├── extractor.py │ ├── lunit.py │ ├── hkust.py │ ├── utils.py │ ├── owkin.py │ ├── histai.py │ ├── paige_ai.py │ ├── __init__.py │ ├── kaiko_ai.py │ ├── bioptimus.py │ └── mahmood_lab.py ├── mypy.ini ├── assets ├── figure.png ├── tiles_subset_2713.npy ├── tiles_subset_460.npy ├── tiles_subset_5426.npy └── tiles_subset_8139.npy ├── .pydocstyle ├── doc8.ini ├── docs ├── _static │ ├── logo-dark.png │ └── logo-light.png ├── index.rst ├── Makefile ├── make.bat ├── contributing.rst └── conf.py ├── pytest.ini ├── CODEOWNERS ├── .github └── workflows │ ├── publish.yml │ ├── page.yml │ └── python-app.yml ├── scripts └── evaluate.sh ├── .pre-commit-config.yaml ├── .gitignore ├── Makefile ├── pyproject.toml └── README.md /notebooks/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit tests for :mod:`plismbench`.""" 2 | -------------------------------------------------------------------------------- /plismbench/__init__.py: -------------------------------------------------------------------------------- 1 | """Top level package for :mod:`plismbench`.""" 2 | -------------------------------------------------------------------------------- /plismbench/engine/extract/__init__.py: -------------------------------------------------------------------------------- 1 | """Feature extraction module.""" 2 | -------------------------------------------------------------------------------- /plismbench/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """A module for utility functionalities.""" 2 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_errors = False 3 | ignore_missing_imports = True 4 | -------------------------------------------------------------------------------- /assets/figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owkin/plism-benchmark/HEAD/assets/figure.png -------------------------------------------------------------------------------- /.pydocstyle: -------------------------------------------------------------------------------- 1 | [pydocstyle] 2 | convention = numpy 3 | match = (?!_).*\.py 4 | add_ignore = D105 5 | -------------------------------------------------------------------------------- 
/plismbench/engine/__init__.py: -------------------------------------------------------------------------------- 1 | """Top-level module for training and inference functionalities.""" 2 | -------------------------------------------------------------------------------- /doc8.ini: -------------------------------------------------------------------------------- 1 | [doc8] 2 | path = docs/ 3 | ignore-path = docs/_build 4 | max-line-length = 88 5 | verbose = 1 6 | -------------------------------------------------------------------------------- /assets/tiles_subset_2713.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owkin/plism-benchmark/HEAD/assets/tiles_subset_2713.npy -------------------------------------------------------------------------------- /assets/tiles_subset_460.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owkin/plism-benchmark/HEAD/assets/tiles_subset_460.npy -------------------------------------------------------------------------------- /assets/tiles_subset_5426.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owkin/plism-benchmark/HEAD/assets/tiles_subset_5426.npy -------------------------------------------------------------------------------- /assets/tiles_subset_8139.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owkin/plism-benchmark/HEAD/assets/tiles_subset_8139.npy -------------------------------------------------------------------------------- /docs/_static/logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owkin/plism-benchmark/HEAD/docs/_static/logo-dark.png -------------------------------------------------------------------------------- /docs/_static/logo-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owkin/plism-benchmark/HEAD/docs/_static/logo-light.png -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = tests 3 | junit_family = xunit2 4 | markers = 5 | local: mark a test as a local test. 6 | -------------------------------------------------------------------------------- /plismbench/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for metrics in plismbench.""" 2 | 3 | from .cosine_similarity import CosineSimilarity 4 | from .retrieval import TopkAccuracy 5 | 6 | 7 | __all__ = ["CosineSimilarity", "TopkAccuracy"] 8 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | PLISM robustness benchmark 2 | =========================================================== 3 | 4 | Repository hosting the PLISM robustness benchmark 5 | 6 | :Release: |version| 7 | :Date: |today| 8 | 9 | ..
toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | Getting Started 14 | User Guide 15 | API Reference 16 | 17 | 18 | Indices and tables 19 | ================== 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Lines starting with '#' are comments. 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # These owners will be the default owners for everything in the repo. 5 | * @afilt 6 | 7 | # Order is important. The last matching pattern has the most precedence. 8 | # So if a pull request only touches javascript files, only these owners 9 | # will be requested to review. 10 | # *.js @octocat @github/js 11 | 12 | # You can also use email addresses if you prefer. 13 | # docs/* docs@example.com 14 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | publish-on-pypi-servers: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout repo and submodule 10 | uses: actions/checkout@v3 11 | 12 | - name: Setup Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: "3.10" 16 | 17 | - name: Install Poetry 18 | run: make install-poetry 19 | 20 | - name: Install dependencies 21 | run: poetry install --all-extras 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = plismbench 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=plismbench 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /plismbench/metrics/base.py: -------------------------------------------------------------------------------- 1 | """Module for base metric object.""" 2 | 3 | from abc import abstractmethod 4 | 5 | from loguru import logger 6 | 7 | 8 | try: 9 | import cupy as cp 10 | except ImportError as error: 11 | logger.error( 12 | f"cupy is not installed. Please run `make install-cupy`.\nError: {error}." 13 | ) 14 | import numpy as np 15 | 16 | 17 | class BasePlismMetric: 18 | """Base class for metrics. 19 | 20 | Attributes 21 | ---------- 22 | device: str: Literal["cpu", "gpu"] 23 | Device to use for computation. 24 | """ 25 | 26 | def __init__(self, device: str, use_mixed_precision: bool = True): 27 | self.device = device 28 | self.ncp = cp if device == "gpu" else np 29 | self.use_mixed_precision = use_mixed_precision 30 | 31 | @abstractmethod 32 | def compute_metric(self, matrix_a: np.ndarray, matrix_b: np.ndarray): 33 | """Compute metric between feature matrices A and B.""" 34 | raise NotImplementedError 35 | -------------------------------------------------------------------------------- /scripts/evaluate.sh: -------------------------------------------------------------------------------- 1 | # List of feature extractors to compute metrics on 2 | extractors=( 3 | virchow2 4 | h0_mini 5 | phikon 6 | phikon_v2 7 | uni 8 | uni2h 9 | hoptimus0 10 | provgigapath 11 | hibou_vit_large 12 | kaiko_vit_base_8 13 | kaiko_vit_large_14 14 | plip 15 | conch 16 | ) 17 | 18 | # List of number of tiles to compute metrics on: the reference for leaderboard is 8,139 ! 
19 | n_tiles=( 20 | 8139 21 | 5426 22 | 2713 23 | ) 24 | 25 | # Set features and metrics directories 26 | features_dir=/home/owkin/project/plism_features/ 27 | metrics_dir=/home/owkin/project/plism_metrics/ 28 | 29 | # Iterate over number of tiles 30 | for _n_tiles in ${n_tiles[*]} 31 | do 32 | # Iterate over extractors 33 | for extractor in ${extractors[*]} 34 | do 35 | plismbench evaluate \ 36 | --extractor "${extractor}" \ 37 | --features-dir $features_dir \ 38 | --metrics-dir $metrics_dir \ 39 | --n-tiles $_n_tiles \ 40 | --device "gpu" 41 | done 42 | done 43 | -------------------------------------------------------------------------------- /plismbench/utils/core.py: -------------------------------------------------------------------------------- 1 | """Generic utilty functions.""" 2 | 3 | import os 4 | import pickle 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | import requests 9 | 10 | 11 | def load_pickle(file_path: str | Path) -> Any: 12 | """Load pickle.""" 13 | with open(file_path, "rb") as handle: 14 | return pickle.load(handle) 15 | 16 | 17 | def write_pickle(data: Any, file_path: str | Path) -> None: 18 | """Write data into a pickle file.""" 19 | with open(file_path, "wb") as handle: 20 | pickle.dump(data, handle) 21 | 22 | 23 | def download_state_dict(url: str, name: str) -> str: 24 | """Download checkpoint from a given URL and store it to disk.""" 25 | output_path = os.path.join(os.environ["HOME"], name) 26 | if os.path.exists(output_path): 27 | pass 28 | else: 29 | response = requests.get(url, stream=True) 30 | response.raise_for_status() # Raise error if download fails 31 | with open(output_path, "wb") as f: 32 | for chunk in response.iter_content(chunk_size=8192): 33 | f.write(chunk) 34 | return output_path 35 | -------------------------------------------------------------------------------- /tests/metrics/test_retrieval.py: -------------------------------------------------------------------------------- 1 | """Test module for retrieval metrics.""" 2 | 3 | import numpy as np 4 | import pytest 5 | from loguru import logger 6 | 7 | from plismbench.metrics.retrieval import TopkAccuracy 8 | 9 | 10 | # Test data 11 | test_data = [ 12 | ( 13 | np.array([[0, 1], [1, 0], [0, -1], [-1, 0]]), 14 | np.array([[0.1, 1.1], [0.2, 1.1], [0.3, 1.1], [0.4, 1.1]]), 15 | [1, 7], 16 | np.array([0.125, 1]), 17 | ), 18 | (np.random.rand(5, 100), np.random.rand(5, 100), [9], np.array([1])), 19 | ] 20 | 21 | 22 | @pytest.mark.parametrize(("matrix_a", "matrix_b", "k", "expected"), test_data) 23 | def test_topk_accuracy(matrix_a, matrix_b, k, expected): 24 | """Test top-k accuracy metric.""" 25 | # Test cpu 26 | metric = TopkAccuracy(device="cpu", k=k) 27 | result = metric.compute_metric(matrix_a, matrix_b) 28 | 29 | # Check np array equality 30 | assert result == pytest.approx(expected) 31 | 32 | 33 | @pytest.mark.local 34 | @pytest.mark.parametrize(("matrix_a", "matrix_b", "k", "expected"), test_data) 35 | def test_topk_accuracy_gpu(matrix_a, matrix_b, k, expected): 36 | """Test top-k accuracy metric on GPU.""" 37 | import cupy as cp 38 | 39 | # Check first if a GPU is available 40 | if cp.cuda.is_available(): 41 | # Test gpu 42 | metric = TopkAccuracy(device="gpu", k=k) 43 | result = metric.compute_metric(matrix_a, matrix_b) 44 | assert result == pytest.approx(expected) 45 | else: 46 | logger.info("No GPU available. 
Skipping GPU test.") 47 | -------------------------------------------------------------------------------- /tests/metrics/test_cosine_similarity.py: -------------------------------------------------------------------------------- 1 | """Test module for cosine similarity metrics.""" 2 | 3 | import numpy as np 4 | import pytest 5 | from loguru import logger 6 | 7 | from plismbench.metrics.cosine_similarity import CosineSimilarity 8 | 9 | 10 | # Test data 11 | test_data = [ 12 | ( 13 | np.array([[0, 1], [1, 0], [1, 1], [1, 1]]), 14 | np.array([[0, 2], [2, 0], [-1, -1], [0, 1]]), 15 | 0.426776695, # Mean cosine similarity should be (1 + 1 - 1 + 1/sqrt(2)) / 4 \approx 0.426776695 16 | ) 17 | ] 18 | 19 | 20 | @pytest.mark.parametrize(("matrix_a", "matrix_b", "expected"), test_data) 21 | def test_cosine_similarity(matrix_a, matrix_b, expected): 22 | """Test cosine similarity metric.""" 23 | # Test cpu 24 | metric = CosineSimilarity(device="cpu", use_mixed_precision=False) 25 | result = metric.compute_metric(matrix_a, matrix_b) 26 | assert result == pytest.approx(expected) 27 | 28 | 29 | @pytest.mark.local 30 | @pytest.mark.parametrize(("matrix_a", "matrix_b", "expected"), test_data) 31 | def test_cosine_similarity_gpu(matrix_a, matrix_b, expected): 32 | """Test cosine similarity metric on GPU.""" 33 | import cupy as cp # Do the import here to avoid CI issues. 34 | 35 | # Check first if a GPU is available 36 | if cp.cuda.is_available(): 37 | # Test gpu 38 | metric = CosineSimilarity(device="gpu", use_mixed_precision=False) 39 | result = metric.compute_metric(matrix_a, matrix_b) 40 | assert result == pytest.approx(expected) 41 | else: 42 | logger.info("No GPU available. Skipping GPU test.") 43 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/sirosen/check-jsonschema 3 | rev: 0.28.0 4 | hooks: 5 | - id: check-github-actions 6 | - id: check-github-workflows 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v4.5.0 9 | hooks: 10 | - id: trailing-whitespace 11 | name: Trim trailing whitespace 12 | - id: end-of-file-fixer 13 | name: Fix end of files 14 | exclude: \.ipynb$ 15 | - repo: https://github.com/python-poetry/poetry 16 | rev: 1.8.4 17 | hooks: 18 | - id: poetry-check 19 | name: Run poetry check to validate configuration 20 | - repo: local 21 | hooks: 22 | - id: ruff 23 | name: Check linting with `ruff` 24 | language: system 25 | types: [python] 26 | require_serial: true 27 | entry: poetry run ruff check 28 | args: [--fix] 29 | files: ^(plismbench|tests|scripts)/ 30 | - id: ruff-format 31 | name: Format files with `ruff` 32 | language: system 33 | types: [python] 34 | require_serial: true 35 | entry: poetry run ruff format 36 | files: ^(plismbench|tests|scripts)/ 37 | - repo: local 38 | hooks: 39 | - id: mypy 40 | name: Test typing with `mypy` 41 | language: system 42 | types: [python] 43 | require_serial: true 44 | entry: poetry run mypy 45 | files: ^plismbench/ 46 | - repo: local 47 | hooks: 48 | - id: jupyter-nb-clear-output 49 | name: Clear Jupyter notebook outputs 50 | files: \.ipynb$ 51 | language: system 52 | entry: poetry run jupyter nbconvert 53 | args: [--ClearOutputPreprocessor.enabled=True, --inplace] 54 | -------------------------------------------------------------------------------- /.github/workflows/page.yml: 
-------------------------------------------------------------------------------- 1 | name: Deploy doc 2 | 3 | on: 4 | # Runs on pushes targeting the main branch (uncomment to next 3 lines to enable) 5 | push: 6 | branches: 7 | - main 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 13 | permissions: 14 | contents: read 15 | pages: write 16 | id-token: write 17 | 18 | # Allow one concurrent deployment 19 | concurrency: 20 | group: "pages" 21 | cancel-in-progress: true 22 | 23 | jobs: 24 | # Build job 25 | build: 26 | runs-on: ubuntu-latest 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | - name: Free Disk Space (Ubuntu) 31 | uses: jlumbroso/free-disk-space@main 32 | with: 33 | tool-cache: true 34 | android: true 35 | dotnet: true 36 | haskell: true 37 | large-packages: true 38 | docker-images: true 39 | swap-storage: true 40 | 41 | - name: Set up Python 42 | uses: actions/setup-python@v5 43 | with: 44 | python-version: "3.10" 45 | 46 | - name: Install Poetry 47 | run: make install-poetry 48 | 49 | - name: Install dependencies 50 | run: poetry install --all-extras --with=docs 51 | 52 | - name: Build docs 53 | run: poetry run make docs 54 | 55 | - name: Upload artifact 56 | uses: actions/upload-pages-artifact@v3 57 | with: 58 | path: docs/_build/html 59 | 60 | # Deployment job 61 | deploy: 62 | environment: 63 | name: github-pages 64 | url: ${{ steps.deployment.outputs.page_url }} 65 | runs-on: ubuntu-latest 66 | needs: build 67 | steps: 68 | - name: Deploy to GitHub Pages 69 | id: deployment 70 | uses: actions/deploy-pages@v4 71 | -------------------------------------------------------------------------------- /plismbench/metrics/cosine_similarity.py: -------------------------------------------------------------------------------- 1 | """Module for cosine similarity metric.""" 2 | 3 | from plismbench.metrics.base import BasePlismMetric 4 | 5 | 6 | class CosineSimilarity(BasePlismMetric): 7 | """Cosine similarity metric.""" 8 | 9 | def __init__(self, device: str, use_mixed_precision: bool = True): 10 | super().__init__(device, use_mixed_precision) 11 | 12 | def compute_metric(self, matrix_a, matrix_b): 13 | """Compute cosine similarity metric.""" 14 | # Compute cosine simlilarity for each pair of tiles between features 15 | # matrix a and b. 16 | if matrix_a.shape[0] != matrix_b.shape[0]: 17 | raise ValueError( 18 | f"Number of tiles must match. Got {matrix_a.shape[0]} and {matrix_b.shape[0]}." 
19 | ) 20 | 21 | # Put matrix_a and matrix_b on the gpu if needed 22 | matrix_a = self.ncp.asarray(matrix_a) # shape (n_tiles, n_features) 23 | matrix_b = self.ncp.asarray(matrix_b) # shape (n_tiles, n_features) 24 | 25 | if self.use_mixed_precision: 26 | matrix_a = matrix_a.astype(self.ncp.float16) 27 | matrix_b = matrix_b.astype(self.ncp.float16) 28 | 29 | # Compute cosine similarity 30 | dot_product_ab = self.ncp.matmul( 31 | matrix_a, matrix_b.T 32 | ) # shape (n_tiles, n_tiles) 33 | norm_a = self.ncp.linalg.norm( 34 | matrix_a, axis=1, keepdims=True 35 | ) # shape (n_tiles, 1) 36 | norm_b = self.ncp.linalg.norm( 37 | matrix_b, axis=1, keepdims=True 38 | ) # shape (n_tiles, 1) 39 | 40 | cosine_ab = dot_product_ab / (norm_a * norm_b.T) # shape (n_tiles, n_tiles) 41 | 42 | _mean_cosine_ab = self.ncp.diag(cosine_ab).mean() 43 | mean_cosine_ab = ( 44 | float(_mean_cosine_ab.get()) 45 | if self.device == "gpu" 46 | else float(_mean_cosine_ab) 47 | ) 48 | 49 | return mean_cosine_ab 50 | -------------------------------------------------------------------------------- /plismbench/engine/extract/core.py: -------------------------------------------------------------------------------- 1 | """Perform feature extraction from the PLISM dataset.""" 2 | 3 | from __future__ import annotations 4 | 5 | from pathlib import Path 6 | 7 | from plismbench.engine.extract.extract_from_h5 import run_extract_h5 8 | from plismbench.engine.extract.extract_from_png import run_extract_streaming 9 | 10 | 11 | def run_extract( 12 | feature_extractor_name: str, 13 | batch_size: int, 14 | device: int, 15 | export_dir: Path, 16 | download_dir: Path | None = None, 17 | streaming: bool = False, 18 | overwrite: bool = False, 19 | workers: int = 8, 20 | ) -> None: 21 | """Run feature extraction. 22 | 23 | If ``streaming=False``, data will be downloaded and stored to disk from 24 | https://huggingface.co/datasets/owkin/plism-dataset. This dataset 25 | contains 91 .h5 files, each containing 16,278 images converted 26 | into numpy arrays. In this scenario, 300 GB of storage is necessary. 27 | 28 | If ``streaming=True``, data will be downloaded on the fly from 29 | https://huggingface.co/datasets/owkin/plism-dataset-tiles but not 30 | stored to disk. This dataset contains 91x16278 images stored as .png 31 | files. Streaming is enabled using the ``datasets`` library and 32 | `datasets.load_dataset(..., streaming=True)`. Note that this comes 33 | with the limitation of using an ``IterableDataset``, meaning that no easy 34 | resume can be performed if feature extraction fails. 35 | """ 36 | if streaming: 37 | run_extract_streaming( 38 | feature_extractor_name=feature_extractor_name, 39 | batch_size=batch_size, 40 | device=device, 41 | export_dir=export_dir, 42 | overwrite=overwrite, 43 | ) 44 | else: 45 | assert download_dir is not None, "Download directory should be specified."
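# Non-streaming path: features are computed from the pre-downloaded per-slide .h5
# archives, so ``download_dir`` must point at the local copy of the dataset
# (see the docstring above for storage requirements).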
46 | run_extract_h5( 47 | feature_extractor_name=feature_extractor_name, 48 | batch_size=batch_size, 49 | device=device, 50 | export_dir=export_dir, 51 | download_dir=download_dir, 52 | overwrite=overwrite, 53 | workers=workers, 54 | ) 55 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | name: Python dev 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | testing: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python: ["3.10"] 15 | name: Testing Python ${{ matrix.python }} 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Free Disk Space (Ubuntu) 20 | uses: jlumbroso/free-disk-space@main 21 | with: 22 | tool-cache: true 23 | android: true 24 | dotnet: true 25 | haskell: true 26 | large-packages: true 27 | docker-images: true 28 | swap-storage: true 29 | 30 | - name: Set up Python 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: ${{ matrix.python }} 34 | 35 | - name: Install Poetry 36 | run: make install-poetry 37 | 38 | - name: Install Hugging Face CLI 39 | run: pip install huggingface_hub 40 | 41 | - name: Log in to Hugging Face 42 | run: | 43 | python -c "from huggingface_hub import login; login(token='${{ secrets.HF_TOKEN }}', new_session=False)" 44 | 45 | - name: Install git lfs 46 | run: | 47 | git lfs install 48 | git lfs pull 49 | 50 | - name: Install dependencies 51 | run: make install-all 52 | 53 | - name: Testing 54 | run: poetry run make test 55 | 56 | linting: 57 | runs-on: ubuntu-latest 58 | name: Test Linting 59 | steps: 60 | - uses: actions/checkout@v4 61 | - name: Set up Python 62 | uses: actions/setup-python@v5 63 | with: 64 | python-version: "3.10" 65 | 66 | - name: Install Poetry 67 | run: make install-poetry 68 | 69 | - name: Install dependencies 70 | run: poetry install --all-extras --with=linting,docs 71 | 72 | - name: Test pre-commit checks 73 | run: poetry run make pre-commit-checks 74 | 75 | - name: Test linting 76 | run: poetry run make lint 77 | 78 | - name: Test typing 79 | run: poetry run make typing 80 | 81 | - name: Test docs 82 | run: poetry run make test-docs 83 | -------------------------------------------------------------------------------- /plismbench/engine/extract/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functionalities for the extraction pipeline.""" 2 | 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | 9 | from plismbench.models.extractor import Extractor 10 | 11 | 12 | # Do not touch those values as PLISM dataset contains 91 slides x 16278 tiles 13 | NUM_SLIDES: int = 91 14 | NUM_TILES_PER_SLIDE: int = 16_278 15 | 16 | 17 | def sort_coords(slide_features: np.ndarray) -> np.ndarray: 18 | """Sort slide features by coordinates.""" 19 | slide_coords = pd.DataFrame(slide_features[:, 1:3], columns=["x", "y"]) 20 | slide_coords.sort_values(["x", "y"], inplace=True) 21 | new_index = slide_coords.index.values 22 | return slide_features[new_index] 23 | 24 | 25 | def save_features( 26 | slide_features: list[np.ndarray], 27 | slide_id: str, 28 | export_path: Path, 29 | ) -> None: 30 | """Save features to disk. 31 | 32 | Parameters 33 | ---------- 34 | slide_features: list[np.ndarray] 35 | Current slide features. 36 | slide_id: str 37 | Current slide id. 38 | export_path: Path 39 | Export path for slide features. 
40 | """ 41 | _output_slide_features = np.concatenate(slide_features, axis=0).astype(np.float32) 42 | output_slide_features = sort_coords(_output_slide_features) 43 | slide_num_tiles = output_slide_features.shape[0] 44 | assert slide_num_tiles == NUM_TILES_PER_SLIDE, ( 45 | f"Output features for slide {slide_id} contains {slide_num_tiles} < {NUM_TILES_PER_SLIDE}." 46 | ) 47 | np.save(export_path, output_slide_features) 48 | 49 | 50 | def process_imgs( 51 | imgs: torch.Tensor, tile_ids: list[str], model: Extractor 52 | ) -> np.ndarray: 53 | """Perform inference on input (already transformed) images. 54 | 55 | Parameters 56 | ---------- 57 | imgs: torch.Tensor 58 | Transformed images (e.g. normalized, cropped, etc.). 59 | tile_ids: list[str]: 60 | List of tile ids. 61 | model: Extractor 62 | Feature extractor. 63 | """ 64 | with torch.inference_mode(): 65 | batch_features = model(imgs).squeeze() # (N_tiles, d) numpy array 66 | batch_tiles_coordinates = np.array( 67 | [tile_id.split("_")[1:] for tile_id in tile_ids] 68 | ).astype(int) # (N_tiles, 3) numpy array 69 | batch_stack = np.concatenate([batch_tiles_coordinates, batch_features], axis=1) 70 | return batch_stack 71 | -------------------------------------------------------------------------------- /plismbench/models/meta.py: -------------------------------------------------------------------------------- 1 | """Models from Meta company.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision import transforms 8 | 9 | from plismbench.models.extractor import Extractor 10 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 11 | 12 | 13 | class Dinov2ViTGiant(Extractor): 14 | """ViT-giant model trained with DINOv2 with 4 registers on ImageNet (1). 15 | 16 | .. note:: 17 | (1) https://github.com/facebookresearch/dinov2?tab=readme-ov-file 18 | 19 | Parameters 20 | ---------- 21 | device: int | list[int] | None = DEFAULT_DEVICE, 22 | Compute resources to use. 23 | If None, will use all available GPUs. 24 | If -1, extraction will run on CPU. 25 | mixed_precision: bool = True 26 | Whether to use mixed_precision. 27 | 28 | """ 29 | 30 | def __init__( 31 | self, 32 | device: int | list[int] | None = DEFAULT_DEVICE, 33 | mixed_precision: bool = False, 34 | ): 35 | super().__init__() 36 | self.output_dim = 1536 37 | self.mixed_precision = mixed_precision 38 | 39 | feature_extractor = torch.hub.load( 40 | "facebookresearch/dinov2", "dinov2_vitg14_reg", verbose=True 41 | ) 42 | 43 | self.feature_extractor, self.device = prepare_module( 44 | feature_extractor, 45 | device, 46 | self.mixed_precision, 47 | ) 48 | if self.device is None: 49 | self.feature_extractor = self.feature_extractor.module 50 | 51 | @property # type: ignore 52 | def transform(self) -> transforms.Compose: 53 | """Transform method to apply element wise.""" 54 | return transforms.Compose( 55 | [ 56 | transforms.ToTensor(), # swap axes and normalize 57 | transforms.Normalize( 58 | mean=(0.485, 0.456, 0.406), 59 | std=(0.229, 0.224, 0.225), 60 | ), 61 | ] 62 | ) 63 | 64 | def __call__(self, images: torch.Tensor) -> np.ndarray: 65 | """Compute and return features. 66 | 67 | Parameters 68 | ---------- 69 | images: torch.Tensor 70 | Input of size (n_tiles, n_channels, dim_x, dim_y). 71 | 72 | Returns 73 | ------- 74 | torch.Tensor: Tensor of size (n_tiles, features_dim). 
75 | """ 76 | features = self.feature_extractor(images.to(self.device)) 77 | return features.cpu().numpy() 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # IDE settings 2 | .vscode/ 3 | .idea/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | docs/api/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # OS 137 | .DS_Store 138 | -------------------------------------------------------------------------------- /plismbench/utils/aggregate.py: -------------------------------------------------------------------------------- 1 | """Get aggregated metrics.""" 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def iqr(x: pd.Series) -> float: 8 | """Get interquartile range.""" 9 | return np.quantile(x, 0.75) - np.quantile(x, 0.25) 10 | 11 | 12 | def aggregate_metrics(dataframe: pd.DataFrame) -> pd.DataFrame: 13 | """Aggregate metrics accross all possible pairs.""" 14 | agg_metrics = ( 15 | dataframe.apply(lambda x: (np.mean(x), np.std(x), np.median(x), iqr(x))) 16 | .set_index(np.array(["mean", "std", "median", "iqr"])) 17 | .round(3) 18 | ) 19 | return agg_metrics 20 | 21 | 22 | def pad(x: pd.Series | pd.DataFrame) -> pd.Series: 23 | """Pad values to third digit.""" 24 | return x.astype(str).str.pad(5, side="right", fillchar="0") 25 | 26 | 27 | def format_results(results: pd.DataFrame) -> dict[str, str]: 28 | """Format metrics.""" 29 | _mean_std = pad(results.loc["mean", :]) + " (" + pad(results.loc["std", :]) + ")" 30 | _median_iqr = ( 31 | pad(results.loc["median", :]) + " (" + pad(results.loc["iqr", :]) + ")" 32 | ) 33 | 34 | mean_std = _mean_std.to_dict() 35 | median_iqr = _median_iqr.to_dict() 36 | output = {} 37 | for key in mean_std.keys(): 38 | output[key] = mean_std[key] + " ; " + median_iqr[key] 39 | return output 40 | 41 | 42 | def get_results(metrics: pd.DataFrame, top_k: list[int]) -> pd.DataFrame: 43 | """Get aggregated robustness results.""" 44 | metric_names = ["cosine_similarity"] + [f"top_{k}_accuracy" for k in top_k] 45 | all_results = aggregate_metrics(metrics[metric_names]) 46 | inter_scanner_results = aggregate_metrics( 47 | metrics.loc[metrics["staining_a"] == metrics["staining_b"], metric_names] 48 | ) 49 | inter_staining_results = aggregate_metrics( 50 | metrics.loc[metrics["scanner_a"] == metrics["scanner_b"], metric_names] 51 | ) 52 | inter_scanner_inter_staining_results = aggregate_metrics( 53 | metrics.loc[ 54 | (metrics["scanner_a"] != metrics["scanner_b"]) 55 | & (metrics["staining_a"] != metrics["staining_b"]), 56 | metric_names, 57 | ] 58 | ) 59 | output_dict = {} 60 | for robustness_type, results in zip( 61 | ["inter-scanner", "inter-staining", "inter-scanner, inter-staining", "all"], 62 | [ 63 | inter_scanner_results, 64 | inter_staining_results, 65 | inter_scanner_inter_staining_results, 66 | all_results, 67 | ], 68 | ): 69 | _output = format_results(results) 70 | output_dict[robustness_type] = _output 71 | return pd.DataFrame(output_dict).T 72 | -------------------------------------------------------------------------------- /plismbench/models/microsoft.py: -------------------------------------------------------------------------------- 1 | """Models from Microsoft company.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | 
import timm 7 | import torch 8 | from torchvision import transforms 9 | 10 | from plismbench.models.extractor import Extractor 11 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 12 | 13 | 14 | class ProvGigaPath(Extractor): 15 | """ProvGigaPath model developed by Microsoft, available on Hugging Face (1). 16 | 17 | .. note:: 18 | (1) https://huggingface.co/prov-gigapath/prov-gigapath 19 | 20 | Parameters 21 | ---------- 22 | device: int | list[int] | None = DEFAULT_DEVICE, 23 | Compute resources to use. 24 | If None, will use all available GPUs. 25 | If -1, extraction will run on CPU. 26 | mixed_precision: bool = True 27 | Whether to use mixed_precision. 28 | 29 | """ 30 | 31 | def __init__( 32 | self, 33 | device: int | list[int] | None = DEFAULT_DEVICE, 34 | mixed_precision: bool = False, 35 | ): 36 | super().__init__() 37 | self.output_dim = 1536 38 | self.mixed_precision = mixed_precision 39 | 40 | feature_extractor = timm.create_model( 41 | "hf_hub:prov-gigapath/prov-gigapath", pretrained=True 42 | ) 43 | 44 | self.feature_extractor, self.device = prepare_module( 45 | feature_extractor, 46 | device, 47 | self.mixed_precision, 48 | ) 49 | if self.device is None: 50 | self.feature_extractor = self.feature_extractor.module 51 | 52 | @property # type: ignore 53 | def transform(self) -> transforms.Compose: 54 | """Transform method to apply element wise.""" 55 | return transforms.Compose( 56 | [ 57 | transforms.Resize( 58 | 256, interpolation=transforms.InterpolationMode.BICUBIC 59 | ), 60 | transforms.CenterCrop(224), 61 | transforms.ToTensor(), 62 | transforms.Normalize( 63 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 64 | ), 65 | ] 66 | ) 67 | 68 | def __call__(self, images: torch.Tensor) -> np.ndarray: 69 | """Compute and return features. 70 | 71 | Parameters 72 | ---------- 73 | images: torch.Tensor 74 | Input of size (n_tiles, n_channels, dim_x, dim_y). 75 | 76 | Returns 77 | ------- 78 | torch.Tensor: Tensor of size (n_tiles, features_dim). 79 | """ 80 | features = self.feature_extractor(images.to(self.device)) 81 | return features.cpu().numpy() 82 | -------------------------------------------------------------------------------- /plismbench/models/standford.py: -------------------------------------------------------------------------------- 1 | """Models from Stanford University School of Medicine.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision import transforms 8 | from transformers import AutoModelForZeroShotImageClassification, AutoProcessor 9 | 10 | from plismbench.models.extractor import Extractor 11 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 12 | 13 | 14 | class PLIP(Extractor): 15 | """PLIP model developed by Stanford University School of Medicine, Stanford, CA (1). 16 | 17 | .. note:: 18 | (1) https://huggingface.co/vinid/plip 19 | 20 | Parameters 21 | ---------- 22 | device: int | list[int] | None = DEFAULT_DEVICE, 23 | Compute resources to use. 24 | If None, will use all available GPUs. 25 | If -1, extraction will run on CPU. 26 | mixed_precision: bool = True 27 | Whether to use mixed_precision.
28 | 29 | """ 30 | 31 | def __init__( 32 | self, 33 | device: int | list[int] | None = DEFAULT_DEVICE, 34 | mixed_precision: bool = False, 35 | ): 36 | super().__init__() 37 | self.output_dim = 512 38 | self.mixed_precision = mixed_precision 39 | 40 | self.processor = AutoProcessor.from_pretrained("vinid/plip") 41 | feature_extractor = AutoModelForZeroShotImageClassification.from_pretrained( 42 | "vinid/plip" 43 | ) 44 | 45 | self.feature_extractor, self.device = prepare_module( 46 | feature_extractor, 47 | device, 48 | self.mixed_precision, 49 | ) 50 | if self.device is None: 51 | self.feature_extractor = self.feature_extractor.module 52 | 53 | def process(self, image) -> torch.Tensor: 54 | """Process input images.""" 55 | plip_input = self.processor(images=image, return_tensors="pt") 56 | return plip_input["pixel_values"][0] 57 | 58 | @property # type: ignore 59 | def transform(self) -> transforms.Lambda: 60 | """Transform method to apply element wise.""" 61 | return transforms.Lambda(self.process) 62 | 63 | def __call__(self, images: torch.Tensor) -> np.ndarray: 64 | """Compute and return features. 65 | 66 | Parameters 67 | ---------- 68 | images: torch.Tensor 69 | Input of size (n_tiles, n_channels, dim_x, dim_y). 70 | 71 | Returns 72 | ------- 73 | torch.Tensor: Tensor of size (n_tiles, features_dim). 74 | """ 75 | features = self.feature_extractor.module.get_image_features( # type: ignore 76 | images.to(self.device) 77 | ) 78 | return features.cpu().numpy() 79 | -------------------------------------------------------------------------------- /plismbench/utils/evaluate.py: -------------------------------------------------------------------------------- 1 | """Utility functions for metrics evaluation.""" 2 | 3 | import itertools 4 | from functools import lru_cache 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | 11 | NUM_TILES_PER_SLIDE: int = 16_278 12 | NUM_SLIDES: int = 91 13 | 14 | 15 | def get_tiles_subset_idx(n_tiles: int) -> np.ndarray: 16 | """Get tiles subset from the original 16_278.""" 17 | if n_tiles == NUM_TILES_PER_SLIDE: 18 | tiles_subset_idx = np.arange(0, NUM_TILES_PER_SLIDE) 19 | else: 20 | tiles_subset_idx = np.load( 21 | Path(__file__).parents[2] / "assets" / f"tiles_subset_{n_tiles}.npy" 22 | ) 23 | assert len(set(tiles_subset_idx)) == n_tiles 24 | return tiles_subset_idx 25 | 26 | 27 | @lru_cache() 28 | def load_features(fpath: Path) -> np.ndarray: 29 | """Load features from path using caching and convert to float32.""" 30 | feats = np.load(fpath) 31 | return feats.astype(np.float32) # will be converted to float16 later on ! 
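# (The ``lru_cache`` above avoids re-reading the same slide's .npy file from disk:
# each of the 91 slides appears in many of the 4,095 evaluated pairs.)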
32 | 33 | 34 | def prepare_features_dataframe(features_dir: Path, extractor: str) -> pd.DataFrame: 35 | """Prepare unique WSI features dataframe with features paths and metadata.""" 36 | # Get {slide_id: features paths} dictionary 37 | features_paths = { 38 | fp: fp.parent.name 39 | for fp in (features_dir / extractor).glob("*/features.npy") 40 | if "_to_GMH_S60" in str(fp) 41 | } 42 | # Prepare list of slide names, staining, and scanner directly 43 | slide_data = [] 44 | for features_path, slide_name in features_paths.items(): 45 | staining, scanner = slide_name.split("_")[:2] 46 | slide_data.append([slide_name, features_path, staining, scanner]) 47 | 48 | # Build output dataset 49 | slide_features = pd.DataFrame( 50 | slide_data, columns=["slide", "features_path", "staining", "scanner"] 51 | ) 52 | return slide_features 53 | 54 | 55 | def prepare_pairs_dataframe(features_dir: Path, extractor: str) -> pd.DataFrame: 56 | """Prepare all pairs dataframe with features paths and metadata.""" 57 | slide_features = prepare_features_dataframe( 58 | features_dir=features_dir, extractor=extractor 59 | ) 60 | assert slide_features.shape == ( 61 | NUM_SLIDES, 62 | 4, 63 | ), "Slide features dataframe should be of shape (91, 4)." 64 | 65 | pairs = slide_features.merge(slide_features, how="cross", suffixes=("_a", "_b")) 66 | pairs.set_index(pairs["slide_a"] + "---" + pairs["slide_b"], inplace=True) 67 | unique_pairs = [ 68 | "---".join([a, b]) 69 | for (a, b) in set(itertools.combinations(slide_features["slide"], 2)) 70 | ] 71 | pairs = ( 72 | pairs.loc[unique_pairs] # type: ignore 73 | .sort_values(["features_path_a", "features_path_b"]) 74 | .reset_index(drop=True) 75 | ) 76 | 77 | assert pairs.shape[0] == int(NUM_SLIDES * (NUM_SLIDES - 1) / 2), ( 78 | "There should be 4,095 unique pairs of slides." 79 | ) 80 | return pairs 81 | -------------------------------------------------------------------------------- /plismbench/models/extractor.py: -------------------------------------------------------------------------------- 1 | """Core abstract method for feature extractors.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Callable 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class Extractor(ABC): 11 | """A base class for :mod:`plismbench` extractors.""" 12 | 13 | _feature_extractor: torch.nn.Module 14 | device: str | torch.device 15 | 16 | def __init__(self, *args, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | 19 | self._transform = lambda x: x 20 | 21 | @property 22 | def feature_extractor(self) -> torch.nn.Module: 23 | """ 24 | Feature extractor module. 25 | 26 | Returns 27 | ------- 28 | feature_extractor: torch.nn.Module 29 | """ 30 | return self._feature_extractor 31 | 32 | @feature_extractor.setter 33 | def feature_extractor(self, feature_extractor_module: torch.nn.Module): 34 | """Set a new feature extractor module. 35 | 36 | Parameters 37 | ---------- 38 | feature_extractor_module: feature_extractor_module 39 | """ 40 | self._feature_extractor = feature_extractor_module 41 | 42 | @property 43 | def transform(self) -> Callable[[np.ndarray], torch.Tensor]: 44 | """ 45 | Transform method to apply element wise. Inputs should be np.ndarray. 46 | 47 | This function is applied on ``np.ndarray`` and not ``PIL.Image.Image`` 48 | as HuggingFace data is stored as numpy arrays for pickle checking purposes. 
49 | If your model needs image resizing, then you will need to add a first 50 | ``transforms.ToPILImage()`` operation, then resizing and finally 51 | ``transforms.ToTensor()``. 52 | If your model is best working on images of shape 224x224, then no need 53 | for rescaling as PLISM tiles have 224x224 shapes. 54 | 55 | Default is identity. 56 | 57 | Returns 58 | ------- 59 | transform: Callable[[np.ndarray], torch.Tensor] 60 | """ 61 | return self._transform 62 | 63 | @transform.setter 64 | def transform(self, transform_function: Callable[[np.ndarray], torch.Tensor]): 65 | """Set a new transform function to the extractor. 66 | 67 | Parameters 68 | ---------- 69 | transform_function: Callable[[np.ndarray], Transformed] 70 | The transform function to be set for the extractor. 71 | """ 72 | self._transform = transform_function 73 | 74 | @abstractmethod 75 | def __call__(self, images: torch.Tensor) -> np.ndarray: 76 | """ 77 | Compute and return the MAP features. 78 | 79 | Parameters 80 | ---------- 81 | images: torch.Tensor 82 | Input of size (N_TILES, 3, DIM_X, DIM_Y). N_TILES=1 for an image, 83 | usually DIM_X = DIM_Y = 224. 84 | 85 | Returns 86 | ------- 87 | features : numpy.ndarray 88 | arrays of size (N_TILES, N_FEATURES) for an image 89 | """ 90 | raise NotImplementedError 91 | -------------------------------------------------------------------------------- /tests/models/test_extractors.py: -------------------------------------------------------------------------------- 1 | """Tests feature extractors available in `plismbench`.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import pytest 7 | import torch 8 | from PIL import Image 9 | 10 | from plismbench.models import FeatureExtractorsEnum 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "extractor", 15 | FeatureExtractorsEnum.choices(), 16 | ) 17 | def test_extract_cpu( 18 | extractor: str, 19 | ) -> None: 20 | """Test feature extraction on CPU for all available models.""" 21 | model = FeatureExtractorsEnum[extractor.upper()].init(device=-1) 22 | # Set a random image and apply transform 23 | x = np.random.rand(224, 224, 3) * 255 24 | x = Image.fromarray(x.astype("uint8")).convert("RGB") 25 | transformed_x = model.transform(x) 26 | assert isinstance(transformed_x, torch.Tensor) 27 | assert transformed_x.shape == (3, 224, 224) 28 | # Get features 29 | features = model(transformed_x.unsqueeze(0)) 30 | expected_output_dim = model.output_dim 31 | assert isinstance(features, np.ndarray) 32 | assert features.shape == (1, expected_output_dim) 33 | 34 | 35 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available.") 36 | @pytest.mark.parametrize( 37 | "extractor", 38 | FeatureExtractorsEnum.choices(), 39 | ) 40 | def test_extract_gpu_w_mixed_precision( 41 | extractor: str, 42 | ) -> None: 43 | """Test feature extraction on GPU for all available models.""" 44 | model = FeatureExtractorsEnum[extractor.upper()].init(device=0) 45 | # Set a random image and apply transform 46 | x = np.random.rand(224, 224, 3) * 255 47 | x = Image.fromarray(x.astype("uint8")).convert("RGB") 48 | transformed_x = model.transform(x) 49 | assert isinstance(transformed_x, torch.Tensor) 50 | assert transformed_x.shape == (3, 224, 224) 51 | # Get features 52 | features = model(transformed_x.unsqueeze(0)) 53 | expected_output_dim = model.output_dim 54 | assert isinstance(features, np.ndarray) 55 | assert features.shape == (1, expected_output_dim) 56 | 57 | 58 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not 
available.") 59 | @pytest.mark.parametrize( 60 | "extractor", 61 | FeatureExtractorsEnum.choices(), 62 | ) 63 | def test_extract_gpu_wo_mixed_precision( 64 | extractor: str, 65 | ) -> None: 66 | """Test feature extraction on GPU for all available models.""" 67 | model = FeatureExtractorsEnum[extractor.upper()].init( 68 | device=0, mixed_precision=False 69 | ) 70 | # Set a random image and apply transform 71 | x = np.random.rand(224, 224, 3) * 255 72 | x = Image.fromarray(x.astype("uint8")).convert("RGB") 73 | transformed_x = model.transform(x) 74 | assert isinstance(transformed_x, torch.Tensor) 75 | assert transformed_x.shape == (3, 224, 224) 76 | # Get features 77 | features = model(transformed_x.unsqueeze(0)) 78 | expected_output_dim = model.output_dim 79 | assert isinstance(features, np.ndarray) 80 | assert features.shape == (1, expected_output_dim) 81 | -------------------------------------------------------------------------------- /plismbench/models/lunit.py: -------------------------------------------------------------------------------- 1 | """Models from Lunit company.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import torch 7 | from timm.models.vision_transformer import VisionTransformer 8 | from torchvision import transforms 9 | 10 | from plismbench.models.extractor import Extractor 11 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 12 | from plismbench.utils.core import download_state_dict 13 | 14 | 15 | class LunitViTS8(Extractor): 16 | """ViT-S/8 from Lunit available at (1). 17 | 18 | .. note:: 19 | (1) https://github.com/lunit-io/benchmark-ssl-pathology/releases/tag/pretrained-weights 20 | 21 | Parameters 22 | ---------- 23 | device: int | list[int] | None = DEFAULT_DEVICE, 24 | Compute resources to use. 25 | If None, will use all available GPUs. 26 | If -1, extraction will run on CPU. 27 | mixed_precision: bool = True 28 | Whether to use mixed_precision. 29 | 30 | """ 31 | 32 | def __init__( 33 | self, 34 | device: int | list[int] | None = DEFAULT_DEVICE, 35 | mixed_precision: bool = False, 36 | ): 37 | super().__init__() 38 | self.output_dim = 384 39 | self.mixed_precision = mixed_precision 40 | 41 | feature_extractor = VisionTransformer( 42 | img_size=224, 43 | patch_size=8, 44 | embed_dim=384, 45 | num_heads=6, 46 | num_classes=0, 47 | ) 48 | state_dict_path = download_state_dict( 49 | url="https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights/dino_vit_small_patch8_ep200.torch", 50 | name="lunit_vit_s8.pth", 51 | ) 52 | state_dict = torch.load(state_dict_path, map_location="cpu") 53 | feature_extractor.load_state_dict(state_dict, strict=False) 54 | 55 | self.feature_extractor, self.device = prepare_module( 56 | feature_extractor, 57 | device, 58 | self.mixed_precision, 59 | ) 60 | if self.device is None: 61 | self.feature_extractor = self.feature_extractor.module 62 | 63 | @property # type: ignore 64 | def transform(self) -> transforms.Compose: 65 | """Transform method to apply element wise.""" 66 | return transforms.Compose( 67 | [ 68 | transforms.ToTensor(), 69 | transforms.Normalize( 70 | mean=(0.70322989, 0.53606487, 0.66096631), 71 | std=(0.21716536, 0.26081574, 0.20723464), 72 | ), 73 | ] 74 | ) 75 | 76 | def __call__(self, images: torch.Tensor) -> np.ndarray: 77 | """Compute and return features. 78 | 79 | Parameters 80 | ---------- 81 | images: torch.Tensor 82 | Input of size (n_tiles, n_channels, dim_x, dim_y). 
83 | 84 | Returns 85 | ------- 86 | torch.Tensor: Tensor of size (n_tiles, features_dim). 87 | """ 88 | features = self.feature_extractor(images.to(self.device)) 89 | return features.cpu().numpy() 90 | -------------------------------------------------------------------------------- /plismbench/models/hkust.py: -------------------------------------------------------------------------------- 1 | """Models from Hong Kong University of Science and Technology.""" 2 | 3 | from __future__ import annotations 4 | 5 | import re 6 | 7 | import numpy as np 8 | import timm 9 | import torch 10 | from torchvision import transforms 11 | 12 | from plismbench.models.extractor import Extractor 13 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 14 | from plismbench.utils.core import download_state_dict 15 | 16 | 17 | def _convert_state_dict(state_dict: dict) -> dict: 18 | """Rename state dict keys to match timm's format.""" 19 | state_dict = { 20 | re.sub(r"blocks\.\d+\.(\d+)", r"blocks.\1", key.replace("backbone.", "")): value 21 | for key, value in state_dict.items() 22 | } 23 | remove_keys = ["mask_token"] + [ 24 | key for key in state_dict.keys() if "dino_head" in key 25 | ] 26 | for key in remove_keys: 27 | state_dict.pop(key) 28 | return state_dict 29 | 30 | 31 | class GPFM(Extractor): 32 | """GPFM model developped by HKUST (1). 33 | 34 | .. note:: 35 | (1) Ma, J., Guo, Z., Zhou, F., Wang, Y., Xu, Y., et al. (2024). 36 | Towards a generalizable pathology foundation model via unified knowledge 37 | distillation (arXiv No. 2407.18449). arXiv. https://arxiv.org/abs/2407.18449 38 | 39 | Parameters 40 | ---------- 41 | device: int | list[int] | None = DEFAULT_DEVICE, 42 | Compute resources to use. 43 | If None, will use all available GPUs. 44 | If -1, extraction will run on CPU. 45 | mixed_precision: bool = True 46 | Whether to use mixed_precision. 47 | 48 | """ 49 | 50 | def __init__( 51 | self, 52 | device: int | list[int] | None = DEFAULT_DEVICE, 53 | mixed_precision: bool = False, 54 | ): 55 | super().__init__() 56 | self.output_dim = 1024 57 | self.mixed_precision = mixed_precision 58 | 59 | _state_dict_path = download_state_dict( 60 | url="https://github.com/birkhoffkiki/GPFM/releases/download/ckpt/GPFM.pth", 61 | name="GPFM.pth", 62 | ) 63 | _state_dict = torch.load(_state_dict_path, map_location="cpu") 64 | state_dict = _convert_state_dict(_state_dict["teacher"]) 65 | 66 | feature_extractor = timm.create_model( 67 | model_name="vit_large_patch14_dinov2", 68 | pretrained=True, 69 | pretrained_cfg={ 70 | "state_dict": state_dict, 71 | "num_classes": 0, 72 | }, 73 | img_size=224, 74 | patch_size=14, 75 | init_values=1e-5, 76 | qkv_bias=True, 77 | dynamic_img_size=True, 78 | ) 79 | 80 | self.feature_extractor, self.device = prepare_module( 81 | feature_extractor, 82 | device, 83 | self.mixed_precision, 84 | ) 85 | if self.device is None: 86 | self.feature_extractor = self.feature_extractor.module 87 | 88 | @property # type: ignore 89 | def transform(self) -> transforms.Compose: 90 | """Transform method to apply element wise.""" 91 | return transforms.Compose( 92 | [ 93 | transforms.ToTensor(), 94 | transforms.Normalize( 95 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 96 | ), 97 | ] 98 | ) 99 | 100 | def __call__(self, images: torch.Tensor) -> np.ndarray: 101 | """Compute and return features. 102 | 103 | Parameters 104 | ---------- 105 | images: torch.Tensor 106 | Input of size (n_tiles, n_channels, dim_x, dim_y). 
107 | 108 | Returns 109 | ------- 110 | torch.Tensor: Tensor of size (n_tiles, features_dim). 111 | """ 112 | features = self.feature_extractor(images.to(self.device)) 113 | return features.cpu().numpy() 114 | -------------------------------------------------------------------------------- /plismbench/metrics/retrieval.py: -------------------------------------------------------------------------------- 1 | """Module for retrieval metrics.""" 2 | 3 | import numpy as np 4 | 5 | from plismbench.metrics.base import BasePlismMetric 6 | 7 | 8 | class TopkAccuracy(BasePlismMetric): 9 | """Top-k accuracy.""" 10 | 11 | def __init__( 12 | self, 13 | device: str, 14 | use_mixed_precision: bool = True, 15 | k: list[int] | None = None, 16 | ): 17 | super().__init__(device, use_mixed_precision) 18 | self.k = [1, 3, 5, 10] if k is None else k 19 | 20 | def compute_metric(self, matrix_a, matrix_b): 21 | """Compute top-k accuracy metric.""" 22 | if matrix_a.shape[0] != matrix_b.shape[0]: 23 | raise ValueError( 24 | f"Number of tiles must match. Got {matrix_a.shape[0]} and {matrix_b.shape[0]}." 25 | ) 26 | 27 | matrix_ab = np.concatenate([matrix_a, matrix_b], axis=0) 28 | 29 | n_tiles = matrix_ab.shape[0] // 2 30 | 31 | if self.use_mixed_precision: 32 | matrix_ab = matrix_ab.astype(np.float16) 33 | 34 | matrix_ab = self.ncp.asarray(matrix_ab) # put concatenated matrix on the gpu 35 | # ``dot_product_ab`` is a block matrix of shape (2*n_tiles, 2*n_tiles) 36 | # [ 37 | # [, ], 38 | # [, ] 39 | # ] 40 | dot_product_ab = self.ncp.matmul( 41 | matrix_ab, matrix_ab.T 42 | ) # shape (2*n_tiles, 2*n_tiles) 43 | norm_ab = self.ncp.linalg.norm( 44 | matrix_ab, axis=1, keepdims=True 45 | ) # shape (2*n_tiles, ) 46 | cosine_ab = dot_product_ab / ( 47 | norm_ab * norm_ab.T 48 | ) # shape (2*n_tiles, 2*n_tiles) 49 | 50 | # Compute top-k indices for each row of cosine_ab using argpartition. 51 | # We use argpartition to efficiently find the top-k elements (excluding self-matches) 52 | kmax = max(self.k) 53 | # ``top_kmax_indices_ab`` has shape (2*n_tiles, kmax), for instance 54 | # ``top_kmax_indices_ab[i, 0]`` represents the closest tile index ``ci`` accross 55 | # slide a and slide b to the tile at index ``i`` (row index), hence ``ci`` 56 | # is spanning between 0 and 2*n_tiles but excludes the index ``i`` of the tile 57 | # itself 58 | top_kmax_indices_ab = self.ncp.argpartition( 59 | -cosine_ab, range(1, kmax + 1), axis=1 60 | )[:, 1 : kmax + 1] 61 | # Compute top-k accuracies by iterating over k values 62 | top_k_accuracies = [] 63 | for k in self.k: 64 | top_k_indices_ab = top_kmax_indices_ab[:, :k] # shape (2*n_tiles, k) 65 | top_k_indices_a = top_k_indices_ab[:n_tiles] # shape (n_tiles, k) 66 | top_k_indices_b = top_k_indices_ab[n_tiles:] # shape (n_tiles, k) 67 | 68 | top_k_accs = [] 69 | for i, top_k_indices in enumerate([top_k_indices_a, top_k_indices_b]): 70 | # If ``i==0``, we look at the closest tiles of each tile of matrix a that 71 | # are present in matrix b, hence ``(n_tiles, 2 * n_tiles)``. See matrix 72 | # block decomposition above. 73 | other_slide_indices = ( 74 | self.ncp.arange(n_tiles, 2 * n_tiles) 75 | if i == 0 76 | else self.ncp.arange(0, n_tiles) 77 | ) 78 | # We now count the number of times one of the top-k closest tiles to 79 | # tile ``i`` for slide a (resp. b) is the same tile but in slide b (resp. 
a) 80 | correct_matches = self.ncp.sum( 81 | self.ncp.any(top_k_indices == other_slide_indices[:, None], axis=1) 82 | ) 83 | _top_k_acc = correct_matches / n_tiles 84 | top_k_acc = ( 85 | float(_top_k_acc.get()) 86 | if self.device == "gpu" 87 | else float(_top_k_acc) 88 | ) 89 | top_k_accs.append(top_k_acc) 90 | 91 | # Average over the two directions 92 | top_k_accuracies.append(sum(top_k_accs) / 2) 93 | 94 | return np.array(top_k_accuracies) 95 | -------------------------------------------------------------------------------- /plismbench/models/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions to load and prepare feature extractors.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch 6 | from transformers.modeling_outputs import BaseModelOutputWithPooling 7 | 8 | 9 | DEFAULT_DEVICE = ( 10 | 0 if (torch.cuda.is_available() or torch.backends.mps.is_available()) else -1 11 | ) 12 | 13 | 14 | class PrecisionModule(torch.nn.Module): 15 | """Precision Module wrapper. 16 | 17 | Parameters 18 | ---------- 19 | module: torch.nn.Module 20 | device_type: str 21 | """ 22 | 23 | def __init__( 24 | self, module: torch.nn.Module, device_type: str, mixed_precision: bool 25 | ): 26 | super(PrecisionModule, self).__init__() 27 | self.module = module 28 | self.device_type = device_type 29 | self.mixed_precision = mixed_precision 30 | 31 | def forward(self, *args, **kwargs): 32 | """Forward pass w/ or w/o ``autocast``.""" 33 | # Mixed precision forward 34 | if self.mixed_precision: 35 | with torch.amp.autocast(device_type=self.device_type): 36 | output = self.module(*args, **kwargs) 37 | # Full precision forward 38 | else: 39 | output = self.module(*args, **kwargs) 40 | if isinstance(output, BaseModelOutputWithPooling): 41 | if "last_hidden_state" in output.keys(): 42 | output = output.last_hidden_state 43 | else: 44 | raise ValueError( 45 | "Model output has class `BaseModelOutputWithPooling` " 46 | "but no `'last_hidden_state'` attribute." 47 | ) 48 | # Back to float32 49 | return output.to(torch.float32) 50 | 51 | 52 | def prepare_module( 53 | module: torch.nn.Module, 54 | device: int | list[int] | None = None, 55 | mixed_precision: bool = True, 56 | ) -> tuple[torch.nn.Module, str | torch.device]: 57 | """ 58 | Prepare torch.nn.Module. 59 | 60 | By: 61 | - setting it to eval mode 62 | - disabling gradients 63 | - moving it to the correct device(s) 64 | 65 | Parameters 66 | ---------- 67 | module: torch.nn.Module 68 | device: Union[None, int, list[int]] = None 69 | Compute resources to use. 70 | If None, will use all available GPUs. 71 | If -1, extraction will run on CPU. 72 | mixed_precision: bool = True 73 | Whether to use mixed_precision (improved throughput on modern GPU cards). 
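Stepping back to ``TopkAccuracy.compute_metric`` in ``plismbench/metrics/retrieval.py`` above: the sketch below restates the same idea in plain NumPy so the metric is easier to picture. It is a simplified re-implementation, not the benchmark's code — the real class runs on ``self.ncp`` (presumably NumPy or CuPy depending on the device), uses ``argpartition`` instead of a full sort, and evaluates several values of k at once.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tiles, dim, k = 50, 128, 5

# Two aligned feature matrices: row i of ``feats_a`` and row i of ``feats_b``
# come from the same tissue tile, digitised under two different conditions.
feats_a = rng.normal(size=(n_tiles, dim)).astype(np.float32)
feats_b = feats_a + 0.1 * rng.normal(size=(n_tiles, dim)).astype(np.float32)

stacked = np.concatenate([feats_a, feats_b], axis=0)            # (2n, d)
normed = stacked / np.linalg.norm(stacked, axis=1, keepdims=True)
cosine = normed @ normed.T                                       # (2n, 2n)
np.fill_diagonal(cosine, -np.inf)                                # exclude self-matches

top_k = np.argsort(-cosine, axis=1)[:, :k]                       # (2n, k)
# For tile i of slide a, the "correct" neighbour is row n_tiles + i (same tile in slide b).
hits_a = np.any(top_k[:n_tiles] == (np.arange(n_tiles) + n_tiles)[:, None], axis=1)
hits_b = np.any(top_k[n_tiles:] == np.arange(n_tiles)[:, None], axis=1)
top_k_accuracy = (hits_a.mean() + hits_b.mean()) / 2
print(f"top-{k} accuracy: {top_k_accuracy:.3f}")
```

In words, the score is the fraction of tiles whose counterpart from the other slide appears among their k nearest neighbours by cosine similarity, averaged over both matching directions.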
74 | 75 | Returns 76 | ------- 77 | torch.nn.Module, str | torch.device 78 | """ 79 | if mixed_precision: 80 | if not (torch.cuda.is_available() or device == -1): 81 | raise ValueError("Mixed precision in only available for CUDA GPUs and CPU.") 82 | module = PrecisionModule( 83 | module, 84 | device_type="cpu" if not torch.cuda.is_available() else "cuda", 85 | mixed_precision=mixed_precision, 86 | ) 87 | 88 | device_: str | torch.device 89 | 90 | if device == -1 or not ( 91 | torch.cuda.is_available() or torch.backends.mps.is_available() 92 | ): 93 | device_ = "cpu" 94 | elif torch.backends.mps.is_available(): 95 | device_ = torch.device("mps") 96 | elif isinstance(device, int): 97 | device_ = f"cuda:{device}" 98 | else: 99 | # Use DataParallel to distribute the module on all GPUs 100 | device_ = "cuda:0" if device is None else f"cuda:{device[0]}" 101 | module = torch.nn.DataParallel(module, device_) # type: ignore 102 | 103 | module.to(device_) 104 | module.eval() 105 | module.requires_grad_(False) 106 | 107 | return module, device_ 108 | 109 | 110 | def prepare_device(gpu: None | int | list[int] = None) -> str: 111 | """Prepare device, copied from `tilingtool.utils.parallel::prepare_module`.""" 112 | if gpu == -1 or not ( 113 | torch.cuda.is_available() or torch.backends.mps.is_available() 114 | ): 115 | device = "cpu" 116 | elif torch.backends.mps.is_available(): 117 | device = str(torch.device("mps")) 118 | elif isinstance(gpu, int): 119 | device = f"cuda:{gpu}" 120 | else: 121 | # Use DataParallel to distribute the module on all GPUs 122 | device = "cuda:0" if gpu is None else f"cuda:{gpu[0]}" 123 | return device 124 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-docs clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | from urllib.request import pathname2url 8 | 9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 10 | endef 11 | export BROWSER_PYSCRIPT 12 | 13 | define PRINT_HELP_PYSCRIPT 14 | import re, sys 15 | 16 | for line in sys.stdin: 17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 18 | if match: 19 | target, help = match.groups() 20 | print("%-20s %s" % (target, help)) 21 | endef 22 | export PRINT_HELP_PYSCRIPT 23 | 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | ifeq (, $(shell which snakeviz)) 27 | PROFILE = pytest --profile-svg 28 | PROFILE_RESULT = prof/combined.svg 29 | PROFILE_VIEWER = $(BROWSER) 30 | else 31 | PROFILE = pytest --profile 32 | PROFILE_RESULT = prof/combined.prof 33 | PROFILE_VIEWER = snakeviz 34 | endif 35 | 36 | help: 37 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 38 | 39 | config: ## Configure poetry 40 | poetry config virtualenvs.in-project true 41 | 42 | lock: ## Generate a new poetry.lock file (To be done after adding new requirements to pyproject.toml) 43 | poetry lock 44 | 45 | install-cupy: ## Install GPU accelerated numpy 46 | conda install -c conda-forge cupy 47 | 48 | install-poetry: ## Install poetry package 49 | pip install poetry==1.7.1 50 | 51 | install: clean ## Install all package and development dependencies for testing to the active Python's site-packages 52 | poetry install --all-extras --with=testing,linting,docs,dev 53 | 54 | install-all: install-poetry install install-cupy ## Install poetry along with all package and development dependencies 
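Before moving on through the build tooling, here is a minimal usage sketch for ``prepare_module`` from ``plismbench/models/utils.py`` shown just above. The toy ``backbone`` below is invented purely for illustration — real extractors pass their timm or Hugging Face modules — and mixed precision is left off so the example stays CPU-friendly.

```python
import torch

from plismbench.models.utils import prepare_module

# A toy stand-in for a real feature extractor; it only serves to show how
# ``prepare_module`` resolves devices and freezes the module.
backbone = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 224 * 224, 16))

# device=-1 forces CPU; an int selects a single CUDA GPU (e.g. 0 -> "cuda:0");
# None wraps the module in DataParallel across all visible GPUs.
module, device = prepare_module(backbone, device=-1, mixed_precision=False)

with torch.inference_mode():
    dummy = torch.rand(2, 3, 224, 224)
    features = module(dummy.to(device))
print(features.shape, features.dtype)  # torch.Size([2, 16]) torch.float32
```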
55 | 56 | clean: clean-build clean-pyc clean-test clean-docs ## Remove all build, test, coverage and Python artifacts 57 | 58 | clean-build: ## Remove build artifacts 59 | rm -fr build/ 60 | rm -fr dist/ 61 | rm -fr .eggs/ 62 | find . -path ./.venv -prune -false -o -name '*.egg-info' -exec rm -fr {} + 63 | find . -path ./.venv -prune -false -o -name '*.egg' -exec rm -f {} + 64 | 65 | clean-pyc: ## Remove Python file artifacts 66 | find . -path ./.venv -prune -false -o -name '*.pyc' -exec rm -f {} + 67 | find . -path ./.venv -prune -false -o -name '*.pyo' -exec rm -f {} + 68 | find . -path ./.venv -prune -false -o -name '*~' -exec rm -f {} + 69 | find . -path ./.venv -prune -false -o -name '__pycache__' -exec rm -fr {} + 70 | 71 | clean-test: ## Remove test and coverage artifacts 72 | rm -f .coverage 73 | rm -f coverage.xml 74 | rm -fr htmlcov/ 75 | rm -fr .pytest_cache 76 | rm -fr .mypy_cache 77 | rm -fr prof/ 78 | rm -fr .ruff_cache 79 | 80 | clean-docs: ## Remove docs artifacts 81 | rm -fr docs/_build 82 | rm -fr docs/api 83 | 84 | format: ## format code ruff formatter 85 | ruff format plismbench tests 86 | 87 | lint: ## Check style with ruff linter 88 | ruff check --fix plismbench tests 89 | 90 | typing: ## Check static typing with mypy 91 | mypy plismbench 92 | 93 | pre-commit-checks: ## Run pre-commit checks on all files 94 | pre-commit run --hook-stage manual --all-files 95 | 96 | lint-all: pre-commit-checks lint typing ## Run all linting checks. 97 | 98 | test-all: ## Run CI tests quickly with the default Python 99 | pytest 100 | 101 | test: ## Run CI tests quickly with the default Python 102 | pytest -m "not local" 103 | 104 | test-docs: docs-api ## Check docs using doc8 105 | pydocstyle plismbench 106 | doc8 docs 107 | $(MAKE) -C docs doctest 108 | 109 | coverage: ## Check code coverage quickly with the default Python 110 | coverage run --source plismbench -m pytest 111 | coverage report -m 112 | coverage html 113 | $(BROWSER) htmlcov/index.html 114 | 115 | profile: ## Create a profile from test cases 116 | $(PROFILE) $(TARGET) 117 | $(PROFILE_VIEWER) $(PROFILE_RESULT) 118 | 119 | docs-api: ## Generate the API documentation for Sphinx 120 | rm -rf docs/api 121 | sphinx-apidoc -e -M -o docs/api plismbench 122 | 123 | docs: docs-api ## Generate Sphinx HTML documentation, including API docs 124 | $(MAKE) -C docs clean 125 | $(MAKE) -C docs html 126 | $(MAKE) open-docs 127 | 128 | .SILENT: 129 | open-docs: ## Open the generated Sphinx HTML documentation 130 | @if [ "$$port" != "" ]; then\ 131 | python3 -m http.server --directory docs/_build/html/ "$$port";\ 132 | else\ 133 | echo "No port was specified as a make argument. Trying a local run...";\ 134 | $(BROWSER) docs/_build/html/index.html;\ 135 | fi 136 | 137 | servedocs: docs ## Compile the docs watching for changes 138 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 
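The ``install-cupy`` target above exists because the robustness metrics (see ``plismbench/metrics/retrieval.py`` earlier in this dump) run their linear algebra through an ``ncp`` attribute that can point at either NumPy or CuPy. ``plismbench/metrics/base.py`` is not reproduced in this section, so the dispatch below is an assumption about how such a switch is typically written, not the project's actual implementation.

```python
# Assumed numpy-vs-cupy dispatch; the attribute name ``ncp`` matches what the
# metric classes use, but the real ``BasePlismMetric`` may differ.
import numpy as np


def get_array_module(device: str = "cpu"):
    """Return cupy on GPU when it is installed, otherwise fall back to numpy."""
    if device == "gpu":
        try:
            import cupy as cp  # provided by ``make install-cupy`` (conda-forge)

            return cp
        except ImportError:
            pass
    return np


ncp = get_array_module("cpu")
x = ncp.asarray([[1.0, 2.0], [3.0, 4.0]])
print(ncp.linalg.norm(x, axis=1, keepdims=True))
```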
139 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry_core>=1.0.0"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "owkin-plismbench" 7 | version = "0.0.0" 8 | description = "Repository hosting PLISM robustness benchmark" 9 | authors = [ 10 | "Alexandre Filiot ", 11 | "Antoine Olivier ", 12 | ] 13 | 14 | readme = "README.md" 15 | homepage = "https://github.com/owkin/plism-benchmark" 16 | repository = "https://github.com/owkin/plism-benchmark" 17 | documentation = "https://owkin.github.io/plism-benchmark" 18 | packages = [{ include = "plismbench" }] 19 | 20 | classifiers = [ 21 | "Development Status :: 2 - Pre-Alpha", 22 | "Natural Language :: English", 23 | "Intended Audience :: Science/Research", 24 | "Topic :: Scientific/Engineering", 25 | "Programming Language :: Python :: 3", 26 | "Programming Language :: Python :: 3.9", 27 | "Programming Language :: Python :: 3.10", 28 | "Programming Language :: Python :: 3.11", 29 | ] 30 | 31 | [[tool.poetry.source]] 32 | name = "PyPI" 33 | priority = "primary" 34 | 35 | [tool.poetry.dependencies] 36 | python = ">=3.10,<4.0" 37 | torch = "^2.6.0" 38 | torchvision = "^0.21.0" 39 | transformers = "^4.49.0" 40 | datasets = "^3.3.1" 41 | loguru = "^0.7.3" 42 | timm = "1.0.12" 43 | pre-commit = "^4.1.0" 44 | typer = "^0.15.1" 45 | h5py = "^3.12.1" 46 | p-tqdm = "^1.4.2" 47 | rprint = "^0.0.8" 48 | matplotlib = "^3.10.0" 49 | seaborn = "^0.13.2" 50 | tabulate = "^0.9.0" 51 | conch = {git = "https://github.com/Mahmoodlab/CONCH.git"} 52 | einops = "^0.8.1" 53 | einops-exts = "^0.0.4" 54 | types-requests = "^2.32.0.20250515" 55 | 56 | [tool.poetry.group.dev] 57 | optional = true 58 | [tool.poetry.group.dev.dependencies] 59 | nbconvert = "*" 60 | notebook = "*" 61 | 62 | [tool.poetry.group.docs] 63 | optional = true 64 | [tool.poetry.group.docs.dependencies] 65 | sphinx = "^7.0.0" 66 | sphinx-gallery = "*" 67 | Jinja2 = "*" 68 | doc8 = "*" 69 | recommonmark = "*" 70 | m2r = "*" 71 | mistune = "*" 72 | nbsphinx = "*" 73 | pandoc = "*" 74 | docutils = "*" 75 | pydocstyle = "*" 76 | sphinxcontrib-fulltoc = "*" 77 | sphinxcontrib-mockautodoc = "*" 78 | sphinx-autodoc-typehints = "*" 79 | sphinx-paramlinks = "*" 80 | pydata_sphinx_theme = "*" 81 | sphinxcontrib-mermaid = "*" 82 | watchdog = "^2.1.8" 83 | sphinx-tabs = "^3.4.1" 84 | 85 | [tool.poetry.group.linting] 86 | optional = true 87 | [tool.poetry.group.linting.dependencies] 88 | mypy = ">=1.4" 89 | pre-commit = ">=2.20.0" 90 | ruff = ">=0.1.2" 91 | pandas-stubs = "*" 92 | types-docutils = "*" 93 | types-python-dateutil = "*" 94 | types-setuptools = "*" 95 | types-Jinja2 = "*" 96 | types-MarkupSafe = "*" 97 | types-PyYAML = "*" 98 | typing_extensions = "*" 99 | 100 | [tool.poetry.group.testing] 101 | optional = true 102 | [tool.poetry.group.testing.dependencies] 103 | pytest = "*" 104 | coverage = "*" 105 | pytest-cov = "*" 106 | pytest-sphinx = "*" 107 | pytest-runner = "*" 108 | pytest-profiling = "*" 109 | 110 | [tool.ruff] 111 | exclude = [ 112 | ".git", 113 | ".github", 114 | ".dvc", 115 | "__pycache__", 116 | ".venv", 117 | ".mypy_cache", 118 | ".ruff_cache", 119 | ".pytest_cache", 120 | "conf.py", 121 | ] 122 | lint.ignore = [ 123 | "B008", # do not perform function calls in argument defaults 124 | "C901", # too complex 125 | "D105", # undocumented magic method 126 | "E501", # line too long, handled 
by black 127 | "E731", # lambda-assignment 128 | "PLR0904", # too many public methods 129 | "PLR0913", # too many arguments 130 | "PLR2004", # magic value comparison 131 | "B019", # use of cache methods 132 | "B009", # getattr with constant value 133 | "N812", # lowercase imported as non constant 134 | "PLW2901", # loop variable overwritten by assignment 135 | "PT011", # broad pytest errors 136 | ] 137 | lint.select = [ 138 | "D", # pydocstyle 139 | "E", # pycodestyle errors 140 | "W", # pycodestyle warnings 141 | "F", # pyflakes 142 | "I", # isort 143 | "N", # pep8-naming conventions 144 | "C", # flake8-comprehensions 145 | "B", # flake8-bugbear 146 | "PL", # pylint 147 | "PT", # flake8-pytest-style 148 | "C90", # mccabe included in flake8 149 | "ASYNC", # flake8-async 150 | ] 151 | line-length = 88 # Must be consistent with black parameter 152 | target-version = "py39" # Must be aligned with the Python lower bound 153 | 154 | [tool.ruff.lint.isort] 155 | known-first-party = ["plismbench"] 156 | lines-after-imports = 2 157 | 158 | [tool.ruff.lint.pydocstyle] 159 | convention = "numpy" 160 | 161 | [tool.ruff.lint.per-file-ignores] 162 | "__init__.py" = ["F401"] 163 | 164 | [tool.poetry.scripts] 165 | plismbench = "plismbench.engine.cli:app" 166 | -------------------------------------------------------------------------------- /plismbench/models/owkin.py: -------------------------------------------------------------------------------- 1 | """Models from Owkin, Inc. company.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision import transforms 8 | from transformers import AutoImageProcessor, AutoModel 9 | 10 | from plismbench.models.extractor import Extractor 11 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 12 | 13 | 14 | class Phikon(Extractor): 15 | """Phikon model developped by Owkin available on Hugging-Face (1). 16 | 17 | .. note:: 18 | (1) https://huggingface.co/owkin/phikon 19 | 20 | Parameters 21 | ---------- 22 | device: int | list[int] | None = DEFAULT_DEVICE, 23 | Compute resources to use. 24 | If None, will use all available GPUs. 25 | If -1, extraction will run on CPU. 26 | mixed_precision: bool = True 27 | Whether to use mixed_precision. 28 | 29 | """ 30 | 31 | def __init__( 32 | self, 33 | device: int | list[int] | None = DEFAULT_DEVICE, 34 | mixed_precision: bool = False, 35 | ): 36 | super().__init__() 37 | self.output_dim = 768 38 | self.mixed_precision = mixed_precision 39 | 40 | self.processor = AutoImageProcessor.from_pretrained("owkin/phikon") 41 | # feature_extractor = ViTModel.from_pretrained( 42 | # "owkin/phikon", add_pooling_layer=False 43 | # ) 44 | feature_extractor = AutoModel.from_pretrained("owkin/phikon") 45 | 46 | self.feature_extractor, self.device = prepare_module( 47 | feature_extractor, 48 | device, 49 | self.mixed_precision, 50 | ) 51 | if self.device is None: 52 | self.feature_extractor = self.feature_extractor.module 53 | 54 | def process(self, image) -> torch.Tensor: 55 | """Process input images.""" 56 | phikon_input = self.processor(images=image, return_tensors="pt") 57 | return phikon_input["pixel_values"][0] 58 | 59 | @property # type: ignore 60 | def transform(self) -> transforms.Lambda: 61 | """Transform method to apply element wise.""" 62 | return transforms.Lambda(self.process) 63 | 64 | def __call__(self, images: torch.Tensor) -> np.ndarray: 65 | """Compute and return features. 
66 | 67 | Parameters 68 | ---------- 69 | images: torch.Tensor 70 | Input of size (n_tiles, n_channels, dim_x, dim_y). 71 | 72 | Returns 73 | ------- 74 | torch.Tensor: Tensor of size (n_tiles, features_dim). 75 | """ 76 | last_hidden_state = self.feature_extractor(images.to(self.device)) 77 | features = last_hidden_state[:, 0] 78 | return features.cpu().numpy() 79 | 80 | 81 | class PhikonV2(Extractor): 82 | """Phikon V2 model developped by Owkin available on Hugging-Face (1). 83 | 84 | You will need to be granted access to be able to use this model. 85 | 86 | .. note:: 87 | (1) https://huggingface.co/owkin/phikon-v2 88 | 89 | Parameters 90 | ---------- 91 | device: int | list[int] | None = DEFAULT_DEVICE, 92 | Compute resources to use. 93 | If None, will use all available GPUs. 94 | If -1, extraction will run on CPU. 95 | mixed_precision: bool = True 96 | Whether to use mixed_precision. 97 | 98 | """ 99 | 100 | def __init__( 101 | self, 102 | device: int | list[int] | None = DEFAULT_DEVICE, 103 | mixed_precision: bool = False, 104 | ): 105 | super().__init__() 106 | self.output_dim = 1024 107 | self.mixed_precision = mixed_precision 108 | 109 | self.processor = AutoImageProcessor.from_pretrained("owkin/phikon-v2") 110 | feature_extractor = AutoModel.from_pretrained("owkin/phikon-v2") 111 | 112 | self.feature_extractor, self.device = prepare_module( 113 | feature_extractor, 114 | device, 115 | self.mixed_precision, 116 | ) 117 | if self.device is None: 118 | self.feature_extractor = self.feature_extractor.module 119 | 120 | def process(self, image) -> torch.Tensor: 121 | """Process input images.""" 122 | phikon_input = self.processor(images=image, return_tensors="pt") 123 | return phikon_input["pixel_values"][0] 124 | 125 | @property # type: ignore 126 | def transform(self) -> transforms.Lambda: 127 | """Transform method to apply element wise.""" 128 | return transforms.Lambda(self.process) 129 | 130 | def __call__(self, images: torch.Tensor) -> np.ndarray: 131 | """Compute and return features. 132 | 133 | Parameters 134 | ---------- 135 | images: torch.Tensor 136 | Input of size (n_tiles, n_channels, dim_x, dim_y). 137 | 138 | Returns 139 | ------- 140 | torch.Tensor: Tensor of size (n_tiles, features_dim). 141 | """ 142 | last_hidden_state = self.feature_extractor(images.to(self.device)) 143 | features = last_hidden_state[:, 0] 144 | return features.cpu().numpy() 145 | -------------------------------------------------------------------------------- /plismbench/models/histai.py: -------------------------------------------------------------------------------- 1 | """Models from HistAI company.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision import transforms 8 | from transformers import AutoImageProcessor, AutoModel 9 | 10 | from plismbench.models.extractor import Extractor 11 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 12 | 13 | 14 | class HibouBase(Extractor): 15 | """Hibou-Base model developped by HistAI available on Hugging-Face (1). 16 | 17 | .. note:: 18 | (1) https://huggingface.co/histai/hibou-b 19 | 20 | Parameters 21 | ---------- 22 | device: int | list[int] | None = DEFAULT_DEVICE, 23 | Compute resources to use. 24 | If None, will use all available GPUs. 25 | If -1, extraction will run on CPU. 26 | mixed_precision: bool = True 27 | Whether to use mixed_precision. 
28 | """ 29 | 30 | def __init__( 31 | self, 32 | device: int | list[int] | None = DEFAULT_DEVICE, 33 | mixed_precision: bool = False, 34 | ): 35 | super().__init__() 36 | self.output_dim = 768 37 | self.mixed_precision = mixed_precision 38 | 39 | self.processor = AutoImageProcessor.from_pretrained( 40 | "histai/hibou-b", trust_remote_code=True 41 | ) 42 | feature_extractor = AutoModel.from_pretrained( 43 | "histai/hibou-b", trust_remote_code=True 44 | ) 45 | 46 | self.feature_extractor, self.device = prepare_module( 47 | feature_extractor, 48 | device, 49 | self.mixed_precision, 50 | ) 51 | if self.device is None: 52 | self.feature_extractor = self.feature_extractor.module 53 | 54 | def process(self, image) -> torch.Tensor: 55 | """Process input images.""" 56 | hibou_input = self.processor(images=image, return_tensors="pt") 57 | return hibou_input["pixel_values"][0] 58 | 59 | @property # type: ignore 60 | def transform(self) -> transforms.Lambda: 61 | """Transform method to apply element wise.""" 62 | return transforms.Lambda(self.process) 63 | 64 | def __call__(self, images: torch.Tensor) -> np.ndarray: 65 | """Compute and return features. 66 | 67 | Parameters 68 | ---------- 69 | images: torch.Tensor 70 | Input of size (n_tiles, n_channels, dim_x, dim_y). 71 | 72 | Returns 73 | ------- 74 | torch.Tensor: Tensor of size (n_tiles, features_dim). 75 | """ 76 | last_hidden_state = self.feature_extractor(images.to(self.device)) 77 | features = last_hidden_state[:, 0] 78 | return features.cpu().numpy() 79 | 80 | 81 | class HibouLarge(Extractor): 82 | """Hibou-Large model developped by HistAI available on Hugging-Face (1). 83 | 84 | .. note:: 85 | (1) https://huggingface.co/histai/hibou-l 86 | 87 | Parameters 88 | ---------- 89 | device: int | list[int] | None = DEFAULT_DEVICE, 90 | Compute resources to use. 91 | If None, will use all available GPUs. 92 | If -1, extraction will run on CPU. 93 | mixed_precision: bool = True 94 | Whether to use mixed_precision. 95 | """ 96 | 97 | def __init__( 98 | self, 99 | device: int | list[int] | None = DEFAULT_DEVICE, 100 | mixed_precision: bool = False, 101 | ): 102 | super().__init__() 103 | self.output_dim = 1024 104 | self.mixed_precision = mixed_precision 105 | 106 | self.processor = AutoImageProcessor.from_pretrained( 107 | "histai/hibou-L", trust_remote_code=True 108 | ) 109 | feature_extractor = AutoModel.from_pretrained( 110 | "histai/hibou-L", trust_remote_code=True 111 | ) 112 | 113 | self.feature_extractor, self.device = prepare_module( 114 | feature_extractor, 115 | device, 116 | self.mixed_precision, 117 | ) 118 | if self.device is None: 119 | self.feature_extractor = self.feature_extractor.module 120 | 121 | def process(self, image) -> torch.Tensor: 122 | """Process input images.""" 123 | hibou_input = self.processor(images=image, return_tensors="pt") 124 | return hibou_input["pixel_values"][0] 125 | 126 | @property # type: ignore 127 | def transform(self) -> transforms.Lambda: 128 | """Transform method to apply element wise.""" 129 | return transforms.Lambda(self.process) 130 | 131 | def __call__(self, images: torch.Tensor) -> np.ndarray: 132 | """Compute and return features. 133 | 134 | Parameters 135 | ---------- 136 | images: torch.Tensor 137 | Input of size (n_tiles, n_channels, dim_x, dim_y). 138 | 139 | Returns 140 | ------- 141 | torch.Tensor: Tensor of size (n_tiles, features_dim). 
142 | """ 143 | last_hidden_state = self.feature_extractor(images.to(self.device)) 144 | features = last_hidden_state[:, 0] 145 | return features.cpu().numpy() 146 | -------------------------------------------------------------------------------- /plismbench/models/paige_ai.py: -------------------------------------------------------------------------------- 1 | """Models from Paige AI company.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | import numpy as np 8 | import timm 9 | import torch 10 | from torchvision import transforms 11 | 12 | from plismbench.models.extractor import Extractor 13 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 14 | 15 | 16 | class Virchow(Extractor): 17 | """Virchow model developped by Paige AI available on Hugging-Face (1). 18 | 19 | .. note:: 20 | (1) https://huggingface.co/paige-ai/Virchow 21 | 22 | Parameters 23 | ---------- 24 | device: int | list[int] | None = DEFAULT_DEVICE, 25 | Compute resources to use. 26 | If None, will use all available GPUs. 27 | If -1, extraction will run on CPU. 28 | mixed_precision: bool = True 29 | Whether to use mixed_precision. 30 | 31 | """ 32 | 33 | def __init__( 34 | self, 35 | device: int | list[int] | None = DEFAULT_DEVICE, 36 | mixed_precision: bool = False, 37 | ): 38 | super().__init__() 39 | self.output_dim = 2560 40 | self.mixed_precision = mixed_precision 41 | 42 | timm_kwargs: dict[str, Any] = { 43 | "mlp_layer": timm.layers.SwiGLUPacked, 44 | "act_layer": torch.nn.SiLU, 45 | } 46 | feature_extractor = timm.create_model( 47 | "hf-hub:paige-ai/Virchow", pretrained=True, **timm_kwargs 48 | ) 49 | 50 | self.feature_extractor, self.device = prepare_module( 51 | feature_extractor, 52 | device, 53 | self.mixed_precision, 54 | ) 55 | if self.device is None: 56 | self.feature_extractor = self.feature_extractor.module 57 | 58 | @property # type: ignore 59 | def transform(self) -> transforms.Compose: 60 | """Transform method to apply element wise.""" 61 | return transforms.Compose( 62 | [ 63 | transforms.ToTensor(), # swap axes and normalize 64 | transforms.Normalize( 65 | mean=(0.485, 0.456, 0.406), 66 | std=(0.229, 0.224, 0.225), 67 | ), 68 | ] 69 | ) 70 | 71 | def __call__(self, images: torch.Tensor) -> np.ndarray: 72 | """Compute and return features. 73 | 74 | Parameters 75 | ---------- 76 | images: torch.Tensor 77 | Input of size (n_tiles, n_channels, dim_x, dim_y). 78 | 79 | Returns 80 | ------- 81 | torch.Tensor: Tensor of size (n_tiles, features_dim). 82 | """ 83 | last_hidden_state = self.feature_extractor(images.to(self.device)) 84 | class_token = last_hidden_state[:, 0] 85 | patch_tokens = last_hidden_state[:, 1:] 86 | features = torch.cat([class_token, patch_tokens.mean(1)], dim=-1) 87 | return features.cpu().numpy() 88 | 89 | 90 | class Virchow2(Extractor): 91 | """Virchow2 model developped by Paige AI available on Hugging-Face (1). 92 | 93 | You will need to be granted access to be able to use this model. 94 | 95 | .. note:: 96 | (1) https://huggingface.co/paige-ai/Virchow2 97 | 98 | Parameters 99 | ---------- 100 | device: int | list[int] | None = DEFAULT_DEVICE, 101 | Compute resources to use. 102 | If None, will use all available GPUs. 103 | If -1, extraction will run on CPU. 104 | mixed_precision: bool = True 105 | Whether to use mixed_precision. 
106 | 107 | """ 108 | 109 | def __init__( 110 | self, 111 | device: int | list[int] | None = DEFAULT_DEVICE, 112 | mixed_precision: bool = False, 113 | ): 114 | super().__init__() 115 | self.output_dim = 2560 116 | self.mixed_precision = mixed_precision 117 | 118 | timm_kwargs: dict[str, Any] = { 119 | "mlp_layer": timm.layers.SwiGLUPacked, 120 | "act_layer": torch.nn.SiLU, 121 | } 122 | feature_extractor = timm.create_model( 123 | "hf-hub:paige-ai/Virchow2", pretrained=True, **timm_kwargs 124 | ) 125 | 126 | self.feature_extractor, self.device = prepare_module( 127 | feature_extractor, 128 | device, 129 | self.mixed_precision, 130 | ) 131 | if self.device is None: 132 | self.feature_extractor = self.feature_extractor.module 133 | 134 | @property # type: ignore 135 | def transform(self) -> transforms.Compose: 136 | """Transform method to apply element wise.""" 137 | return transforms.Compose( 138 | [ 139 | transforms.ToTensor(), # swap axes and normalize 140 | transforms.Normalize( 141 | mean=(0.485, 0.456, 0.406), 142 | std=(0.229, 0.224, 0.225), 143 | ), 144 | ] 145 | ) 146 | 147 | def __call__(self, images: torch.Tensor) -> np.ndarray: 148 | """Compute and return features. 149 | 150 | Parameters 151 | ---------- 152 | images: torch.Tensor 153 | Input of size (n_tiles, n_channels, dim_x, dim_y). 154 | 155 | Returns 156 | ------- 157 | torch.Tensor: Tensor of size (n_tiles, features_dim). 158 | """ 159 | last_hidden_state = self.feature_extractor(images.to(self.device)) 160 | class_token = last_hidden_state[:, 0] 161 | patch_tokens = last_hidden_state[:, 5:] 162 | features = torch.cat([class_token, patch_tokens.mean(1)], dim=-1) 163 | return features.cpu().numpy() 164 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | Contributing 4 | ------------ 5 | 6 | Thank you for considering contributing to the PLISM robustness benchmark project! 7 | This section provides instructions on how to set up the project locally and how to 8 | contribute to the codebase. 9 | 10 | 11 | Installation 12 | ~~~~~~~~~~~~ 13 | 14 | To get started, you can download the source code from the `Github repository`_ by 15 | cloning it: 16 | 17 | You can clone the repository: 18 | 19 | .. code-block:: console 20 | 21 | $ git clone git@github.com:owkin/plism-benchmark.git 22 | 23 | Once you have a copy of the source code, we recommend installing the latest version of 24 | `poetry`_: 25 | 26 | .. code-block:: console 27 | 28 | $ make install-poetry 29 | 30 | Next, create a Python environment using your preferred management system (``conda``, 31 | ``pip``, ...). If you don't have a preferred system, ``poetry`` will automatically 32 | create a new environment in ``.venv/`` when you run ``make config``. If you're using 33 | your own environment, make sure to activate it. To activate ``poetry``'s environment, 34 | you can run: ``poetry shell``. Please configure ``poetry`` by running: 35 | 36 | .. code-block:: console 37 | 38 | $ make config 39 | 40 | 41 | Once that's done, and if it has not been generated yet, 42 | you must generate the ``poetry.lock`` file by running: 43 | 44 | .. code-block:: console 45 | 46 | $ make lock 47 | 48 | To install all required dependencies, you can run the following command: 49 | 50 | .. code-block:: console 51 | 52 | $ make install-all 53 | 54 | .. _Github repository: https://github.com/owkin/plism-benchmark 55 | .. 
_poetry: https://python-poetry.org/docs/ 56 | 57 | 58 | Pre-commit 59 | ~~~~~~~~~~ 60 | 61 | You can run all the aforementioned styling checks manually as described. 62 | However, we encourage you to use `pre-commit hooks `_ 63 | instead to automatically run ``ruff`` and ``mypy``. 64 | This can be done by running : 65 | 66 | .. code-block:: console 67 | 68 | $ pre-commit install 69 | 70 | from the root of the ``plism-benchmark`` repository. Now all of 71 | the styling checks will be run each time you commit changes without your 72 | needing to run each one manually. In addition, using ``pre-commit`` will also 73 | allow you to more easily remain up-to-date with code checks as they evolve. 74 | 75 | If you don’t want to use ``pre-commit`` as part of your workflow, you can 76 | still use it to run its checks with: 77 | 78 | .. code-block:: console 79 | 80 | $ make pre-commit-checks 81 | 82 | without needing to have done ``pre-commit install`` beforehand. 83 | 84 | 85 | Guidelines 86 | ~~~~~~~~~~ 87 | 88 | To contribute to the PLISM robustness benchmark project, follow these steps: 89 | 90 | 1. Fork the `Github repository `_. 91 | 2. Create a new branch for your changes. 92 | 3. Implement them and commit them with clear commit messages. 93 | 4. Push your changes to your branch. 94 | 5. Open a pull request in the ``main`` branch of the `Github repository `_. 95 | 96 | When opening a pull request, make sure to include a clear description of your changes 97 | and why they are necessary. 98 | 99 | 100 | Testing 101 | ~~~~~~~ 102 | 103 | The PLISM robustness benchmark project uses `pytest `_ 104 | for testing. To run the tests, simply run: 105 | 106 | .. code-block:: console 107 | 108 | $ make test-all 109 | 110 | 111 | Make sure that all tests pass before submitting a pull request. 112 | 113 | 114 | Documentation 115 | ~~~~~~~~~~~~~ 116 | 117 | The PLISM robustness benchmark project uses `Sphinx `_ 118 | for documentation. To build the documentation, run: 119 | 120 | .. code-block:: console 121 | 122 | $ make docs 123 | 124 | The documentation will be built in the ``docs/_build/`` directory. 125 | 126 | 127 | New dependencies 128 | ~~~~~~~~~~~~~~~~ 129 | 130 | If or when you add additional dependencies to your project, you can use ``poetry`` 131 | in the following manner: 132 | 133 | .. code-block:: console 134 | 135 | $ poetry add xformers 136 | 137 | 138 | If you already have a ``requirements.txt`` file with your dependencies, you can inject 139 | them using ``poetry`` with the command: 140 | 141 | .. code-block:: console 142 | 143 | $ cat requirements.txt | xargs poetry add 144 | 145 | 146 | If your project requires dependencies that can't be installed using pip, make sure to 147 | add the corresponding installation commands to the ``Makefile`` under the 148 | ``make install`` section like this: 149 | 150 | .. code-block:: Makefile 151 | 152 | install: clean 153 | conda install # Example of dependency only installed with conda 154 | curl | sh # Example of dependency only installed with bash 155 | poetry install 156 | 157 | You can also add a library located in a git repository, the minimum information you 158 | need to specify is the location of the repository with the git key, and if necessary 159 | the branch from which the library is to be installed. By default ``poetry`` will revert 160 | to the master branch. You can do using the following command: 161 | 162 | .. 
code-block:: console 163 | 164 | $ poetry add "https://github.com/org/mypackage.git#branch=my_branch" 165 | 166 | 167 | Useful tip 168 | ~~~~~~~~~~ 169 | 170 | The repository comes with a preconfigured ``Makefile`` encapsulating numerous 171 | useful commands. To check them out, run the command: 172 | 173 | .. code-block:: console 174 | 175 | $ make help 176 | -------------------------------------------------------------------------------- /notebooks/visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "\n", 11 | "from plismbench.utils.metrics import (\n", 12 | " format_results,\n", 13 | " get_aggregated_results,\n", 14 | " get_leaderboard_results,\n", 15 | ")\n", 16 | "from plismbench.utils.viz import EXTRACTOR_LABELS_DICT, display_plism_metrics" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "`metrics_root_dir` should have this architecture (as produced by `plismbench evaluate`).\n", 24 | "By default, metrics are only computed for 8139 tiles. As as example, you might have different folders corresponding to different extractors (here `h0_mini` and `conch`).\n", 25 | "\n", 26 | "```bash\n", 27 | ".\n", 28 | "├── 2713_tiles\n", 29 | "│   ├── conch\n", 30 | "│   │   ├── metrics.csv\n", 31 | "│   │   ├── pickles\n", 32 | "│   │   └── results.csv\n", 33 | "│   ├── h0_mini\n", 34 | "│   │   ├── metrics.csv\n", 35 | "│   │   ├── pickles\n", 36 | "│   │   └── results.csv\n", 37 | "...\n", 38 | "├── 5426_tiles\n", 39 | "│   ├── conch\n", 40 | "│   │   ├── metrics.csv\n", 41 | "│   │   ├── pickles\n", 42 | "│   │   └── results.csv\n", 43 | "│   ├── h0_mini\n", 44 | "│   │   ├── metrics.csv\n", 45 | "│   │   ├── pickles\n", 46 | "│   │   └── results.csv\n", 47 | "...\n", 48 | "└── 8139_tiles\n", 49 | " ├── conch\n", 50 | " │   ├── metrics.csv\n", 51 | " │   ├── pickles\n", 52 | " │   └── results.csv\n", 53 | " ├── h0_mini\n", 54 | " │   ├── metrics.csv\n", 55 | " │   ├── pickles\n", 56 | " └── └── results.csv\n", 57 | "```" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# Set metrics root directory\n", 67 | "metrics_root_dir = Path(\"/home/owkin/project/plism_metrics/\")\n", 68 | "\n", 69 | "# Retrieve a more detailed version of the results\n", 70 | "agg_type = \"median\" # choose between \"median\" or \"mean\"\n", 71 | "n_tiles = 8139 # default number of tiles\n", 72 | "raw_results = format_results(\n", 73 | " metrics_root_dir,\n", 74 | " agg_type=agg_type,\n", 75 | " n_tiles=n_tiles,\n", 76 | ")\n", 77 | "\n", 78 | "raw_results.head(10)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "# Aggregate and rank results for a specific metric, aggregation over pairs and robustness type\n", 88 | "metric = \"cosine_similarity\" # choose between \"cosine_similarity\", \"top_1_accuracy\", \"top_3_accuracy\", \"top_5_accuracy\", \"top_10_accuracy\"\n", 89 | "robustness_type = \"all\" # choose between \"all\", \"inter-scanner\", \"inter-staining\", \"inter-scanner, inter-staining\"\n", 90 | "results = get_aggregated_results(\n", 91 | " results=raw_results,\n", 92 | " metric_name=metric,\n", 93 | " agg_type=agg_type,\n", 94 | " 
robustness_type=robustness_type,\n", 95 | ")\n", 96 | "\n", 97 | "results.head(15)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "# Visualize the results\n", 107 | "display_plism_metrics(\n", 108 | " raw_results,\n", 109 | " xlim=(-0.005, 0.25), # may depend on the metric displayed on x-axis\n", 110 | " ylim=(0.4, 0.9), # may depend on the metric displayed on y-axis\n", 111 | " metric_x=\"top_10_accuracy_median\", # should be in ``raw_results``\n", 112 | " metric_y=\"cosine_similarity_median\", # should be in ``raw_results``\n", 113 | " robustness_x=\"all\", # should be in ``raw_results``\n", 114 | " robustness_y=\"all\", # should be in ``raw_results``\n", 115 | " label_x=\"Top-10 accuracy (all pairs)\",\n", 116 | " label_y=\"Cosine similarity (all pairs)\",\n", 117 | " fig_save_path=None, # can be None, a string or Path. You can export to .svg then use Inkscape to move the overlapping labels apart.\n", 118 | ")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "# Get leaderboard results\n", 128 | "leaderboard = get_leaderboard_results(metrics_root_dir=metrics_root_dir)\n", 129 | "leaderboard" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "# Export to markdown\n", 139 | "leaderboard.index = leaderboard.index.map(EXTRACTOR_LABELS_DICT)\n", 140 | "print(leaderboard.astype(str).to_markdown(floatfmt=\".3f\"))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [] 149 | } 150 | ], 151 | "metadata": { 152 | "kernelspec": { 153 | "display_name": "Python (vesibench)", 154 | "language": "python", 155 | "name": "python3" 156 | }, 157 | "language_info": { 158 | "codemirror_mode": { 159 | "name": "ipython", 160 | "version": 3 161 | }, 162 | "file_extension": ".py", 163 | "mimetype": "text/x-python", 164 | "name": "python", 165 | "nbconvert_exporter": "python", 166 | "pygments_lexer": "ipython3", 167 | "version": "3.10.16" 168 | }, 169 | "vscode": { 170 | "interpreter": { 171 | "hash": "bdedc4665527ff43f21f83597a20c9857360c358c9dc57bfea9e7d2a253a1bcc" 172 | } 173 | } 174 | }, 175 | "nbformat": 4, 176 | "nbformat_minor": 4 177 | } 178 | -------------------------------------------------------------------------------- /plismbench/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """Aggregation of robustness metrics across different extractors.""" 2 | 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | 7 | 8 | pd.set_option("future.no_silent_downcasting", True) 9 | 10 | 11 | def get_extractor_results(results_path: Path) -> pd.DataFrame: 12 | """Get robustness results for a given extractor.""" 13 | extractor = results_path.parent.name 14 | results = pd.read_csv(results_path, index_col=0) 15 | results.insert(0, "extractor", extractor) 16 | results.insert(1, "robustness_type", results.index.values) 17 | return results 18 | 19 | 20 | def get_results(metrics_root_dir: Path, n_tiles: int = 8139) -> pd.DataFrame: 21 | """Get robustness results for all extractors and a given number of tiles.""" 22 | results_paths = list((metrics_root_dir / f"{n_tiles}_tiles").glob("*/results.csv")) 23 | results = pd.concat( 24 | [get_extractor_results(results_path) for 
results_path in results_paths] 25 | ).reset_index(drop=True) 26 | return results 27 | 28 | 29 | def format_results( 30 | metrics_root_dir: Path, 31 | agg_type: str = "median", 32 | n_tiles: int = 8139, 33 | top_k: list[int] | None = None, 34 | ) -> pd.DataFrame: 35 | """Add float columns with parsed metrics wrt an aggregation type ("mean" or "median").""" 36 | if top_k is None: 37 | top_k = [1, 3, 5, 10] 38 | metric_names = ["cosine_similarity"] + [f"top_{k}_accuracy" for k in top_k] 39 | results = get_results(metrics_root_dir, n_tiles=n_tiles) 40 | metric_idx = 0 if agg_type == "mean" else 1 41 | agg_cols = ["_mean", "_std"] if agg_type == "mean" else ["_median", "_iqr"] 42 | output_results = results.map( 43 | lambda x: x.split(" ; ")[metric_idx] if ";" in x else x 44 | ) 45 | for metric_name in metric_names: 46 | metric_agg_cols = [f"{metric_name}{agg_col}" for agg_col in agg_cols] 47 | output_results[metric_agg_cols] = output_results[metric_name].str.extract( 48 | r"([0-9.]+)\s?\(([^)]+)\)" 49 | ) 50 | return output_results 51 | 52 | 53 | def rank_results( 54 | results: pd.DataFrame, 55 | robustness_type: str = "all", 56 | metric_name: str = "top_1_accuracy_median", 57 | ) -> pd.DataFrame: 58 | """Rank results according to a robustness type and metric name.""" 59 | output = pd.pivot( 60 | results, columns="robustness_type", index="extractor", values=metric_name 61 | ) 62 | return output.sort_values(robustness_type, ascending=False) 63 | 64 | 65 | def get_aggregated_results( 66 | results: pd.DataFrame, 67 | metric_name: str = "top_1_accuracy", 68 | robustness_type: str = "all", 69 | agg_type: str = "median", 70 | top_k: list[int] | None = None, 71 | ) -> pd.DataFrame: 72 | """Retrieve results from .csv and rank by a given metric.""" 73 | if top_k is None: 74 | top_k = [1, 3, 5, 10] 75 | supported_metric_names = ["cosine_similarity"] + [ 76 | f"top_{k}_accuracy" for k in top_k 77 | ] 78 | if metric_name not in supported_metric_names: 79 | raise ValueError( 80 | f"{metric_name} metric not supported. Supported: {supported_metric_names}." 81 | ) 82 | if agg_type not in (supported_agg_types := ["mean", "median"]): 83 | raise ValueError( 84 | f"{agg_type} aggregation not supported. Supported: {supported_agg_types}." 85 | ) 86 | if robustness_type not in ( 87 | supported_robustness_types := [ 88 | "all", 89 | "inter-scanner", 90 | "inter-scanner, inter-staining", 91 | "inter-staining", 92 | ] 93 | ): 94 | raise ValueError( 95 | f"{robustness_type} robustness type not supported. Supported: {supported_robustness_types}." 
96 | ) 97 | ranked_results = rank_results( 98 | results, 99 | metric_name=f"{metric_name}_{agg_type}", 100 | robustness_type=robustness_type, 101 | ) 102 | ranked_results.insert(0, "extractor", ranked_results.index.values) 103 | return ranked_results 104 | 105 | 106 | def get_leaderboard_results( 107 | metrics_root_dir: Path, 108 | ) -> pd.DataFrame: 109 | """Generate leaderboard results.""" 110 | # Get all results 111 | raw_results = format_results( 112 | metrics_root_dir=metrics_root_dir, agg_type="median", n_tiles=8139, top_k=None 113 | ) 114 | # Get aggregated results for each type of robustness for cosine similarity and top 10 accuracy 115 | cosine_sim_results = get_aggregated_results( 116 | results=raw_results, metric_name="cosine_similarity", agg_type="median" 117 | ) 118 | top_10_acc_results = get_aggregated_results( 119 | results=raw_results, metric_name="top_10_accuracy", agg_type="median" 120 | ) 121 | # Merge the 2 dataframes into one 122 | leaderboard_cols = [ 123 | "all_cosine_similarity", 124 | "inter-scanner_top_10_accuracy", 125 | "inter-staining_top_10_accuracy", 126 | "inter-scanner, inter-staining_top_10_accuracy", 127 | ] 128 | leaderboard_cols_labels = [ 129 | "Cosine similarity (all)", 130 | "Top-10 accuracy (cross-scanner)", 131 | "Top-10 accuracy (cross-staining)", 132 | "Top-10 accuracy (cross-scanner, cross-staining)", 133 | ] 134 | leaderboard_results = ( 135 | cosine_sim_results.sort_index() 136 | .merge( 137 | top_10_acc_results.iloc[:, 1:], 138 | left_index=True, 139 | right_index=True, 140 | suffixes=("_cosine_similarity", "_top_10_accuracy"), 141 | )[leaderboard_cols] 142 | .astype(float) 143 | ) 144 | leaderboard_results.columns = leaderboard_cols_labels # type: ignore 145 | leaderboard_results.insert( 146 | 4, "Leaderboard metric", leaderboard_results.mean(axis=1).round(3) 147 | ) 148 | leaderboard_results = leaderboard_results.sort_values( 149 | "Leaderboard metric", ascending=False 150 | ) 151 | leaderboard_results["Rank"] = [ 152 | f"#{i}" for i in range(1, leaderboard_results.shape[0] + 1) 153 | ] 154 | return leaderboard_results 155 | -------------------------------------------------------------------------------- /plismbench/engine/cli.py: -------------------------------------------------------------------------------- 1 | """A module containing CLI commands of the repository.""" 2 | 3 | from __future__ import annotations 4 | 5 | from pathlib import Path 6 | from typing import Annotated, Union 7 | 8 | import typer 9 | from huggingface_hub import login, snapshot_download 10 | from loguru import logger 11 | 12 | from plismbench.engine.evaluate import compute_metrics 13 | from plismbench.engine.extract.core import run_extract 14 | from plismbench.models import FeatureExtractorsEnum 15 | from plismbench.models.utils import DEFAULT_DEVICE 16 | 17 | 18 | app = typer.Typer(name="plismbench") 19 | 20 | 21 | @app.command() 22 | def extract( 23 | extractor: Annotated[ 24 | str, 25 | typer.Option( 26 | "--extractor", 27 | help="The name of the feature extractor as defined in ``plismbench.models.__init__.py``", 28 | ), 29 | ], 30 | export_dir: Annotated[ 31 | Path, 32 | typer.Option( 33 | "--export-dir", 34 | help=( 35 | "The root folder where features will be stored." 
36 | " The final export directory is ``export_dir / extractor``" 37 | ), 38 | ), 39 | ], 40 | streaming: Annotated[ 41 | bool, 42 | typer.Option( 43 | "--streaming", 44 | help="Whether to stream images instead of storing to disk (300Go).", 45 | ), 46 | ] = False, 47 | download_dir: Annotated[ 48 | Union[Path, None], 49 | typer.Option( 50 | "--download-dir", 51 | help="Folder containing the .h5 files downloaded from Hugging Face.", 52 | ), 53 | ] = None, 54 | device: Annotated[ 55 | int, typer.Option("--device", help="The CUDA devnumber or -1 for CPU.") 56 | ] = DEFAULT_DEVICE, 57 | batch_size: Annotated[ 58 | int, typer.Option("--batch-size", help="Features extraction batch size.") 59 | ] = 32, 60 | workers: Annotated[ 61 | int, typer.Option("--workers", help="Number of workers for async loading.") 62 | ] = 8, 63 | overwrite: Annotated[ 64 | bool, 65 | typer.Option( 66 | "--overwrite", 67 | help="Whether to overwrite the previous features extraction run.", 68 | ), 69 | ] = False, 70 | ): 71 | """Perform features extraction on PLISM histology tiles dataset streamed from Hugging-Face. 72 | 73 | .. code-block:: console 74 | 75 | $ plismbench extract --extractor h0_mini --batch-size 32 --export-dir $HOME/tmp/features/ --download-dir $HOME/tmp/slides/ 76 | 77 | """ 78 | supported_feature_extractors = FeatureExtractorsEnum.choices() 79 | if extractor not in supported_feature_extractors: 80 | raise NotImplementedError( 81 | f"Extractor {extractor} not supported." 82 | f" Supported extractors are: {supported_feature_extractors}." 83 | ) 84 | run_extract( 85 | feature_extractor_name=extractor, 86 | export_dir=export_dir / extractor, 87 | download_dir=download_dir, 88 | device=device, 89 | batch_size=batch_size, 90 | workers=workers, 91 | overwrite=overwrite, 92 | streaming=streaming, 93 | ) 94 | 95 | 96 | @app.command() 97 | def download( 98 | download_dir: Annotated[ 99 | Path, 100 | typer.Option( 101 | "--download-dir", 102 | help="Folder containing the .h5 files downloaded from Hugging Face.", 103 | ), 104 | ], 105 | hf_token: Annotated[str, typer.Option("--token", help="Hugging Face token.")], 106 | workers: Annotated[ 107 | int, typer.Option("--workers", help="Number of workers for parallel download.") 108 | ] = 8, 109 | ): 110 | """Download PLISM dataset from Hugging Face.""" 111 | login(token=hf_token, new_session=False) 112 | _ = snapshot_download( 113 | repo_id="owkin/plism-dataset", 114 | repo_type="dataset", 115 | local_dir=download_dir, 116 | allow_patterns=["*_to_GMH_S60.tif.h5"], 117 | ignore_patterns=[".gitattribues"], 118 | max_workers=workers, 119 | ) 120 | 121 | 122 | @app.command() 123 | def evaluate( 124 | extractor: Annotated[ 125 | str, 126 | typer.Option( 127 | "--extractor", 128 | help="The name of the feature extractor as defined in ``plismbench.models.__init__.py``", 129 | ), 130 | ], 131 | features_dir: Annotated[ 132 | Path, 133 | typer.Option( 134 | "--features-dir", 135 | help=( 136 | "The root folder where features will be stored." 137 | " The final export directory is ``export_dir / extractor``." 138 | ), 139 | ), 140 | ], 141 | metrics_dir: Annotated[ 142 | Path, 143 | typer.Option( 144 | "--metrics-dir", 145 | help=( 146 | "Folder containing the output metrics." 147 | " The final export directory is ``metrics_dir / extractor``." 148 | ), 149 | ), 150 | ], 151 | n_tiles: Annotated[ 152 | Union[str, None], 153 | typer.Option( 154 | "--n-tiles", help="Number of tiles per slide for metrics computation." 
155 | ), 156 | ] = None, 157 | top_k: Annotated[ 158 | Union[str, None], 159 | typer.Option("--top-k", help="Values of k for top-k accuracy computation."), 160 | ] = None, 161 | device: Annotated[ 162 | str, 163 | typer.Option( 164 | "--device", help="'cpu' (parallel computation) or 'gpu' (sequential)." 165 | ), 166 | ] = "gpu", 167 | workers: Annotated[ 168 | int, 169 | typer.Option( 170 | "--workers", help="Number of workers for cpu parallel computations." 171 | ), 172 | ] = 4, 173 | overwrite: Annotated[ 174 | bool, 175 | typer.Option( 176 | "--overwrite", 177 | help="Whether to overwrite existing metrics.", 178 | ), 179 | ] = False, 180 | ): 181 | """Compute robustness metrics for a list of feature extractors.""" 182 | logger.info(f"Computing metrics for extractor {extractor}.") 183 | _ = compute_metrics( 184 | features_root_dir=features_dir, 185 | metrics_save_dir=metrics_dir, 186 | extractor=extractor, 187 | top_k=top_k if top_k is None else [int(t) for t in top_k.split(" ")], 188 | n_tiles=int(n_tiles) if n_tiles is not None else n_tiles, 189 | device=device, 190 | overwrite=overwrite, 191 | workers=workers, 192 | ) 193 | 194 | 195 | if __name__ == "__main__": 196 | app() 197 | -------------------------------------------------------------------------------- /plismbench/engine/extract/extract_from_h5.py: -------------------------------------------------------------------------------- 1 | """Download PLISM tiles dataset as h5 files and extract features for a given model.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Callable 6 | from functools import partial 7 | from pathlib import Path 8 | 9 | import h5py 10 | import numpy as np 11 | import torch 12 | from loguru import logger 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | 16 | from plismbench.engine.extract.utils import ( 17 | NUM_SLIDES, 18 | NUM_TILES_PER_SLIDE, 19 | process_imgs, 20 | save_features, 21 | ) 22 | from plismbench.models import FeatureExtractorsEnum 23 | 24 | 25 | class H5Dataset(torch.utils.data.Dataset): 26 | """Dataset wrapper iterating over a .h5 file content. 27 | 28 | Parameters 29 | ---------- 30 | file_path: Path 31 | Path to the .h5 file. 32 | """ 33 | 34 | def __init__(self, file_path: Path): 35 | super().__init__() 36 | self.file_path = file_path 37 | self.data = h5py.File(self.file_path, "r", libver="latest", swmr=True) 38 | self.keys = list(self.data.keys()) 39 | 40 | def __len__(self): 41 | """Get length of dataset.""" 42 | length = len(self.keys) 43 | assert length == NUM_TILES_PER_SLIDE, ( 44 | f"H5 file for slide {self.file_path.stem} does not contain {NUM_TILES_PER_SLIDE} tiles!" 45 | ) 46 | return length 47 | 48 | def __getitem__(self, idx): 49 | """Get next item (``tile_id``, ``tile_array``).""" 50 | tile_id = self.keys[idx] 51 | tile_array = self.data[tile_id][:] 52 | return tile_id, tile_array 53 | 54 | 55 | def collate( 56 | batch: list[tuple[str, torch.Tensor]], 57 | transform: Callable[[np.ndarray], torch.Tensor], 58 | ) -> tuple[list[str], torch.Tensor]: 59 | """Return tile ids and transformed images. 60 | 61 | Parameters 62 | ---------- 63 | batch: list[dict[str, Any]] 64 | List of length ``batch_size`` made of tuples. 65 | Each tuple represents a tile_id and the corresponding image. 66 | The image is a torch.float32 tensor (between 0 and 1). 67 | transform: Callable[[np.ndarray], torch.Tensor] 68 | Transform function taking ``np.ndarray`` image as inputs. 
69 | 70 | Returns 71 | ------- 72 | output: tuple[list[str], torch.Tensor] 73 | A tuple made of tiles ids and transformed input images. 74 | """ 75 | tile_ids: list[str] = [b[0] for b in batch] 76 | raw_imgs: list[np.ndarray] = [b[1] for b in batch] # type: ignore 77 | imgs = torch.stack([transform(img) for img in raw_imgs]) 78 | output = (tile_ids, imgs) 79 | return output 80 | 81 | 82 | def get_dataloader( 83 | slide_h5_path: Path, 84 | transform: Callable[[np.ndarray], torch.Tensor], 85 | batch_size: int = 32, 86 | workers: int = 8, 87 | ) -> DataLoader: 88 | """Get PLISM tiles dataset dataloader transformed with ``transform`` function. 89 | 90 | Parameters 91 | ---------- 92 | slide_h5_path: Path 93 | Path to the .h5 containing tiles for a given slide. 94 | transform: Callable[[np.ndarray], torch.Tensor] 95 | Transform function taking ``np.ndarray`` image as inputs. 96 | batch_size: int = 32 97 | Batch size for features extraction. 98 | workers: int = 8 99 | Number of workers to load images. 100 | 101 | Returns 102 | ------- 103 | dataloader: DataLoader 104 | DataLoader returning (tile_ids, images). 105 | See ``collate`` function for details. 106 | """ 107 | dataset = H5Dataset(file_path=slide_h5_path) 108 | dataloader = DataLoader( 109 | dataset, 110 | batch_size=batch_size, 111 | collate_fn=partial(collate, transform=transform), 112 | num_workers=workers, 113 | pin_memory=True, 114 | shuffle=False, 115 | ) 116 | return dataloader 117 | 118 | 119 | def run_extract_h5( 120 | feature_extractor_name: str, 121 | batch_size: int, 122 | device: int, 123 | export_dir: Path, 124 | download_dir: Path, 125 | overwrite: bool = False, 126 | workers: int = 8, 127 | ) -> None: 128 | """Run features extraction.""" 129 | if overwrite: 130 | logger.warning("You are about to overwrite existing features.") 131 | logger.info(f"Download directory set to {str(download_dir)}.") 132 | logger.info(f"Export directory set to {str(export_dir)}.") 133 | 134 | # Create export directory if it doesn't exist 135 | export_dir.mkdir(exist_ok=True, parents=True) 136 | 137 | # Initialize the feature extractor 138 | feature_extractor = FeatureExtractorsEnum[feature_extractor_name.upper()].init( 139 | device=device 140 | ) 141 | image_transform = feature_extractor.transform 142 | 143 | slide_h5_paths = list(download_dir.glob("*.tif.h5")) 144 | assert (n_slides := len(slide_h5_paths)) == NUM_SLIDES, ( 145 | f"Download uncomplete: found {n_slides}/{NUM_SLIDES}" 146 | ) 147 | 148 | for slide_h5_path in tqdm(slide_h5_paths): 149 | # Get slide id 150 | slide_id = slide_h5_path.stem 151 | # Get output path for features 152 | slide_features_export_dir = Path(export_dir / slide_id) 153 | slide_features_export_path = slide_features_export_dir / "features.npy" 154 | if slide_features_export_path.exists(): 155 | if overwrite: 156 | logger.info( 157 | f"Features for slide {slide_id} already extracted. Overwriting..." 158 | ) 159 | else: 160 | logger.info( 161 | f"Features for slide {slide_id} already extracted. Skipping..." 
162 | ) 163 | continue 164 | slide_features_export_dir.mkdir(exist_ok=True, parents=True) 165 | # Instanciate the dataloader 166 | dataloader = get_dataloader( 167 | slide_h5_path=slide_h5_path, 168 | transform=image_transform, 169 | batch_size=batch_size, 170 | workers=workers, 171 | ) 172 | # Iterate over the full dataset and store features each time 16,278 input images have been processed 173 | slide_features: list[np.ndarray] = [] 174 | for tile_ids, tile_images in tqdm( 175 | dataloader, total=len(dataloader), leave=False 176 | ): 177 | batch_stack = process_imgs(tile_images, tile_ids, model=feature_extractor) 178 | slide_features.append(batch_stack) 179 | save_features( 180 | slide_features, 181 | slide_id=slide_id, 182 | export_path=slide_features_export_path, 183 | ) 184 | logger.success(f"Successfully saved features for slide: {slide_id}") 185 | -------------------------------------------------------------------------------- /plismbench/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit tests for :mod:`plismbench.models`.""" 2 | 3 | from __future__ import annotations 4 | 5 | from enum import Enum 6 | 7 | from plismbench.models.bioptimus import H0Mini, HOptimus0, HOptimus1 8 | from plismbench.models.extractor import Extractor 9 | from plismbench.models.histai import HibouBase, HibouLarge 10 | from plismbench.models.hkust import GPFM 11 | from plismbench.models.kaiko_ai import KaikoViTBase, KaikoViTLarge, Midnight12k 12 | from plismbench.models.lunit import LunitViTS8 13 | from plismbench.models.mahmood_lab import CONCH, UNI, CONCHv15, UNI2h 14 | from plismbench.models.meta import Dinov2ViTGiant 15 | from plismbench.models.microsoft import ProvGigaPath 16 | from plismbench.models.owkin import Phikon, PhikonV2 17 | from plismbench.models.paige_ai import Virchow, Virchow2 18 | from plismbench.models.standford import PLIP 19 | 20 | 21 | class StringEnum(Enum): 22 | """A base class string enumerator.""" 23 | 24 | def __str__(self) -> str: 25 | return str(self.value) 26 | 27 | @classmethod 28 | def choices(cls): 29 | """Get Enum names.""" 30 | return tuple(i.value for i in cls) 31 | 32 | 33 | class FeatureExtractorsEnum(StringEnum): 34 | """A class enumerator for feature extractors.""" 35 | 36 | # please follow the format "upper case = lower case" 37 | # this should map exactly the name in constants 38 | 39 | # Bioptimus 40 | H0_MINI = "h0_mini" 41 | HOPTIMUS0 = "hoptimus0" 42 | # HOPTIMUS1 = "hoptimus1" # access not granted for now 43 | # Kaiko AI 44 | KAIKO_VIT_BASE = "kaiko_vit_base" 45 | KAIKO_VIT_LARGE = "kaiko_vit_large" 46 | MIDNIGHT_12K = "midnight_12k" 47 | # Paige AI 48 | VIRCHOW = "virchow" 49 | VIRCHOW2 = "virchow2" 50 | # Microsoft 51 | PROVGIGAPATH = "provgigapath" 52 | # Mahmood Lab 53 | CONCH = "conch" 54 | CONCHV15 = "conchv15" 55 | UNI = "uni" 56 | UNI2H = "uni2h" 57 | # HistAI 58 | HIBOU_BASE = "hibou_base" 59 | HIBOU_LARGE = "hibou_large" 60 | # Owkin 61 | PHIKON = "phikon" 62 | PHIKONV2 = "phikonv2" 63 | # HKUST 64 | GPFM = "gpfm" 65 | # Standford 66 | PLIP = "plip" 67 | # Lunit 68 | LUNIT_VIT_SMALL_8 = "lunit_vit_small_8" 69 | # Meta 70 | DINOV2_VIT_GIANT_IMAGENET = "dinov2_vit_giant_imagenet" 71 | 72 | def init( # noqa: PLR0911, PLR0912 73 | self, 74 | device: int | list[int] | None, 75 | mixed_precision: bool = True, 76 | **kwargs, 77 | ) -> Extractor: 78 | """Initialize the feature extractor. 
Mixed precision is set by default.""" 79 | if self is self.H0_MINI: 80 | return H0Mini( 81 | device=device, 82 | mixed_precision=mixed_precision, 83 | **kwargs, 84 | ) 85 | elif self is self.HOPTIMUS0: 86 | return HOptimus0( 87 | device=device, 88 | mixed_precision=mixed_precision, 89 | **kwargs, 90 | ) 91 | # access not granted for now 92 | # elif self is self.HOPTIMUS1: 93 | # return HOptimus1( 94 | # device=device, 95 | # mixed_precision=mixed_precision, 96 | # **kwargs, 97 | # ) 98 | elif self is self.KAIKO_VIT_BASE: 99 | return KaikoViTBase( 100 | device=device, 101 | mixed_precision=mixed_precision, 102 | **kwargs, 103 | ) 104 | elif self is self.KAIKO_VIT_LARGE: 105 | return KaikoViTLarge( 106 | device=device, 107 | mixed_precision=mixed_precision, 108 | **kwargs, 109 | ) 110 | elif self is self.MIDNIGHT_12K: 111 | return Midnight12k( 112 | device=device, 113 | mixed_precision=mixed_precision, 114 | **kwargs, 115 | ) 116 | elif self is self.VIRCHOW: 117 | return Virchow( 118 | device=device, 119 | mixed_precision=mixed_precision, 120 | **kwargs, 121 | ) 122 | elif self is self.VIRCHOW2: 123 | return Virchow2( 124 | device=device, 125 | mixed_precision=mixed_precision, 126 | **kwargs, 127 | ) 128 | elif self is self.PROVGIGAPATH: 129 | return ProvGigaPath( 130 | device=device, 131 | mixed_precision=mixed_precision, 132 | **kwargs, 133 | ) 134 | elif self is self.CONCH: 135 | return CONCH( 136 | device=device, 137 | mixed_precision=mixed_precision, 138 | **kwargs, 139 | ) 140 | elif self is self.CONCHV15: 141 | return CONCHv15( 142 | device=device, 143 | mixed_precision=mixed_precision, 144 | **kwargs, 145 | ) 146 | elif self is self.UNI: 147 | return UNI( 148 | device=device, 149 | mixed_precision=mixed_precision, 150 | **kwargs, 151 | ) 152 | elif self is self.UNI2H: 153 | return UNI2h( 154 | device=device, 155 | mixed_precision=mixed_precision, 156 | **kwargs, 157 | ) 158 | elif self is self.HIBOU_BASE: 159 | return HibouBase( 160 | device=device, 161 | mixed_precision=mixed_precision, 162 | **kwargs, 163 | ) 164 | elif self is self.HIBOU_LARGE: 165 | return HibouLarge( 166 | device=device, 167 | mixed_precision=mixed_precision, 168 | **kwargs, 169 | ) 170 | elif self is self.PHIKON: 171 | return Phikon( 172 | device=device, 173 | mixed_precision=mixed_precision, 174 | **kwargs, 175 | ) 176 | elif self is self.PHIKONV2: 177 | return PhikonV2( 178 | device=device, 179 | mixed_precision=mixed_precision, 180 | **kwargs, 181 | ) 182 | elif self is self.GPFM: 183 | return GPFM( 184 | device=device, 185 | mixed_precision=mixed_precision, 186 | **kwargs, 187 | ) 188 | elif self is self.PLIP: 189 | return PLIP( 190 | device=device, 191 | mixed_precision=mixed_precision, 192 | **kwargs, 193 | ) 194 | elif self is self.DINOV2_VIT_GIANT_IMAGENET: 195 | return Dinov2ViTGiant( 196 | device=device, 197 | mixed_precision=mixed_precision, 198 | **kwargs, 199 | ) 200 | elif self is self.LUNIT_VIT_SMALL_8: 201 | return LunitViTS8( 202 | device=device, 203 | mixed_precision=mixed_precision, 204 | **kwargs, 205 | ) 206 | else: 207 | raise NotImplementedError(f"Extractor {self} is not supported.") 208 | -------------------------------------------------------------------------------- /plismbench/models/kaiko_ai.py: -------------------------------------------------------------------------------- 1 | """Models from Kaiko AI company.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import torch 7 | from torchvision import transforms 8 | from transformers import AutoModel 9 | 10 
| from plismbench.models.extractor import Extractor 11 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 12 | 13 | 14 | class KaikoViTBase(Extractor): 15 | """Kaiko ViT-Base model available on Pytorch Hub (1-2). 16 | 17 | .. note:: 18 | (1) kaiko. ai, Aben, N., de Jong, E. D., Gatopoulos, I., Känzig, N., Karasikov, M., Lagré, A., Moser, R., van Doorn, J., & Tang, F. (2024). Towards large-scale training of pathology foundation models. arXiv. https://arxiv.org/abs/2404.15217 19 | (2) https://github.com/kaiko-ai/towards_large_pathology_fms 20 | 21 | Parameters 22 | ---------- 23 | device: int | list[int] | None = DEFAULT_DEVICE, 24 | Compute resources to use. 25 | If None, will use all available GPUs. 26 | If -1, extraction will run on CPU. 27 | mixed_precision: bool = True 28 | Whether to use mixed_precision. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | device: int | list[int] | None = DEFAULT_DEVICE, 34 | mixed_precision: bool = False, 35 | ): 36 | super().__init__() 37 | self.output_dim = 768 38 | self.mixed_precision = mixed_precision 39 | 40 | feature_extractor = torch.hub.load( 41 | "kaiko-ai/towards_large_pathology_fms", 42 | "vitb8", 43 | trust_repo=True, 44 | verbose=True, 45 | ) 46 | 47 | self.feature_extractor, self.device = prepare_module( 48 | feature_extractor, 49 | device, 50 | self.mixed_precision, 51 | ) 52 | if self.device is None: 53 | self.feature_extractor = self.feature_extractor.module 54 | 55 | @property # type: ignore 56 | def transform(self) -> transforms.Compose: 57 | """Transform method to apply element wise.""" 58 | return transforms.Compose( 59 | [ 60 | transforms.ToTensor(), # swap axes and normalize 61 | transforms.Normalize( 62 | mean=(0.5, 0.5, 0.5), 63 | std=(0.5, 0.5, 0.5), 64 | ), 65 | ] 66 | ) 67 | 68 | def __call__(self, images: torch.Tensor) -> np.ndarray: 69 | """Compute and return features. 70 | 71 | Parameters 72 | ---------- 73 | images: torch.Tensor 74 | Input of size (n_tiles, n_channels, dim_x, dim_y). 75 | 76 | Returns 77 | ------- 78 | torch.Tensor: Tensor of size (n_tiles, features_dim). 79 | """ 80 | features = self.feature_extractor(images.to(self.device)) 81 | return features.cpu().numpy() 82 | 83 | 84 | class KaikoViTLarge(Extractor): 85 | """Kaiko ViT-Large model available on Pytorch Hub (1-2). 86 | 87 | .. note:: 88 | (1) kaiko. ai, Aben, N., de Jong, E. D., Gatopoulos, I., Känzig, N., Karasikov, M., Lagré, A., Moser, R., van Doorn, J., & Tang, F. (2024). Towards large-scale training of pathology foundation models. arXiv. https://arxiv.org/abs/2404.15217 89 | (2) https://github.com/kaiko-ai/towards_large_pathology_fms 90 | 91 | Parameters 92 | ---------- 93 | device: int | list[int] | None = DEFAULT_DEVICE, 94 | Compute resources to use. 95 | If None, will use all available GPUs. 96 | If -1, extraction will run on CPU. 97 | mixed_precision: bool = True 98 | Whether to use mixed_precision. 
99 | """ 100 | 101 | def __init__( 102 | self, 103 | device: int | list[int] | None = DEFAULT_DEVICE, 104 | mixed_precision: bool = False, 105 | ): 106 | super().__init__() 107 | self.output_dim = 1024 108 | self.mixed_precision = mixed_precision 109 | 110 | feature_extractor = torch.hub.load( 111 | "kaiko-ai/towards_large_pathology_fms", 112 | "vitl14", 113 | trust_repo=True, 114 | verbose=True, 115 | ) 116 | 117 | self.feature_extractor, self.device = prepare_module( 118 | feature_extractor, 119 | device, 120 | self.mixed_precision, 121 | ) 122 | if self.device is None: 123 | self.feature_extractor = self.feature_extractor.module 124 | 125 | @property # type: ignore 126 | def transform(self) -> transforms.Compose: 127 | """Transform method to apply element wise.""" 128 | return transforms.Compose( 129 | [ 130 | transforms.ToTensor(), # swap axes and normalize 131 | transforms.Normalize( 132 | mean=(0.5, 0.5, 0.5), 133 | std=(0.5, 0.5, 0.5), 134 | ), 135 | ] 136 | ) 137 | 138 | def __call__(self, images: torch.Tensor) -> np.ndarray: 139 | """Compute and return features. 140 | 141 | Parameters 142 | ---------- 143 | images: torch.Tensor 144 | Input of size (n_tiles, n_channels, dim_x, dim_y). 145 | 146 | Returns 147 | ------- 148 | torch.Tensor: Tensor of size (n_tiles, features_dim). 149 | """ 150 | features = self.feature_extractor(images.to(self.device)) 151 | return features.cpu().numpy() 152 | 153 | 154 | class Midnight12k(Extractor): 155 | """Midnight-12k model developped by Kaiko AI available on Hugging-Face (1). 156 | 157 | .. note:: 158 | (1) https://huggingface.co/kaiko-ai/midnight 159 | 160 | Parameters 161 | ---------- 162 | device: int | list[int] | None = DEFAULT_DEVICE, 163 | Compute resources to use. 164 | If None, will use all available GPUs. 165 | If -1, extraction will run on CPU. 166 | mixed_precision: bool = True 167 | Whether to use mixed_precision. 168 | 169 | """ 170 | 171 | def __init__( 172 | self, 173 | device: int | list[int] | None = DEFAULT_DEVICE, 174 | mixed_precision: bool = False, 175 | ): 176 | super().__init__() 177 | self.output_dim = 3072 178 | self.mixed_precision = mixed_precision 179 | 180 | feature_extractor = AutoModel.from_pretrained("kaiko-ai/midnight") 181 | 182 | self.feature_extractor, self.device = prepare_module( 183 | feature_extractor, 184 | device, 185 | self.mixed_precision, 186 | ) 187 | 188 | if self.device is None: 189 | self.feature_extractor = self.feature_extractor.module 190 | 191 | @property # type: ignore 192 | def transform(self) -> transforms.Compose: 193 | """Transform method to apply element wise.""" 194 | return transforms.Compose( 195 | [ 196 | transforms.ToTensor(), # swap axes and normalize 197 | transforms.Normalize( 198 | mean=(0.5, 0.5, 0.5), 199 | std=(0.5, 0.5, 0.5), 200 | ), 201 | ] 202 | ) 203 | 204 | def __call__(self, images: torch.Tensor) -> np.ndarray: 205 | """Compute and return features. 206 | 207 | Parameters 208 | ---------- 209 | images: torch.Tensor 210 | Input of size (n_tiles, n_channels, dim_x, dim_y). 211 | 212 | Returns 213 | ------- 214 | torch.Tensor: Tensor of size (n_tiles, features_dim). 
215 | """ 216 | last_hidden_state = self.feature_extractor(images.to(self.device)) 217 | class_token = last_hidden_state[:, 0] 218 | patch_tokens = last_hidden_state[:, 1:] 219 | features = torch.cat([class_token, patch_tokens.mean(1)], dim=-1) 220 | return features.cpu().numpy() 221 | -------------------------------------------------------------------------------- /plismbench/models/bioptimus.py: -------------------------------------------------------------------------------- 1 | """Models from Bioptimus company.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any 6 | 7 | import numpy as np 8 | import timm 9 | import torch 10 | from torchvision import transforms 11 | 12 | from plismbench.models.extractor import Extractor 13 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_module 14 | 15 | 16 | class HOptimus0(Extractor): 17 | """H-Optimus-0 model developped by Bioptimus available on Hugging-Face (1). 18 | 19 | .. note:: 20 | (1) https://huggingface.co/bioptimus/H-optimus-0 21 | 22 | Parameters 23 | ---------- 24 | device: int | list[int] | None = DEFAULT_DEVICE, 25 | Compute resources to use. 26 | If None, will use all available GPUs. 27 | If -1, extraction will run on CPU. 28 | mixed_precision: bool = True 29 | Whether to use mixed_precision. 30 | 31 | """ 32 | 33 | def __init__( 34 | self, 35 | device: int | list[int] | None = DEFAULT_DEVICE, 36 | mixed_precision: bool = False, 37 | ): 38 | super().__init__() 39 | self.output_dim = 1536 40 | self.mixed_precision = mixed_precision 41 | 42 | timm_kwargs: dict[str, Any] = { 43 | "init_values": 1e-5, 44 | "dynamic_img_size": False, 45 | } 46 | feature_extractor = timm.create_model( 47 | "hf-hub:bioptimus/H-optimus-0", pretrained=True, **timm_kwargs 48 | ) 49 | 50 | self.feature_extractor, self.device = prepare_module( 51 | feature_extractor, 52 | device, 53 | self.mixed_precision, 54 | ) 55 | if self.device is None: 56 | self.feature_extractor = self.feature_extractor.module 57 | 58 | @property # type: ignore 59 | def transform(self) -> transforms.Compose: 60 | """Transform method to apply element wise.""" 61 | return transforms.Compose( 62 | [ 63 | transforms.ToTensor(), # swap axes and normalize 64 | transforms.Normalize( 65 | mean=(0.707223, 0.578729, 0.703617), 66 | std=(0.211883, 0.230117, 0.177517), 67 | ), 68 | ] 69 | ) 70 | 71 | def __call__(self, images: torch.Tensor) -> np.ndarray: 72 | """Compute and return features. 73 | 74 | Parameters 75 | ---------- 76 | images: torch.Tensor 77 | Input of size (n_tiles, n_channels, dim_x, dim_y). 78 | 79 | Returns 80 | ------- 81 | torch.Tensor: Tensor of size (n_tiles, features_dim). 82 | """ 83 | features = self.feature_extractor(images.to(self.device)) 84 | return features.cpu().numpy() 85 | 86 | 87 | class H0Mini(Extractor): 88 | """H0-mini model developped by Owkin & Bioptimus available on Hugging-Face (1). 89 | 90 | You will need to be granted access to be able to use this model. 91 | 92 | .. note:: 93 | (1) https://huggingface.co/bioptimus/H0-mini 94 | 95 | Parameters 96 | ---------- 97 | device: int | list[int] | None = DEFAULT_DEVICE, 98 | Compute resources to use. 99 | If None, will use all available GPUs. 100 | If -1, extraction will run on CPU. 101 | mixed_precision: bool = True 102 | Whether to use mixed_precision. 
103 | 104 | """ 105 | 106 | def __init__( 107 | self, 108 | device: int | list[int] | None = DEFAULT_DEVICE, 109 | mixed_precision: bool = False, 110 | ): 111 | super().__init__() 112 | self.output_dim = 768 113 | self.mixed_precision = mixed_precision 114 | 115 | timm_kwargs: dict[str, Any] = { 116 | "mlp_layer": timm.layers.SwiGLUPacked, 117 | "act_layer": torch.nn.SiLU, 118 | } 119 | feature_extractor = timm.create_model( 120 | "hf-hub:bioptimus/H0-mini", pretrained=True, **timm_kwargs 121 | ) 122 | 123 | self.feature_extractor, self.device = prepare_module( 124 | feature_extractor, 125 | device, 126 | self.mixed_precision, 127 | ) 128 | if self.device is None: 129 | self.feature_extractor = self.feature_extractor.module 130 | 131 | @property # type: ignore 132 | def transform(self) -> transforms.Compose: 133 | """Transform method to apply element wise.""" 134 | return transforms.Compose( 135 | [ 136 | transforms.ToTensor(), # swap axes and normalize 137 | transforms.Normalize( 138 | mean=(0.707223, 0.578729, 0.703617), 139 | std=(0.211883, 0.230117, 0.177517), 140 | ), 141 | ] 142 | ) 143 | 144 | def __call__(self, images: torch.Tensor) -> np.ndarray: 145 | """Compute and return features. 146 | 147 | Parameters 148 | ---------- 149 | images: torch.Tensor 150 | Input of size (n_tiles, n_channels, dim_x, dim_y). 151 | 152 | Returns 153 | ------- 154 | torch.Tensor: Tensor of size (n_tiles, features_dim). 155 | """ 156 | last_hidden_state = self.feature_extractor(images.to(self.device)) 157 | features = last_hidden_state[:, 0] # only cls token 158 | return features.cpu().numpy() 159 | 160 | 161 | class HOptimus1(Extractor): 162 | """H-Optimus-1 model developped by Bioptimus available on Hugging-Face (1). 163 | 164 | You will need to be granted access to be able to use this model. 165 | 166 | .. note:: 167 | (1) https://huggingface.co/bioptimus/H-optimus-1 168 | 169 | Parameters 170 | ---------- 171 | device: int | list[int] | None = DEFAULT_DEVICE, 172 | Compute resources to use. 173 | If None, will use all available GPUs. 174 | If -1, extraction will run on CPU. 175 | mixed_precision: bool = True 176 | Whether to use mixed_precision. 177 | 178 | """ 179 | 180 | def __init__( 181 | self, 182 | device: int | list[int] | None = DEFAULT_DEVICE, 183 | mixed_precision: bool = False, 184 | ): 185 | super().__init__() 186 | self.output_dim = 1536 187 | self.mixed_precision = mixed_precision 188 | 189 | timm_kwargs: dict[str, Any] = { 190 | "init_values": 1e-5, 191 | "dynamic_img_size": False, 192 | } 193 | feature_extractor = timm.create_model( 194 | "hf-hub:bioptimus/H-optimus-1", pretrained=True, **timm_kwargs 195 | ) 196 | 197 | self.feature_extractor, self.device = prepare_module( 198 | feature_extractor, 199 | device, 200 | self.mixed_precision, 201 | ) 202 | if self.device is None: 203 | self.feature_extractor = self.feature_extractor.module 204 | 205 | @property # type: ignore 206 | def transform(self) -> transforms.Compose: 207 | """Transform method to apply element wise.""" 208 | return transforms.Compose( 209 | [ 210 | transforms.ToTensor(), # swap axes and normalize 211 | transforms.Normalize( 212 | mean=(0.707223, 0.578729, 0.703617), 213 | std=(0.211883, 0.230117, 0.177517), 214 | ), 215 | ] 216 | ) 217 | 218 | def __call__(self, images: torch.Tensor) -> np.ndarray: 219 | """Compute and return features. 220 | 221 | Parameters 222 | ---------- 223 | images: torch.Tensor 224 | Input of size (n_tiles, n_channels, dim_x, dim_y). 
225 | 226 | Returns 227 | ------- 228 | torch.Tensor: Tensor of size (n_tiles, features_dim). 229 | """ 230 | features = self.feature_extractor(images.to(self.device)) 231 | return features.cpu().numpy() 232 | -------------------------------------------------------------------------------- /plismbench/utils/viz.py: -------------------------------------------------------------------------------- 1 | """Visualization of robustness results across different extractors.""" 2 | 3 | from pathlib import Path 4 | from typing import Any 5 | 6 | import matplotlib.pyplot as plt 7 | import pandas as pd 8 | import seaborn as sns 9 | 10 | 11 | sns.set_style("darkgrid") 12 | pd.set_option("future.no_silent_downcasting", True) 13 | 14 | # Please leave those 2 dictionnaries as is. 15 | EXTRACTOR_LABELS_DICT = { 16 | "conch": "CONCH", 17 | "gpfm": "GPFM", 18 | "hibou_vit_base": "Hibou Base", 19 | "hibou_vit_large": "Hibou Large", 20 | "h0_mini": "H0-Mini", 21 | "hoptimus0": "H-Optimus-0", 22 | "kaiko_vit_base_8": "Kaiko ViT-B/8", 23 | "kaiko_vit_large_14": "Kaiko ViT-L/14", 24 | "midnight_12k": "Midnight-12k", 25 | "phikon": "Phikon", 26 | "phikon_v2": "Phikon v2", 27 | "plip": "PLIP", 28 | "provgigapath": "Prov-GigaPath", 29 | "uni": "UNI", 30 | "uni2h": "UNI2-h", 31 | "virchow": "Virchow", 32 | "virchow2": "Virchow2", 33 | } 34 | EXTRACTOR_PARAMETERS_DICT = { 35 | "conch": 86_000_000, 36 | "gpfm": 307_000_000, 37 | "hibou_vit_base": 86_000_000, 38 | "hibou_vit_large": 307_000_000, 39 | "h0_mini": 86_000_000, 40 | "hoptimus0": 1_100_000_000, 41 | "kaiko_vit_base_8": 86_000_000, 42 | "kaiko_vit_large_14": 307_000_000, 43 | "midnight_12k": 1_100_000_000, 44 | "phikon": 86_000_000, 45 | "phikon_v2": 307_000_000, 46 | "plip": 86_000_000, 47 | "provgigapath": 1_100_000_000, 48 | "uni": 307_000_000, 49 | "uni2h": 681_000_000, 50 | "virchow": 632_000_000, 51 | "virchow2": 632_000_000, 52 | } 53 | 54 | 55 | def expand_columns(raw_results: pd.DataFrame) -> pd.DataFrame: 56 | """Expand columns so as to have one column per metric and robustness type.""" 57 | output = [] 58 | # Robustness types are "all", "inter-scanner", "inter-staining", 59 | # "inter-scanner, inter-staining" 60 | for robustness_type in raw_results["robustness_type"].unique(): 61 | subset = raw_results[ 62 | raw_results["robustness_type"] == robustness_type 63 | ].sort_values("extractor") 64 | subset = subset.set_index("extractor").iloc[:, 1:] 65 | subset.columns = [f"{c}__{robustness_type}" for c in subset.columns] 66 | output.append(subset) 67 | output_df = pd.concat(output, axis=1) 68 | output_df.insert( 69 | 0, "extractor", output_df.index.to_series().replace(EXTRACTOR_LABELS_DICT) 70 | ) 71 | output_df.insert( 72 | 1, "Parameters", output_df.index.to_series().replace(EXTRACTOR_PARAMETERS_DICT) 73 | ) 74 | return output_df 75 | 76 | 77 | def display_plism_metrics( 78 | raw_results: pd.DataFrame, 79 | metric_x: str = "cosine_similarity_median", 80 | metric_y: str = "top_1_accuracy_median", 81 | robustness_x: str = "all", 82 | robustness_y: str = "all", 83 | label_x: str = "Median Cosine Similarity", 84 | label_y: str = "Median Top-1 Accuracy", 85 | fig_save_path: str | Path | None = None, 86 | xlim: tuple[float, float] | None = None, 87 | ylim: tuple[float, float] | None = None, 88 | palette: Any | None = None, 89 | ): 90 | """Display PLISM robustness metrics. 91 | 92 | Parameters 93 | ---------- 94 | raw_results: pd.DataFrame 95 | Raw results as computed by ``plismbench.utils.metrics.format_results``. 
96 | metric_x: str = "cosine_similarity_median" 97 | Metric to display on the x-axis. Should follow the ``{metric}_{aggregation}`` pattern. 98 | Supported metrics depend on the columns of ``raw_results`` but are 99 | typically "cosine_similarity", "top_1_accuracy", "top_3_accuracy", 100 | "top_5_accuracy" and "top_10_accuracy". Supported aggregation types 101 | are either "mean" or "median". 102 | metric_y: str = "top_1_accuracy_median" 103 | Metric to display on the y-axis. 104 | robustness_x: str = "all" 105 | Type of robustness for ``metric_x``. 106 | Supported types are "all", "inter-scanner", "inter-staining" 107 | and "inter-scanner, inter-staining". 108 | robustness_y: str = "all" 109 | Type of robustness for ``metric_y``. 110 | label_x: str = "Median Cosine Similarity" 111 | Label for x-axis (can be anything). 112 | label_y: str = "Median Top-1 Accuracy" 113 | Label for y-axis (can be anything). 114 | xlim: tuple[float, float] | None = None 115 | Limits for x-axis. 116 | ylim: tuple[float, float] | None = None 117 | Limits for y-axis. 118 | palette: Any | None = None 119 | Color palette. 120 | fig_save_path: str | Path | None = None 121 | Figure save path. 122 | """ 123 | # Set figure 124 | fig, ax = plt.subplots(1, 1, figsize=(10, 6), dpi=200) 125 | 126 | # Default limits for axes 127 | xlim = (0, 1) if xlim is None else xlim 128 | ylim = (0, 1) if ylim is None else ylim 129 | 130 | # Define x and y columns in ``results_df`` 131 | col_x = f"{metric_x}__{robustness_x}" 132 | col_y = f"{metric_y}__{robustness_y}" 133 | results_df = expand_columns(raw_results) 134 | results_df = results_df[["extractor", "Parameters", col_x, col_y]] 135 | results_df[[col_x, col_y]] = results_df[[col_x, col_y]].astype(float) 136 | 137 | # Default color palette 138 | if palette is None: 139 | palette = sns.color_palette("tab20")[: results_df.shape[0]] 140 | 141 | # Display metrics for each extractor 142 | sns.scatterplot( 143 | data=results_df, 144 | x=col_x, 145 | y=col_y, 146 | hue="extractor", 147 | size="Parameters", 148 | sizes=(50, 2000), 149 | palette=palette, 150 | edgecolor="black", 151 | alpha=0.9, 152 | legend=True, 153 | ax=ax, 154 | ) 155 | 156 | # Mark the exact position of each extractor with a small black cross 157 | sns.scatterplot( 158 | data=results_df, 159 | x=col_x, 160 | y=col_y, 161 | s=5, 162 | color="black", 163 | marker="+", 164 | facecolor="black", 165 | alpha=0.7, 166 | legend=True, 167 | ax=ax, 168 | ) 169 | 170 | # Set labels and limits 171 | ax.set_xlabel(label_x) 172 | ax.set_ylabel(label_y) 173 | ax.set_xlim(xlim) 174 | ax.set_ylim(ylim) 175 | 176 | # Please leave these values as-is 177 | for _, row in results_df.iterrows(): 178 | plt.text( 179 | row[col_x] + 0.001, 180 | row[col_y] + 0.001, 181 | row["extractor"], 182 | fontsize=10, 183 | ha="left", 184 | va="bottom", 185 | fontweight="bold", 186 | ) 187 | 188 | # Please leave these values as-is: adding circles in the legend 189 | # proportional to the model size.
190 | scatter_handles = [ 191 | plt.scatter( 192 | [], [], s=50, edgecolor="black", color="gray", alpha=0.5, label=" 22M" 193 | ), 194 | plt.scatter( 195 | [], [], s=155, edgecolor="black", color="gray", alpha=0.5, label=" 86M" 196 | ), 197 | plt.scatter( 198 | [], [], s=555, edgecolor="black", color="gray", alpha=0.5, label=" 307M" 199 | ), 200 | plt.scatter( 201 | [], 202 | [], 203 | s=1143, 204 | edgecolor="black", 205 | color="gray", 206 | alpha=0.5, 207 | label=" 632M", 208 | ), 209 | plt.scatter( 210 | [], 211 | [], 212 | s=2000, 213 | edgecolor="black", 214 | color="gray", 215 | alpha=0.5, 216 | label=" 1,100M", 217 | ), 218 | ] 219 | ax.legend( 220 | handles=scatter_handles, 221 | title="No. parameters", 222 | loc="center left", 223 | bbox_to_anchor=(1, 0.6), 224 | fontsize=12, 225 | labelspacing=0.2, 226 | handleheight=2, 227 | ) 228 | # Export figure 229 | if fig_save_path is not None: 230 | fig.savefig(fig_save_path, dpi=300, bbox_inches="tight") 231 | plt.show() 232 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # This file is execfile()d with the current directory set to its 4 | # containing dir. 5 | # 6 | # Note that not all possible configuration values are present in this 7 | # autogenerated file. 8 | # 9 | # All configuration values have a default; values that are commented out 10 | # serve to show the default. 11 | 12 | # Workaround https://github.com/mgaitan/sphinxcontrib-mermaid/issues/72 13 | import errno 14 | 15 | # If extensions (or modules to document with autodoc) are in another 16 | # directory, add these directories to sys.path here. If the directory is 17 | # relative to the documentation root, use os.path.abspath to make it 18 | # absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | import datetime 23 | import pkg_resources 24 | 25 | import sphinx.util.osutil 26 | 27 | sphinx.util.osutil.ENOENT = errno.ENOENT 28 | 29 | sys.path.insert(0, os.path.abspath('..')) 30 | 31 | # -- General configuration --------------------------------------------- 32 | 33 | # If your documentation needs a minimal Sphinx version, state it here. 34 | # 35 | # needs_sphinx = '1.0' 36 | 37 | # Add any Sphinx extension module names here, as strings. They can be 38 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 39 | extensions = [ 40 | 'sphinx.ext.autodoc', 41 | 'sphinx.ext.viewcode', 42 | 'sphinx.ext.doctest', 43 | 'sphinx.ext.napoleon', 44 | 'sphinx.ext.intersphinx', 45 | 'sphinxcontrib.mermaid', 46 | ] 47 | 48 | # Add any paths that contain templates here, relative to this directory. 49 | templates_path = ['_templates'] 50 | 51 | # The suffix(es) of source filenames. 52 | # You can specify multiple suffix as a list of string: 53 | # 54 | # source_suffix = ['.rst', '.md'] 55 | source_suffix = '.rst' 56 | 57 | # The master toctree document. 58 | master_doc = 'index' 59 | 60 | # package distribution 61 | package_distribution = pkg_resources.get_distribution('owkin-plismbench') 62 | 63 | # General information about the project. 64 | project = 'PLISM robustness benchmark' 65 | copyright = f'{datetime.date.today().year}, Owkin Inc.' 66 | author = 'Owkin Inc.' 67 | 68 | # The version info for the project you're documenting, acts as replacement 69 | # for |version| and |release|, also used in various other places throughout 70 | # the built documents. 
71 | # 72 | # The short X.Y version. 73 | version = package_distribution.version 74 | # The full version, including alpha/beta/rc tags. 75 | release = package_distribution.version 76 | 77 | # The language for content autogenerated by Sphinx. Refer to documentation 78 | # for a list of supported languages. 79 | # 80 | # This is also used if you do content translation via gettext catalogs. 81 | # Usually you set "language" from the command line for these cases. 82 | language = 'en' 83 | 84 | # List of patterns, relative to source directory, that match files and 85 | # directories to ignore when looking for source files. 86 | # This patterns also effect to html_static_path and html_extra_path 87 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 88 | 89 | # The name of the Pygments (syntax highlighting) style to use. 90 | pygments_style = 'sphinx' 91 | 92 | # If true, `todo` and `todoList` produce output, else they produce nothing. 93 | todo_include_todos = False 94 | 95 | # Autodoc documentation 96 | autodoc_default_options = {'member-order': 'bysource', 'undoc-members': True} 97 | 98 | # -- Options for HTML output ------------------------------------------- 99 | 100 | # The theme to use for HTML and HTML Help pages. See the documentation for 101 | # a list of builtin themes. 102 | # 103 | html_theme = 'pydata_sphinx_theme' 104 | 105 | # Theme options are theme-specific and customize the look and feel of a 106 | # theme further. For a list of options available for each theme, see the 107 | # documentation. 108 | # 109 | # html_theme_options = {} 110 | 111 | # Add any paths that contain custom static files (such as style sheets) here, 112 | # relative to this directory. They are copied after the builtin static files, 113 | # so a file named "default.css" will overwrite the builtin "default.css". 114 | 115 | 116 | # -- Options for HTMLHelp output --------------------------------------- 117 | 118 | # Output file base name for HTML help builder. 119 | htmlhelp_basename = 'plismbenchdoc' 120 | 121 | # -- Options for LaTeX output ------------------------------------------ 122 | 123 | latex_elements: dict = { 124 | # The paper size ('letterpaper' or 'a4paper'). 125 | # 126 | # 'papersize': 'letterpaper', 127 | # The font size ('10pt', '11pt' or '12pt'). 128 | # 129 | # 'pointsize': '10pt', 130 | # Additional stuff for the LaTeX preamble. 131 | # 132 | # 'preamble': '', 133 | # Latex figure (float) alignment 134 | # 135 | # 'figure_align': 'htbp', 136 | } 137 | 138 | # Grouping the document tree into LaTeX files. List of tuples 139 | # (source start file, target name, title, author, documentclass 140 | # [howto, manual, or own class]). 141 | latex_documents = [ 142 | (master_doc, 'plismbench.tex', project, 'Owkin Inc.', 'manual'), 143 | ] 144 | 145 | # -- Options for manual page output ------------------------------------ 146 | 147 | # One entry per manual page. List of tuples 148 | # (source start file, name, description, authors, manual section). 149 | man_pages = [(master_doc, 'plismbench', project, [author], 1)] 150 | 151 | # -- Options for Texinfo output ---------------------------------------- 152 | 153 | # Grouping the document tree into Texinfo files. 
List of tuples 154 | # (source start file, target name, title, author, 155 | # dir menu entry, description, category) 156 | texinfo_documents = [ 157 | ( 158 | master_doc, 159 | 'plismbench', 160 | project, 161 | author, 162 | 'plismbench', 163 | 'Repository hosting PLIM robustness benchmark', 164 | 'Data Science', 165 | ), 166 | ] 167 | 168 | # -- Options for pydata-sphinx-theme ----------------------------------- 169 | 170 | html_static_path = ["_static"] 171 | html_theme_options = { 172 | "github_url": "https://github.com/owkin/plism-benchmark", 173 | "show_prev_next": False, 174 | "logo": { 175 | "image_light": "logo-light.png", 176 | "image_dark": "logo-dark.png", 177 | } 178 | } 179 | 180 | # -- Options for sphinx.ext.intersphinx -------------------------------- 181 | 182 | intersphinx_mapping = { 183 | 'python': ('https://docs.python.org/3', None), 184 | 'pandas': ('https://pandas.pydata.org/pandas-docs/dev', None), 185 | 'numpy': ('https://numpy.org/doc/stable/', None), 186 | 'scipy': ('https://docs.scipy.org/doc/scipy/', None), 187 | 'torch': ('https://pytorch.org/docs/master/', None), 188 | 'matplotlib': ('https://matplotlib.org/stable/', None), 189 | 'torchvision': ('https://pytorch.org/vision/stable/', None), 190 | 'loguru': ('https://loguru.readthedocs.io/en/stable/', None), 191 | 'openslide': ('https://openslide.org/api/python/', None), 192 | 'ml_collections': ('https://ml-collections.readthedocs.io/en/stable/', None), 193 | } 194 | pdf_documents = [ 195 | ('index'), 196 | ] 197 | 198 | # -- Options for checking cross-references ---------------------------- 199 | 200 | nitpicky = True 201 | 202 | # -- Options for redirecting references ------------------------------- 203 | 204 | autodoc_docstring_signature = True 205 | autodoc_type_aliases = { 206 | "np.ndarray": "numpy.ndarray", 207 | "pd.DataFrame": "pandas.DataFrame", 208 | "ConfigDict": "ml_collections.config_dict.ConfigDict", 209 | "Callable": "collections.abc.Callable", 210 | "Any": "typing.Any", 211 | } 212 | 213 | docstring_alias = { 214 | "np.ndarray": "numpy.ndarray", 215 | "pd.DataFrame": "pandas.DataFrame", 216 | "Path": "pathlib.Path", 217 | "ConfigDict": "ml_collections.config_dict.ConfigDict", 218 | "Callable": "collections.abc.Callable", 219 | "Any": "typing.Any", 220 | } 221 | 222 | 223 | def autodoc_process_docstring(app, what, name, obj, options, lines): 224 | for i in range(len(lines)): 225 | for old , new in docstring_alias.items(): 226 | lines[i] = lines[i].replace(old, new) 227 | 228 | 229 | def setup(app): 230 | app.connect("autodoc-process-docstring", autodoc_process_docstring) 231 | -------------------------------------------------------------------------------- /plismbench/engine/extract/extract_from_png.py: -------------------------------------------------------------------------------- 1 | """Stream PLISM tiles dataset and extract features on-the-fly for a given model.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Callable 6 | from functools import partial 7 | from math import ceil 8 | from pathlib import Path 9 | 10 | import datasets 11 | import numpy as np 12 | import torch 13 | from loguru import logger 14 | from PIL import Image 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | 18 | from plismbench.engine.extract.utils import ( 19 | NUM_SLIDES, 20 | NUM_TILES_PER_SLIDE, 21 | process_imgs, 22 | save_features, 23 | ) 24 | from plismbench.models import FeatureExtractorsEnum 25 | from plismbench.models.extractor import Extractor 26 | 
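# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, paths are placeholders): the streaming
# entry point defined below can also be called directly from Python, e.g.
#
#     from pathlib import Path
#     from plismbench.engine.extract.extract_from_png import run_extract_streaming
#
#     run_extract_streaming(
#         feature_extractor_name="h0_mini",  # any FeatureExtractorsEnum value
#         batch_size=32,
#         device=0,  # GPU index (-1 to run on CPU)
#         export_dir=Path("features/h0_mini"),
#         overwrite=False,
#     )
#
# Features are then written as one ``features.npy`` file per slide under
# ``export_dir / <slide_id>``.
# ---------------------------------------------------------------------------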
27 | 28 | def collate( 29 | batch: list[dict[str, str | Image.Image]], 30 | transform: Callable[[np.ndarray], torch.Tensor], 31 | ) -> tuple[list[str], list[str], torch.Tensor]: 32 | """Return slide ids, tile ids and transformed images. 33 | 34 | Parameters 35 | ---------- 36 | batch: list[dict[str, str | Image.Image]], 37 | List of length ``batch_size`` made of dictionnaries. 38 | Each dictionnary is a single input with keys: 'slide_id', 39 | 'tile_id' and 'png'. The image is a ``PIL.Image.Image`` 40 | with type unit8 (0-255) 41 | transform: Callable[[np.ndarray], torch.Tensor] 42 | Transform function taking ``np.ndarray`` image as inputs. 43 | Prior to calling this transform function, conversion from a 44 | ``PIL.Image.Image`` to an array is performed. 45 | 46 | Returns 47 | ------- 48 | output: tuple[list[str], list[str], torch.Tensor] 49 | A tuple made of slides ids, tiles ids and transformed input images. 50 | """ 51 | slide_ids: list[str] = [b["slide_id"] for b in batch] # type: ignore 52 | tile_ids: list[str] = [b["tile_id"] for b in batch] # type: ignore 53 | imgs = torch.stack([transform(np.array(b["png"])) for b in batch]) 54 | output = (slide_ids, tile_ids, imgs) 55 | return output 56 | 57 | 58 | def resume_streaming( 59 | export_dir: Path, 60 | slide_features: list[np.ndarray], 61 | current_num_tiles: int, 62 | slide_features_export_path: Path, 63 | feature_extractor: Extractor, 64 | slide_ids: list[str], 65 | tile_ids: list[str], 66 | imgs: torch.Tensor, 67 | reference_slide_id: str, 68 | ) -> tuple[list[np.ndarray], bool, int]: 69 | """Resume streaming without re-extracting slides.""" 70 | set_continue = False 71 | # If the current slide has features already available 72 | if slide_features_export_path.exists(): 73 | set_continue = True 74 | # Check that the batch contains all tiles from the same slide 75 | next_slide_id = slide_ids[-1] 76 | # Otherwise enter a specific condition 77 | if next_slide_id != reference_slide_id: 78 | next_slide_features_export_dir = Path(export_dir / next_slide_id) 79 | next_slide_features_export_path = ( 80 | next_slide_features_export_dir / "features.npy" 81 | ) 82 | # If the next slide was also already extracted, skip 83 | if next_slide_features_export_path.exists(): 84 | logger.info( 85 | f"Features for slide {next_slide_id} already extracted, skipping..." 86 | ) 87 | # Otherwise, start the feature extraction by adding the tiles from 88 | # slide N+1 into `slide_features` 89 | else: 90 | # New slide without features detected is `next_slide_id`. 91 | # We retrieve the maximum index at which all tiles in the batch comes from slide N 92 | mask = np.array(slide_ids) != reference_slide_id 93 | idx = mask.argmax() 94 | # And only process the later, then export the slides features 95 | batch_stack = process_imgs( 96 | imgs[idx:], tile_ids[idx:], model=feature_extractor 97 | ) 98 | current_num_tiles += batch_stack.shape[0] 99 | slide_features.append(batch_stack) 100 | else: 101 | logger.info( 102 | f"Features for slide {reference_slide_id} already extracted, skipping..." 
103 | ) 104 | return slide_features, set_continue, current_num_tiles 105 | 106 | 107 | def run_extract_streaming( 108 | feature_extractor_name: str, 109 | batch_size: int, 110 | device: int, 111 | export_dir: Path, 112 | overwrite: bool, 113 | ) -> None: 114 | """Run features extraction with streaming.""" 115 | logger.info(f"Export directory set to {str(export_dir)}.") 116 | if overwrite: 117 | logger.warning("You are about to overwrite existing features.") 118 | 119 | # Create export directory if it doesn't exist 120 | export_dir.mkdir(exist_ok=True, parents=True) 121 | 122 | # Initialize the feature extractor 123 | feature_extractor = FeatureExtractorsEnum[feature_extractor_name.upper()].init( 124 | device=device 125 | ) 126 | image_transform = feature_extractor.transform 127 | 128 | # Create the dataset and dataloader without actually loading the files to disk (`streaming=True`) 129 | # The dataset is sorted by slide_id, meaning that the first 16278 indexes belong to the same first slide, 130 | # then 16278:32556 to the second slide, etc. 131 | dataset = datasets.load_dataset( 132 | "owkin/plism-dataset-tiles", split="train", streaming=True 133 | ) 134 | collate_fn = partial(collate, transform=image_transform) 135 | dataloader = DataLoader( 136 | dataset, 137 | batch_size=batch_size, 138 | collate_fn=collate_fn, 139 | num_workers=0, 140 | pin_memory=True, 141 | shuffle=False, 142 | ) 143 | 144 | # Iterate over the full dataset and store features each time 16278 input images have been processed 145 | slide_features: list[np.ndarray] = [] 146 | current_num_tiles = 0 147 | 148 | for slide_ids, tile_ids, imgs in tqdm( 149 | dataloader, 150 | total=ceil(NUM_SLIDES * NUM_TILES_PER_SLIDE / batch_size), 151 | desc="Extracting features", 152 | ): 153 | reference_slide_id = slide_ids[0] 154 | 155 | # Get output path for features 156 | slide_features_export_dir = Path(export_dir / reference_slide_id) 157 | slide_features_export_path = slide_features_export_dir / "features.npy" 158 | slide_features_export_dir.mkdir(exist_ok=True, parents=True) 159 | if not overwrite: 160 | slide_features, continue_, current_num_tiles = resume_streaming( 161 | export_dir=export_dir, 162 | slide_features=slide_features, 163 | current_num_tiles=current_num_tiles, 164 | slide_features_export_path=slide_features_export_path, 165 | feature_extractor=feature_extractor, 166 | slide_ids=slide_ids, 167 | tile_ids=tile_ids, 168 | imgs=imgs, 169 | reference_slide_id=reference_slide_id, 170 | ) 171 | if continue_: 172 | continue 173 | 174 | # If we're on the same slide, we just add the batch features to the running list 175 | if all(slide_id == reference_slide_id for slide_id in slide_ids): 176 | batch_stack = process_imgs(imgs, tile_ids, model=feature_extractor) 177 | slide_features.append(batch_stack) 178 | # For the very last slide, the last batch may be of size < `batch_size` 179 | current_num_tiles += batch_stack.shape[0] 180 | # If the current batch contains exactly the last `batch_size` tile features for the slide, 181 | # export the slide features and reset `slide_features` and `current_num_tiles` 182 | if current_num_tiles == NUM_TILES_PER_SLIDE: 183 | save_features( 184 | slide_features, 185 | slide_id=reference_slide_id, 186 | export_path=slide_features_export_path, 187 | ) 188 | logger.success( 189 | f"Successfully saved features for slide: {reference_slide_id}" 190 | ) 191 | slide_features = [] 192 | current_num_tiles = 0 193 | # The current batch contains tiles from slide N (`reference_slide_id`) and slide N+1 194 
| else: 195 | # We retrieve the maximum index at which all tiles in the batch comes from slide N 196 | mask = np.array(slide_ids) != reference_slide_id 197 | idx = mask.argmax() 198 | # And only process the later, then export the slides features 199 | batch_stack = process_imgs( 200 | imgs[:idx], tile_ids[:idx], model=feature_extractor 201 | ) 202 | current_num_tiles += batch_stack.shape[0] 203 | slide_features.append(batch_stack) 204 | save_features( 205 | slide_features, 206 | slide_id=reference_slide_id, 207 | export_path=slide_features_export_path, 208 | ) 209 | logger.success( 210 | f"Successfully saved features for slide: {reference_slide_id}" 211 | ) 212 | # We initialize `slide_features` and `current_num_tiles` with respectively 213 | # the tile features from slide N+1 214 | slide_features = [ 215 | process_imgs(imgs[idx:], tile_ids[idx:], model=feature_extractor) 216 | ] 217 | current_num_tiles = batch_size - idx 218 | -------------------------------------------------------------------------------- /plismbench/engine/evaluate.py: -------------------------------------------------------------------------------- 1 | """Compute robustness metrics: cosine similarity and top-k accuracies.""" 2 | 3 | import sys 4 | from functools import partial 5 | from pathlib import Path 6 | 7 | import cupy as cp 8 | import numpy as np 9 | import pandas as pd 10 | from loguru import logger 11 | from p_tqdm import p_map 12 | from rich import print as rprint 13 | from tqdm import tqdm 14 | 15 | from plismbench.metrics import CosineSimilarity, TopkAccuracy 16 | from plismbench.utils.aggregate import get_results 17 | from plismbench.utils.core import load_pickle, write_pickle 18 | from plismbench.utils.evaluate import ( 19 | get_tiles_subset_idx, 20 | load_features, 21 | prepare_pairs_dataframe, 22 | ) 23 | 24 | 25 | # Leave those two variables as-is 26 | STAININGS: list[str] = [ 27 | "GIV", 28 | "GIVH", 29 | "GM", 30 | "GMH", 31 | "GV", 32 | "GVH", 33 | "HR", 34 | "HRH", 35 | "KR", 36 | "KRH", 37 | "LM", 38 | "LMH", 39 | "MY", 40 | ] 41 | 42 | SCANNERS: list[str] = ["AT2", "GT450", "P", "S210", "S360", "S60", "SQ"] 43 | NUM_SLIDES: int = 91 44 | NUM_TILES_PER_SLIDE: int = 16_278 45 | DEFAULT_NUM_TILES_PER_SLIDE_METRICS: int = NUM_TILES_PER_SLIDE // 2 # 8_139 46 | SUPPORTED_NUM_TILES = [None, 460, 2_713, 5_426, 8_139, 16_278] 47 | 48 | 49 | def compute_metrics_ab( 50 | fp_a: Path, 51 | fp_b: Path, 52 | tiles_subset_idx: np.ndarray, 53 | top_k: list[int], 54 | device: str, 55 | pickles_save_dir: Path, 56 | overwrite: bool, 57 | ) -> list[float]: 58 | """Compute metrics between float16 features from slide a and slide b.""" 59 | # Check if a pickle has already been dumped to disk to avoid computing 60 | # the metrics twice for a given slides pair. 61 | pickle_key = "---".join([fp_a.parent.name, fp_b.parent.name]) 62 | if (pickle_path := pickles_save_dir / f"{pickle_key}.pkl").exists(): 63 | if overwrite: 64 | pass 65 | else: 66 | try: 67 | return load_pickle(pickle_path) 68 | except Exception as exc: # type: ignore 69 | logger.info(f"{str(pickle_path)} seems to be corrupted:\n{exc}.") 70 | 71 | matrix_a, matrix_b = ( 72 | load_features(fp_a), 73 | load_features(fp_b), 74 | ) 75 | # Coordinates should be equal for tiles location matching 76 | np.testing.assert_allclose(matrix_a[:, :3], matrix_b[:, :3]) 77 | # Concanenate features from slide a and b to compute 78 | # top-k accuracies. Note: top-k accuracy is computed 79 | # over a subset of tiles. 80 | # Warning: convert matrix to float16 ! 
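    # Note: the first three columns of each matrix store tile coordinates
    # (checked for alignment with ``assert_allclose`` above); the remaining
    # columns store the embeddings, restricted here to the ``tiles_subset_idx``
    # rows used for metrics computation.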
81 | features_a, features_b = ( 82 | matrix_a[tiles_subset_idx, 3:], 83 | matrix_b[tiles_subset_idx, 3:], 84 | ) 85 | 86 | if device == "gpu": 87 | mempool = cp.get_default_memory_pool() 88 | pinned_mempool = cp.get_default_pinned_memory_pool() 89 | 90 | # Compute cosine similarity 91 | cosine_metric = CosineSimilarity(device=device, use_mixed_precision=True) 92 | cosine_similarity = cosine_metric.compute_metric(features_a, features_b) 93 | 94 | # Compute top-k accuracies 95 | topk_metric = TopkAccuracy(device=device, k=top_k, use_mixed_precision=True) 96 | top_k_accuracies = topk_metric.compute_metric( 97 | matrix_a=features_a, 98 | matrix_b=features_b, 99 | ) 100 | 101 | if device == "gpu": 102 | mempool.free_all_blocks() 103 | pinned_mempool.free_all_blocks() 104 | 105 | metrics_ab = [cosine_similarity, *list(top_k_accuracies)] 106 | write_pickle(metrics_ab, pickle_path) 107 | return metrics_ab 108 | 109 | 110 | def compute_metrics( 111 | features_root_dir: Path, 112 | metrics_save_dir: Path, 113 | extractor: str, 114 | top_k: list[int] | None = None, 115 | n_tiles: int | None = None, 116 | device: str = "gpu", 117 | workers: int = 4, 118 | overwrite: bool = False, 119 | ): 120 | """Compute robustness metrics and save it to disk. 121 | 122 | Parameters 123 | ---------- 124 | features_root_dir: Path 125 | The root folder where features will be stored. 126 | The final export directory is ``features_root_dir / extractor`` 127 | metrics_save_dir: Path 128 | Folder containing the output metrics. 129 | The final export directory is ``metrics_save_dir / extractor``. 130 | extractor: str 131 | The name of the feature extractor as defined in ``plismbench.models.__init__.py`` 132 | top_k: list[int] | None = None 133 | Values of k for top-k accuracy computation. 134 | n_tiles: int | None = None 135 | Number of tiles per slide for metrics computation. 136 | device: str = "gpu" 137 | Device on which matrix operations will be performed. 138 | workers: int = 4 139 | Number of workers for cpu parallel computations if ``device='cpu'``. 140 | overwrite: bool = False 141 | Whether to overwrite existing metrics. 142 | """ 143 | # Supported number of tiles correspond to 144 | # None: DEFAULT_NUM_TILES_PER_SLIDE_METRICS = 8_139 145 | # 460: corresponds to 10 tiles per TMA - meant for debugging purposes 146 | # 2_713: NUM_TILES_PER_SLIDE / 6 147 | # 5_426: NUM_TILES_PER_SLIDE / 3 148 | # 8_139: NUM_TILES_PER_SLIDE / 2 149 | # 16_278: NUM_TILES_PER_SLIDE 150 | 151 | if n_tiles not in SUPPORTED_NUM_TILES: 152 | raise ValueError( 153 | f"n_tiles should take values in {SUPPORTED_NUM_TILES}. Got {n_tiles}." 154 | ) 155 | n_tiles = DEFAULT_NUM_TILES_PER_SLIDE_METRICS if n_tiles is None else n_tiles 156 | top_k = [1, 3, 5, 10] if top_k is None else top_k 157 | 158 | metrics_save_dir = metrics_save_dir / f"{n_tiles}_tiles" / extractor 159 | pickles_save_dir = metrics_save_dir / "pickles" 160 | metrics_export_path: Path = metrics_save_dir / "metrics.csv" 161 | if metrics_export_path.exists(): 162 | if overwrite: 163 | logger.info("Metrics already exist. Overwriting...") 164 | else: 165 | logger.info("Metrics already exist. 
Skipping...") 166 | sys.exit() 167 | pickles_save_dir.mkdir(exist_ok=True, parents=True) 168 | logger.info(f"Metrics will be saved at {str(metrics_export_path)}.") 169 | logger.info(f"Slide pairs pickles will be saved at {str(pickles_save_dir)}.") 170 | 171 | slide_pairs = prepare_pairs_dataframe( 172 | features_dir=features_root_dir, extractor=extractor 173 | ) 174 | n_pairs = slide_pairs.shape[0] 175 | features_paths_pairs = slide_pairs[["features_path_a", "features_path_b"]].values 176 | tiles_subset_idx = get_tiles_subset_idx(n_tiles=n_tiles) 177 | logger.warning( 178 | f"Will compute metrics on {n_tiles} tiles per slide. " 179 | f"Top-k accuracies will be computed for k in {top_k}." 180 | ) 181 | 182 | if device not in (supported_device := ["cpu", "gpu"]): 183 | raise ValueError( 184 | f"Device {device} not supported. Please choose among {supported_device}." 185 | ) 186 | logger.warning(f"Metrics will be computed on {device}.") 187 | 188 | if device == "gpu": 189 | logger.info("Running on gpu: sequential computation over pairs.") 190 | pairs_metrics = [] 191 | for fp_a, fp_b in tqdm(features_paths_pairs, total=n_pairs): 192 | metrics_ab = compute_metrics_ab( 193 | fp_a=fp_a, 194 | fp_b=fp_b, 195 | tiles_subset_idx=tiles_subset_idx, 196 | top_k=top_k, 197 | device=device, 198 | pickles_save_dir=pickles_save_dir, 199 | overwrite=overwrite, 200 | ) 201 | pairs_metrics.append((fp_a, fp_b, *metrics_ab)) 202 | else: 203 | logger.info("Running on cpu: parallel computation over pairs.") 204 | logger.warning( 205 | f"Number of workers: {workers}. Try reducing it if you have RAM issues." 206 | ) 207 | _compute_metrics_ab = partial( 208 | compute_metrics_ab, 209 | tiles_subset_idx=tiles_subset_idx, 210 | top_k=top_k, 211 | device=device, 212 | pickles_save_dir=pickles_save_dir, 213 | overwrite=overwrite, 214 | ) 215 | metrics = p_map( 216 | _compute_metrics_ab, 217 | features_paths_pairs[:, 0], 218 | features_paths_pairs[:, 1], 219 | num_cpus=workers, 220 | ) 221 | pairs_metrics = [ 222 | (fp_a, fp_b, *m) for ((fp_a, fp_b), m) in zip(features_paths_pairs, metrics) 223 | ] 224 | 225 | metrics = pd.DataFrame( 226 | pairs_metrics, 227 | columns=[ 228 | "features_path_a", 229 | "features_path_b", 230 | "cosine_similarity", 231 | ] 232 | + [f"top_{k}_accuracy" for k in top_k], 233 | ) 234 | output = slide_pairs.merge( 235 | metrics, 236 | how="inner", 237 | on=["features_path_a", "features_path_b"], 238 | ) 239 | assert (n_rows := output.shape[0]) == n_pairs, ( 240 | f"Output dataframe with metrics have n_rows: {n_rows} < {n_pairs}." 
241 | ) 242 | # Export metrics for all pairs 243 | output.to_csv(metrics_export_path, index=None) # type: ignore 244 | robustness_results = get_results(metrics=output, top_k=top_k) 245 | # Get and export aggregated results 246 | results_export_path = metrics_save_dir / "results.csv" 247 | robustness_results.to_csv(results_export_path, index=True) # type: ignore 248 | # Only display median (IQR) 249 | logger.info("Robustness results [median (iqr)]:") 250 | rprint(robustness_results.map(lambda x: x.split(" ; ")[1])) 251 | logger.success("Successfully computed and stored metrics.") 252 | -------------------------------------------------------------------------------- /plismbench/models/mahmood_lab.py: -------------------------------------------------------------------------------- 1 | """Models from Mahmood Lab.""" 2 | 3 | from __future__ import annotations 4 | 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | import numpy as np 9 | import timm 10 | import torch 11 | from conch.open_clip_custom import create_model_from_pretrained 12 | from huggingface_hub import snapshot_download 13 | from loguru import logger 14 | from torchvision import transforms 15 | from transformers import AutoModel 16 | 17 | from plismbench.models.extractor import Extractor 18 | from plismbench.models.utils import DEFAULT_DEVICE, prepare_device, prepare_module 19 | 20 | 21 | class UNI(Extractor): 22 | """UNI model developped by Mahmood Lab available on Hugging-Face (1). 23 | 24 | .. note:: 25 | (1) https://huggingface.co/MahmoodLab/UNI 26 | 27 | Parameters 28 | ---------- 29 | device: int | list[int] | None = DEFAULT_DEVICE, 30 | Compute resources to use. 31 | If None, will use all available GPUs. 32 | If -1, extraction will run on CPU. 33 | mixed_precision: bool = True 34 | Whether to use mixed_precision. 35 | 36 | """ 37 | 38 | def __init__( 39 | self, 40 | device: int | list[int] | None = DEFAULT_DEVICE, 41 | mixed_precision: bool = False, 42 | ): 43 | super().__init__() 44 | self.output_dim = 1024 45 | self.mixed_precision = mixed_precision 46 | 47 | timm_kwargs: dict[str, Any] = { 48 | "init_values": 1e-5, 49 | "dynamic_img_size": True, 50 | } 51 | feature_extractor = timm.create_model( 52 | "hf-hub:MahmoodLab/uni", pretrained=True, **timm_kwargs 53 | ) 54 | 55 | self.feature_extractor, self.device = prepare_module( 56 | feature_extractor, 57 | device, 58 | self.mixed_precision, 59 | ) 60 | if self.device is None: 61 | self.feature_extractor = self.feature_extractor.module 62 | 63 | @property # type: ignore 64 | def transform(self) -> transforms.Compose: 65 | """Transform method to apply element wise.""" 66 | return transforms.Compose( 67 | [ 68 | transforms.ToTensor(), # swap axes and normalize 69 | transforms.Normalize( 70 | mean=(0.485, 0.456, 0.406), 71 | std=(0.229, 0.224, 0.225), 72 | ), 73 | ] 74 | ) 75 | 76 | def __call__(self, images: torch.Tensor) -> np.ndarray: 77 | """Compute and return features. 78 | 79 | Parameters 80 | ---------- 81 | images: torch.Tensor 82 | Input of size (n_tiles, n_channels, dim_x, dim_y). 83 | 84 | Returns 85 | ------- 86 | torch.Tensor: Tensor of size (n_tiles, features_dim). 87 | """ 88 | features = self.feature_extractor(images.to(self.device)) 89 | return features.cpu().numpy() 90 | 91 | 92 | class UNI2h(Extractor): 93 | """UNI2-h model developped by Mahmood Lab available on Hugging-Face (1). 94 | 95 | .. 
note:: 96 | (1) https://huggingface.co/MahmoodLab/UNI2-h 97 | 98 | Parameters 99 | ---------- 100 | device: int | list[int] | None = DEFAULT_DEVICE, 101 | Compute resources to use. 102 | If None, will use all available GPUs. 103 | If -1, extraction will run on CPU. 104 | mixed_precision: bool = True 105 | Whether to use mixed_precision. 106 | 107 | """ 108 | 109 | def __init__( 110 | self, 111 | device: int | list[int] | None = DEFAULT_DEVICE, 112 | mixed_precision: bool = False, 113 | ): 114 | super().__init__() 115 | self.output_dim = 1536 116 | self.mixed_precision = mixed_precision 117 | 118 | timm_kwargs: dict[str, Any] = { 119 | "img_size": 224, 120 | "patch_size": 14, 121 | "depth": 24, 122 | "num_heads": 24, 123 | "init_values": 1e-5, 124 | "embed_dim": 1536, 125 | "mlp_ratio": 2.66667 * 2, 126 | "num_classes": 0, 127 | "no_embed_class": True, 128 | "mlp_layer": timm.layers.SwiGLUPacked, 129 | "act_layer": torch.nn.SiLU, 130 | "reg_tokens": 8, 131 | "dynamic_img_size": True, 132 | } 133 | feature_extractor = timm.create_model( 134 | "hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs 135 | ) 136 | 137 | self.feature_extractor, self.device = prepare_module( 138 | feature_extractor, 139 | device, 140 | self.mixed_precision, 141 | ) 142 | if self.device is None: 143 | self.feature_extractor = self.feature_extractor.module 144 | 145 | @property # type: ignore 146 | def transform(self) -> transforms.Compose: 147 | """Transform method to apply element wise.""" 148 | return transforms.Compose( 149 | [ 150 | transforms.ToTensor(), # swap axes and normalize 151 | transforms.Normalize( 152 | mean=(0.485, 0.456, 0.406), 153 | std=(0.229, 0.224, 0.225), 154 | ), 155 | ] 156 | ) 157 | 158 | def __call__(self, images: torch.Tensor) -> np.ndarray: 159 | """Compute and return features. 160 | 161 | Parameters 162 | ---------- 163 | images: torch.Tensor 164 | Input of size (n_tiles, n_channels, dim_x, dim_y). 165 | 166 | Returns 167 | ------- 168 | torch.Tensor: Tensor of size (n_tiles, features_dim). 169 | """ 170 | features = self.feature_extractor(images.to(self.device)) 171 | return features.cpu().numpy() 172 | 173 | 174 | class CONCH(Extractor): 175 | """CONCH model developped by Mahmood Lab available on Hugging-Face (1). 176 | 177 | .. note:: 178 | (1) https://huggingface.co/MahmoodLab/CONCH 179 | 180 | Parameters 181 | ---------- 182 | device: int | list[int] | None = DEFAULT_DEVICE, 183 | Compute resources to use. 184 | If None, will use all available GPUs. 185 | If -1, extraction will run on CPU. 186 | mixed_precision: bool = True 187 | Whether to use mixed_precision. 
188 | 189 | """ 190 | 191 | def __init__( 192 | self, 193 | device: int | list[int] | None = DEFAULT_DEVICE, 194 | mixed_precision: bool = False, 195 | ): 196 | super().__init__() 197 | self.output_dim = 512 198 | self.mixed_precision = mixed_precision 199 | 200 | checkpoint_dir = snapshot_download(repo_id="MahmoodLab/CONCH") 201 | checkpoint_path = Path(checkpoint_dir) / "pytorch_model.bin" 202 | 203 | feature_extractor, self.processor = create_model_from_pretrained( 204 | "conch_ViT-B-16", 205 | force_image_size=224, 206 | checkpoint_path=str(checkpoint_path), 207 | device=prepare_device(device), 208 | ) 209 | 210 | self.feature_extractor, self.device = prepare_module( 211 | feature_extractor, 212 | device, 213 | self.mixed_precision, 214 | ) 215 | if self.device is None: 216 | self.feature_extractor = self.feature_extractor.module 217 | 218 | def process(self, image) -> torch.Tensor: 219 | """Process input images.""" 220 | conch_input = self.processor(image) 221 | return conch_input 222 | 223 | @property # type: ignore 224 | def transform(self) -> transforms.Lambda: 225 | """Transform method to apply element wise.""" 226 | return transforms.Lambda(self.process) 227 | 228 | def __call__(self, images: torch.Tensor) -> np.ndarray: 229 | """Compute and return features. 230 | 231 | Parameters 232 | ---------- 233 | images: torch.Tensor 234 | Input of size (n_tiles, n_channels, dim_x, dim_y). 235 | 236 | Returns 237 | ------- 238 | torch.Tensor: Tensor of size (n_tiles, features_dim). 239 | """ 240 | features = self.feature_extractor.module.encode_image( # type: ignore 241 | images.to(self.device), proj_contrast=False, normalize=False 242 | ) 243 | return features.cpu().numpy() 244 | 245 | 246 | class CONCHv15(Extractor): 247 | """Conchv15 model available from TITAN on Hugging-Face (1). 248 | 249 | .. note:: 250 | (1) https://huggingface.co/MahmoodLab/conchv1_5 251 | """ 252 | 253 | def __init__( 254 | self, 255 | device: int | list[int] | None = DEFAULT_DEVICE, 256 | mixed_precision: bool = False, 257 | ): 258 | super().__init__() 259 | self.output_dim = 768 260 | self.mixed_precision = mixed_precision 261 | 262 | titan = AutoModel.from_pretrained("MahmoodLab/TITAN", trust_remote_code=True) 263 | feature_extractor, _ = titan.return_conch() 264 | 265 | self.feature_extractor, self.device = prepare_module( 266 | feature_extractor, 267 | device, 268 | self.mixed_precision, 269 | ) 270 | if self.device is None: 271 | self.feature_extractor = self.feature_extractor.module 272 | 273 | logger.info("This model is best performing on 448x448 images.") 274 | 275 | @property # type: ignore 276 | def transform(self) -> transforms.Lambda: 277 | """Transform method to apply element wise.""" 278 | return transforms.Compose( 279 | [ 280 | transforms.ToTensor(), 281 | transforms.Normalize( 282 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 283 | ), 284 | ] 285 | ) 286 | 287 | def __call__(self, images: torch.Tensor) -> np.ndarray: 288 | """Compute and return features. 289 | 290 | Args: 291 | images (torch.Tensor): Input of size (n_tiles, n_channels, dim_x, dim_y). 292 | 293 | Returns 294 | ------- 295 | torch.Tensor: Tensor of size (n_tiles, features_dim). 296 | """ 297 | features = self.feature_extractor(images.to(self.device)) 298 | return features.cpu().numpy() 299 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 
3 | ![header](https://capsule-render.vercel.app/api?type=waving&height=140&color=0:56b4e9,50:009e73,100:cc79a7&text=Plismbench:&section=header&fontAlign=16&fontSize=45&textBg=false&descAlignY=45&fontAlignY=20&descSize=20&desc=A%20%20robustness%20%20benchmark%20%20of%20%20pathology%20%20foundation%20%20models&descAlign=52)
4 | 
5 | 
6 | 
7 | [![Python dev](https://github.com/owkin/plism-benchmark/actions/workflows/python-app.yml/badge.svg)](https://github.com/owkin/plism-benchmark/actions/workflows/python-app.yml) [![Deploy doc](https://github.com/owkin/plism-benchmark/actions/workflows/page.yml/badge.svg)](https://github.com/owkin/plism-benchmark/actions/workflows/page.yml) [![Arxiv](https://img.shields.io/badge/Arxiv-2501.16239-red?style=flat-square)](https://arxiv.org/abs/2501.16239)
8 | [![Hugging face](https://img.shields.io/badge/%F0%9F%A4%97%20%20-PLISM-yellow)](https://huggingface.co/datasets/owkin/plism-dataset)
9 | 
10 | 
11 | 
12 | ## Documentation
13 | 
14 | The documentation can be found [here](https://owkin.github.io/plism-benchmark).
15 | Please refer to the Installation section to install this repository.
16 | 
17 | ## Benchmark
18 | 
19 | Benchmark results as of 03/03/2025.
20 | 
21 | | Extractor | Cosine similarity (all) | Top-10 accuracy (cross-scanner) | Top-10 accuracy (cross-staining) | Top-10 accuracy (cross-scanner, cross-staining) | Leaderboard metric | Rank |
22 | |:---------------|--------------------------:|----------------------------------:|-----------------------------------:|--------------------------------------------------:|---------------------:|:-------|
23 | | H0-Mini | 0.800 | 0.864 | 0.318 | 0.183 | 0.541 | #1 |
24 | | CONCH | 0.846 | 0.752 | 0.241 | 0.155 | 0.498 | #2 |
25 | | H-Optimus-0 | 0.685 | 0.744 | 0.327 | 0.166 | 0.480 | #3 |
26 | | Virchow2 | 0.777 | 0.609 | 0.306 | 0.163 | 0.464 | #4 |
27 | | Midnight-12k | 0.748 | 0.435 | 0.200 | 0.108 | 0.373 | #5 |
28 | | Prov-GigaPath | 0.570 | 0.592 | 0.118 | 0.054 | 0.333 | #6 |
29 | | UNI2-h | 0.591 | 0.501 | 0.190 | 0.046 | 0.332 | #7 |
30 | | UNI | 0.547 | 0.532 | 0.169 | 0.053 | 0.325 | #8 |
31 | | Kaiko ViT-B/8 | 0.764 | 0.346 | 0.147 | 0.045 | 0.325 | #9 |
32 | | GPFM | 0.594 | 0.356 | 0.092 | 0.017 | 0.265 | #10 |
33 | | PLIP | 0.878 | 0.054 | 0.040 | 0.004 | 0.244 | #11 |
34 | | Phikon | 0.622 | 0.125 | 0.021 | 0.004 | 0.193 | #12 |
35 | | Kaiko ViT-L/14 | 0.569 | 0.115 | 0.041 | 0.006 | 0.183 | #13 |
36 | | Phikon v2 | 0.557 | 0.064 | 0.030 | 0.003 | 0.164 | #14 |
37 | | Hibou Large | 0.490 | 0.061 | 0.030 | 0.008 | 0.147 | #15 |
38 | 
39 | Our robustness benchmark is based on two different metrics: top-10 accuracy and cosine similarity. These metrics are computed over 4,095 unique slide pairs. Through our evaluation pipeline, robustness metrics are computed for all pairs, but also cross-scanner (fixed staining), cross-staining (fixed scanner), and cross-scanner, cross-staining. Details are available in the `results.csv` file generated at the end of the evaluation.
40 | 
41 | We plan to update this benchmark regularly with the latest extractors. Feel free to submit a PR sharing the results obtained with your own feature extractor (see contribution guidelines).
42 | 
43 | > [!IMPORTANT]
44 | > The leaderboard metric is the average of 4 metrics: median cosine similarity for all pairs, median cross-scanner top-10 accuracy, median cross-staining top-10 accuracy, and median cross-scanner, cross-staining top-10 accuracy. Each median is computed over the corresponding slide pairs (e.g. for cross-scanner, pairs of slides with different scanners but a common staining). For instance, H0-Mini's leaderboard metric is (0.800 + 0.864 + 0.318 + 0.183) / 4 ≈ 0.541.
45 | 
46 | 
47 | 
48 | ### Run PLISM benchmark with your model
49 | 
50 | The following commands can be run through the CLI command `plismbench`.
51 | You can find a detailed description of each subcommand by typing:
52 | 
53 | ```bash
54 | plismbench --help
55 | ```
56 | 
57 | ### Hardware requirements
58 | 
59 | This benchmark can be executed on CPU or GPU. We strongly advise running it on GPU to benefit from `cupy` acceleration. From downloading the data to computing the results, running the benchmark takes approximately the following times on our workstation (**32 CPUs, 1 Nvidia T4 (16 GB) and 120 GB RAM**):
60 | 
61 | _With storage to disk_
62 | 
63 | - 2h45 for a ViT-B: 15 minutes for download, 1h30 for features extraction, 1h for robustness metrics computation.
64 | - 4h45 for a ViT-g: 15 minutes for download, 3h for features extraction, 1h30 for robustness metrics computation.
65 | 
66 | _Without storage to disk_
67 | 
68 | - 3h30 for a ViT-B: 2h30 for features extraction, 1h for robustness metrics computation.
69 | - 6h30 for a ViT-g: 5h for features extraction, 1h30 for robustness metrics computation.
70 | 
71 | 
72 | ### [Optional] Download
73 | 
74 | **If you don't have 250 GB available to store the PLISM dataset to disk, we advise you to perform the features extraction by streaming images on the fly (see next section). In that case, you can skip this section.**
75 | 
76 | First, you will need to download the [PLISM dataset](https://huggingface.co/datasets/owkin/plism-dataset) hosted on Hugging Face using the following command:
77 | 
78 | ```bash
79 | plismbench download --download-dir /your/download/dir --token your_hf_token --workers 8
80 | ```
81 | 
82 | > [!NOTE]
83 | > 225 GB are required to store the 91 WSI-level .h5 files; the download takes approximately 10 minutes (32 workers).
84 | >
85 | 
86 | ### Features extraction
87 | 
88 | Please follow these next steps:
89 | 
90 | 0. Let's set `org=your_company_or_group_name`
91 | 1. Implement your model in ``plismbench/models/org.py`` (a minimal sketch is given below, after the Compute metrics section)
92 | 2. Add it to the ``plismbench.models.__init__.py`` enum
93 | 3. Add a related test in ``tests/models/test_org.py``
94 | 4. Perform features extraction using the following script (example with `H0_mini`):
95 | 
96 | ```bash
97 | plismbench extract \
98 |     --extractor h0_mini \
99 |     --batch-size 8 \
100 |     --export-dir /your/features/export/dir/ \
101 |     --download-dir /the/previous/download/dir/ \
102 |     --workers 8
103 | ```
104 | 
105 | The output features directory will automatically be set to `export_dir/extractor`.
106 | 
107 | **Specify ``--streaming`` if you want to download images on the fly without storing them to disk.**
108 | 
109 | 
110 | > [!NOTE]
111 | > 10 GB of storage and 1h30 are necessary to extract all features with a ViT-B model, 16 CPUs and 1 Nvidia T4 (16 GB). 2h30 are necessary if streaming mode is enabled.
112 | >
113 | 
114 | > [!IMPORTANT]
115 | > If your model is meant to be integrated into `plismbench`, prior tests will be conducted on CI/CD, which requires a login step to Hugging Face. This step calls `secrets.HF_TOKEN`, i.e. the HF token of the CODEOWNER of this repository.
116 | 
117 | > ```yaml
118 | > - name: Log in to Hugging Face
119 | >   run: python -c "from huggingface_hub import login; login(token='${{ secrets.HF_TOKEN }}', new_session=False)"
120 | >```
121 | > Please make sure that 1) your model is public, and 2) the CODEOWNER has access to it. For instance, if your model is publicly available on HF but under gated access, please check with the CODEOWNER to be granted access (you can ask through your PR). **We only benchmark public models.**
122 | 
123 | ### Compute metrics
124 | 
125 | Simply run (example with `H0_mini`):
126 | 
127 | ```bash
128 | plismbench evaluate \
129 |     --extractor h0_mini \
130 |     --features-dir /your/features/previous/export/dir/ \
131 |     --metrics-dir /your/metrics/export/dir/
132 | ```
133 | 
134 | The input features directory will automatically be set to `export_dir/extractor`.
135 | 
136 | > [!NOTE]
137 | > 1h is necessary to compute metrics for a ViT-B model with 16 CPUs and 1 Nvidia T4 (16 GB).
138 | >
139 | 
140 | 
141 | Note that the `evaluate` pipeline runs regardless of the models registered in ``plismbench.models.__init__.py``. The only requirement is to store your model's features inside `/your/features/previous/export/dir/` under the `extractor` folder (e.g. `your/features/previous/export/dir/h0_mini/`).
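
For reference, here is roughly what step 1 of the features extraction workflow above (implementing your own model) looks like in practice. This is a minimal, hypothetical sketch: the class name, Hugging Face hub id and output dimension are placeholders, and the import paths for `Extractor`, `prepare_module` and `DEFAULT_DEVICE` are assumed to mirror the existing implementations under `plismbench/models/` (e.g. `mahmood_lab.py`). Adapt them to the actual module layout before opening a PR.

```python
"""Hypothetical feature extractor for an organization ``org``."""

import numpy as np
import timm
import torch
from torchvision import transforms

# Assumed import paths, mirroring the existing extractors.
from plismbench.models.extractor import Extractor
from plismbench.models.utils import DEFAULT_DEVICE, prepare_module


class MyViTBase(Extractor):
    """Placeholder ViT-B feature extractor hosted on Hugging Face."""

    def __init__(
        self,
        device: int | list[int] | None = DEFAULT_DEVICE,
        mixed_precision: bool = False,
    ):
        super().__init__()
        self.output_dim = 768  # embedding dimension of your model
        self.mixed_precision = mixed_precision

        # Replace the hub id with your own public checkpoint.
        feature_extractor = timm.create_model(
            "hf-hub:your-org/your-model", pretrained=True, num_classes=0
        )
        self.feature_extractor, self.device = prepare_module(
            feature_extractor, device, self.mixed_precision
        )
        if self.device is None:
            self.feature_extractor = self.feature_extractor.module

    @property
    def transform(self) -> transforms.Compose:
        """Transform applied to each input tile."""
        return transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                ),
            ]
        )

    def __call__(self, images: torch.Tensor) -> np.ndarray:
        """Return features of shape (n_tiles, features_dim) as a numpy array."""
        features = self.feature_extractor(images.to(self.device))
        return features.cpu().numpy()
```

Once the class is registered in `FeatureExtractorsEnum` and covered by a test, it can be selected with `plismbench extract --extractor <enum_value>` exactly like the built-in models.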
142 | 
143 | The `evaluate` command can run on two different types of device:
144 | 
145 | - `--device="cpu"` (uses `numpy`): in that case, please specify a number of `--workers`. Metrics computation will be parallelized over all possible slide pairs. **Depending on your RAM, setting too high a number of workers may cause memory errors.** Indeed, if `n_tiles=8139` and `workers=32`, then 32 matrices of shape (8139, d) will be stored in RAM, plus 32 matrices of shape (16278, 16278) to compute top-k accuracies, since these require dot products between slide A and slide B. With float32 values, a single (16278, 16278) matrix already occupies roughly 1 GB, so 32 workers can require tens of GB of RAM for these matrices alone. Please lower the number of workers if you encounter RAM issues.
146 | - `--device="gpu"` (uses `cupy`): in that case, there is no need to specify the number of workers. Matrix operations are performed directly on the GPU, sequentially over all possible slide pairs. **Depending on your GPU RAM, you may encounter CUDA memory errors.** We advise switching to CPU in that case. As an example, we managed to run `evaluate` on GPU (1 T4, 16 GB) with Virchow2 concatenated features (d=2563) and `n_tiles=8139` (1 hour).
147 | 
148 | 
149 | > [!IMPORTANT]
150 | > The `evaluate` command will compute metrics for each slide pair (individual pickles and a final .csv with 1 row per pair) as well as metrics aggregated over pairs (.csv file). Metrics are cosine similarity and top-k accuracies (with k=[1, 3, 5, 10]) by default. We compute mean (std) and median (iqr) over all possible slide pairs, inter-scanner pairs, inter-staining pairs, and inter-scanner + inter-staining pairs.
151 | >
152 | > The number of tiles can be set to either 460 (for debugging purposes), 2713 (1/6th of the total number of tiles per slide, which is 16278), 5426 (1/3rd), 8139 (half) or 16278 (the total number of tiles). **If `None`, the default number of tiles will be 8139, which is the reference for our benchmark.**
153 | 
154 | ### Get your results
155 | 
156 | By default, results are available at `/your/metrics/export/dir/8139_tiles/your_extractor/results.csv`. Here is an example with `H0_mini` on a subset of 460 tiles. For each type of robustness and each metric, we report `mean (std) ; median (iqr)`.
157 | 
158 | |                               | cosine_similarity | top_1_accuracy | top_3_accuracy | top_5_accuracy | top_10_accuracy |
159 | |:------------------------------|:------------------------------|:------------------------------|:------------------------------|:------------------------------|:------------------------------|
160 | | inter-scanner | 0.914 (0.051) ; 0.923 (0.056) | 0.673 (0.255) ; 0.701 (0.433) | 0.781 (0.213) ; 0.835 (0.307) | 0.823 (0.193) ; 0.882 (0.251) | 0.875 (0.162) ; 0.931 (0.173) |
161 | | inter-staining | 0.769 (0.160) ; 0.830 (0.085) | 0.190 (0.167) ; 0.152 (0.213) | 0.309 (0.218) ; 0.292 (0.307) | 0.372 (0.240) ; 0.374 (0.336) | 0.467 (0.266) ; 0.501 (0.357) |
162 | | inter-scanner, inter-staining | 0.737 (0.156) ; 0.792 (0.106) | 0.104 (0.108) ; 0.072 (0.127) | 0.197 (0.163) ; 0.166 (0.227) | 0.253 (0.190) ; 0.231 (0.274) | 0.346 (0.226) ; 0.346 (0.335) |
163 | | all | 0.753 (0.158) ; 0.803 (0.106) | 0.153 (0.194) ; 0.089 (0.166) | 0.251 (0.229) ; 0.195 (0.273) | 0.307 (0.244) ; 0.266 (0.321) | 0.397 (0.265) ; 0.387 (0.373) |
164 | 
165 | You can generate those results by executing:
166 | 
167 | ```python
168 | from plismbench.utils.metrics import format_results
169 | results = format_results(metrics_root_dir="/path/to/metrics/root_dir/")
170 | ```
171 | 
172 | Please check `notebooks/visualization.ipynb` for details.
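
If you want to sanity-check the leaderboard metric by hand from such a `results.csv`, recall that it is simply the average of four medians: cosine similarity over all pairs, and top-10 accuracy for the three cross-scanner/cross-staining settings. The sketch below is a rough illustration only, assuming the row labels and the `mean (std) ; median (iqr)` cell format shown in the table above; the supported way to obtain leaderboard numbers is the `get_leaderboard_results` helper described in the next section.

```python
import pandas as pd


def leaderboard_metric(results_csv: str) -> float:
    """Average of the four median robustness metrics (see the leaderboard note above)."""
    df = pd.read_csv(results_csv, index_col=0)

    def median_of(cell: str) -> float:
        # Cells look like "0.753 (0.158) ; 0.803 (0.106)"; keep the median value.
        return float(cell.split(";")[1].strip().split(" ")[0])

    values = [
        median_of(df.loc["all", "cosine_similarity"]),
        median_of(df.loc["inter-scanner", "top_10_accuracy"]),
        median_of(df.loc["inter-staining", "top_10_accuracy"]),
        median_of(df.loc["inter-scanner, inter-staining", "top_10_accuracy"]),
    ]
    return sum(values) / len(values)
```

On the 460-tile example above this gives (0.803 + 0.931 + 0.501 + 0.346) / 4 ≈ 0.645; keep in mind that the leaderboard figures are computed on 8,139 tiles.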
173 | 
174 | ### Get leaderboard results
175 | 
176 | 
177 | You can generate the leaderboard results from a `metrics.csv` file by using:
178 | 
179 | ```python
180 | from plismbench.utils.metrics import get_leaderboard_results
181 | leaderboard_results = get_leaderboard_results(metrics_root_dir="/path/to/metrics/root_dir/")
182 | ```
183 | 
184 | Please check `notebooks/visualization.ipynb` for details.
185 | 
186 | ## Contribute
187 | 
188 | Please refer to our [documentation](https://owkin.github.io/plism-benchmark) to follow our contribution guidelines.
189 | To add a new feature extractor:
190 | - Create a `.py` file with your organization id (e.g. `owkin.py`) or re-use an existing one in `plismbench.models`
191 | - Add a class inheriting from `Extractor` (please refer to the other models' implementations for guidance)
192 | - Add your model to `FeatureExtractorsEnum` located in `plismbench.models.__init__`
193 | - Don't forget to test it in `tests/models/test_extractors.py`
194 | 
195 | > [!IMPORTANT]
196 | > Please report the output of `get_leaderboard_results` in your PR description as illustrated above, along with the number of tiles used to compute the metrics.
197 | >
198 | 
199 | ## License
200 | 
201 | This repository is licensed under the [CC BY 4.0 license](https://creativecommons.org/licenses/by/4.0/deed.en).
202 | 
203 | ## Acknowledgments
204 | 
205 | We thank the PLISM dataset's authors for their unique contribution.
206 | 
207 | ## Third-party licenses
208 | 
209 | - PLISM dataset (Ochi et al., 2024) is distributed under [CC BY 4.0 license](https://plus.figshare.com/collections/Pathology_Images_of_Scanners_and_Mobilephones_PLISM_Dataset/6773925).
210 | - Elastix (Klein et al., 2010; Shamonin et al., 2014) is distributed under [Apache 2.0 license](https://github.com/SuperElastix/elastix).
211 | 
212 | ## How to cite
213 | 
214 | If you use this dataset, please cite the original article (Ochi et al., 2024) and our work as follows:
215 | 
216 | Filiot, A., Dop, N., Tchita, O., Riou, A., Peeters, T., Valter, D., Scalbert, M., Saillard, C., Robin, G., & Olivier, A. (2025). Distilling foundation models for robust and efficient models in digital pathology. arXiv. https://arxiv.org/abs/2501.16239
217 | 
218 | or
219 | 
220 | ```
221 | @misc{filiot2025distillingfoundationmodelsrobust,
222 |       title={Distilling foundation models for robust and efficient models in digital pathology},
223 |       author={Alexandre Filiot and Nicolas Dop and Oussama Tchita and Auriane Riou and Thomas Peeters and Daria Valter and Marin Scalbert and Charlie Saillard and Geneviève Robin and Antoine Olivier},
224 |       year={2025},
225 |       eprint={2501.16239},
226 |       archivePrefix={arXiv},
227 |       primaryClass={cs.CV},
228 |       url={https://arxiv.org/abs/2501.16239},
229 | }
230 | ```
231 | 
232 | ## References
233 | 
234 | - (Ochi et al., 2024) Ochi, M., Komura, D., Onoyama, T. et al. Registered multi-device/staining histology image dataset for domain-agnostic machine learning models. Sci Data 11, 330 (2024).
235 | 
--------------------------------------------------------------------------------