├── .gitignore
├── LICENSE
├── README.md
├── docs
│   └── logos
│       └── logo.png
├── environment.yaml
├── examples
│   ├── README.md
│   ├── annotator_2d.py
│   ├── annotator_2d_wsi.py
│   ├── finetuning
│   │   ├── README.md
│   │   ├── annotator_with_finetuned_model.py
│   │   └── finetune_nuclick.py
│   └── train_semantic.py
├── experiments
│   ├── README.md
│   ├── benchmarking
│   │   ├── README.md
│   │   ├── cellvit
│   │   │   ├── cellvit_inference.py
│   │   │   ├── eval_util.py
│   │   │   └── pannuke_semantic.py
│   │   ├── hovernet
│   │   │   ├── eval_util.py
│   │   │   ├── hovernet_inference.py
│   │   │   └── semantic.py
│   │   ├── hovernext
│   │   │   ├── evaluate_ais_hover.py
│   │   │   └── hovernext_inference.py
│   │   ├── instanseg
│   │   │   └── instanseg_inference.py
│   │   ├── outdated
│   │   │   └── cellvitplusplus
│   │   │       ├── eval.py
│   │   │       └── infer_cvtplus.py
│   │   └── stardist
│   │       └── stardist_inference.py
│   ├── data
│   │   ├── README.md
│   │   ├── cvtplus_preproc.py
│   │   ├── dataloaders.py
│   │   ├── get_paths.py
│   │   ├── interactive_data.py
│   │   ├── load_original_data.py
│   │   ├── load_padded_data.py
│   │   ├── semantic
│   │   │   └── get_semantic_paths.py
│   │   └── util.py
│   ├── patho-sam
│   │   ├── README.md
│   │   ├── evaluate_ais.py
│   │   ├── evaluate_amg.py
│   │   ├── evaluate_iterative_prompting.py
│   │   ├── get_results.py
│   │   ├── per_image_eval.py
│   │   ├── push_inference.py
│   │   ├── run_ais.py
│   │   ├── run_amg.py
│   │   ├── run_iter_boxes.py
│   │   ├── run_iter_points.py
│   │   └── util.py
│   ├── semantic_segmentation
│   │   ├── README.md
│   │   ├── benchmark_methods
│   │   │   ├── evaluate_other_methods.py
│   │   │   └── run_biomedparse.py
│   │   ├── generalists
│   │   │   ├── evaluate_pannuke.py
│   │   │   ├── submit_training.py
│   │   │   ├── train_generalist.py
│   │   │   └── train_pannuke.py
│   │   └── results
│   │       ├── get_bar_plots.py
│   │       └── get_qualitative_plots.py
│   └── training
│       ├── README.md
│       ├── generalists
│       │   ├── get_generalist_datasets.py
│       │   ├── run_all_finetuning.py
│       │   └── train_generalist_histopathology.py
│       └── specialists
│           ├── finetune_neutrophils.py
│           ├── finetune_specialist.py
│           ├── get_specialist_dataset.py
│           └── run_all_finetuning.py
├── patho_sam
│   ├── __init__.py
│   ├── __version__.py
│   ├── automatic_segmentation.py
│   ├── evaluation
│   │   ├── class_weights.py
│   │   └── evaluation.py
│   ├── io
│   │   ├── __init__.py
│   │   └── util.py
│   ├── semantic_segmentation.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── semantic_trainer.py
│   │   └── util.py
│   └── util.py
├── pyproject.toml
├── scripts
│   ├── plotting
│   │   └── get_dataset_figure.py
│   └── test_pannuke.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
164 | # Other stuff
165 | *.png
166 | gpu_jobs/
167 | checkpoints/
168 | logs/
169 | *.out
170 | *.zarr
171 | *.sh
172 | configs/
173 | inference_utils/
174 | *.tif
175 | *.svg
176 | *.csv
177 | *.tiff
178 | data/
179 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Computational Cell Analytics (Research Group of Prof. Dr. Constantin Pape)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Segment Anything for Histopathology
2 |
3 |
4 |
5 | PathoSAM implements interactive annotation and (automatic) instance and semantic segmentation for histopathology images. It is built on top of [Segment Anything](https://segment-anything.com/) by Meta AI and our prior work [Segment Anything for Microscopy](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html). It specializes Segment Anything for nucleus segmentation in histopathology data. Its core components are:
6 | - The publicly available `patho_sam` models for interactive data annotation that were fine-tuned on openly available histopathology images.
7 | - The `patho_sam` library, which provides training functionality based on [Segment Anything for Microscopy](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html), and supports:
8 | - Application of Segment Anything to histopathology images, including whole-slide images, and fine-tuning on your data.
9 | - Semantic segmentation.
10 |
11 | Based on these components, `patho_sam` enables fast interactive and automatic annotation for histopathology images; see [Usage](#usage) for details.
12 |
13 | ## Installation
14 |
15 | How to install the `patho_sam` Python library from source?
16 |
17 | To create the environment and install `patho_sam` into it, follow these steps:
18 |
19 | 1. Clone the repository: `git clone https://github.com/computational-cell-analytics/patho-sam`
20 | 2. Enter it: `cd patho-sam`
21 | 3. Create the environment with the necessary requirements: `conda env create -f environment.yaml`
22 | 4. Activate the environment: `conda activate patho-sam`
23 | 5. Install `patho_sam`: `pip install -e .`
24 |
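As a quick sanity check, the installed package should now import without errors:

```bash
python -c "import patho_sam"
```
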
25 | ## Usage
26 |
27 | ### Using example scripts:
28 |
29 | See the [examples](./examples/) folder for more details.
30 |
31 | ### Using CLI:
32 |
33 | - Download the example whole-slide image by running the following via terminal: `patho_sam.example_data` (see `patho_sam.example_data -h` for more details about the CLI).
34 | - Run automatic segmentation on your own WSI or the example data by running the following via terminal:
35 | ```bash
36 | patho_sam.automatic_segmentation -i /home/anwai/.cache/micro_sam/sample_data/whole-slide-histopathology-example-image.svs -o segmentation.tif
37 | ```
38 |
39 | > NOTE 1: See `patho_sam.automatic_segmentation -h` for more details about the CLI.
40 |
41 | > NOTE 2: You can find your cache directory using: `python -c "from micro_sam.util import get_cache_directory; print(get_cache_directory())"`.
42 |
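Automatic segmentation can also be run from Python. Below is a minimal sketch assuming the `micro_sam.automatic_segmentation` API (on which `patho_sam` builds) and a hypothetical input image `my_image.tif`:

```python
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

# Load the PathoSAM generalist model together with the decoder for automatic instance segmentation (AIS).
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_histopathology")

# Run automatic instance segmentation and write the result to 'segmentation.tif'.
segmentation = automatic_instance_segmentation(
    predictor=predictor, segmenter=segmenter, input_path="my_image.tif", output_path="segmentation.tif",
)
```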
43 |
44 | ## Citation
45 |
46 | If you are using this repository in your research, please cite:
47 | - [Our preprint](https://doi.org/10.48550/arXiv.2502.00408).
48 | - The [Segment Anything for Microscopy](https://www.nature.com/articles/s41592-024-02580-4) publication.
49 | - The original [Segment Anything](https://arxiv.org/abs/2304.02643) publication.
50 |
--------------------------------------------------------------------------------
/docs/logos/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-cell-analytics/patho-sam/482e2e336b9e5b75028c013452a70aba8107207a/docs/logos/logo.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: patho-sam
2 | channels:
3 | - conda-forge
4 | dependencies:
5 | - micro_sam
6 | # Note: installing the pytorch package from conda-forge will generally
7 | # give you the most optmized version for your system, if you have a modern
8 | # enough OS and CUDA version (CUDA >= 12). For older versions, you can
9 | # specify the CUDA version by pinning libtorch.
10 | # For example, add this line for a CUDA 11 version:
11 | # - libtorch=*=cuda11*
12 | # or, to enforce a CPU installation, change to
13 | # - "pytorch=*=cpu*"
14 | - pytorch >=2.5
15 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Example Scripts
2 |
3 | Examples for using the [`micro_sam`](https://github.com/computational-cell-analytics/micro-sam) annotation tools:
4 |
5 | - `annotator_2d_wsi.py`: Run the interactive 2d annotation tool, suitable for whole-slide images (WSIs).
6 | - `annotator_2d.py`: Run the interactive 2d annotation tool.
7 | - `train_semantic.py`: Train a model for semantic segmentation of nuclei in PanNuke histopathology images.
8 |
9 | Jupyter notebooks demonstrating automatic segmentation and finetuning on example data are available in the [notebooks](https://github.com/computational-cell-analytics/micro-sam/tree/master/notebooks) folder of the `micro-sam` repository.
10 |
11 | The folder `finetuning` contains example scripts that show how a Segment Anything model can be fine-tuned on custom data with the `micro_sam.train` library, and how the finetuned models can then be used within the `micro_sam` annotation tools.
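
Each example is a standalone script. For instance, the 2d annotator can be started from this folder with the following command (it downloads the required example data on the first run):

```bash
python annotator_2d.py
```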
12 |
--------------------------------------------------------------------------------
/examples/annotator_2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from micro_sam.util import get_cache_directory
4 | from micro_sam.sam_annotator import annotator_2d
5 | from micro_sam.sample_data import fetch_wholeslide_histopathology_example_data
6 |
7 |
8 | from patho_sam.io import read_wsi
9 |
10 |
11 | DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
12 |
13 |
14 | def whole_slide_image_annotator(use_finetuned_model):
15 | """Run the 2d annotator for an example histopathology image.
16 |
17 | See 'fetch_wholeslide_histopathology_example_data' for details on the data.
18 | """
19 | example_data = fetch_wholeslide_histopathology_example_data(DATA_CACHE)
20 |
21 | # Read a small ROI for the 2d annotator.
22 | image = read_wsi(example_data, image_size=(10000, 15000, 512, 512)) # spatial shape: (512, 512)
23 |
24 | if use_finetuned_model:
25 | model_type = "vit_b_histopathology"
26 | else:
27 | model_type = "vit_b"
28 |
29 |     # Path for caching the computed image embeddings.
30 | save_path = f"./embedding_{model_type}.zarr"
31 |
32 | annotator_2d(image=image, embedding_path=save_path, model_type=model_type)
33 |
34 |
35 | def main():
36 | # Whether to use the fine-tuned SAM model for WSIs.
37 | use_finetuned_model = True
38 |
39 | # 2d annotator for WSI data.
40 | whole_slide_image_annotator(use_finetuned_model)
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
/examples/annotator_2d_wsi.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from micro_sam.util import get_cache_directory
4 | from micro_sam.sam_annotator import annotator_2d
5 | from micro_sam.sample_data import fetch_wholeslide_histopathology_example_data
6 |
7 | from patho_sam.io import read_wsi
8 |
9 |
10 | DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
11 |
12 |
13 | def whole_slide_image_annotator(use_finetuned_model):
14 | """Run the 2d annotator for an example whole-slide image for histopathology.
15 |
16 | See 'fetch_wholeslide_histopathology_example_data' for details on the data.
17 | """
18 | example_data = fetch_wholeslide_histopathology_example_data(DATA_CACHE)
19 |
20 | # Load the WSI image.
21 | image = read_wsi(example_data)
22 |
23 | if use_finetuned_model:
24 | model_type = "vit_b_histopathology"
25 | else:
26 | model_type = "vit_b"
27 |
28 |     # Path for caching the computed image embeddings.
29 | save_path = f"./embedding_{model_type}.zarr"
30 |
31 | annotator_2d(
32 | image=image,
33 | embedding_path=save_path,
34 | model_type=model_type,
35 | tile_shape=(384, 384),
36 | halo=(64, 64),
37 | precompute_amg_state=True,
38 | )
39 |
40 |
41 | def main():
42 | # Whether to use the fine-tuned SAM model for WSIs.
43 | use_finetuned_model = True
44 |
45 | # 2d annotator for WSI data.
46 | whole_slide_image_annotator(use_finetuned_model)
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/examples/finetuning/README.md:
--------------------------------------------------------------------------------
1 | # Example for model finetuning
2 |
3 | This folder contains example scripts that show how to finetune a SAM model on your own data and how the finetuned model can be used:
4 |
5 | - `finetune_nuclick.py`: Shows how to finetune the model on new data. Set `train_instance_segmentation` to `True` in order to also train a decoder for automatic instance segmentation.
6 | - `annotator_with_finetuned_model.py`: Use the finetuned model in the 2d annotator.
7 |
--------------------------------------------------------------------------------
/examples/finetuning/annotator_with_finetuned_model.py:
--------------------------------------------------------------------------------
1 | import imageio.v3 as imageio
2 |
3 | from micro_sam.sam_annotator import annotator_2d
4 |
5 |
6 | def run_annotator_with_finetuned_model():
7 | """Run the 2d annotator with a custom (finetuned) model.
8 |
9 | Here, we use the model that is produced by `finetune_nuclick.py` and apply it
10 | for an image from the validation set.
11 | """
12 |     # Load an image from the validation set, so the model was not directly trained on it.
13 | im = imageio.imread("./data/IHC_nuclick/IHC/images/Validation/ROI_338_2.png")
14 |
15 | # set the checkpoint and the path for caching the embeddings
16 | checkpoint = "./finetuned_nuclick_model.pth"
17 | embedding_path = "./embeddings/embeddings-finetuned.zarr"
18 |
19 | # Adapt this if you finetune a different model type, e.g. vit_h.
20 | model_type = "vit_b" # We finetune a vit_b in the example script.
21 |
22 | # Run the 2d annotator with the custom model.
23 | annotator_2d(im, model_type=model_type, embedding_path=embedding_path, checkpoint=checkpoint)
24 |
25 |
26 | if __name__ == "__main__":
27 | run_annotator_with_finetuned_model()
28 |
--------------------------------------------------------------------------------
/examples/finetuning/finetune_nuclick.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | import torch_em
6 | from torch_em.data import MinInstanceSampler
7 | from torch_em.data.datasets import get_nuclick_dataset
8 | from torch_em.transform.label import PerObjectDistanceTransform
9 |
10 | import micro_sam.training as sam_training
11 | from micro_sam.util import export_custom_sam_model
12 |
13 | from patho_sam.training import get_train_val_split, histopathology_identity
14 |
15 |
16 | DATA_FOLDER = "data"
17 |
18 |
19 | def get_dataloaders(batch_size, patch_shape, train_instance_segmentation):
20 | """This returns the NuClick dataloaders implemented in `torch-em`.
21 | https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/histopathology/nuclick.py
22 | It will automatically download the NuClick data.
23 |
24 | NOTE: To replace this with another data loader, you need to return a torch data loader
25 | that returns `x, y` tensors, where `x` is the image data and `y` are corresponding labels.
26 | The labels have to be in a label mask instance segmentation format.
27 | i.e. a tensor of the same spatial shape as `x`, with instance labels for objects.
28 | Important: the ID 0 is reserved for background, and the IDS must be consecutive.
29 |
30 | See https://github.com/computational-cell-analytics/micro-sam/blob/master/examples/finetuning/finetune_hela.py
31 | for more details on how to create your custom dataloaders.
32 | """
33 | os.makedirs(DATA_FOLDER, exist_ok=True)
34 |
35 | # All relevant stuff for the dataset.
36 |     raw_transform = histopathology_identity  # Avoids normalizing the inputs, i.e. keeps the intensities in [0, 255].
37 | sampler = MinInstanceSampler(min_num_instances=2, min_size=10) # Ensures at least 2 foreground objects per sample.
38 | label_dtype = torch.float32 # Converts labels to expected dtype.
39 |
40 | if train_instance_segmentation:
41 | # Computes the distance transform for objects to perform end-to-end automatic instance segmentation.
42 | label_transform = PerObjectDistanceTransform(
43 | distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True,
44 | )
45 | else:
46 | label_transform = torch_em.transform.label.connected_components
47 |
48 | dataset = get_nuclick_dataset(
49 | path=DATA_FOLDER,
50 | patch_shape=patch_shape,
51 | split="Train",
52 | sampler=sampler,
53 | label_dtype=label_dtype,
54 | label_transform=label_transform,
55 | raw_transform=raw_transform,
56 | download=True, # This will download the image and segmentation data for training.
57 | )
58 |
59 | # Get the datasets.
60 | train_ds, val_ds = get_train_val_split(dataset)
61 |
62 | # Get the dataloaders.
63 | train_loader = torch_em.get_data_loader(train_ds, batch_size=batch_size, shuffle=True)
64 | val_loader = torch_em.get_data_loader(val_ds, batch_size=batch_size, shuffle=True)
65 |
66 | return train_loader, val_loader
67 |
68 |
69 | def run_training(checkpoint_name, model_type, train_instance_segmentation):
70 | """Run the actual model training."""
71 |
72 | # All hyperparameters for training.
73 | batch_size = 1 # the training batch size
74 | patch_shape = (512, 512) # the size of patches for training
75 | n_objects_per_batch = 25 # the number of objects per batch that will be sampled
76 | device = "cuda" if torch.cuda.is_available() else "cpu" # the device/GPU used for training.
77 |
78 | # Get the dataloaders.
79 | train_loader, val_loader = get_dataloaders(batch_size, patch_shape, train_instance_segmentation)
80 |
81 | # Run training.
82 | sam_training.train_sam(
83 | name=checkpoint_name,
84 | model_type=model_type,
85 | train_loader=train_loader,
86 | val_loader=val_loader,
87 | n_epochs=100,
88 | n_objects_per_batch=n_objects_per_batch,
89 | with_segmentation_decoder=train_instance_segmentation,
90 | device=device,
91 | )
92 |
93 |
94 | def export_model(checkpoint_name, model_type):
95 | """Export the trained model."""
96 | export_path = "./finetuned_nuclick_model.pth"
97 | checkpoint_path = os.path.join("checkpoints", checkpoint_name, "best.pt")
98 | export_custom_sam_model(checkpoint_path=checkpoint_path, model_type=model_type, save_path=export_path)
99 |
100 |
101 | def main():
102 | """Finetune a Segment Anything model.
103 |
104 | This example uses image data and segmentations from the NuClick dataset for lymphocyte segmentation,
105 | but can be easily adapted for other data (including data you have annotated with 'micro_sam' beforehand).
106 | """
107 | # The 'model_type' determines which base model is used to initialize the weights that are finetuned.
108 |     # We use 'vit_b' here because it can be trained faster. Note that 'vit_h' usually yields higher quality results.
109 | model_type = "vit_b"
110 |
111 | # The name of the checkpoint. The checkpoints will be stored in './checkpoints/'
112 | checkpoint_name = "sam_nuclick"
113 |
114 |     # Train an additional convolutional decoder for end-to-end automatic instance segmentation.
115 | train_instance_segmentation = True
116 |
117 | run_training(checkpoint_name, model_type, train_instance_segmentation)
118 | export_model(checkpoint_name, model_type)
119 |
120 |
121 | if __name__ == "__main__":
122 | main()
123 |
--------------------------------------------------------------------------------
/examples/train_semantic.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 | import torch_em
6 | from torch_em.data import MinTwoInstanceSampler
7 | from torch_em.data.datasets import get_pannuke_dataset
8 |
9 | import micro_sam.training as sam_training
10 | from micro_sam.instance_segmentation import get_unetr
11 |
12 | from patho_sam.training import SemanticInstanceTrainer, get_train_val_split
13 |
14 |
15 | DATA_FOLDER = "data"
16 |
17 |
18 | def get_dataloaders(data_path):
19 | """This returns the PanNuke dataloaders implemented in `torch-em`.
20 | https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/histopathology/pannuke.py
21 | It will automatically download the PanNuke data.
22 |
23 | NOTE: To replace this with another data loader, you need to return a torch data loader
24 | that returns `x, y` tensors, where `x` is the image data and `y` are corresponding labels.
25 | The labels have to be in a label mask semantic segmentation format.
26 | i.e. a tensor of the same spatial shape as `x`, with semantic labels for objects.
27 |     Important: the ID 0 is reserved for background, and the labels must cover all semantic classes.
28 | """
29 | # All relevant stuff for the dataset.
30 |     raw_transform = sam_training.identity  # Avoids normalizing the inputs, i.e. keeps the intensities in [0, 255].
31 |     sampler = MinTwoInstanceSampler()  # Ensures that at least two instances are present per sampled patch.
32 | label_dtype = torch.float32 # Converts labels to expected dtype.
33 |
34 | # Get the dataset
35 | dataset = get_pannuke_dataset(
36 | path=data_path,
37 | patch_shape=(1, 512, 512),
38 | ndim=2,
39 | folds=["fold_1", "fold_2"],
40 | custom_label_choice="semantic",
41 | sampler=sampler,
42 | label_dtype=label_dtype,
43 | raw_transform=raw_transform,
44 | download=True,
45 | )
46 |
47 | # Create custom splits.
48 | train_dataset, val_dataset = get_train_val_split(dataset)
49 |
50 | # Get the dataloaders.
51 | train_loader = torch_em.get_data_loader(dataset=train_dataset, batch_size=1, shuffle=True)
52 | val_loader = torch_em.get_data_loader(dataset=val_dataset, batch_size=1, shuffle=True)
53 |
54 | return train_loader, val_loader
55 |
56 |
57 | def train_pannuke_semantic_segmentation(checkpoint_name, model_type):
58 | """Script for semantic segmentation for PanNuke data."""
59 |
60 | # Parameters for training
61 | num_classes = 6 # available classes are [0, 1, 2, 3, 4, 5]
62 | device = "cuda" if torch.cuda.is_available() else "cpu"
63 | train_loader, val_loader = get_dataloaders(data_path=os.path.join(DATA_FOLDER, "pannuke"))
64 |
65 | # Get the trainable Segment Anything Model.
66 | model = sam_training.get_trainable_sam_model(
67 | model_type=model_type,
68 | device=device,
69 | checkpoint_path=None, # override to provide filepath for your trained SAM model.
70 | )
71 |
72 | # Get the UNETR model for semantic segmentation pipeline
73 | unetr = get_unetr(
74 | image_encoder=model.sam.image_encoder, device=device, out_channels=num_classes, flexible_load_checkpoint=True,
75 | )
76 |
77 | # All other stuff we need for training
78 | optimizer = torch.optim.AdamW(unetr.parameters(), lr=1e-4)
79 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=5)
80 |
81 | # Converts the per-batch inputs and corresponding labels to desired format required for training.
82 | convert_inputs = sam_training.util.ConvertToSemanticSamInputs()
83 |
84 | # Trainer for semantic segmentation (implemented using 'torch_em')
85 | trainer = SemanticInstanceTrainer(
86 | name=checkpoint_name,
87 | train_loader=train_loader,
88 | val_loader=val_loader,
89 | model=unetr,
90 | optimizer=optimizer,
91 | device=device,
92 | lr_scheduler=scheduler,
93 | log_image_interval=100,
94 | mixed_precision=True,
95 | compile_model=False,
96 | convert_inputs=convert_inputs,
97 | num_classes=num_classes,
98 |         dice_weight=0,  # With 0 the trainer uses cross-entropy loss only; override to use a weighted dice-cross entropy loss.
99 | )
100 | trainer.fit(epochs=100)
101 |
102 |
103 | def main():
104 | """Finetune a Segment Anything model for semantic segmentation on the PanNuke dataset.
105 |
106 |     This example uses image data and semantic segmentation labels for the PanNuke dataset,
107 | but can easily be adapted for other data (including data you have annotated with patho_sam beforehand).
108 | NOTE: You must provide semantic class labels to train within this setup.
109 | """
110 | # The model_type determines which base model is used to initialize the weights that are finetuned.
111 | # We use 'vit_b' here because it can be trained faster. Note that 'vit_h' yields higher quality results.
112 | model_type = "vit_b"
113 |
114 | # The name of checkpoint. The checkpoints will be stored in './checkpoints/'.
115 | checkpoint_name = "pannuke_semantic"
116 |
117 | train_pannuke_semantic_segmentation(checkpoint_name, model_type)
118 |
119 |
120 | if __name__ == "__main__":
121 | main()
122 |
--------------------------------------------------------------------------------
/experiments/README.md:
--------------------------------------------------------------------------------
1 | # Segment Anything for Histopathology *Experiment Scripts*
2 |
3 | This directory contains all experiment scripts for building and benchmarking Segment Anything models for histopathology.
4 |
5 | Here is a brief description of relevant folders:
6 | - `benchmarking`: Contains scripts to run reference methods (e.g. HoVerNet, HoVerNeXt, CellViT, InstanSeg, StarDist).
7 | - `data`: Contains scripts for preprocessing images and corresponding labels for all experiments.
8 | - `patho-sam`: Contains scripts for running SAM models (e.g. the default SAM, `micro-sam` and `patho-sam` models) for automatic and interactive instance segmentation.
9 | - `semantic_segmentation`: Contains scripts for training and evaluating semantic segmentation using `patho-sam` generalist models.
10 | - `training`: Contains scripts for training the specialist and generalist `patho-sam` models.
11 |
12 | > NOTE 1: In some scripts the expected filepaths / directories are hard-coded. Replace them with the respective paths where you would like to store / fetch files.
13 |
14 | > NOTE 2: We provide [example scripts](../examples/) for convenience. Feel free to check them out for both interactive and automatic segmentation using `PathoSAM` models.
15 |
--------------------------------------------------------------------------------
/experiments/benchmarking/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarking Methods for Automatic Instance (and Semantic) Segmentation
2 |
3 | This folder contains scripts for the reference methods against which we compare the automatic segmentation of `patho-sam` (an example invocation is shown after the list):
4 | - `cellvit`: Contains scripts to run CellViT for automatic instance segmentation and classification with several model variants.
5 | - `hovernet`: Contains scripts to run HoVerNet for automatic instance segmentation and classification with several model variants.
6 | - `hovernext`: Contains scripts to run HoVerNeXt for automatic instance segmentation and classification with several model variants.
7 | - `instanseg`: Contains scripts to run InstanSeg for automatic instance segmentation.
8 | - `stardist`: Contains scripts to run StarDist for automatic instance segmentation.
9 | - Other miscellaneous folders:
10 |   - `outdated`: Contains other scripts that are either outdated or under development.
11 |
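Each of these scripts exposes a small command line interface. For example, CellViT inference could be started as follows (a sketch with hypothetical paths; see `cellvit/cellvit_inference.py` for the exact arguments):

```bash
python cellvit/cellvit_inference.py -i /path/to/data -d pannuke -o /path/to/results -c /path/to/CellViT-SAM-H-x40.pth
```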
--------------------------------------------------------------------------------
/experiments/benchmarking/cellvit/cellvit_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | import argparse
5 |
6 | from eval_util import evaluate_cellvit
7 |
8 |
9 | def run_inference(input_dir, output_dir, dataset, checkpoint_path):
10 | if dataset not in ["pannuke", "nuclick", "srsanet", "lizard", "cpm15", "consep", "cpm17", "monuseg"]:
11 | data_dir = os.path.join(input_dir, "original_data", dataset, "eval_split")
12 | else:
13 | data_dir = os.path.join(input_dir, "vit_data", dataset, "eval_split")
14 |
15 | checkpoint = os.path.splitext(os.path.basename(checkpoint_path))[0].split("-", 1)[1]
16 |
17 | result_dir = os.path.join(output_dir, "results")
18 | output_path = os.path.join(output_dir, dataset, checkpoint)
19 | os.makedirs(output_path, exist_ok=True)
20 | args = [
21 | "--model", f"{checkpoint_path}",
22 | "--outdir", f"{output_path}",
23 | "--magnification", "40",
24 | "--data", f"{data_dir}",
25 | ]
26 |
27 | command = [
28 | "python", os.path.expanduser("~/CellViT/cell_segmentation/inference/inference_cellvit_experiment_monuseg.py"),
29 | ] + args
30 |
31 | print(f"Running inference with CellViT {checkpoint} model on {dataset} dataset...")
32 | subprocess.run(command)
33 |
34 | plot_dir = os.path.join(output_dir, dataset, checkpoint, dataset, "plots")
35 | if os.path.exists(plot_dir):
36 | shutil.rmtree(plot_dir)
37 |
38 | if os.path.exists(os.path.join(output_path, "inference_monuseg.log")):
39 | os.remove(os.path.join(output_path, "inference_monuseg.log"))
40 |
41 | evaluate_cellvit(output_path, checkpoint, dataset, data_dir, result_dir)
42 | print(f"Successfully ran inference with CellViT {checkpoint} model on {dataset} dataset")
43 |
44 |
45 | def get_cellvit_args():
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data")
48 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on")
49 | parser.add_argument("-o", "--output_path", type=str, default=None, help="The path where the results are saved")
50 | parser.add_argument("-c", "--checkpoint_path", type=str, default=None, help="The checkpoints to use for inference")
51 | args = parser.parse_args()
52 | return args
53 |
54 |
55 | def main():
56 | args = get_cellvit_args()
57 | run_inference(
58 | input_dir=args.input,
59 | output_dir=args.output_path,
60 |         dataset=args.dataset,
61 |         checkpoint_path=args.checkpoint_path,
62 | )
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/experiments/benchmarking/cellvit/eval_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | from glob import glob
4 | from tqdm import tqdm
5 | from natsort import natsorted
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import imageio.v3 as imageio
10 | from skimage.measure import label
11 |
12 | from elf.evaluation import mean_segmentation_accuracy
13 |
14 |
15 | def zip_predictions(path, target_dir):
16 | print(f"Zipping {path}...")
17 | zip_name = os.path.basename(path) + ".zip"
18 | zip_path = os.path.join(target_dir, zip_name)
19 | with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
20 | for root, _dirs, files in os.walk(path):
21 | for file in files:
22 | file_path = os.path.join(root, file)
23 | arcname = os.path.relpath(file_path, start=path)
24 | zipf.write(file_path, arcname)
25 | print("Successfully zipped results")
26 |
27 |
28 | def _run_evaluation(gt_paths, prediction_paths, verbose=True):
29 | print(len(gt_paths), len(prediction_paths))
30 | assert len(gt_paths) == len(prediction_paths), \
31 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}"
32 |
33 | msas, sa50s, sa75s = [], [], []
34 | for gt_path, pred_path in tqdm(
35 | zip(gt_paths, prediction_paths),
36 | desc="Evaluate predictions",
37 | total=len(gt_paths),
38 | disable=not verbose,
39 | ):
40 | assert os.path.exists(gt_path), gt_path
41 | assert os.path.exists(pred_path), pred_path
42 |
43 | gt = imageio.imread(gt_path)
44 | gt = label(gt)
45 | pred = imageio.imread(pred_path)
46 |
47 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True)
48 | sa50, sa75 = scores[0], scores[5]
49 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75)
50 |
51 | return msas, sa50s, sa75s
52 |
53 |
54 | def evaluate_cellvit(prediction_dir, checkpoint, dataset, label_dir, result_dir):
55 |     save_path = os.path.join(result_dir, dataset, checkpoint, f'{dataset}_cellvit_{checkpoint}_ais_result.csv')
56 | if os.path.exists(save_path):
57 | print(f"Results for {dataset} evaluation already exist")
58 | return
59 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
60 | gt_paths = natsorted(glob(os.path.join(label_dir, "test_labels", "*")))
61 | if len(prediction_paths) == 0:
62 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found")
63 | return
64 |
65 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths)
66 | results = pd.DataFrame.from_dict(
67 | {
68 | "mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)],
69 | }
70 | )
71 | print(results.head(2))
72 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True)
73 | results.to_csv(save_path, index=False)
74 |
--------------------------------------------------------------------------------
/experiments/benchmarking/cellvit/pannuke_semantic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 |
6 |
7 | CVT_CP = ["256-x20", "256-x40", "SAM-H-x20", "SAM-H-x40"]
8 |
9 |
10 | def run_inference(checkpoint_path, input_dir, output_dir, dataset):
11 | data_dir = os.path.join(input_dir, dataset, "eval_split")
12 | checkpoint = os.path.splitext((os.path.basename(checkpoint_path)[8:]))[0]
13 | output_path = os.path.join(output_dir, dataset, checkpoint)
14 | os.makedirs(output_path, exist_ok=True)
15 | args = [
16 | "--model", f"{checkpoint_path}",
17 | "--outdir", f"{output_path}",
18 | "--magnification", "40",
19 | "--data", f"{data_dir}",
20 | ]
21 |
22 | command = [
23 | "python3",
24 | os.path.expanduser("~/CellViT/cell_segmentation/inference/inference_cellvit_experiment_pannuke.py"), # noqa
25 | ] + args
26 |
27 | print(f"Running inference with CellViT {checkpoint} model on {dataset} dataset...")
28 | subprocess.run(command)
29 | plot_dir = os.path.join(output_path, "plots")
30 | if os.path.exists(plot_dir):
31 | shutil.rmtree(plot_dir)
32 |
33 | print(f"Successfully ran inference with CellViT {checkpoint} model on {dataset} dataset")
34 |
35 |
36 | def get_cellvit_args():
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data")
39 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on")
40 | parser.add_argument("-o", "--output_path", type=str, default=None, help="The path where the results are saved")
41 | parser.add_argument("-c", "--checkpoint_path", type=str, default=None, help="The checkpoints to use for inference")
42 | args = parser.parse_args()
43 | return args
44 |
45 |
46 | def main():
47 | args = get_cellvit_args()
48 | run_inference(
49 | input_dir=args.input,
50 | output_dir=args.output_path,
51 |         dataset=args.dataset,
52 |         checkpoint_path=args.checkpoint_path,
53 | )
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/experiments/benchmarking/hovernet/eval_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 | from natsort import natsorted
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import imageio.v3 as imageio
9 | from skimage.measure import label
10 |
11 | from elf.evaluation import mean_segmentation_accuracy
12 |
13 |
14 | DATASETS = [
15 | "consep",
16 | "cpm15",
17 | "cpm17",
18 | "cryonuseg",
19 | "lizard",
20 | "lynsec_he",
21 | "lynsec_ihc",
22 | "monuseg",
23 | "nuclick",
24 | "nuinsseg",
25 | "pannuke",
26 | "puma",
27 | "srsanet",
28 | "tnbc",
29 | ]
30 |
31 |
32 | def _run_evaluation(gt_paths, prediction_paths, verbose=True):
33 | assert len(gt_paths) == len(
34 | prediction_paths
35 | ), f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}"
36 | msas, sa50s, sa75s = [], [], []
37 |
38 | for gt_path, pred_path in tqdm(
39 | zip(gt_paths, prediction_paths),
40 | desc="Evaluate predictions",
41 | total=len(gt_paths),
42 | disable=not verbose,
43 | ):
44 | assert os.path.exists(gt_path), gt_path
45 | assert os.path.exists(pred_path), pred_path
46 |
47 | gt = imageio.imread(gt_path)
48 | gt = label(gt)
49 | pred = imageio.imread(pred_path)
50 |
51 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True)
52 | sa50, sa75 = scores[0], scores[5]
53 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75)
54 |
55 | return msas, sa50s, sa75s
56 |
57 |
58 | def evaluate_hovernet(prediction_dir, label_dir, result_dir, dataset, checkpoint):
59 | gt_paths = natsorted(glob(os.path.join(label_dir, "*")))
60 | save_path = os.path.join(result_dir, dataset, checkpoint, f'{dataset}_hovernet_{checkpoint}_ais_result.csv')
61 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*.tiff")))
62 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True)
63 | print(f"evaluation {dataset} dataset on checkpoint {checkpoint} ...")
64 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths)
65 | results = pd.DataFrame.from_dict(
66 | {
67 | "mSA": [np.mean(msas)],
68 | "SA50": [np.mean(sa50s)],
69 | "SA75": [np.mean(sa75s)],
70 | }
71 | )
72 | print(results.head())
73 | results.to_csv(save_path, index=False)
74 |
--------------------------------------------------------------------------------
/experiments/benchmarking/hovernet/hovernet_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 | from glob import glob
6 | from tqdm import tqdm
7 | from natsort import natsorted
8 |
9 | import imageio.v3 as imageio
10 | from scipy.io import loadmat
11 |
12 | from eval_util import evaluate_hovernet
13 |
14 |
15 | def mat_to_tiff(path):
16 | pred_mat_paths = [p for p in natsorted(glob(os.path.join(path, "mat", "*.mat")))]
17 | for mpath in tqdm(pred_mat_paths, desc="Preprocessing labels"):
18 | pred_path = os.path.join(path, os.path.basename(mpath.replace(".mat", ".tiff")))
19 | pred = loadmat(mpath)["inst_map"]
20 | imageio.imwrite(pred_path, pred)
21 |
22 |
23 | def run_inference(input_dir, output_dir, dataset, checkpoint_path):
24 | checkpoint = os.path.basename(checkpoint_path).split("_")[2]
25 | output_path = os.path.join(output_dir, "inference", dataset, checkpoint)
26 | input_path = os.path.join(input_dir, dataset, "eval_split", "test_images")
27 |
28 | os.makedirs(output_path, exist_ok=True)
29 | if checkpoint in ["consep", "cpm17", "kumar"]:
30 | model_mode = "original"
31 | nr_types = 0
32 | type_info = ""
33 | else:
34 | model_mode = "fast"
35 | type_info = os.path.expanduser("~/hover_net/type_info.json")
36 | if checkpoint == "pannuke":
37 | nr_types = 6
38 | else:
39 | nr_types = 5
40 |
41 | args = [
42 | "--nr_types", f"{nr_types}",
43 | "--type_info_path", f"{type_info}",
44 | "--model_mode", f"{model_mode}",
45 | "--model_path", f"{checkpoint_path}",
46 | "--nr_inference_workers", "2",
47 | "--nr_post_proc_worker", "0",
48 | "tile",
49 | "--input_dir", f"{input_path}",
50 | "--output_dir", f"{output_path}",
51 | "--save_raw_map",
52 | ]
53 |
54 | command = ["python3", os.path.expanduser("~/hover_net/run_infer.py")] + args
55 | print(f"Running inference with HoVerNet {checkpoint} model on {dataset} dataset...")
56 |
57 | subprocess.run(command)
58 | mat_to_tiff(os.path.join(output_path))
59 | evaluate_hovernet(
60 | prediction_dir=output_path,
61 | label_dir=os.path.join(input_dir, dataset, "eval_split", "test_labels"),
62 | result_dir=os.path.join(output_dir, "results"),
63 | checkpoint=checkpoint,
64 | dataset=dataset,
65 | )
66 | shutil.rmtree(os.path.join(output_path, "json"))
67 | shutil.rmtree(os.path.join(output_path, "mat"))
68 | shutil.rmtree(os.path.join(output_path, "overlay"))
69 | print(f"Inference on {dataset} dataset with the HoVerNet {checkpoint} model successfully completed")
70 |
71 |
72 | def get_hovernet_args():
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data")
75 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on")
76 | parser.add_argument("-o", "--output_dir", type=str, default=None, help="The path where the results are saved")
77 | parser.add_argument(
78 | "-c", "--checkpoint_path", type=str, default=None,
79 | help="The path to the HoVerNet checkpoint to use for inference."
80 | )
81 | args = parser.parse_args()
82 | return args
83 |
84 |
85 | def main():
86 | args = get_hovernet_args()
87 | run_inference(
88 | input_dir=args.input,
89 | output_dir=args.output_dir,
90 | dataset=args.dataset,
91 |         checkpoint_path=args.checkpoint_path,
92 | )
93 |
94 |
95 | if __name__ == "__main__":
96 | main()
97 |
--------------------------------------------------------------------------------
/experiments/benchmarking/hovernet/semantic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | from glob import glob
5 | from tqdm import tqdm
6 | from natsort import natsorted
7 |
8 | import cv2
9 | import json
10 | import numpy as np
11 | import imageio.v3 as imageio
12 |
13 |
14 | def json_to_tiff(path):
15 | label_json_paths = [p for p in natsorted(glob(os.path.join(path, "json", "*.json")))]
16 | img_shape = (256, 256)
17 | for mpath in tqdm(label_json_paths, desc="Postprocessing labels"):
18 | label_path = os.path.join(path, os.path.basename(mpath.replace(".json", ".tiff")))
19 | with open(mpath, 'r') as file:
20 | data = json.load(file)
21 | pred_class_map = np.zeros(img_shape, dtype=np.int32)
22 | for id, cell_data in enumerate(data['nuc'].items(), start=1):
23 | cell_data = cell_data[1]
24 | contour = np.array(cell_data["contour"])
25 | contour[:, 0] = np.clip(contour[:, 0], 0, img_shape[1])
26 | contour[:, 1] = np.clip(contour[:, 1], 0, img_shape[0])
27 | contour = contour.reshape((-1, 1, 2))
28 | cell_type = cell_data["type"]
29 | contour = np.vstack((contour, [contour[0]]))
30 | contour = contour.astype(np.int32)
31 | cv2.fillPoly(pred_class_map, [contour], cell_type)
32 |
33 | imageio.imwrite(label_path, pred_class_map)
34 |
35 |
36 | def run_inference(model_dir, input_dir, output_dir, type_info_path):
37 | for dataset in ["pannuke"]:
38 | for checkpoint in ["pannuke"]:
39 | output_path = os.path.join(output_dir, "inference", dataset, checkpoint)
40 | input_path = os.path.join(input_dir, dataset, "semantic_split", "test_images")
41 | if os.path.exists(os.path.join(output_dir, "results", dataset, checkpoint, 'ais_result.csv')):
42 | print(f"Inference with HoVerNet model (type: {checkpoint}) on {dataset} dataset already done")
43 | continue
44 |
45 | os.makedirs(output_path, exist_ok=True)
46 | if checkpoint in ["consep", "cpm17", "kumar"]:
47 | model_mode = "original"
48 | model_path = os.path.join(
49 | model_dir, "checkpoints", f"hovernet_original_{checkpoint}_notype_tf2pytorch.tar"
50 | )
51 | nr_types = 0
52 | else:
53 | model_mode = "fast"
54 |
55 | model_path = os.path.join(model_dir, "checkpoints", f"hovernet_fast_{checkpoint}_type_tf2pytorch.tar")
56 | if checkpoint == "pannuke":
57 | nr_types = 6
58 | else:
59 | nr_types = 5
60 |
61 | args = [
62 | "--nr_types", f"{nr_types}",
63 | "--type_info_path", "/user/titus.griebel/u12649/hover_net/type_info.json",
64 | "--model_mode", f"{model_mode}",
65 | "--model_path", f"{model_path}",
66 | "--nr_inference_workers", "2",
67 | "--nr_post_proc_worker", "0",
68 | "tile",
69 | "--input_dir", f"{input_path}",
70 | "--output_dir", f"{output_path}",
71 | ]
72 |
73 | command = ["python3", "/user/titus.griebel/u12649/hover_net/run_infer.py"] + args
74 | print(f"Running inference with HoVerNet {checkpoint} model on {dataset} dataset...")
75 |
76 | subprocess.run(command)
77 | json_to_tiff(output_path)
78 | shutil.rmtree(os.path.join(output_path, "json"))
79 | shutil.rmtree(os.path.join(output_path, "mat"))
80 | shutil.rmtree(os.path.join(output_path, "overlay"))
81 | print(f"Inference on {dataset} dataset with the HoVerNet {checkpoint} model successfully completed")
82 |
83 |
84 | run_inference(
85 | model_dir="/mnt/lustre-grete/usr/u12649/models/hovernet",
86 | input_dir="/mnt/lustre-grete/usr/u12649/data/original_data",
87 | output_dir="/mnt/lustre-grete/usr/u12649/models/hovernet_types",
88 | type_info_path="/user/titus.griebel/u12649/hover_net/type_info.json",
89 | )
90 |
--------------------------------------------------------------------------------
/experiments/benchmarking/hovernext/evaluate_ais_hover.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 | from natsort import natsorted
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import imageio.v3 as imageio
9 | from skimage.measure import label
10 |
11 | from elf.evaluation import mean_segmentation_accuracy
12 |
13 |
14 | CHECKPOINTS = [
15 | "lizard_convnextv2_large",
16 | "lizard_convnextv2_base",
17 | "lizard_convnextv2_tiny",
18 | "pannuke_convnextv2_tiny_1",
19 | "pannuke_convnextv2_tiny_2",
20 | "pannuke_convnextv2_tiny_3",
21 | ]
22 |
23 | DATASETS = [
24 | "consep",
25 | "cpm15",
26 | "cpm17",
27 | "cryonuseg",
28 | "lizard",
29 | "lynsec_he",
30 | "lynsec_ihc",
31 | "monusac",
32 | "monuseg",
33 | "nuclick",
34 | "nuinsseg",
35 | "pannuke",
36 | "puma",
37 | "srsanet",
38 | "tnbc",
39 | ]
40 |
41 |
42 | def _run_evaluation(gt_paths, prediction_paths, verbose=True):
43 | assert len(gt_paths) == len(prediction_paths), \
44 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}"
45 |
46 | msas, sa50s, sa75s = [], [], []
47 | for gt_path, pred_path in tqdm(
48 | zip(gt_paths, prediction_paths, strict=False),
49 | desc="Evaluate predictions",
50 | total=len(gt_paths),
51 | disable=not verbose,
52 | ):
53 | assert os.path.exists(gt_path), gt_path
54 | assert os.path.exists(pred_path), pred_path
55 |
56 | gt = imageio.imread(gt_path)
57 | gt = label(gt)
58 | pred = imageio.imread(pred_path)
59 |
60 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True)
61 | sa50, sa75 = scores[0], scores[5]
62 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75)
63 |
64 | return msas, sa50s, sa75s
65 |
66 |
67 | def evaluate_all_datasets_hovernext(prediction_dir, label_dir, result_dir):
68 | for dataset in DATASETS:
69 | gt_paths = natsorted(glob(os.path.join(label_dir, dataset, "eval_split", "test_labels", "*")))
70 | for checkpoint in CHECKPOINTS:
71 | save_path = os.path.join(
72 | result_dir, dataset, checkpoint, f'{dataset}_hovernext_{checkpoint}_ais_result.csv'
73 | )
74 | if os.path.exists(save_path):
75 | continue
76 |
77 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, dataset, checkpoint, "*")))
78 | if len(prediction_paths) == 0:
79 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found")
80 | continue
81 |
82 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True)
83 | print(f"evaluation {dataset} dataset on checkpoint {checkpoint} ...")
84 |
85 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths)
86 | results = pd.DataFrame.from_dict(
87 | {"mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)]}
88 | )
89 | print(results.head())
90 | results.to_csv(save_path, index=False)
91 |
92 |
93 | evaluate_all_datasets_hovernext(
94 | "/mnt/lustre-grete/usr/u12649/models/hovernext/inference",
95 | "/mnt/lustre-grete/usr/u12649/data/original_data",
96 | "/mnt/lustre-grete/usr/u12649/models/hovernext/results",
97 | )
98 |
--------------------------------------------------------------------------------
/experiments/benchmarking/hovernext/hovernext_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 |
4 | # This script does not work right now; use the same script inside the hovernext clone
5 |
6 |
7 | def run_inference(input_dir, output_dir):
8 | for dataset in [
9 | "cpm15",
10 | "cpm17",
11 | "cryonuseg",
12 | "janowczyk",
13 | "lizard",
14 | "lynsec",
15 | "monusac",
16 | "monuseg",
17 | "nuinsseg",
18 | "pannuke",
19 | "puma",
20 | "tnbc",
21 | ]:
22 | for model in ["pannuke_convnextv2_tiny_1", "pannuke_convnextv2_tiny_2", "pannuke_convnextv2_tiny_3"]:
23 | output_path = os.path.join(output_dir, dataset, model)
24 | input_path = os.path.join(input_dir, dataset, "loaded_dataset/complete_dataset/eval_split/test_images/*")
25 |             os.makedirs(output_path, exist_ok=True)
26 |             if len(os.listdir(output_path)) > 0:  # Skip datasets that already have predictions.
27 |                 continue
28 | args = [
29 | "--input", f"{input_path}",
30 | "--cp", f"{model}",
31 | "--output_root", f"{output_path}",
32 | "--tile_size", "512",
33 | ]
34 |
35 | command = ["python3", "/user/titus.griebel/u12649/hover_next_inference/main.py"] + args
36 | print(f"Running inference with HoVerNeXt {model} model on {dataset} dataset...")
37 | subprocess.run(command)
38 | print(f"Inference on {dataset} dataset with the HoVerNeXt model {model} successfully completed")
39 |
40 |
41 | run_inference(
42 | input_dir="/mnt/lustre-grete/usr/u12649/data/test",
43 | output_dir="/mnt/lustre-grete/usr/u12649/models/hovernext/inference",
44 | )
45 |
--------------------------------------------------------------------------------
/experiments/benchmarking/instanseg/instanseg_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from glob import glob
4 | from tqdm import tqdm
5 | from natsort import natsorted
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import imageio.v3 as imageio
10 | from skimage.measure import label
11 |
12 | from tukra.io import read_image
13 | from tukra.inference import segment_using_instanseg
14 |
15 | from elf.evaluation import mean_segmentation_accuracy
16 |
17 |
18 | def _run_evaluation(gt_paths, prediction_paths, verbose=True):
19 | assert len(gt_paths) == len(prediction_paths), \
20 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}"
21 |
22 | msas, sa50s, sa75s = [], [], []
23 | for gt_path, pred_path in tqdm(
24 | zip(gt_paths, prediction_paths, strict=False),
25 | desc="Evaluate predictions",
26 | total=len(gt_paths),
27 | disable=not verbose,
28 | ):
29 | assert os.path.exists(gt_path), gt_path
30 | assert os.path.exists(pred_path), pred_path
31 |
32 | gt = imageio.imread(gt_path)
33 | gt = label(gt)
34 | pred = imageio.imread(pred_path)
35 |
36 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True)
37 | sa50, sa75 = scores[0], scores[5]
38 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75)
39 |
40 | return msas, sa50s, sa75s
41 |
42 |
43 | def evaluate_instanseg(prediction_dir, label_dir, result_dir, dataset):
44 | gt_paths = natsorted(glob(os.path.join(label_dir, "eval_split", "test_labels", "*")))
45 | for checkpoint in ["instanseg"]:
46 | save_path = os.path.join(result_dir, dataset, checkpoint, f'{dataset}_instanseg_{checkpoint}_ais_result.csv')
47 | if os.path.exists(save_path):
48 | continue
49 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
50 | if len(prediction_paths) == 0:
51 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found")
52 | continue
53 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True)
54 | print(f"Evaluating {dataset} dataset on InstanSeg ...")
55 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths)
56 | results = pd.DataFrame.from_dict(
57 | {"mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)]}
58 | )
59 |
60 | results.to_csv(save_path, index=False)
61 | print(results.head(2))
62 |
63 |
64 | def infer_instanseg(data_dir, output_path, dataset):
65 | image_paths = natsorted(glob(os.path.join(data_dir, "eval_split", "test_images", "*")))
66 | os.makedirs(output_path, exist_ok=True)
67 | for image_path in tqdm(image_paths, desc=f"Performing inference on {dataset}"):
68 | image = read_image(image_path)
69 | segmentation = segment_using_instanseg(
70 | image=image, model_type="brightfield_nuclei", target="nuclei", scale="small"
71 | )
72 | imageio.imwrite(os.path.join(output_path, os.path.basename(image_path).replace(".png", ".tiff")), segmentation)
73 |
74 |
75 | def run_inference(input_dir, output_dir, dataset):
76 | output_path = os.path.join(output_dir, "inference", dataset, "instanseg")
77 | input_path = os.path.join(input_dir, dataset)
78 | if os.path.exists(
79 | os.path.join(output_dir, "results", dataset, "instanseg", f'{dataset}_instanseg_instanseg_ais_result.csv')
80 | ):
81 | return
82 |
83 | os.makedirs(output_path, exist_ok=True)
84 | print(f"Running inference with InstanSeg model on {dataset} dataset... \n")
85 | infer_instanseg(input_path, output_path, dataset)
86 | print(f"Inference on {dataset} dataset with the InstanSeg model successfully completed \n")
87 |
88 | evaluate_instanseg(
89 | prediction_dir=output_path,
90 | label_dir=input_path,
91 | result_dir=os.path.join(output_dir, "results"),
92 | dataset=dataset,
93 | )
94 |
95 |
96 | def get_instanseg_args():
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data")
99 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on")
100 | parser.add_argument("-o", "--output_dir", type=str, default=None, help="The path where the results are saved")
101 | args = parser.parse_args()
102 | return args
103 |
104 |
105 | def main():
106 | args = get_instanseg_args()
107 | run_inference(input_dir=args.input, output_dir=args.output_dir, dataset=args.dataset)
108 |
109 |
110 | if __name__ == "__main__":
111 | main()
112 |
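113 | # Example invocation (paths are placeholders):
114 | #
115 | #     python instanseg_inference.py -i /path/to/test_data -d pannuke -o /path/to/output_dir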
--------------------------------------------------------------------------------
/experiments/benchmarking/outdated/cellvitplusplus/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 | from natsort import natsorted
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import imageio.v3 as imageio
9 | from skimage.measure import label
10 |
11 | from elf.evaluation import mean_segmentation_accuracy
12 |
13 |
14 | DATASETS = [
15 | "consep",
16 | "cpm15",
17 | "cpm17",
18 | "cryonuseg",
19 | "lizard",
20 | "lynsec_he",
21 | "lynsec_ihc",
22 | "monusac",
23 | "monuseg",
24 | "nuclick",
25 | "nuinsseg",
26 | "pannuke",
27 | "puma",
28 | "srsanet",
29 | "tnbc",
30 | ]
31 |
32 | CVTPP_CP = [
33 | # 'Virchow-x40-AMP',
34 | "SAM-H-x40-AMP",
35 | # '256-x40-AMP'
36 | ]
37 |
38 |
39 | def _run_evaluation(gt_paths, prediction_paths, verbose=True):
40 | print(len(gt_paths), len(prediction_paths))
41 | assert len(gt_paths) == len(prediction_paths), \
42 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}"
43 |
44 | msas, sa50s, sa75s = [], [], []
45 | for gt_path, pred_path in tqdm(
46 | zip(gt_paths, prediction_paths), desc="Evaluate predictions", total=len(gt_paths), disable=not verbose,
47 | ):
48 | assert os.path.exists(gt_path), gt_path
49 | assert os.path.exists(pred_path), pred_path
50 |
51 | gt = imageio.imread(gt_path)
52 | gt = label(gt)
53 | pred = imageio.imread(pred_path)
54 |
55 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True)
56 | sa50, sa75 = scores[0], scores[5]
57 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75)
58 |
59 | return msas, sa50s, sa75s
60 |
61 |
62 | def evaluate_cellvit(prediction_dir, checkpoint, dataset, result_dir, label_dir):
63 | save_path = os.path.join(result_dir, dataset, checkpoint, "ais_result.csv")
64 | if os.path.exists(save_path):
65 | print(f"Results for {dataset} evaluation already exist")
66 | return
67 |
68 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, dataset, checkpoint, "*.tiff")))
69 | gt_paths = natsorted(glob(os.path.join(label_dir, "*.tiff")))
70 | if len(prediction_paths) == 0:
71 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found")
72 | return
73 |
74 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths)
75 | results = pd.DataFrame.from_dict(
76 | {"mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)]}
77 | )
78 | print(results.head())
79 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True)
80 | results.to_csv(save_path, index=False)
81 |
82 |
83 | def main():
84 | input_dir = "/mnt/lustre-grete/usr/u12649/data/final_test"
85 | prediction_dir = "/mnt/lustre-grete/usr/u12649/models/cellvitpp/inference/"
86 | result_dir = "/mnt/lustre-grete/usr/u12649/models/cellvitpp/results"
87 | for dataset in DATASETS:
88 | label_dir = os.path.join(input_dir, dataset, "loaded_testset", "eval_split", "test_labels")
89 | for checkpoint in CVTPP_CP:
90 | evaluate_cellvit(prediction_dir, checkpoint, dataset, result_dir, label_dir)
91 |
92 |
93 | if __name__ == "__main__":
94 |     main()
95 | 
--------------------------------------------------------------------------------
/experiments/benchmarking/outdated/cellvitplusplus/infer_cvtplus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import pandas as pd
4 | from glob import glob
5 | from natsort import natsorted
6 |
7 | DATASETS = [
8 | "consep",
9 | "cpm15",
10 | "cpm17",
11 | "cryonuseg",
12 | "lizard",
13 | "lynsec_he",
14 | "lynsec_ihc",
15 | "monusac",
16 | "monuseg",
17 | "nuclick",
18 | "nuinsseg",
19 | "pannuke",
20 | "puma",
21 | "srsanet",
22 | "tnbc",
23 | ]
24 |
25 | CVTPP_CP = [
26 | # 'Virchow-x40-AMP',
27 | 'SAM-H-x40-AMP',
28 | # '256-x40-AMP'
29 | ]
30 |
31 |
32 | def run_inference(model_dir, input_dir, output_dir, result_dir, label_dir):
33 | for dataset in DATASETS:
34 | data_dir = os.path.join(input_dir, dataset)
35 |         # build the file list expected by `detect_cells.py`: path, mpp and magnification per image
36 |         paths = natsorted(glob(os.path.join(data_dir, '*.tiff')))
37 |         files = {
38 |             "path": paths, "slide_mpp": [0.25] * len(paths), "magnification": [40] * len(paths),
39 |         }
40 | filelist_df = pd.DataFrame(files)
41 | os.makedirs(os.path.join(input_dir, "file_lists"), exist_ok=True)
42 | csv_filelist = os.path.join(input_dir, "file_lists", f"{dataset}_filelist.csv")
43 | filelist_df.to_csv(csv_filelist, index=False)
44 |
45 | for checkpoint in CVTPP_CP:
46 | checkpoint_path = os.path.join(model_dir, f"CellViT-{checkpoint}.pth")
47 | output_path = os.path.join(output_dir, dataset, checkpoint)
48 | if os.path.exists(output_path):
49 | if len(os.listdir(output_path)) > 1:
50 | continue
51 | os.makedirs(output_path, exist_ok=True)
52 | args = [
53 | "--model", f"{checkpoint_path}",
54 | "--outdir", f"{output_path}",
55 | "process_dataset",
56 | "--filelist", f"{csv_filelist}"
57 | ]
58 | command = [
59 | "python3",
60 | "/user/titus.griebel/u12649/CellViT-plus-plus/cellvit/detect_cells.py",
61 | ] + args
62 | print(f"Running inference with CellViT-plus-plus {checkpoint} model on {dataset} dataset...")
63 | subprocess.run(command)
64 |
65 | for file in glob(os.path.join(output_path, '*json')):
66 | os.remove(file)
67 | # evaluate_cellvit(output_path, checkpoint, dataset, result_dir, label_dir)
68 |
69 | print(f"Successfully ran inference with CellViT {checkpoint} model on {dataset} dataset")
70 |
71 |
72 | def main():
73 | run_inference(
74 | "/mnt/lustre-grete/usr/u12649/models/cellvit_plusplus/checkpoints",
75 | "/mnt/lustre-grete/usr/u12649/data/cvtplus/preprocessed",
76 | "/mnt/lustre-grete/usr/u12649/models/cellvit_plusplus/inference/",
77 | "/mnt/lustre-grete/usr/u12649/models/cellvit_plusplus/results",
78 | "/mnt/lustre-grete/usr/u12649/data/final_test"
79 | )
80 |
81 |
82 | if __name__ == "__main__":
83 | main()
84 |
--------------------------------------------------------------------------------
/experiments/benchmarking/stardist/stardist_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from tqdm import tqdm
4 | from glob import glob
5 | from natsort import natsorted
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import imageio.v3 as imageio
10 | from skimage.measure import label
11 |
12 | from tukra.io import read_image
13 | from tukra.inference import segment_using_stardist
14 |
15 | from elf.evaluation import mean_segmentation_accuracy
16 |
17 |
18 | def _run_evaluation(gt_paths, prediction_paths, verbose=True):
19 | assert len(gt_paths) == len(prediction_paths), \
20 | f"label / prediction mismatch: {len(gt_paths)} / {len(prediction_paths)}"
21 |
22 | msas, sa50s, sa75s = [], [], []
23 | for gt_path, pred_path in tqdm(
24 | zip(gt_paths, prediction_paths, strict=False),
25 | desc="Evaluate predictions",
26 | total=len(gt_paths),
27 | disable=not verbose,
28 | ):
29 | assert os.path.exists(gt_path), gt_path
30 | assert os.path.exists(pred_path), pred_path
31 |
32 | gt = imageio.imread(gt_path)
33 | gt = label(gt)
34 | pred = imageio.imread(pred_path)
35 |
36 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True)
37 | sa50, sa75 = scores[0], scores[5]
38 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75)
39 |
40 | return msas, sa50s, sa75s
41 |
42 |
43 | def evaluate_stardist(prediction_dir, label_dir, result_dir, dataset):
44 | gt_paths = natsorted(glob(os.path.join(label_dir, "eval_split", "test_labels", "*")))
45 | for checkpoint in ["stardist"]:
46 | save_path = os.path.join(result_dir, dataset, checkpoint, f'{dataset}_stardist_stardist_ais_result.csv')
47 | if os.path.exists(save_path):
48 | continue
49 |
50 | prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*")))
51 | if len(prediction_paths) == 0:
52 | print(f"No predictions for {dataset} dataset on {checkpoint} checkpoint found")
53 | continue
54 |
55 | os.makedirs(os.path.join(result_dir, dataset, checkpoint), exist_ok=True)
56 | print(f"Evaluating {dataset} dataset on Stardist ...")
57 | msas, sa50s, sa75s = _run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths)
58 | results = pd.DataFrame.from_dict(
59 | {"mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)]}
60 | )
61 |
62 | results.to_csv(save_path, index=False)
63 | print(results.head(2))
64 |
65 |
66 | def infer_stardist(data_dir, output_path):
67 | image_paths = natsorted(glob(os.path.join(data_dir, "eval_split", "test_images", "*")))
68 | os.makedirs(output_path, exist_ok=True)
69 | for image_path in image_paths:
70 | image = read_image(image_path)
71 | segmentation = segment_using_stardist(image=image, model_name="2D_versatile_he")
72 | imageio.imwrite(os.path.join(output_path, os.path.basename(image_path)), segmentation)
73 |
74 |
75 | def run_inference(input_dir, output_dir, dataset):
76 | output_path = os.path.join(output_dir, 'inference', dataset, "stardist")
77 | input_path = os.path.join(input_dir, dataset)
78 | if os.path.exists(
79 | os.path.join(output_dir, "results", dataset, "stardist", f'{dataset}_stardist_stardist_ais_result.csv')
80 | ):
81 | return
82 |
83 | os.makedirs(output_path, exist_ok=True)
84 | print(f"Running inference with StarDist model on {dataset} dataset... \n")
85 | infer_stardist(input_path, output_path)
86 | print(f"Inference on {dataset} dataset with the StarDist model successfully completed \n")
87 |
88 | evaluate_stardist(
89 | prediction_dir=output_path,
90 | label_dir=input_path,
91 | result_dir=os.path.join(output_dir, 'results'),
92 | dataset=dataset
93 | )
94 |
95 |
96 | def get_stardist_args():
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("-i", "--input", type=str, default=None, help="The path to the input data")
99 | parser.add_argument("-d", "--dataset", type=str, default=None, help="The datasets to infer on")
100 | parser.add_argument("-o", "--output_dir", type=str, default=None, help="The path where the results are saved")
101 | args = parser.parse_args()
102 | return args
103 |
104 |
105 | def main():
106 | args = get_stardist_args()
107 | run_inference(input_dir=args.input, output_dir=args.output_dir, dataset=args.dataset)
108 |
109 |
110 | if __name__ == "__main__":
111 | main()
112 |
--------------------------------------------------------------------------------
/experiments/data/README.md:
--------------------------------------------------------------------------------
1 | # Scripts to Preprocess Datasets for Evaluation
2 |
3 | This folder contains scripts for preprocessing the datasets used in all of our benchmarking experiments (both automatic and interactive segmentation).
4 | 
5 | > TLDR: The idea of these scripts is to prepare the input images the way they would be used from a practical point of view. You can also write your own scripts / provide the paths to your images in our [example scripts](../../examples/). The layout produced by the loading scripts is sketched below.
6 |
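7 | For the interactive benchmarking data, the loading scripts cache each dataset as tiff images and labels and then create an evaluation split. A rough sketch of the resulting layout (`<data_root>` is a placeholder; the folder names are the ones used by `interactive_data.py`):
8 | 
9 | ```
10 | <data_root>/<dataset>/loaded_testset/
11 | ├── images/          # 0001.tiff, 0002.tiff, ...
12 | ├── labels/
13 | └── eval_split/      # test (and validation) split created by `create_val_split`
14 |     ├── test_images/
15 |     └── test_labels/
16 | ```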
--------------------------------------------------------------------------------
/experiments/data/cvtplus_preproc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from glob import glob
4 | from tqdm import tqdm
5 |
6 | import numpy as np
7 | import imageio.v3 as imageio
8 |
9 |
10 | DATASETS = [
11 | "consep",
12 | "cpm15",
13 | "cpm17",
14 | "cryonuseg",
15 | "lizard",
16 | "lynsec_he",
17 | "lynsec_ihc",
18 | "monusac",
19 | "monuseg",
20 | "nuclick",
21 | "nuinsseg",
22 | "pannuke",
23 | "puma",
24 | "srsanet",
25 | "tnbc",
26 | ]
27 |
28 |
29 | def preprocess_cvtplus(input_dir, output_dir):
30 |     # create tiled pyramidal tiffs with pyvips
31 | import pyvips
32 |
33 | for dataset in DATASETS:
34 | data_dir = os.path.join(input_dir, dataset, "loaded_testset", "eval_split", "test_images")
35 | intermediate_folder = os.path.join(output_dir, "intermediate", dataset)
36 | output_folder = os.path.join(output_dir, "preprocessed", dataset)
37 | os.makedirs(intermediate_folder, exist_ok=True)
38 | os.makedirs(output_folder, exist_ok=True)
39 | for img_path in tqdm(glob(os.path.join(data_dir, "*.tiff"))):
40 | img = imageio.imread(img_path)
41 |             img_uint = img.astype(np.uint8)  # cast to 8-bit before building the pyramidal tiff
42 | intermediate_file = os.path.join(intermediate_folder, os.path.basename(img_path))
43 | imageio.imwrite(intermediate_file, img_uint)
44 | output_file = os.path.join(output_folder, os.path.basename(img_path))
45 | image = pyvips.Image.new_from_file(intermediate_file)
46 | image.tiffsave(output_file, tile=True, tile_width=512, tile_height=512, pyramid=True)
47 |
48 | shutil.rmtree(intermediate_folder)
49 |
50 |
51 | def main():
52 | preprocess_cvtplus(
53 | input_dir="/mnt/lustre-grete/usr/u12649/data/final_test", output_dir="/mnt/lustre-grete/usr/u12649/data/cvtplus"
54 | )
55 |
56 |
57 | if __name__ == "__main__":
58 | main()
59 |
--------------------------------------------------------------------------------
/experiments/data/dataloaders.py:
--------------------------------------------------------------------------------
1 | import micro_sam.training as sam_training
2 |
3 | from torch_em.data import datasets
4 | from torch_em.data import MinInstanceSampler
5 |
6 |
7 | def get_dataloaders(patch_shape, data_path, dataset):
8 | raw_transform = sam_training.identity
9 | sampler = MinInstanceSampler(min_num_instances=3)
10 |
11 | if dataset == "consep":
12 | loader = datasets.get_consep_loader(
13 | path=data_path,
14 | patch_shape=patch_shape,
15 | batch_size=1,
16 | download=True,
17 | split="test",
18 | raw_transform=raw_transform,
19 | sampler=sampler,
20 | )
21 |
22 | elif dataset == "cpm15":
23 | loader = datasets.get_cpm_loader(
24 | path=data_path,
25 | patch_shape=patch_shape,
26 | batch_size=1,
27 | download=False,
28 | split="test",
29 | raw_transform=raw_transform,
30 | sampler=sampler,
31 | data_choice="cpm15",
32 | )
33 |
34 | elif dataset == "cpm17":
35 | loader = datasets.get_cpm_loader(
36 | path=data_path,
37 | patch_shape=patch_shape,
38 | batch_size=1,
39 | download=False,
40 | split="test",
41 | raw_transform=raw_transform,
42 | sampler=sampler,
43 | data_choice="cpm17",
44 | )
45 |
46 | elif dataset == "cryonuseg":
47 | loader = datasets.get_cryonuseg_loader(
48 | path=data_path,
49 | patch_shape=(1,) + patch_shape,
50 | batch_size=1,
51 | rater="b1",
52 | split="test",
53 | download=True,
54 | raw_transform=raw_transform,
55 | sampler=sampler,
56 | )
57 |
58 | elif dataset == "glas":
59 | loader = datasets.get_glas_loader(
60 | path=data_path,
61 | patch_shape=patch_shape,
62 | batch_size=1,
63 | download=True,
64 | split="test",
65 | raw_transform=raw_transform,
66 | sampler=MinInstanceSampler(min_num_instances=2),
67 | )
68 |
69 | elif dataset == "lizard":
70 | loader = datasets.get_lizard_loader(
71 | path=data_path,
72 | patch_shape=patch_shape,
73 | batch_size=1,
74 | download=True,
75 | split="test",
76 | raw_transform=raw_transform,
77 | sampler=sampler,
78 | )
79 |
80 | elif dataset == "lynsec_he":
81 | loader = datasets.get_lynsec_loader(
82 | path=data_path,
83 | patch_shape=patch_shape,
84 | batch_size=1,
85 | choice="h&e",
86 | download=True,
87 | raw_transform=raw_transform,
88 | sampler=sampler,
89 | )
90 |
91 | elif dataset == "lynsec_ihc":
92 | loader = datasets.get_lynsec_loader(
93 | path=data_path,
94 | patch_shape=patch_shape,
95 | batch_size=1,
96 | choice="ihc",
97 | download=True,
98 | raw_transform=raw_transform,
99 | sampler=sampler,
100 | )
101 |
102 | elif dataset == "monusac":
103 | loader = datasets.get_monusac_loader(
104 | path=data_path,
105 | patch_shape=patch_shape,
106 | batch_size=1,
107 | split="test",
108 | download=True,
109 | raw_transform=raw_transform,
110 | sampler=sampler,
111 | )
112 |
113 | elif dataset == "monuseg":
114 | loader = datasets.get_monuseg_loader(
115 | path=data_path,
116 | patch_shape=patch_shape,
117 | batch_size=1,
118 | split="test",
119 | download=True,
120 | raw_transform=raw_transform,
121 | sampler=sampler,
122 | )
123 |
124 | elif dataset == "nuinsseg":
125 | loader = datasets.get_nuinsseg_loader(
126 | path=data_path,
127 | patch_shape=patch_shape,
128 | batch_size=1,
129 | download=True,
130 | raw_transform=raw_transform,
131 | sampler=sampler,
132 | )
133 | elif dataset == "nuclick":
134 | loader = datasets.get_nuclick_loader(
135 | path=data_path,
136 | patch_shape=patch_shape,
137 | batch_size=1,
138 | download=True,
139 | split="Validation",
140 | raw_transform=raw_transform,
141 | sampler=sampler,
142 | )
143 |
144 | elif dataset == "pannuke":
145 | loader = datasets.get_pannuke_loader(
146 | path=data_path,
147 | patch_shape=(1,) + patch_shape,
148 | batch_size=1,
149 | folds=["fold_3"],
150 | ndim=2,
151 | download=True,
152 | raw_transform=raw_transform,
153 | sampler=sampler,
154 | )
155 |
156 | elif dataset == "puma":
157 | loader = datasets.get_puma_loader(
158 | path=data_path,
159 | patch_shape=patch_shape,
160 | batch_size=1,
161 | annotations="nuclei",
162 | download=True,
163 | split="test",
164 | raw_transform=raw_transform,
165 | sampler=sampler,
166 | )
167 |
168 | elif dataset == "srsanet":
169 | loader = datasets.get_srsanet_loader(
170 | path=data_path,
171 | patch_shape=patch_shape,
172 | batch_size=1,
173 | download=True,
174 | split="test",
175 | raw_transform=raw_transform,
176 | sampler=sampler,
177 | )
178 |
179 | elif dataset == "tnbc":
180 | loader = datasets.get_tnbc_loader(
181 | path=data_path,
182 | patch_shape=patch_shape,
183 | batch_size=1,
184 | ndim=2,
185 | download=True,
186 | split="test",
187 | raw_transform=raw_transform,
188 | sampler=sampler,
189 | )
190 |     else:
191 |         raise ValueError(f"'{dataset}' is not a supported dataset.")
192 | 
193 |     return loader
194 | 
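195 | 
196 | # Example usage (assumed path; mirrors how `interactive_data.py` builds its loaders):
197 | #
198 | #     loader = get_dataloaders(patch_shape=(512, 512), data_path="/path/to/pannuke", dataset="pannuke")
199 | #     image, label = next(iter(loader))  # batched torch tensors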
--------------------------------------------------------------------------------
/experiments/data/get_paths.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from natsort import natsorted
4 |
5 | import h5py
6 | import imageio.v3 as imageio
7 |
8 | from torch_em.data.datasets import histopathology
9 |
10 |
11 | def get_dataset_paths(data_path, dataset) -> list:
12 |
13 | if dataset == "consep":
14 | data_paths = histopathology.consep.get_consep_paths(
15 | path=data_path, download=True, split="test",
16 | )
17 | label_key = 'labels'
18 | image_key = 'raw'
19 |
20 | elif dataset == "cpm15":
21 | image_paths, label_paths = histopathology.cpm.get_cpm_paths(
22 | path=data_path, download=False, split="test", data_choice="cpm15",
23 | )
24 |
25 | elif dataset == "cpm17":
26 | image_paths, label_paths = histopathology.cpm.get_cpm_paths(
27 | path=data_path, download=False, split="test", data_choice="cpm17",
28 | )
29 |
30 | elif dataset == "cryonuseg":
31 | image_paths, label_paths = histopathology.cryonuseg.get_cryonuseg_paths(
32 | path=data_path, rater_choice="b1", split="test", download=True,
33 | )
34 |
35 | elif dataset == "glas":
36 | data_paths = histopathology.glas.get_glas_paths(
37 | path=data_path, download=True, split="test",
38 | )
39 | label_key = 'labels'
40 | image_key = 'raw'
41 |
42 | elif dataset == "lizard":
43 | data_paths = histopathology.lizard.get_lizard_paths(
44 | path=data_path, download=True, split="test",
45 | )
46 | label_key = 'labels/segmentation'
47 | image_key = 'image'
48 |
49 | elif dataset == "lynsec_he":
50 | image_paths, label_paths = histopathology.lynsec.get_lynsec_paths(
51 | path=data_path, choice="h&e", download=True,
52 | )
53 |
54 | elif dataset == "lynsec_ihc":
55 | image_paths, label_paths = histopathology.lynsec.get_lynsec_paths(
56 | path=data_path, choice="ihc", download=True,
57 | )
58 |
59 | elif dataset == "monuseg":
60 | image_paths, label_paths = histopathology.monuseg.get_monuseg_paths(
61 | path=data_path, split="test", download=True,
62 | )
63 |
64 | elif dataset == "nuclick":
65 | image_paths, label_paths = histopathology.nuclick.get_nuclick_paths(
66 | path=data_path, download=True, split="Validation",
67 | )
68 |
69 | elif dataset == "nuinsseg":
70 | image_paths, label_paths = histopathology.nuinsseg.get_nuinsseg_paths(
71 | path=data_path, download=True,
72 | )
73 |
74 | elif dataset == "pannuke":
75 | data_paths = histopathology.pannuke.get_pannuke_paths(
76 | path=data_path, folds=["fold_3"], download=True,
77 | )
78 | cached_images = os.path.join(data_path, "loaded_images")
79 | cached_labels = os.path.join(data_path, "loaded_labels")
80 | os.makedirs(cached_images, exist_ok=True)
81 | os.makedirs(cached_labels, exist_ok=True)
82 |
83 | for h5_path in data_paths:
84 | with h5py.File(h5_path, 'r') as file:
85 | images = file['images']
86 | labels = file['labels/instances']
87 | images = images[:]
88 | labels = labels[:]
89 |
90 | # PanNuke is provided in an array of shape (C, B, H, W)
91 | images = images.transpose(1, 2, 3, 0) # --> (B, H, W, C)
92 |
93 | counter = 1
94 | for image, label in zip(images, labels):
95 | image_path = os.path.join(cached_images, f"{counter:04}.tiff")
96 | label_path = os.path.join(cached_labels, f"{counter:04}.tiff")
97 |
98 | assert image.shape == (256, 256, 3)
99 | imageio.imwrite(image_path, image)
100 | imageio.imwrite(label_path, label)
101 |
102 | counter += 1
103 |
104 | image_paths = glob(os.path.join(cached_images, "*.tiff"))
105 | label_paths = glob(os.path.join(cached_labels, "*.tiff"))
106 |
107 | elif dataset == "puma":
108 | data_paths = histopathology.puma.get_puma_paths(
109 | path=data_path, annotations="nuclei", download=True, split="test",
110 | )
111 | label_key = 'labels/nuclei'
112 | image_key = 'raw'
113 |
114 | elif dataset == "srsanet":
115 | image_paths, label_paths = histopathology.srsanet.get_srsanet_paths(
116 | path=data_path,
117 | download=True,
118 | split="test",
119 | )
120 |
121 | elif dataset == "tnbc":
122 | data_paths = histopathology.tnbc.get_tnbc_paths(
123 | path=data_path, download=True, split="test",
124 | )
125 | label_key = 'labels/instances'
126 | image_key = 'raw'
127 |     else:
128 |         raise ValueError(f"'{dataset}' is not a supported dataset.")
129 |
130 | if dataset in ["consep", "lizard", "glas", "puma", "tnbc"]:
131 | cached_images = os.path.join(data_path, "loaded_images")
132 | cached_labels = os.path.join(data_path, "loaded_labels")
133 | os.makedirs(cached_images, exist_ok=True)
134 | os.makedirs(cached_labels, exist_ok=True)
135 | for h5_path in data_paths:
136 | with h5py.File(h5_path, 'r') as file:
137 | img = file[image_key]
138 | label = file[label_key]
139 | img = img[:]
140 | label = label[:]
141 | image = img.transpose(1, 2, 0)
142 |
143 | img_path = os.path.join(
144 | cached_images, os.path.basename(h5_path).replace(".h5", ".tiff"))
145 | label_path = os.path.join(
146 | cached_labels, os.path.basename(h5_path).replace(".h5", ".tiff"))
147 | assert image.shape[:2] == label.shape, f"{image.shape}, {label.shape}"
148 |
149 | imageio.imwrite(img_path, image)
150 | imageio.imwrite(label_path, label)
151 | image_paths = glob(os.path.join(cached_images, "*.tiff"))
152 | label_paths = glob(os.path.join(cached_labels, "*.tiff"))
153 |
154 | return natsorted(image_paths), natsorted(label_paths)
155 |
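156 | # Example usage (assumed path): cache a dataset's test split as tiffs and collect the paths.
157 | #
158 | #     image_paths, label_paths = get_dataset_paths("/path/to/data/pannuke", "pannuke")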
--------------------------------------------------------------------------------
/experiments/data/interactive_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import imageio.v3 as imageio
5 |
6 | from dataloaders import get_dataloaders
7 | from util import DATASETS, create_val_split, remove_empty_labels
8 |
9 |
10 | def load_testsets(path, dsets=DATASETS, patch_shape=(512, 512)) -> None:
11 | for dataset in dsets:
12 | if os.path.exists(os.path.join(path, dataset, "loaded_testset", "images")):
13 | if len(os.listdir(os.path.join(path, dataset, "loaded_testset", "images"))) > 1:
14 | print(f"Dataset {dataset} is loaded already.")
15 | continue
16 |
17 | print(f"Loading {dataset} dataset...")
18 | dpath = os.path.join(path, dataset)
19 | os.makedirs(dpath, exist_ok=True)
20 | loader = get_dataloaders(patch_shape=patch_shape, data_path=dpath, dataset=dataset)
21 |
22 | image_output_path = os.path.join(path, dataset, "loaded_testset", "images")
23 | label_output_path = os.path.join(path, dataset, "loaded_testset", "labels")
24 |
25 | os.makedirs(image_output_path, exist_ok=True)
26 | os.makedirs(label_output_path, exist_ok=True)
27 |
28 | for idx, (image, label) in enumerate(loader, start=1):
29 | image = image.squeeze().numpy()
30 | label = label.squeeze().numpy()
31 | image = image.transpose(1, 2, 0)
32 | if image.shape[-1] == 4: # deletes alpha channel if one exists
33 | image = image[..., :-1]
34 |
35 | imageio.imwrite(os.path.join(image_output_path, f"{idx:04}.tiff"), image)
36 | imageio.imwrite(os.path.join(label_output_path, f"{idx:04}.tiff"), label)
37 |
38 | remove_empty_labels(dpath)
39 | create_val_split(os.path.join(dpath, "loaded_testset"), custom_name="eval_split", dataset=dataset)
40 | print(f"{dataset} testset has successfully been loaded.")
41 |
42 |
43 | def dataloading_args():
44 | parser = argparse.ArgumentParser()
45 | parser.add_argument("-p", "--path", type=str, default=None)
46 | parser.add_argument("-d", "--datasets", type=str, default=None)
47 | parser.add_argument("--patch_shape", type=tuple, default=(512, 512))
48 |
49 | args = parser.parse_args()
50 | return args
51 |
52 |
53 | def main():
54 | args = dataloading_args()
55 | if args.path is not None:
56 | data_path = args.path
57 | else:
58 | data_path = "/mnt/lustre-grete/usr/u12649/data/final_test/"
59 |
60 | if args.datasets is not None:
61 |         load_testsets(data_path, [args.datasets], tuple(args.patch_shape))
62 |     else:
63 |         load_testsets(data_path, patch_shape=tuple(args.patch_shape))
64 |
65 |
66 | if __name__ == "__main__":
67 | main()
68 |
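69 | # Example invocation (path is a placeholder; `--patch_shape` takes two integers):
70 | #
71 | #     python interactive_data.py -p /path/to/test_data -d pannuke --patch_shape 512 512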
--------------------------------------------------------------------------------
/experiments/data/load_original_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import imageio.v3 as imageio
7 |
8 | from util import DATASETS
9 | from get_paths import get_dataset_paths
10 |
11 |
12 | def load_datasets(path, datasets=DATASETS):
13 | for dataset in datasets:
14 | dataset_path = os.path.join(path, dataset)
15 | image_outpath = os.path.join(dataset_path, "loaded_images")
16 | label_outpath = os.path.join(dataset_path, "loaded_labels")
17 |         os.makedirs(image_outpath, exist_ok=True)
18 |         os.makedirs(label_outpath, exist_ok=True)
19 |         if len(os.listdir(image_outpath)) > 1:
20 |             continue
21 | 
22 |         print(f"Loading {dataset}...")
23 |         image_paths, label_paths = get_dataset_paths(dataset_path, dataset)
24 |         assert len(image_paths) == len(label_paths)
25 | 
26 |         count = 1
27 |         for image_path, label_path in tqdm(zip(image_paths, label_paths), desc="Moving files to new directory..."):
28 |             img_ext = os.path.splitext(image_path)[1]
29 |             label_ext = os.path.splitext(label_path)[1]
30 |             image_dest = os.path.join(image_outpath, f"{count:04}{img_ext}")
31 |             label_dest = os.path.join(label_outpath, f"{count:04}{label_ext}")
32 | 
33 |             img = imageio.imread(image_path)
34 |             if img.shape[2] == 4:  # checks for and removes alpha channel
35 |                 img = img[:, :, :-1]
36 |                 imageio.imwrite(image_path, img)
37 |                 print("Alpha channel successfully removed.")
38 | 
39 |             shutil.move(image_path, image_dest)
40 |             shutil.move(label_path, label_dest)
41 | 
42 |             count += 1
43 | 
44 | 
45 | def dataloading_args():
46 |     parser = argparse.ArgumentParser()
47 |     parser.add_argument("-p", "--path", type=str, default=None)
48 |     parser.add_argument("-d", "--datasets", type=str, default=None)
49 | 
50 |     args = parser.parse_args()
51 |     return args
52 | 
53 | 
54 | def main():
55 |     args = dataloading_args()
56 |     if args.path is not None:
57 |         data_path = args.path
58 |     else:
59 |         data_path = "/mnt/lustre-grete/usr/u12649/data/original_data"
60 | 
61 |     if args.datasets is not None:
62 |         load_datasets(data_path, [args.datasets])
63 |     else:
64 |         load_datasets(data_path)
65 | 
66 | 
67 | if __name__ == "__main__":
68 |     main()
69 | 
--------------------------------------------------------------------------------
/experiments/data/load_padded_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import numpy as np
7 | import imageio.v3 as imageio
8 |
9 | from util import DATASETS, PADDING_DS
10 | from get_paths import get_dataset_paths
11 |
12 |
13 | def load_datasets(path, datasets=DATASETS):
14 | for dataset in datasets:
15 | dataset_path = os.path.join(path, dataset)
16 | image_outpath = os.path.join(dataset_path, "loaded_images")
17 | label_outpath = os.path.join(dataset_path, "loaded_labels")
18 | os.makedirs(image_outpath, exist_ok=True)
19 | os.makedirs(label_outpath, exist_ok=True)
20 | if os.path.exists(os.path.join(dataset_path, "eval_split", "test_images")):
21 | if len(os.listdir(os.path.join(dataset_path, "eval_split", "test_images"))) > 1:
22 | continue
23 |
24 | print(f"Loading {dataset}...")
25 | image_paths, label_paths = get_dataset_paths(dataset_path, dataset)
26 | assert len(image_paths) == len(label_paths)
27 |
28 | count = 1
29 | for image_path, label_path in tqdm(zip(image_paths, label_paths), desc="Moving files to new directory..."):
30 | img_ext = os.path.splitext(image_path)[1]
31 | label_ext = os.path.splitext(label_path)[1]
32 | image_dest = os.path.join(image_outpath, f"{count:04}{img_ext}")
33 | label_dest = os.path.join(label_outpath, f"{count:04}{label_ext}")
34 |
35 | img = imageio.imread(image_path)
36 | if img.shape[2] == 4: # checks for and removes alpha channel
37 | img = img[:, :, :-1]
38 | imageio.imwrite(image_path, img)
39 | print("Alpha channel successfully removed.")
40 |
41 | if dataset in PADDING_DS:
42 | padded_img = np.zeros((512, 512, 3), dtype=img.dtype)
43 | padded_img[:256, :256, :] = img
44 | assert padded_img.shape == (512, 512, 3), padded_img.shape
45 | imageio.imwrite(image_path, padded_img)
46 |
47 | label = imageio.imread(label_path)
48 | padded_label = np.zeros((512, 512), dtype=label.dtype)
49 |                 padded_label[:256, :256] = label
50 | assert padded_label.shape == (512, 512), padded_label.shape
51 | imageio.imwrite(label_path, padded_label)
52 |
53 | if dataset != 'lizard':
54 |             if img.shape[0] != img.shape[1] or img.shape[0] % 16 != 0:  # pad to a square whose side length is a multiple of 16
55 | shape = img.shape
56 | new_shape = max(shape[:2]) // 16
57 | new_dim = (new_shape + 1) * 16
58 | padded_img = np.zeros((new_dim, new_dim, 3), dtype=img.dtype)
59 | padded_img[:shape[0], :shape[1], :] = img
60 |
61 | label = imageio.imread(label_path)
62 | padded_label = np.zeros((new_dim, new_dim), dtype=label.dtype)
63 | padded_label[:shape[0], :shape[1]] = label
64 | imageio.imwrite(label_path, padded_label)
65 | assert padded_label.shape == padded_img.shape[:2]
66 |
67 | imageio.imwrite(image_path, padded_img)
68 |
69 | if dataset == 'lizard':
70 | shape = img.shape
71 | if img.shape[0] != img.shape[1] or img.shape[0] % 1024 != 0:
72 | new_shape = max(shape[:2]) // 1024
73 | new_dim = (new_shape + 1) * 1024
74 | padded_img = np.zeros((new_dim, new_dim, 3), dtype=img.dtype)
75 | padded_img[:shape[0], :shape[1], :] = img
76 | imageio.imwrite(image_path, padded_img)
77 |
78 | label = imageio.imread(label_path)
79 | padded_label = np.zeros((new_dim, new_dim), dtype=label.dtype)
80 | padded_label[:shape[0], :shape[1]] = label
81 | imageio.imwrite(label_path, padded_label)
82 | assert padded_label.shape == padded_img.shape[:2]
83 |
84 | shutil.move(image_path, image_dest)
85 | shutil.move(label_path, label_dest)
86 |
87 | count += 1
88 |
89 |
90 | def dataloading_args():
91 | parser = argparse.ArgumentParser()
92 | parser.add_argument("-p", "--path", type=str, default=None)
93 | parser.add_argument("-d", "--datasets", type=str, default=None)
94 |
95 | args = parser.parse_args()
96 | return args
97 |
98 |
99 | def main():
100 | args = dataloading_args()
101 | if args.path is not None:
102 | data_path = args.path
103 | else:
104 | data_path = "/mnt/lustre-grete/usr/u12649/data/vit_data"
105 |
106 | if args.datasets is not None:
107 | load_datasets(data_path, [args.datasets])
108 | else:
109 | load_datasets(data_path)
110 |
111 |
112 | if __name__ == "__main__":
113 | main()
114 |
--------------------------------------------------------------------------------
/experiments/data/semantic/get_semantic_paths.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from natsort import natsorted
4 |
5 | import h5py
6 | import imageio.v3 as imageio
7 |
8 | from torch_em.data.datasets import histopathology
9 |
10 | from patho_sam.training.util import remap_labels
11 |
12 |
13 | def get_dataset_paths(data_path, dataset):
14 | cached_images = os.path.join(data_path, "loaded_images")
15 | cached_labels = os.path.join(data_path, "loaded_labels")
16 | os.makedirs(cached_images, exist_ok=True)
17 | os.makedirs(cached_labels, exist_ok=True)
18 |
19 | if dataset == "pannuke":
20 | data_paths = histopathology.pannuke.get_pannuke_paths(
21 | path=data_path, folds=["fold_3"], download=True,
22 | )
23 |
24 | for h5_path in data_paths:
25 | with h5py.File(h5_path, 'r') as file:
26 | images = file['images']
27 | labels = file['labels/semantic']
28 | images = images[:]
29 | labels = labels[:]
30 |
31 | # PanNuke is provided in an array of shape (C, B, H, W)
32 | images = images.transpose(1, 2, 3, 0) # --> (B, H, W, C)
33 |
34 | counter = 1
35 | for image, label in zip(images, labels):
36 | image_path = os.path.join(cached_images, f"{counter:04}.tiff")
37 | label_path = os.path.join(cached_labels, f"{counter:04}.tiff")
38 |
39 | assert image.shape == (256, 256, 3)
40 | imageio.imwrite(image_path, image)
41 | imageio.imwrite(label_path, label)
42 |
43 | counter += 1
44 |
45 | elif dataset == "puma":
46 | data_paths = histopathology.puma.get_puma_paths(
47 | path=data_path, annotations="nuclei", download=True, split="test",
48 | )
49 |
50 | for h5_path in data_paths:
51 | with h5py.File(h5_path, 'r') as file:
52 | img = file['raw']
53 | label = file['labels/semantic/nuclei']
54 | img = img[:]
55 | label = label[:]
56 | image = img.transpose(1, 2, 0)
57 | label = remap_labels(label, name=None)
58 | img_path = os.path.join(
59 | cached_images, os.path.basename(h5_path).replace(".h5", ".tiff"))
60 | label_path = os.path.join(
61 | cached_labels, os.path.basename(h5_path).replace(".h5", ".tiff"))
62 | assert image.shape[:2] == label.shape, f"{image.shape}, {label.shape}"
63 |
64 | imageio.imwrite(img_path, image)
65 | imageio.imwrite(label_path, label)
66 |
67 | else:
68 |         raise ValueError(f"'{dataset}' is not a supported dataset.")
69 |
70 | image_paths = glob(os.path.join(cached_images, "*.tiff"))
71 | label_paths = glob(os.path.join(cached_labels, "*.tiff"))
72 |
73 | return natsorted(image_paths), natsorted(label_paths)
74 |
--------------------------------------------------------------------------------
/experiments/patho-sam/README.md:
--------------------------------------------------------------------------------
1 | # Scripts for Evaluating Segment Anything Models for Histopathology
2 |
3 | This folder contains scripts for evaluating interactive and automatic instance segmentation of nuclei in histopathology images.
4 | 
5 | > TLDR: The scripts expect paths to the input images and their corresponding labels. You can write your own function to automate the processes in the scripts. The keywords mean: 1) AIS (Automatic Instance Segmentation), 2) AMG (Automatic Mask Generation), 3) Iterative Prompting (i.e. interactive segmentation, where prompts are derived from the ground-truth labels to simulate a user, and additional point prompts are added over 7 iterations). See our [preprint](https://doi.org/10.48550/arXiv.2502.00408) for more details on this. A minimal sketch of the shared evaluation call is shown below.
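6 | 
7 | The evaluation step that all of these scripts share boils down to the call below (a minimal sketch: `gt_paths` and `pred_paths` stand for your own lists of label and prediction files, and the save path is a placeholder):
8 | 
9 | ```python
10 | from micro_sam.evaluation.evaluation import run_evaluation
11 | 
12 | # computes mSA / SA50 / SA75 over paired label and prediction images
13 | res = run_evaluation(gt_paths, pred_paths, save_path="results/ais.csv")
14 | print(res)
15 | ```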
--------------------------------------------------------------------------------
/experiments/patho-sam/evaluate_ais.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from micro_sam.evaluation.evaluation import run_evaluation
4 | from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder
5 |
6 | from util import VANILLA_MODELS, get_default_arguments, get_pred_paths, get_test_paths, get_val_paths
7 |
8 |
9 | def run_instance_segmentation_with_decoder_inference(
10 | model_type, checkpoint, experiment_folder, input_path, tiling_window_params
11 | ):
12 | val_image_paths, val_gt_paths = get_val_paths(input_path)
13 | test_image_paths, _ = get_test_paths(input_path)
14 | prediction_folder = run_instance_segmentation_with_decoder(
15 | checkpoint, model_type, experiment_folder, val_image_paths, val_gt_paths, test_image_paths,
16 | tiling_window_params=tiling_window_params
17 | )
18 | return prediction_folder
19 |
20 |
21 | def eval_instance_segmentation_with_decoder(prediction_folder, experiment_folder, input_path):
22 | print("Evaluating", prediction_folder)
23 | _, gt_paths = get_test_paths(input_path)
24 | pred_paths = get_pred_paths(prediction_folder)
25 | save_path = os.path.join(experiment_folder, "results", "instance_segmentation_with_decoder.csv")
26 | res = run_evaluation(gt_paths, pred_paths, save_path=save_path)
27 | print(res)
28 |
29 |
30 | def main():
31 | args = get_default_arguments()
32 |
33 | if args.checkpoint is None:
34 | ckpt = VANILLA_MODELS[args.model]
35 | else:
36 | ckpt = args.checkpoint
37 |
38 | if args.tiling_window:
39 |         tiling_window_params = {"tile_shape": [384, 384], "halo": [64, 64]}  # tile shape and overlap (halo) for tiled prediction
40 | else:
41 | tiling_window_params = None
42 |
43 | prediction_folder = run_instance_segmentation_with_decoder_inference(
44 | args.model, ckpt, args.experiment_folder, args.input_path, tiling_window_params
45 | )
46 | eval_instance_segmentation_with_decoder(prediction_folder, args.experiment_folder, args.input_path)
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/experiments/patho-sam/evaluate_amg.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from micro_sam.evaluation.evaluation import run_evaluation
4 | from micro_sam.evaluation.inference import run_amg
5 |
6 | from util import (
7 | VANILLA_MODELS, get_default_arguments, get_pred_paths, get_test_paths, get_val_paths,
8 | )
9 |
10 |
11 | def run_amg_inference(model_type, checkpoint, experiment_folder, input_path, tiling_window_params=None):
12 | val_image_paths, val_gt_paths = get_val_paths(input_path)
13 | test_image_paths, _ = get_test_paths(input_path)
14 | prediction_folder = run_amg(
15 | checkpoint,
16 | model_type,
17 | experiment_folder,
18 | val_image_paths,
19 | val_gt_paths,
20 | test_image_paths,
21 | tiling_window_params=tiling_window_params,
22 | )
23 | return prediction_folder
24 |
25 |
26 | def eval_amg(prediction_folder, experiment_folder, input_path):
27 | print("Evaluating", prediction_folder)
28 | _, gt_paths = get_test_paths(input_path)
29 | pred_paths = get_pred_paths(prediction_folder)
30 | save_path = os.path.join(experiment_folder, "results", "amg.csv")
31 | res = run_evaluation(gt_paths, pred_paths, save_path=save_path)
32 | print(res)
33 |
34 |
35 | def main():
36 | args = get_default_arguments()
37 | if args.checkpoint is None:
38 | ckpt = VANILLA_MODELS[args.model]
39 | else:
40 | ckpt = args.checkpoint
41 |
42 | if args.tiling_window:
43 | tiling_window_params = {"tile_shape": [384, 384], "halo": [64, 64]}
44 | else:
45 | tiling_window_params = None
46 |
47 | prediction_folder = run_amg_inference(
48 | args.model, ckpt, args.experiment_folder, args.input_path, tiling_window_params
49 | )
50 | eval_amg(prediction_folder, args.experiment_folder, args.input_path)
51 |
52 |
53 | if __name__ == "__main__":
54 | main()
55 |
--------------------------------------------------------------------------------
/experiments/patho-sam/evaluate_iterative_prompting.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from micro_sam.evaluation import inference
4 | from micro_sam.evaluation.evaluation import run_evaluation_for_iterative_prompting
5 |
6 | from util import get_default_arguments, get_model, get_test_paths
7 |
8 |
9 | def _run_iterative_prompting(exp_folder, predictor, start_with_box_prompt, use_masks, input_path):
10 | prediction_root = os.path.join(exp_folder, "start_with_box" if start_with_box_prompt else "start_with_point")
11 | embedding_folder = os.path.join(exp_folder, "embeddings")
12 | image_paths, gt_paths = get_test_paths(input_path)
13 | inference.run_inference_with_iterative_prompting(
14 | predictor=predictor,
15 | image_paths=image_paths,
16 | gt_paths=gt_paths,
17 | embedding_dir=embedding_folder,
18 | prediction_dir=prediction_root,
19 | start_with_box_prompt=start_with_box_prompt,
20 | use_masks=use_masks,
21 | )
22 | return prediction_root
23 |
24 |
25 | def _evaluate_iterative_prompting(prediction_root, start_with_box_prompt, exp_folder, input_path):
26 | _, gt_paths = get_test_paths(input_path)
27 |
28 | run_evaluation_for_iterative_prompting(
29 | gt_paths=gt_paths,
30 | prediction_root=prediction_root,
31 | experiment_folder=exp_folder,
32 | start_with_box_prompt=start_with_box_prompt,
33 | )
34 |
35 |
36 | def main():
37 | args = get_default_arguments()
38 |
39 |     start_with_box_prompt = args.box  # whether to start the first iteration with a box prompt instead of a single point
40 |
41 | # Get the predictor to perform inference
42 | predictor = get_model(model_type=args.model, ckpt=args.checkpoint)
43 |
44 | prediction_root = _run_iterative_prompting(
45 | args.experiment_folder, predictor, start_with_box_prompt, args.use_masks, args.input_path
46 | )
47 | _evaluate_iterative_prompting(prediction_root, start_with_box_prompt, args.experiment_folder, args.input_path)
48 |
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
--------------------------------------------------------------------------------
/experiments/patho-sam/get_results.py:
--------------------------------------------------------------------------------
1 | import os
2 | from natsort import natsorted
3 |
4 | import pandas as pd
5 |
6 |
7 | SAM_TYPES = ["vit_b", "vit_l", "vit_h", "vit_b_lm"]
8 | SAM_MODELS = ["generalist_sam", "lm_sam", "pannuke_sam", "vanilla_sam", "nuclick_sam", "glas_sam", "cryonuseg_sam"]
9 | MODEL_NAMES = ["hovernet", "cellvit", "hovernext", "stardist", "cellvitpp", "instanseg"] + SAM_MODELS
10 |
11 | DATASETS = [
12 | "consep",
13 | "cpm15",
14 | "cpm17",
15 | "cryonuseg",
16 | "glas",
17 | "lizard",
18 | "lynsec_he",
19 | "lynsec_ihc",
20 | "monusac",
21 | "monuseg",
22 | "nuclick",
23 | "nuinsseg",
24 | "pannuke",
25 | "puma",
26 | "srsanet",
27 | "tnbc",
28 | ]
29 |
30 | HNXT_CP = [
31 | "lizard_convnextv2_large",
32 | "lizard_convnextv2_base",
33 | "lizard_convnextv2_tiny",
34 | "pannuke_convnextv2_tiny_1",
35 | "pannuke_convnextv2_tiny_2",
36 | "pannuke_convnextv2_tiny_3",
37 | ]
38 |
39 | CVT_CP = [
40 | "256-x20",
41 | "256-x40",
42 | "SAM-H-x20",
43 | "SAM-H-x40",
44 | ]
45 |
46 | HVNT_CP = [
47 | "consep",
48 | "cpm17",
49 | "kumar",
50 | "monusac",
51 | "pannuke",
52 | ]
53 |
54 | CVTPP_CP = ["SAM-H-x40-AMP"]
55 |
56 | CHECKPOINTS = {
57 | "hovernet": HVNT_CP,
58 | "hovernext": HNXT_CP,
59 | "cellvit": CVT_CP,
60 | "cellvitpp": CVTPP_CP,
61 | "generalist_sam": SAM_TYPES,
62 | "pannuke_sam": SAM_TYPES,
63 | "old_generalist_sam": SAM_TYPES,
64 | "nuclick_sam": SAM_TYPES,
65 | "vanilla_sam": SAM_TYPES,
66 | "glas_sam": SAM_TYPES,
67 | "cryonuseg_sam": SAM_TYPES,
68 | "lm_sam": ["vit_b_lm"],
69 | "stardist": ["stardist"],
70 | "instanseg": ["instanseg"],
71 | }
72 |
73 |
74 | def get_results(path, overwrite=False):
75 | os.makedirs(os.path.join(path, "sum_results", "concatenated_results"), exist_ok=True)
76 | for mode in ['ais', 'amg', 'boxes', 'points']:
77 | csv_concat = os.path.join(path, "sum_results", "concatenated_results", f"concatenated_{mode}_results.csv")
78 | concat_df = pd.DataFrame()
79 | for model in MODEL_NAMES:
80 | if model not in SAM_MODELS and mode != 'ais':
81 | continue
82 |
83 | if model == "vanilla_sam" and mode == 'ais':
84 | continue
85 |
86 | for checkpoint in CHECKPOINTS[model]:
87 | if mode in ['ais', 'amg']:
88 | result_dict = {"dataset": [], "msa": [], "sa50": [], "sa75": []}
89 | else:
90 | result_dict = {
91 | "dataset": [],
92 | "msa_1st": [],
93 | "msa_8th": [],
94 | "sa50_1st": [],
95 | "sa50_8th": [],
96 | "sa75_1st": [],
97 | "sa75_8th": [],
98 | }
99 |
100 | os.makedirs(os.path.join(path, "sum_results", model), exist_ok=True)
101 | csv_out = os.path.join(path, "sum_results", model, f"{mode}_{model}_{checkpoint}_results.csv")
102 | if os.path.exists(csv_out) and not overwrite:
103 | print(f"{csv_out} already exists.")
104 | continue
105 |
106 | for dataset in natsorted(DATASETS):
107 | if model in SAM_MODELS:
108 | csv_path = os.path.join(
109 | path, model, "results", dataset, mode, f"{dataset}_{model}_{checkpoint}_{mode}.csv"
110 | )
111 | else:
112 | csv_path = os.path.join(
113 | path, model, "results", dataset, checkpoint,
114 | f"{dataset}_{model}_{checkpoint}_{mode}_result.csv"
115 | )
116 |
117 | if not os.path.exists(csv_path):
118 | continue
119 |
120 | df = pd.read_csv(csv_path)
121 | if mode in ['ais', 'amg']:
122 | result_dict["msa"].append(df.loc[0, "mSA"])
123 | result_dict["sa50"].append(df.loc[0, "SA50"])
124 | result_dict["sa75"].append(df.loc[0, "SA75"])
125 | result_dict["dataset"].append(dataset)
126 | else:
127 | result_dict["msa_1st"].append(df.loc[0, "mSA"])
128 | result_dict["sa50_1st"].append(df.loc[0, "SA50"])
129 | result_dict["sa75_1st"].append(df.loc[0, "SA75"])
130 | result_dict["msa_8th"].append(df.loc[7, "mSA"])
131 | result_dict["sa50_8th"].append(df.loc[7, "SA50"])
132 | result_dict["sa75_8th"].append(df.loc[7, "SA75"])
133 | result_dict["dataset"].append(dataset)
134 |
135 | df = pd.DataFrame(result_dict)
136 | if df.empty:
137 | continue
138 |
139 | df.to_csv(csv_out, index=False)
140 | if mode in ["amg", "ais"]:
141 | df = df[["dataset", "msa"]]
142 | df.rename(columns={"msa": f"{model}_{checkpoint}_{mode}"}, inplace=True)
143 | else:
144 | df = df[["dataset", "msa_1st", "msa_8th"]]
145 | df.rename(
146 | columns={
147 | "msa_1st": f"{model}_{checkpoint}_{mode}",
148 |                             "msa_8th": f"{model}_{checkpoint}_I{mode[0]}"  # "Ib" / "Ip": 8th iteration of box / point prompting
149 | },
150 | inplace=True,
151 | )
152 |
153 | if concat_df.empty:
154 | concat_df = df
155 | else:
156 | concat_df = pd.merge(concat_df, df, on="dataset", how="outer")
157 | print(f"{model} (checkpoint: {checkpoint}) added to concat results")
158 |
159 | if not concat_df.empty:
160 | concat_df.to_csv(csv_concat, index=False)
161 |
162 |
163 | def main():
164 | get_results("/mnt/lustre-grete/usr/u12649/models/", overwrite=True)
165 |
166 |
167 | if __name__ == "__main__":
168 | main()
169 |
--------------------------------------------------------------------------------
/experiments/patho-sam/per_image_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 | from natsort import natsorted
5 | from typing import List, Optional, Union
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import imageio.v3 as imageio
10 | from skimage.measure import label
11 |
12 | from elf.evaluation import mean_segmentation_accuracy
13 |
14 |
15 | DATASETS = [
16 | "consep",
17 | "cpm15",
18 | "cpm17",
19 | "cryonuseg",
20 | "lizard",
21 | "lynsec_he",
22 | "lynsec_ihc",
23 | "monusac",
24 | "monuseg",
25 | "nuclick",
26 | "nuinsseg",
27 | "pannuke",
28 | "puma",
29 | "srsanet",
30 | "tnbc",
31 | ]
32 |
33 | SAM_MODELS = ['generalist_sam', 'pannuke_sam']
34 |
35 | MODELS = [
36 | 'stardist',
37 | 'hovernet',
38 | 'hovernext',
39 | 'instanseg',
40 | 'cellvit',
41 | # 'cellvitpp',
42 | 'generalist_sam',
43 | 'pannuke_sam',
44 | # 'old_generalist_sam'
45 | ]
46 |
47 | MODEL_NAMES = {
48 | 'stardist': "StarDist",
49 | 'hovernet': "HoVerNet",
50 | 'hovernext': "HoVerNeXt",
51 | 'instanseg': "InstanSeg",
52 | 'cellvit': "CellViT",
53 | 'pannuke_sam': "Patho-SAM (Specialist)",
54 | 'generalist_sam': "Patho-SAM (Generalist)",
55 | }
56 |
57 | HNXT_CP = [
58 | "lizard_convnextv2_large",
59 | "lizard_convnextv2_base",
60 | "lizard_convnextv2_tiny",
61 | "pannuke_convnextv2_tiny_1",
62 | "pannuke_convnextv2_tiny_2",
63 | "pannuke_convnextv2_tiny_3",
64 | ]
65 | CVT_CP = [
66 | "256-x20",
67 | "256-x40",
68 | "SAM-H-x20",
69 | "SAM-H-x40",
70 | ]
71 |
72 | CVTPP_CP = ["SAM-H-x40-AMP"]
73 |
74 | SAM_TYPES = ["vit_b", "vit_l", "vit_h"]
75 |
76 |
77 | HVNT_CP = [
78 | 'consep',
79 | 'cpm17',
80 | 'kumar',
81 | 'monusac',
82 | 'pannuke',
83 | ]
84 |
85 | CHECKPOINTS = {
86 | 'hovernet': HVNT_CP,
87 | 'hovernext': HNXT_CP,
88 | 'cellvit': CVT_CP,
89 | 'cellvitpp': CVTPP_CP,
90 | 'generalist_sam': SAM_TYPES,
91 | 'pannuke_sam': ['vit_b'],
92 | 'old_generalist_sam': ['vit_b'],
93 | 'stardist': ['stardist'],
94 | 'instanseg': ['instanseg'],
95 | }
96 |
97 |
98 | def _run_evaluation(gt_paths, prediction_paths, verbose=True):
99 | assert len(gt_paths) == len(prediction_paths), f"{len(gt_paths)}, {len(prediction_paths)}"
100 | msas, sa50s, sa75s = [], [], []
101 |
102 | for gt_path, pred_path in tqdm(
103 | zip(gt_paths, prediction_paths), desc="Evaluate predictions", total=len(gt_paths), disable=not verbose
104 | ):
105 | assert os.path.exists(gt_path), gt_path
106 | assert os.path.exists(pred_path), pred_path
107 |
108 | gt = imageio.imread(gt_path)
109 | gt = label(gt)
110 | pred = imageio.imread(pred_path)
111 |
112 | msa, scores = mean_segmentation_accuracy(pred, gt, return_accuracies=True)
113 | sa50, sa75 = scores[0], scores[5]
114 | msas.append(msa), sa50s.append(sa50), sa75s.append(sa75)
115 |
116 | return msas, sa50s, sa75s
117 |
118 |
119 | def run_evaluation(
120 | gt_paths: List[Union[os.PathLike, str]],
121 | prediction_paths: List[Union[os.PathLike, str]],
122 | save_path: Optional[Union[os.PathLike, str]] = None,
123 | verbose: bool = True,
124 | ) -> np.ndarray:
125 | """Run evaluation for instance segmentation predictions.
126 |
127 | Args:
128 | gt_paths: The list of paths to ground-truth images.
129 | prediction_paths: The list of paths with the instance segmentations to evaluate.
130 | save_path: Optional path for saving the results.
131 | verbose: Whether to print the progress.
132 |
133 | Returns:
134 |         The per-image mSA scores as a numpy array (also saved to `save_path`).
135 | """
136 | assert len(gt_paths) == len(prediction_paths), f"{len(gt_paths)}, {len(prediction_paths)}"
137 | # if a save_path is given and it already exists then just load it instead of running the eval
138 | if save_path is not None and os.path.exists(save_path):
139 |         return np.load(save_path)
140 |
141 | msas, sa50s, sa75s = _run_evaluation(gt_paths, prediction_paths, verbose=verbose)
142 |
143 |     results = np.array(msas)
144 |     np.save(save_path, results)
145 |     return results
146 |
147 | def per_sample_eval(inf_path, data_path):
148 | for model in MODELS:
149 | for checkpoint in CHECKPOINTS[model]:
150 | for dataset in DATASETS:
151 | if model not in SAM_MODELS:
152 | inference_paths = natsorted(
153 | glob(os.path.join(inf_path, model, "inference", dataset, checkpoint, "*.tiff"))
154 | )
155 | else:
156 | inference_paths = natsorted(
157 | glob(
158 | os.path.join(
159 | inf_path, model, "inference", dataset, checkpoint, "instance",
160 | "instance_segmentation_with_decoder", "inference", "*.tiff"
161 | )
162 | )
163 | )
164 |
165 | label_paths = natsorted(
166 | glob(os.path.join(data_path, dataset, "loaded_testset", "eval_split", "test_labels", "*.tiff"))
167 | )
168 | save_path = os.path.join(
169 | inf_path, model, "results", "per_sample", f"{model}_{dataset}_{checkpoint}_ps.npy"
170 | )
171 | if os.path.exists(save_path):
172 | continue
173 |
174 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
175 | print(f"evaluating {model}, checkpoint {checkpoint}")
176 | run_evaluation(gt_paths=label_paths, prediction_paths=inference_paths, save_path=save_path)
177 |
178 |
179 | if __name__ == "__main__":
180 |     per_sample_eval(
181 |         data_path="/mnt/lustre-grete/usr/u12649/data/final_test",
182 |         inf_path="/mnt/lustre-grete/usr/u12649/models",
183 |     )
184 | 
--------------------------------------------------------------------------------
/experiments/patho-sam/push_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from datetime import datetime
4 |
5 | from util import DATASETS, SAM_TYPES, MODEL_NAMES
6 |
7 |
8 | CHECKPOINTS = {
9 | "vanilla_sam": SAM_TYPES,
10 | "generalist_sam": SAM_TYPES,
11 | "pannuke_sam": ["vit_b"],
12 | "lm_sam": ["vit_b_lm"],
13 | "glas_sam": ["vit_b"],
14 | "nuclick_sam": ["vit_b"],
15 | }
16 |
17 |
18 | def write_batch_script(out_path, _name, model_type, dataset, dry, model_name):
19 | "Writing scripts for patho-sam inference."
20 | batch_script = """#!/bin/bash
21 | #SBATCH -t 2-00:00:00
22 | #SBATCH --nodes=1
23 | #SBATCH --ntasks=1
24 | #SBATCH -p grete:shared
25 | #SBATCH -G A100:1
26 | #SBATCH -A nim00007
27 | #SBATCH -x ggpu114
28 | #SBATCH -c 16
29 | #SBATCH --job-name=patho-sam-inference
30 |
31 | source ~/.bashrc
32 | conda activate sam2 \n"""
33 |
34 | # python script
35 | python_script = f"python {_name}.py "
36 | python_script += f"-d {dataset} " # dataset to infer on
37 | python_script += f"-m {model_type} " # name of the model configuration
38 | python_script += f"-n {model_name} " # name of the model
39 |
40 | # let's add the python script to the bash script
41 | batch_script += python_script
42 |
43 | _op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh"
44 | with open(_op, "w") as f:
45 | f.write(batch_script)
46 |
47 | cmd = ["sbatch", _op]
48 | if not dry:
49 | subprocess.run(cmd)
50 |
51 |
52 | def get_batch_script_names(tmp_folder):
53 | tmp_folder = os.path.expanduser(tmp_folder)
54 | os.makedirs(tmp_folder, exist_ok=True)
55 |
56 | script_name = "patho-sam-inference"
57 |
58 | dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
59 | tmp_name = script_name + dt
60 | batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")
61 |
62 | return batch_script
63 |
64 |
65 | def submit_slurm(args):
66 | "Submit python script that needs gpus with given inputs on a slurm node."
67 | tmp_folder = "./gpu_jobs"
68 | model_path = "/mnt/lustre-grete/usr/u12649/models"
69 | script_names = ["run_ais", "run_amg"]
70 | for script_name in script_names:
71 | print(f"Running for {script_name}")
72 | for dataset in DATASETS:
73 | for model_name in MODEL_NAMES:
74 | for model in CHECKPOINTS[model_name]:
75 | if model_name == 'glas_sam' and dataset != 'glas':
76 | continue
77 |
78 | if model_name == 'nuclick_sam' and dataset != 'nuclick':
79 | continue
80 |
81 | if script_name != 'run_amg' and model_name == 'vanilla_sam':
82 | continue
83 |
84 | result = os.path.join(
85 | model_path, model_name, "results", dataset,
86 | f"{script_name[4:]}", f"{dataset}_{model_name}_{model}_{script_name[4:]}.csv"
87 | )
88 | print(result)
89 | if os.path.exists(result):
90 | continue
91 |
92 | write_batch_script(
93 | out_path=get_batch_script_names(tmp_folder),
94 | _name=script_name,
95 | model_type=model,
96 | dataset=dataset,
97 | dry=args.dry,
98 | model_name=model_name,
99 | )
100 |
101 |
102 | def main(args):
103 | submit_slurm(args)
104 |
105 |
106 | if __name__ == "__main__":
107 | import argparse
108 | parser = argparse.ArgumentParser()
109 | parser.add_argument("-m", "--model_type", type=str, default=None)
110 | parser.add_argument("--dry", action="store_true")
111 | args = parser.parse_args()
112 | main(args)
113 |
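114 | # Example: preview the generated batch scripts under ./gpu_jobs without
115 | # submitting them to SLURM (a sketch of typical usage):
116 | #   python push_inference.py --dry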
--------------------------------------------------------------------------------
/experiments/patho-sam/run_ais.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 |
6 | from util import TILING_WINDOW_DS, PADDING_DS
7 |
8 |
9 | def run_inference(input_dir, output_dir, model_type, dataset, model_name, checkpoint_path):
10 | output_path = os.path.join(output_dir, model_name, "inference", dataset, model_type, "instance")
11 | os.makedirs(output_path, exist_ok=True)
12 | if dataset not in PADDING_DS:
13 | input_path = os.path.join(input_dir, "original_data", dataset, "eval_split")
14 | else:
15 | input_path = os.path.join(input_dir, "vit_data", dataset, "eval_split")
16 |
17 | args = [
18 | "-m", f"{model_type}",
19 | "-c", f"{checkpoint_path}",
20 | "--experiment_folder", f"{output_path}",
21 | "-i", f"{input_path}",
22 | ]
23 |
24 | if dataset in TILING_WINDOW_DS:
25 | args.append("--tiling_window")
26 | command = [
27 | "python3", os.path.expanduser("~/patho-sam/experiments/patho-sam/evaluate_ais.py"),
28 | ] + args
29 | print(f"Running inference with {model_name} model (type: {model_type}) on {dataset} dataset...")
30 | subprocess.run(command)
31 | os.makedirs(os.path.join(output_dir, model_name, 'results', dataset, 'ais'), exist_ok=True)
32 |
33 | shutil.copy(
34 | os.path.join(
35 | output_dir, model_name, "inference", dataset, model_type, 'instance',
36 | 'results', 'instance_segmentation_with_decoder.csv'
37 | ),
38 | os.path.join(output_dir, model_name, 'results', dataset, 'ais', f'{dataset}_{model_name}_{model_type}_ais.csv')
39 | )
40 |
41 | print(f"Successfully ran inference with {model_name} model (type: {model_type}) on {dataset} dataset")
42 |
43 |
44 | def get_ais_args():
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument(
47 | "-d", "--dataset", type=str, default=None, help="The dataset to infer on."
48 | )
49 | parser.add_argument(
50 | "-m", "--model", type=str, default=None, help="Provide the model type to infer with {vit_b, vit_l, vit_h}."
51 | )
52 | parser.add_argument(
53 | "-o", "--output_dir", type=str, default=None, help="Provide path where the results will be stored."
54 | )
55 | parser.add_argument(
56 | "-i", "--input_dir", type=str, default=None, help="Provide path where the dataset is located."
57 | )
58 | parser.add_argument(
59 | "-n", "--name", type=str, default=None,
60 | help="Provide the name of the model to infer with {generalist_sam, vanilla_sam, ..}."
61 | )
62 | parser.add_argument(
63 | "-c", "--checkpoint_path", type=str, default=None,
64 | help="(Optional) provide the path to the checkpoint to use for inference."
65 | )
66 |     return parser.parse_args()
67 |
68 | def main():
69 | args = get_ais_args()
70 | run_inference(
71 | output_dir=args.output_dir,
72 | input_dir=args.input_dir,
73 | model_type=args.model,
74 | dataset=args.dataset,
75 | model_name=args.name,
76 | checkpoint_path=args.checkpoint_path,
77 | )
78 |
79 |
80 | if __name__ == "__main__":
81 | main()
82 |
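83 | # Example invocation (hypothetical paths):
84 | #   python run_ais.py -d pannuke -m vit_b -n generalist_sam \
85 | #       -i /path/to/data -o /path/to/models -c /path/to/checkpoints/best.pt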
--------------------------------------------------------------------------------
/experiments/patho-sam/run_amg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 |
6 | from util import TILING_WINDOW_DS, PADDING_DS
7 |
8 |
9 | def run_inference(input_dir, output_dir, model_type, dataset, model_name, checkpoint_path):
10 | output_path = os.path.join(output_dir, model_name, "inference", dataset, model_type, "amg")
11 | os.makedirs(output_path, exist_ok=True)
12 | if os.path.exists(
13 | os.path.join(output_dir, model_name, 'results', dataset, 'amg', f'{dataset}_{model_name}_{model_type}_amg.csv')
14 | ):
15 | print(f"Inference with {model_name} model (type: {model_type}) on {dataset} dataset already done")
16 | return
17 |
18 | if dataset not in PADDING_DS:
19 | input_path = os.path.join(input_dir, "original_data", dataset, "eval_split")
20 | else:
21 | input_path = os.path.join(input_dir, "vit_data", dataset, "eval_split")
22 |
23 | args = [
24 | "-m", f"{model_type}",
25 | "-c", f"{checkpoint_path}",
26 | "--experiment_folder", f"{output_path}",
27 | "-i", f"{input_path}",
28 | ]
29 |
30 | if dataset in TILING_WINDOW_DS:
31 | args.append("--tiling_window")
32 |
33 | command = [
34 | "python3", os.path.expanduser("~/patho-sam/experiments/patho-sam/evaluate_amg.py"),
35 | ] + args
36 | print(f"Running inference with {model_name} model (type: {model_type}) on {dataset} dataset...")
37 | subprocess.run(command)
38 | os.makedirs(os.path.join(output_dir, model_name, 'results', dataset, 'amg'), exist_ok=True)
39 |
40 | shutil.copy(
41 | os.path.join(output_dir, model_name, "inference", dataset, model_type, 'amg', 'results', 'amg.csv'),
42 | os.path.join(output_dir, model_name, 'results', dataset, 'amg', f'{dataset}_{model_name}_{model_type}_amg.csv')
43 | )
44 | print(f"Successfully ran amg inference with {model_name} model (type: {model_type}) on {dataset} dataset")
45 |
46 |
47 | def get_amg_args():
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument(
50 | "-d", "--dataset", type=str, default=None, help="The dataset to infer on."
51 | )
52 | parser.add_argument(
53 | "-m", "--model", type=str, default=None, help="Provide the model type to infer with {vit_b, vit_l, vit_h}."
54 | )
55 | parser.add_argument(
56 | "-o", "--output_dir", type=str, default=None, help="Provide path where the results will be stored."
57 | )
58 | parser.add_argument(
59 | "-i", "--input_dir", type=str, default=None, help="Provide path where the dataset is located."
60 | )
61 | parser.add_argument(
62 | "-n", "--name", type=str, default=None,
63 | help="Provide the name of the model to infer with {generalist_sam, pannuke_sam, vanilla_sam, ..}."
64 | )
65 | parser.add_argument(
66 | "-c", "--checkpoint_path", type=str, default=None,
67 | help="(Optional) provide the path to the checkpoint to use for inference."
68 | )
69 |     return parser.parse_args()
70 |
71 | def main():
72 | args = get_amg_args()
73 | run_inference(
74 | output_dir=args.output_dir,
75 | input_dir=args.input_dir,
76 | model_type=args.model,
77 | dataset=args.dataset,
78 | model_name=args.name,
79 | checkpoint_path=args.checkpoint_path,
80 | )
81 |
82 |
83 | if __name__ == "__main__":
84 | main()
85 |
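86 | # Example invocation for the vanilla SAM baseline (hypothetical paths). Omitting
87 | # -c passes the literal string "None", which the evaluation script converts back
88 | # to Python None (see 'none_or_str' in util.py):
89 | #   python run_amg.py -d pannuke -m vit_b -n vanilla_sam -i /path/to/data -o /path/to/models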
--------------------------------------------------------------------------------
/experiments/patho-sam/run_iter_boxes.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 |
6 |
7 | def run_inference(input_dir, output_dir, model_type, dataset, model_name, checkpoint_path, use_masks):
8 | output_path = os.path.join(output_dir, model_name, "inference", dataset, model_type, "boxes")
9 | os.makedirs(output_path, exist_ok=True)
10 | if os.path.exists(
11 | os.path.join(
12 | output_dir, model_name, 'results', dataset, 'boxes',
13 | f'{dataset}_{model_name}_{model_type}_boxes.csv'
14 | )
15 | ):
16 | print(f"Inference with {model_name} model on {dataset} dataset already done")
17 | return
18 |
19 | input_path = os.path.join(input_dir, dataset, "loaded_testset", "eval_split")
20 | args = [
21 | "-m", f"{model_type}",
22 | "-c", f"{checkpoint_path}",
23 | "--experiment_folder", f"{output_path}",
24 | "-i", f"{input_path}",
25 | "--box",
26 | ]
27 |
28 | if use_masks:
29 | args.append("--use_masks")
30 |
31 | command = [
32 | "python3", os.path.expanduser("~/patho-sam/experiments/patho-sam/evaluate_iterative_prompting.py"),
33 | ] + args
34 |
35 | print(f"Running inference with {model_name} model (type: {model_type}) on {dataset} dataset...")
36 | subprocess.run(command)
37 | shutil.rmtree(os.path.join(output_path, "embeddings"))
38 | os.makedirs(os.path.join(output_dir, model_name, 'results', dataset, 'boxes'), exist_ok=True)
39 | shutil.copy(
40 | os.path.join(
41 | output_dir, model_name, "inference", dataset, model_type, 'boxes', 'results',
42 | 'iterative_prompting_without_mask', 'iterative_prompts_start_box.csv'
43 | ),
44 | os.path.join(
45 | output_dir, model_name, 'results', dataset, 'boxes', f'{dataset}_{model_name}_{model_type}_boxes.csv'
46 | )
47 | )
48 | print(f"Successfully ran inference with {model_name} model (type: {model_type}) on {dataset} dataset")
49 |
50 |
51 | def get_iterative_boxes_args():
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument(
54 | "-d", "--dataset", type=str, default=None, help="The dataset to infer on."
55 | )
56 | parser.add_argument(
57 | "-m", "--model", type=str, default=None, help="Provide the model type to infer with {vit_b, vit_l, vit_h}."
58 | )
59 | parser.add_argument(
60 | "-o", "--output_dir", type=str, default=None, help="Provide path where the results will be stored."
61 | )
62 | parser.add_argument(
63 | "-i", "--input_dir", type=str, default=None, help="Provide path where the dataset is located."
64 | )
65 | parser.add_argument(
66 | "-n", "--name", type=str, default=None,
67 | help="Provide the name of the model to infer with {generalist_sam, pannuke_sam, vanilla_sam, ..}."
68 | )
69 | parser.add_argument(
70 | "-c", "--checkpoint_path", type=str, default=None,
71 | help="(Optional) provide the path to the checkpoint to use for inference."
72 | )
73 | parser.add_argument(
74 | "--masks_off", action="store_false", help="To disable the usage of logit masks for iterative prompting."
75 | )
76 |     return parser.parse_args()
77 |
78 | def main():
79 | args = get_iterative_boxes_args()
80 | run_inference(
81 | output_dir=args.output_dir,
82 | input_dir=args.input_dir,
83 | model_type=args.model,
84 | dataset=args.dataset,
85 | model_name=args.name,
86 | checkpoint_path=args.checkpoint_path,
87 |         use_masks=args.masks_off  # NOTE: 'store_false', so this is True unless --masks_off is passed.
88 | )
89 |
90 |
91 | if __name__ == "__main__":
92 | main()
93 |
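94 | # Example (hypothetical paths); add --masks_off to disable logit masks between
95 | # the iterative prompting rounds:
96 | #   python run_iter_boxes.py -d pannuke -m vit_b -n generalist_sam -i /path/to/data -o /path/to/models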
--------------------------------------------------------------------------------
/experiments/patho-sam/run_iter_points.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import subprocess
5 |
6 |
7 | def run_inference(input_dir, output_dir, model_type, dataset, model_name, checkpoint_path, use_masks):
8 | output_path = os.path.join(output_dir, model_name, "inference", dataset, model_type, "points")
9 | os.makedirs(output_path, exist_ok=True)
10 | if os.path.exists(
11 | os.path.join(
12 | output_dir, model_name, 'results', dataset, 'points',
13 | f'{dataset}_{model_name}_{model_type}_points.csv'
14 | )
15 | ):
16 | print(f"Inference with {model_name} model (type: {model_type}) on {dataset} dataset already done")
17 | return
18 |
19 | input_path = os.path.join(input_dir, dataset, "loaded_testset", "eval_split")
20 | args = [
21 | "-m", f"{model_type}",
22 | "-c", f"{checkpoint_path}",
23 | "--experiment_folder", f"{output_path}",
24 | "-i", f"{input_path}",
25 | ]
26 |
27 | if use_masks:
28 | args.append("--use_masks")
29 |
30 | command = [
31 |         "python3", os.path.expanduser("~/patho-sam/experiments/patho-sam/evaluate_iterative_prompting.py"),
32 | ] + args
33 |
34 | print(f"Running inference with {model_name} model (type: {model_type}) on {dataset} dataset...")
35 | subprocess.run(command)
36 | embedding_path = os.path.join(output_path, "embeddings")
37 | shutil.rmtree(embedding_path)
38 | os.makedirs(os.path.join(output_dir, model_name, 'results', dataset, 'points'), exist_ok=True)
39 | shutil.copy(
40 | os.path.join(
41 | output_dir, model_name, "inference", dataset, model_type, 'points', 'results',
42 | 'iterative_prompting_without_mask', 'iterative_prompts_start_point.csv'
43 | ),
44 | os.path.join(
45 | output_dir, model_name, 'results', dataset, 'points',
46 | f'{dataset}_{model_name}_{model_type}_points.csv'
47 | )
48 | )
49 |
50 | print(
51 | f"Successfully ran iterative points inference with {model_name} "
52 |         f"model (type: {model_type}) on {dataset} dataset"
53 | )
54 |
55 |
56 | def get_iterative_points_args():
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument(
59 | "-d", "--dataset", type=str, default=None, help="The dataset to infer on."
60 | )
61 | parser.add_argument(
62 | "-m", "--model", type=str, default=None, help="Provide the model type to infer with {vit_b, vit_l, vit_h}."
63 | )
64 | parser.add_argument(
65 | "-o", "--output_dir", type=str, default=None, help="Provide path where the results will be stored."
66 | )
67 | parser.add_argument(
68 | "-i", "--input_dir", type=str, default=None, help="Provide path where the dataset is located."
69 | )
70 | parser.add_argument(
71 | "-n", "--name", type=str, default=None,
72 | help="Provide the name of the model to infer with {generalist_sam, pannuke_sam, vanilla_sam, ..}."
73 | )
74 | parser.add_argument(
75 | "-c", "--checkpoint_path", type=str, default=None,
76 | help="(Optional) provide the path to the checkpoint to use for inference."
77 | )
78 | parser.add_argument(
79 | "--masks_off", action="store_false", help="To disable the usage of logit masks for iterative prompting."
80 | )
81 |     return parser.parse_args()
82 |
83 | def main():
84 | args = get_iterative_points_args()
85 | run_inference(
86 | output_dir=args.output_dir,
87 | input_dir=args.input_dir,
88 | model_type=args.model,
89 | dataset=args.dataset,
90 | model_name=args.name,
91 | checkpoint_path=args.checkpoint_path,
92 |         use_masks=args.masks_off  # NOTE: 'store_false', so this is True unless --masks_off is passed.
93 | )
94 |
95 |
96 | if __name__ == "__main__":
97 | main()
98 |
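99 | # Example (hypothetical paths):
100 | #   python run_iter_points.py -d pannuke -m vit_b -n pannuke_sam -i /path/to/data -o /path/to/models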
--------------------------------------------------------------------------------
/experiments/patho-sam/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from glob import glob
4 | from natsort import natsorted
5 |
6 | from micro_sam.util import get_sam_model
7 |
8 |
9 | ROOT = "/scratch/projects/nim00007/sam/data/"
10 | EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/new_models"
11 |
12 | VANILLA_MODELS = {
13 | "vit_t": "/scratch-grete/projects/nim00007/sam/models/vanilla/vit_t_mobile_sam.pth",
14 | "vit_b": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_b_01ec64.pth",
15 | "vit_l": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_l_0b3195.pth",
16 | "vit_h": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_h_4b8939.pth",
17 | "vit_b_lm": None,
18 | "vit_b_histopathology": None,
19 | }
20 |
21 |
22 | SAM_TYPES = ["vit_b", "vit_l", "vit_h"]
23 |
24 | MODEL_NAMES = ["lm_sam", "vanilla_sam", "generalist_sam", "pannuke_sam", "nuclick_sam", "glas_sam"]
25 |
26 | DATASETS = [
27 | "consep",
28 | "cpm15",
29 | "cpm17",
30 | "cryonuseg",
31 | "glas",
32 | "lizard",
33 | "lynsec_he",
34 | "lynsec_ihc",
35 | "monuseg",
36 | "nuclick",
37 | "nuinsseg",
38 | "pannuke",
39 | "puma",
40 | "srsanet",
41 | "tnbc",
42 | ]
43 |
44 | TILING_WINDOW_DS = [
45 | "cpm15",
46 | "consep",
47 | "lizard",
48 | "monuseg",
49 | "puma",
50 | ]
51 |
52 | PADDING_DS = [
53 | "pannuke",
54 | "srsanet",
55 | "nuclick",
56 | ]
57 |
58 |
59 | def get_dataset_paths(dataset_name, split_choice):
60 | file_search_specs = "*"
61 | is_explicit_split = True
62 |
63 | # if the datasets have different modalities/species, let's make use of it
64 | split_names = dataset_name.split("/")
65 | if len(split_names) > 1:
66 | assert len(split_names) <= 2
67 | dataset_name = [split_names[0], "slices", split_names[1]]
68 | else:
69 | dataset_name = [*split_names, "slices"]
70 |
71 | # if there is an explicit val/test split made, let's look at them
72 | if is_explicit_split:
73 | dataset_name.append(split_choice)
74 |
75 | raw_dir = os.path.join(ROOT, *dataset_name, "raw", file_search_specs)
76 | labels_dir = os.path.join(ROOT, *dataset_name, "labels", file_search_specs)
77 |
78 | return raw_dir, labels_dir
79 |
80 |
81 | def get_model(model_type, ckpt):
82 | if ckpt is None:
83 | ckpt = VANILLA_MODELS[model_type]
84 | predictor = get_sam_model(model_type=model_type, checkpoint_path=ckpt)
85 | return predictor
86 |
87 |
88 | def get_pred_paths(prediction_folder):
89 | pred_paths = sorted(glob(os.path.join(prediction_folder, "*")))
90 | return pred_paths
91 |
92 |
93 | def get_default_arguments():
94 | parser = argparse.ArgumentParser()
95 | parser.add_argument(
96 | "-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor"
97 | )
98 | parser.add_argument("-c", "--checkpoint", type=none_or_str, default=None) # expects best.pt
99 | parser.add_argument("-e", "--experiment_folder", type=str, required=True) # empty directory for saving the output
100 | parser.add_argument(
101 | "-i", "--input_path", type=str, required=True, default=None,
102 | help="Requires path to a directory containing 'test_images', 'test_labels', 'val_images' and 'val_labels' directories that contain the data", # noqa
103 | )
104 | parser.add_argument(
105 |         "--tiling_window", action="store_true", help="To use a tiling window for inputs larger than 512 x 512"
106 | )
107 | parser.add_argument(
108 | "--box", action="store_true", help="If passed, starts with first prompt as box"
109 | )
110 | parser.add_argument(
111 | "--use_masks", action="store_true", help="To use logits masks for iterative prompting."
112 | )
113 | args = parser.parse_args()
114 | return args
115 |
116 |
117 | def dataloading_args():
118 | parser = argparse.ArgumentParser()
119 | parser.add_argument("-p", "--path", type=str, default=None)
120 | parser.add_argument("-d", "--datasets", type=str, default=None)
121 |     parser.add_argument("--patch_shape", type=lambda s: tuple(map(int, s.split(","))), default=(512, 512))
122 |
123 | args = parser.parse_args()
124 | return args
125 |
126 |
127 | def none_or_str(value):
128 | if value == "None":
129 | return None
130 | return value
131 |
132 |
133 | def get_val_paths(input_path):
134 | val_image_paths = natsorted(glob(os.path.join(input_path, "val_images/*")))
135 | val_label_paths = natsorted(glob(os.path.join(input_path, "val_labels/*")))
136 | return val_image_paths, val_label_paths
137 |
138 |
139 | def get_test_paths(input_path):
140 | test_image_paths = natsorted(glob(os.path.join(input_path, "test_images/*")))
141 | test_label_paths = natsorted(glob(os.path.join(input_path, "test_labels/*")))
142 | return test_image_paths, test_label_paths
143 |
144 |
145 | def get_inference_args():
146 | parser = argparse.ArgumentParser()
147 | parser.add_argument(
148 | "-d", "--dataset", type=str, default=None,
149 | help="The dataset to infer on. If None, all datasets will be chosen."
150 | )
151 | parser.add_argument(
152 | "-m", "--model", type=str, default=None,
153 | help="Provide the model type to infer with {vit_b, vit_l, vit_h}."
154 | )
155 | parser.add_argument(
156 | "--use_masks", action="store_true", help="To use logits masks for iterative prompting."
157 | )
158 | parser.add_argument(
159 | "--tiling_window", action="store_true",
160 |         help="To use a tiling window for inputs larger than 512 x 512")
161 | parser.add_argument(
162 | "-n", "--name", type=str, default=None,
163 | help="Provide the name of the model to infer with {generalist_sam, vanilla_sam, ..}."
164 | )
165 | args = parser.parse_args()
166 | return args
167 |
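168 | # Minimal usage sketch for the helpers above (hypothetical paths):
169 | #   predictor = get_model("vit_b", ckpt=None)  # falls back to VANILLA_MODELS["vit_b"]
170 | #   test_images, test_labels = get_test_paths("/path/to/dataset/eval_split")
171 | # NOTE: 'none_or_str' exists because the runner scripts interpolate a missing
172 | # checkpoint as the literal string "None" on the command line.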
--------------------------------------------------------------------------------
/experiments/semantic_segmentation/README.md:
--------------------------------------------------------------------------------
1 | # Scripts for Training and Evaluating Semantic Segmentation using PathoSAM
2 |
3 | This folder contains scripts for training a semantic segmentation model, starting from `patho-sam` generalist models. Here is a brief overview of the folders:
4 | - `benchmarking`: Contains scripts to run BioMedParse for semantic segmentation in histopathology images, and scripts to evaluate other benchmark methods for the task.
5 | - `generalists`: Contains scripts to run the training and evaluation of semantic segmentation (in all decoder/pretraining combinations) starting from `patho-sam` generalist models.
6 | - `results`: Contains scripts for plotting the results of semantic segmentation.
7 |
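8 | For example, the four decoder/pretraining training combinations can be submitted via the helper script in `generalists` (a sketch; paths are cluster-specific):
9 |
10 | ```bash
11 | # Dry run: only write the batch scripts under ./gpu_jobs
12 | python generalists/submit_training.py --dry
13 | # Submit for real, storing checkpoints under a custom root
14 | python generalists/submit_training.py -s /path/to/save_root
15 | ```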
--------------------------------------------------------------------------------
/experiments/semantic_segmentation/benchmark_methods/evaluate_other_methods.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 | from natsort import natsorted
5 |
6 | import numpy as np
7 | import pandas as pd
8 |
9 | from tukra.io import read_image
10 |
11 | from patho_sam.evaluation.evaluation import semantic_segmentation_quality, extract_class_weights_for_pannuke
12 |
13 |
14 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external"
15 |
16 | CLASS_IDS = [1, 2, 3, 4, 5]
17 |
18 |
19 | def evaluate_benchmark_methods(per_class_weights):
20 | # Get the original images first.
21 | image_paths = natsorted(glob(os.path.join(ROOT, "semantic_split", "test_images", "*.tiff")))
22 | gt_paths = natsorted(glob(os.path.join(ROOT, "semantic_split", "test_labels", "*.tiff")))
23 |
24 | assert image_paths and len(image_paths) == len(gt_paths)
25 |
26 | cellvit_256_20_scores, cellvit_256_40_scores, cellvit_sam_20_scores, cellvit_sam_40_scores = [], [], [], []
27 | hovernet_scores = []
28 | hovernext_1_scores, hovernext_2_scores = [], []
29 | for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
30 | # Load the input image and corresponding labels.
31 | gt = read_image(gt_path)
32 |
33 | # If the inputs do not have any semantic labels, we do not evaluate them!
34 | if len(np.unique(gt)) == 1:
35 | continue
36 |
37 | # Get the filename
38 | fname = os.path.basename(image_path)
39 |
40 | # Get predictions and scores per experiment.
41 | # 1. cellvit results (for all models).
42 | cellvit_256_20 = read_image(os.path.join(ROOT, "cellvit", "256-x20", fname))
43 | cellvit_256_40 = read_image(os.path.join(ROOT, "cellvit", "256-x40", fname))
44 | cellvit_sam_20 = read_image(os.path.join(ROOT, "cellvit", "SAM-H-x20", fname))
45 | cellvit_sam_40 = read_image(os.path.join(ROOT, "cellvit", "SAM-H-x40", fname))
46 |
47 | cellvit_256_20_scores.append(
48 | semantic_segmentation_quality(ground_truth=gt, segmentation=cellvit_256_20, class_ids=CLASS_IDS)
49 | )
50 | cellvit_256_40_scores.append(
51 | semantic_segmentation_quality(ground_truth=gt, segmentation=cellvit_256_40, class_ids=CLASS_IDS)
52 | )
53 | cellvit_sam_20_scores.append(
54 | semantic_segmentation_quality(ground_truth=gt, segmentation=cellvit_sam_20, class_ids=CLASS_IDS)
55 | )
56 | cellvit_sam_40_scores.append(
57 | semantic_segmentation_quality(ground_truth=gt, segmentation=cellvit_sam_40, class_ids=CLASS_IDS)
58 | )
59 |
60 | # 2. hovernet results.
61 | hovernet = read_image(os.path.join(ROOT, "hovernet", "pannuke", fname))
62 |
63 | hovernet_scores.append(
64 | semantic_segmentation_quality(ground_truth=gt, segmentation=hovernet, class_ids=CLASS_IDS)
65 | )
66 |
67 | # 3. hovernext results.
68 | hovernext_1 = read_image(os.path.join(ROOT, "hovernext", "pannuke_convnextv2_tiny_1", fname))
69 | hovernext_2 = read_image(os.path.join(ROOT, "hovernext", "pannuke_convnextv2_tiny_2", fname))
70 |
71 | hovernext_1_scores.append(
72 | semantic_segmentation_quality(ground_truth=gt, segmentation=hovernext_1, class_ids=CLASS_IDS)
73 | )
74 | hovernext_2_scores.append(
75 | semantic_segmentation_quality(ground_truth=gt, segmentation=hovernext_2, class_ids=CLASS_IDS)
76 | )
77 |
78 | def _get_average_results(sq_per_image, fname):
79 | msq_neoplastic_cells = np.nanmean([sq[0] for sq in sq_per_image])
80 | msq_inflammatory = np.nanmean([sq[1] for sq in sq_per_image])
81 | msq_connective = np.nanmean([sq[2] for sq in sq_per_image])
82 | msq_dead = np.nanmean([sq[3] for sq in sq_per_image])
83 | msq_epithelial = np.nanmean([sq[4] for sq in sq_per_image])
84 |
85 | all_msq = [msq_neoplastic_cells, msq_inflammatory, msq_connective, msq_dead, msq_epithelial]
86 | weighted_mean_msq = [msq * weight for msq, weight in zip(all_msq, per_class_weights)]
87 |
88 | results = {
89 | "neoplastic_cells": msq_neoplastic_cells,
90 | "inflammatory_cells": msq_inflammatory,
91 | "connective_cells": msq_connective,
92 | "dead_cells": msq_dead,
93 | "epithelial_cells": msq_epithelial,
94 | "weighted_mean": np.sum(weighted_mean_msq),
95 | "absolute_mean": np.mean(all_msq)
96 | }
97 | results = pd.DataFrame.from_dict([results])
98 | results.to_csv(fname)
99 | print(results)
100 |
101 | # Get average results per method.
102 | _get_average_results(cellvit_256_20_scores, "cellvit_256_20_semantic.csv")
103 | _get_average_results(cellvit_256_40_scores, "cellvit_256_40_semantic.csv")
104 | _get_average_results(cellvit_sam_20_scores, "cellvit_sam_20_semantic.csv")
105 | _get_average_results(cellvit_sam_40_scores, "cellvit_sam_40_semantic.csv")
106 | _get_average_results(hovernet_scores, "hovernet_semantic.csv")
107 | _get_average_results(hovernext_1_scores, "hovernext_1_semantic.csv")
108 | _get_average_results(hovernext_2_scores, "hovernext_2_semantic.csv")
109 |
110 |
111 | def main():
112 | # Get per class weights.
113 | fpath = os.path.join(*ROOT.rsplit("/")[:-2], "data", "pannuke", "pannuke_fold_3.h5")
114 |     fpath = "/" + fpath  # restore the leading slash dropped by os.path.join.
115 | per_class_weights = extract_class_weights_for_pannuke(fpath=fpath)
116 |
117 | # Run evaluation for external benchmark methods.
118 | evaluate_benchmark_methods(per_class_weights)
119 |
120 |
121 | if __name__ == "__main__":
122 | main()
123 |
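124 | # NOTE: The weighted mean above scales each class's mean score by its relative
125 | # frequency in the PanNuke fold. A small worked example with hypothetical numbers:
126 | #   all_msq = [0.6, 0.5, 0.4, 0.2, 0.55]; weights = [0.4, 0.2, 0.2, 0.05, 0.15]
127 | #   sum(m * w for m, w in zip(all_msq, weights))  # -> 0.5125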
--------------------------------------------------------------------------------
/experiments/semantic_segmentation/benchmark_methods/run_biomedparse.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 | from natsort import natsorted
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import imageio.v3 as imageio
9 |
10 | from tukra.io import read_image
11 | from tukra.inference import get_biomedparse
12 |
13 | from patho_sam.evaluation.evaluation import semantic_segmentation_quality, extract_class_weights_for_pannuke
14 |
15 |
16 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external"
17 |
18 | MAPPING = {
19 | "neoplastic cells": 1,
20 | "inflammatory cells": 2,
21 | "connective tissue cells": 3,
22 | "dead cells": 4, # NOTE: dead cells are not a semantic class involved in biomedparse.
23 | "epithelial cells": 5,
24 | }
25 |
26 |
27 | def evaluate_biomedparse_for_histopathology(dataset):
28 |
29 | # Other stuff for biomedparse
30 | modality = "Pathology" # choice of modality to determine the semantic targets.
31 | model = get_biomedparse.get_biomedparse_model() # get the biomedparse model.
32 |
33 | # Get per class weights.
34 | fpath = os.path.join(*ROOT.rsplit("/")[:-2], "data", "pannuke", "pannuke_fold_3.h5")
35 |     fpath = "/" + fpath  # restore the leading slash dropped by os.path.join.
36 | per_class_weights = extract_class_weights_for_pannuke(fpath=fpath)
37 |
38 | # Get the inputs and corresponding labels.
39 | image_paths = natsorted(glob(os.path.join(ROOT, dataset, "test_images", "*.tiff")))
40 | gt_paths = natsorted(glob(os.path.join(ROOT, dataset, "test_labels", "*.tiff")))
41 |
42 | assert image_paths and len(image_paths) == len(gt_paths)
43 |
44 | # Get results directory
45 | result_dir = os.path.join(ROOT, "biomedparse", dataset)
46 | os.makedirs(result_dir, exist_ok=True)
47 |
48 | sq_per_image = []
49 | for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
50 | # Get the input image and corresponding semantic labels.
51 | image = read_image(image_path).astype("float32")
52 | gt = read_image(gt_path)
53 |
54 | # If the inputs do not have any semantic labels, we do not evaluate them!
55 | if len(np.unique(gt)) == 1:
56 | continue
57 |
58 | # Run inference per image.
59 | prediction = get_biomedparse.run_biomedparse_automatic_inference(
60 | input_path=image, modality_type=modality, model=model, verbose=False,
61 | )
62 |
63 | semantic_seg = np.zeros_like(gt, dtype="uint8")
64 | if prediction is not None:
65 | prompts = list(prediction.keys()) # Extracting detected classes.
66 | segmentations = list(prediction.values()) # Extracting the segmentations.
67 |
68 | # Map all predicted labels.
69 | for prompt, per_seg in zip(prompts, segmentations):
70 |                 semantic_seg[per_seg > 0] = MAPPING[prompt]  # NOTE: overlapping masks are overwritten in prompt order.
71 |
72 | # Evaluate scores.
73 | sq_score = semantic_segmentation_quality(
74 | ground_truth=gt, segmentation=semantic_seg, class_ids=list(MAPPING.values()),
75 | )
76 | sq_per_image.append(sq_score)
77 |
78 | # Store the semantic segmentation results to avoid running them all the time.
79 | imageio.imwrite(os.path.join(result_dir, os.path.basename(image_path)), semantic_seg, compression="zlib")
80 |
81 | def _get_average_results(sq_per_image, fname):
82 | msq_neoplastic_cells = np.nanmean([sq[0] for sq in sq_per_image])
83 | msq_inflammatory = np.nanmean([sq[1] for sq in sq_per_image])
84 | msq_connective = np.nanmean([sq[2] for sq in sq_per_image])
85 |         msq_dead = np.nanmean([sq[3] for sq in sq_per_image])  # Absent in biomedparse, so expected to be zero.
86 | msq_epithelial = np.nanmean([sq[4] for sq in sq_per_image])
87 |
88 | all_msq = [msq_neoplastic_cells, msq_inflammatory, msq_connective, msq_dead, msq_epithelial]
89 | weighted_mean_msq = [msq * weight for msq, weight in zip(all_msq, per_class_weights)]
90 |
91 | results = {
92 | "neoplastic_cells": msq_neoplastic_cells,
93 | "inflammatory_cells": msq_inflammatory,
94 | "connective_cells": msq_connective,
95 | "dead_cells": msq_dead,
96 | "epithelial_cells": msq_epithelial,
97 | "weighted_mean": np.sum(weighted_mean_msq),
98 | "absolute_mean": np.mean(all_msq)
99 | }
100 | results = pd.DataFrame.from_dict([results])
101 | results.to_csv(fname)
102 | print(results)
103 |
104 | # Get average results for biomedparse.
105 | _get_average_results(sq_per_image, "biomedparse_semantic.csv")
106 |
107 |
108 | def main():
109 | # Run automatic (semantic) segmentation inference using biomedparse
110 | evaluate_biomedparse_for_histopathology("pannuke")
111 | evaluate_biomedparse_for_histopathology("puma")
112 |
113 |
114 | if __name__ == "__main__":
115 | main()
116 |
--------------------------------------------------------------------------------
/experiments/semantic_segmentation/generalists/evaluate_pannuke.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from tqdm import tqdm
4 | from natsort import natsorted
5 |
6 | import numpy as np
7 | import pandas as pd
8 |
9 | import torch
10 |
11 | from tukra.io import read_image
12 |
13 | from micro_sam.util import get_sam_model
14 | from micro_sam.instance_segmentation import get_unetr
15 |
16 | from patho_sam.evaluation.evaluation import semantic_segmentation_quality, extract_class_weights_for_pannuke
17 |
18 |
19 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external"
20 |
21 |
22 | def evaluate_pannuke_semantic_segmentation(args):
23 | # Stuff needed for inference
24 | model_type = args.model_type
25 | checkpoint_path = args.checkpoint_path
26 | num_classes = 6 # available classes are [0, 1, 2, 3, 4, 5]
27 | device = "cuda" if torch.cuda.is_available() else "cpu"
28 |
29 | # Get per class weights.
30 | fpath = os.path.join(*ROOT.rsplit("/")[:-2], "data", "pannuke", "pannuke_fold_3.h5")
31 |     fpath = "/" + fpath  # restore the leading slash dropped by os.path.join.
32 | per_class_weights = extract_class_weights_for_pannuke(fpath=fpath)
33 |
34 | # Get the inputs and corresponding labels.
35 | image_paths = natsorted(glob(os.path.join(ROOT, "semantic_split", "test_images", "*.tiff")))
36 | gt_paths = natsorted(glob(os.path.join(ROOT, "semantic_split", "test_labels", "*.tiff")))
37 |
38 | assert len(image_paths) == len(gt_paths) and image_paths
39 |
40 | # Get the SAM model
41 | predictor = get_sam_model(model_type=model_type, device=device)
42 |
43 | # Get the UNETR model for semantic segmentation pipeline
44 | unetr = get_unetr(
45 | image_encoder=predictor.model.image_encoder, out_channels=num_classes, device=device,
46 | )
47 |
48 | # Load the model weights
49 | model_state = torch.load(checkpoint_path, map_location="cpu")["model_state"]
50 | unetr.load_state_dict(model_state)
51 | unetr.to(device)
52 | unetr.eval()
53 |
54 | sq_per_image = []
55 | with torch.no_grad():
56 | for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
57 | # Read input image and corresponding labels.
58 | image = read_image(image_path)
59 | gt = read_image(gt_path)
60 |
61 | # If the inputs do not have any semantic labels, we do not evaluate them!
62 | if len(np.unique(gt)) == 1:
63 | continue
64 |
65 | # Pad the input image to fit the trained image shape.
66 |             # NOTE: The image stays at the top-left; the padding is added at the bottom and right.
67 | image = np.pad(array=image, pad_width=((0, 256), (0, 256), (0, 0)), mode='constant')
68 |
69 | # Run inference
70 | tensor_image = image.transpose(2, 0, 1)
71 | tensor_image = torch.from_numpy(tensor_image[None]).to(device, torch.float32)
72 | outputs = unetr(tensor_image)
73 |
74 | # Perform argmax to get per class outputs.
75 | masks = torch.argmax(outputs, dim=1)
76 | masks = masks.detach().cpu().numpy().squeeze()
77 |
78 | # Unpad the images back to match the original shape.
79 | masks = masks[:256, :256]
80 |
81 |             # Calculate the score.
82 | sq_per_image.append(
83 | semantic_segmentation_quality(gt, masks, class_ids=[1, 2, 3, 4, 5])
84 | )
85 |
86 | def _get_average_results(sq_per_image, fname):
87 | msq_neoplastic_cells = np.nanmean([sq[0] for sq in sq_per_image])
88 | msq_inflammatory = np.nanmean([sq[1] for sq in sq_per_image])
89 | msq_connective = np.nanmean([sq[2] for sq in sq_per_image])
90 | msq_dead = np.nanmean([sq[3] for sq in sq_per_image])
91 | msq_epithelial = np.nanmean([sq[4] for sq in sq_per_image])
92 |
93 | all_msq = [msq_neoplastic_cells, msq_inflammatory, msq_connective, msq_dead, msq_epithelial]
94 | weighted_mean_msq = [msq * weight for msq, weight in zip(all_msq, per_class_weights)]
95 |
96 | results = {
97 | "neoplastic_cells": msq_neoplastic_cells,
98 | "inflammatory_cells": msq_inflammatory,
99 | "connective_cells": msq_connective,
100 | "dead_cells": msq_dead,
101 | "epithelial_cells": msq_epithelial,
102 | "weighted_mean": np.sum(weighted_mean_msq),
103 | "absolute_mean": np.mean(all_msq)
104 | }
105 | results = pd.DataFrame.from_dict([results])
106 | results.to_csv(fname)
107 | print(results)
108 |
109 | # Get average results per method.
110 | fname = checkpoint_path.rsplit("/")[-2] # Fetches the name of the style of training for semantic segmentation.
111 | _get_average_results(sq_per_image, f"pathosam_{fname}.csv")
112 |
113 |
114 | def main(args):
115 | evaluate_pannuke_semantic_segmentation(args)
116 |
117 |
118 | if __name__ == "__main__":
119 | import argparse
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument("-i", "--input_path", default="/mnt/vast-nhr/projects/cidas/cca/test/data", type=str)
122 | parser.add_argument("-m", "--model_type", default="vit_b", type=str)
123 | parser.add_argument("-c", "--checkpoint_path", required=True)
124 | parser.add_argument("--view", action="store_true")
125 | args = parser.parse_args()
126 | main(args)
127 |
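128 | # Example (hypothetical checkpoint path; the parent folder name is used to label the results CSV):
129 | #   python evaluate_pannuke.py -m vit_b -c /path/to/models/checkpoints/<run_name>/best.pt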
--------------------------------------------------------------------------------
/experiments/semantic_segmentation/generalists/submit_training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | import itertools
5 | from datetime import datetime
6 |
7 |
8 | def submit_batch_script(script_name, decoder_only, decoder_from_pretrained, save_root, dry):
9 | batch_script = """#!/bin/bash
10 | #SBATCH -c 16
11 | #SBATCH --mem 64G
12 | #SBATCH -p grete-h100:shared
13 | #SBATCH -t 2-00:00:00
14 | #SBATCH -G H100:1
15 | #SBATCH -A gzz0001
16 | #SBATCH --job-name=patho-sam
17 |
18 | source ~/.bashrc
19 | micromamba activate super
20 | """
21 | # Prepare the python scripts
22 | python_script = "python train_pannuke.py "
23 | python_script += f"-s {save_root} "
24 |
25 | if decoder_only:
26 | python_script += "--decoder_only "
27 |
28 | if decoder_from_pretrained:
29 | python_script += "--decoder_from_pretrained "
30 |
31 | # Add the python script to the bash script
32 | batch_script += python_script
33 | with open(script_name, "w") as f:
34 | f.write(batch_script)
35 |
36 | if not dry:
37 | cmd = ["sbatch", script_name]
38 | subprocess.run(cmd)
39 |
40 |
41 | def get_batch_script_names(tmp_folder):
42 | tmp_folder = os.path.expanduser(tmp_folder)
43 | os.makedirs(tmp_folder, exist_ok=True)
44 |
45 | script_name = "patho-sam-semantic-segmentation"
46 |
47 | dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
48 | tmp_name = script_name + dt
49 | batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")
50 |
51 | return batch_script
52 |
53 |
54 | def main(args):
55 | tmp_folder = "./gpu_jobs"
56 | if os.path.exists(tmp_folder):
57 | shutil.rmtree(tmp_folder)
58 |
59 | for decoder_only, decoder_from_pretrained in itertools.product(
60 | [True, False], [True, False]
61 | ):
62 | submit_batch_script(
63 | script_name=get_batch_script_names(tmp_folder),
64 | decoder_only=decoder_only,
65 | decoder_from_pretrained=decoder_from_pretrained,
66 | save_root=args.save_root,
67 | dry=args.dry,
68 | )
69 |
70 |
71 | if __name__ == "__main__":
72 | import argparse
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument("--dry", action="store_true")
75 | parser.add_argument(
76 | "-s", "--save_root", type=str, default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/models"
77 | )
78 | args = parser.parse_args()
79 | main(args)
80 |
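81 | # Example: write the four batch scripts (decoder_only x decoder_from_pretrained)
82 | # without submitting them:
83 | #   python submit_training.py --dry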
--------------------------------------------------------------------------------
/experiments/semantic_segmentation/results/get_qualitative_plots.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | from skimage.segmentation import find_boundaries
6 |
7 | from tukra.io import read_image
8 |
9 |
10 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/semantic/external"
11 |
12 |
13 | def get_qualitative_plots():
14 | # Get the inputs and corresponding labels
15 | fnames = [
16 | '1445.tiff', '1145.tiff', '1391.tiff', '2087.tiff',
17 | '2382.tiff', '0335.tiff', '0594.tiff', '2386.tiff',
18 | '2316.tiff', '2381.tiff', "1446.tiff", "1407.tiff"
19 | ]
20 | image_paths = [os.path.join(ROOT, "semantic_split", "test_images", fname) for fname in fnames]
21 | gt_paths = [os.path.join(ROOT, "semantic_split", "test_labels", fname) for fname in fnames]
22 |
23 | for image_path, gt_path in zip(image_paths, gt_paths):
24 | this_fname = os.path.basename(image_path)
25 |
26 | image = read_image(image_path)
27 | gt = read_image(gt_path)
28 |
29 | # Get an overlay for the input image
30 | mask_overlay = np.zeros_like(image)
31 | # Individual colors per class
32 | mask_overlay[find_boundaries(gt == 1)] = [255, 0, 0]
33 | mask_overlay[find_boundaries(gt == 2)] = [0, 255, 0]
34 | mask_overlay[find_boundaries(gt == 3)] = [0, 0, 255]
35 | mask_overlay[find_boundaries(gt == 4)] = [255, 255, 0]
36 | mask_overlay[find_boundaries(gt == 5)] = [0, 255, 255]
37 |
38 | # Map overlay over the image
39 | alpha = 0.5
40 | overlay = alpha * image + (1.0 - alpha) * mask_overlay
41 | overlay = overlay.astype("uint8")
42 |
43 | def _get_expected_labels(fpath):
44 | label = read_image(fpath)
45 |
46 | # Individual colors per class
47 | label_overlay = np.zeros_like(image)
48 | label_overlay[label == 1] = [255, 0, 0]
49 | label_overlay[label == 2] = [0, 255, 0]
50 | label_overlay[label == 3] = [0, 0, 255]
51 | label_overlay[label == 4] = [255, 255, 0]
52 | label_overlay[label == 5] = [0, 255, 255]
53 |
54 | return label_overlay.astype("uint8")
55 |
56 | import matplotlib.pyplot as plt
57 | fig, ax = plt.subplots(1, 6, figsize=(30, 20))
58 | ax[0].imshow(overlay)
59 | ax[0].axis("off")
60 |
61 | ax[1].imshow(_get_expected_labels(os.path.join(ROOT, "hovernet", "pannuke", this_fname)))
62 | ax[1].axis("off")
63 |
64 | ax[2].imshow(_get_expected_labels(os.path.join(ROOT, "hovernext", "pannuke_convnextv2_tiny_1", this_fname)))
65 | ax[2].axis("off")
66 |
67 | ax[3].imshow(_get_expected_labels(os.path.join(ROOT, "cellvit", "SAM-H-x40", this_fname)))
68 | ax[3].axis("off")
69 |
70 |         ax[4].imshow(_get_expected_labels(f"biomedparse_{this_fname}"))  # predictions cached in the working directory.
71 | ax[4].axis("off")
72 |
73 |         ax[5].imshow(_get_expected_labels(f"pathosam_{this_fname}"))  # predictions cached in the working directory.
74 | ax[5].axis("off")
75 |
76 | plt.subplots_adjust(wspace=0.05, hspace=0)
77 | plt.savefig(("./" + Path(this_fname).stem + ".svg"), bbox_inches="tight")
78 | plt.close()
79 |
80 |
81 | def main():
82 | get_qualitative_plots()
83 |
84 |
85 | if __name__ == "__main__":
86 | main()
87 |
--------------------------------------------------------------------------------
/experiments/training/README.md:
--------------------------------------------------------------------------------
1 | # Finetuning Segment Anything Model for Histopathology
2 |
3 | This folder contains scripts for finetuning a generalist model for nuclei instance segmentation in H&E-stained histopathology images. Here is a brief overview of the folders:
4 | - `generalists`: Contains scripts for training our generalist `patho-sam` models on a large histopathology dataset for (automatic and interactive) nuclei instance segmentation, inspired by `micro-sam`.
5 | - `specialists`: Contains scripts for training our specialist `patho-sam` models on PanNuke (for nuclei), GlaS (for glands), NuClick (for lymphocytes) and PUMA (for neutrophils) for task- and data-specific exploration.
6 |
7 | > NOTE: The scripts should run out-of-the-box and download most datasets automatically. For some datasets, automatic downloads are not supported. See the respective `get_generalist_datasets.py` or `get_specialist_dataset.py` for details about their downloads!
8 |
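9 | For example, a generalist finetuning run can also be launched directly (a sketch; paths are cluster-specific):
10 |
11 | ```bash
12 | python generalists/train_generalist_histopathology.py -m vit_b -s /path/to/save_root --n_objects 40
13 | ```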
--------------------------------------------------------------------------------
/experiments/training/generalists/get_generalist_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.utils.data as data_util
5 |
6 | import torch_em
7 | from torch_em.data import ConcatDataset, MinInstanceSampler, datasets
8 | from torch_em.transform.label import PerObjectDistanceTransform
9 |
10 | from patho_sam.training import histopathology_identity
11 |
12 |
13 | def _get_train_val_split(ds, val_fraction=0.2):
14 | generator = torch.Generator().manual_seed(42)
15 | train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator)
16 | return train_ds, val_ds
17 |
18 |
19 | def get_concat_hp_datasets(path, patch_shape, split_choice):
20 | # Important stuff for dataloaders.
21 | label_dtype = torch.float32
22 | sampler = MinInstanceSampler(min_num_instances=4, min_size=10)
23 |
24 | # Expected raw and label transforms.
25 | raw_transform = histopathology_identity
26 | label_transform = PerObjectDistanceTransform(
27 | distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=10,
28 | )
29 |
30 | # Datasets used for training: CPM15, CPM17, Lizard, MoNuSeg, PanNuke, PUMA, TNBC
31 | cpm15_ds = datasets.get_cpm_dataset(
32 | path=os.path.join(path, "cpm15"), patch_shape=patch_shape, sampler=sampler,
33 | label_dtype=label_dtype, raw_transform=raw_transform, data_choice="cpm15",
34 | split=split_choice, label_transform=label_transform, n_samples=50, # NOTE: oversampling the data.
35 | )
36 |
37 | lizard_ds = datasets.get_lizard_dataset(
38 | path=os.path.join(path, "lizard"), patch_shape=patch_shape, download=True, sampler=sampler,
39 | label_dtype=label_dtype, split=split_choice, label_transform=label_transform, raw_transform=raw_transform,
40 | )
41 |
42 | puma_ds = datasets.get_puma_dataset(
43 | path=os.path.join(path, "puma"), patch_shape=patch_shape, download=True, sampler=sampler, split=split_choice,
44 | label_transform=label_transform, raw_transform=raw_transform, label_dtype=label_dtype,
45 | )
46 |
47 | tnbc_ds = datasets.get_tnbc_dataset(
48 | path=os.path.join(path, "tnbc"), patch_shape=patch_shape, download=True, sampler=sampler,
49 | split=split_choice, label_transform=label_transform, label_dtype=label_dtype,
50 | ndim=2, raw_transform=raw_transform, n_samples=50, # NOTE: oversampling the data.
51 | )
52 |
53 | def _get_cpm17_dataset():
54 | cpm17_ds = datasets.get_cpm_dataset(
55 | path=os.path.join(path, "cpm17"), patch_shape=patch_shape, sampler=sampler,
56 | label_dtype=label_dtype, raw_transform=raw_transform, data_choice="cpm17",
57 | split="train", label_transform=label_transform, n_samples=50, # NOTE: oversampling the data.
58 | )
59 | cpm17_train_ds, cpm17_val_ds = _get_train_val_split(ds=cpm17_ds)
60 | if split_choice == "train":
61 | return cpm17_train_ds
62 | else:
63 | return cpm17_val_ds
64 |
65 | def _get_monuseg_dataset():
66 | monuseg_ds = datasets.get_monuseg_dataset(
67 | path=os.path.join(path, "monuseg"), patch_shape=patch_shape, download=True, split="train",
68 | sampler=sampler, label_transform=label_transform, label_dtype=label_dtype,
69 | ndim=2, raw_transform=raw_transform, n_samples=50, # NOTE: oversampling the data.
70 | )
71 | monuseg_train_ds, monuseg_val_ds = _get_train_val_split(ds=monuseg_ds)
72 | if split_choice == "train":
73 | return monuseg_train_ds
74 | else:
75 | return monuseg_val_ds
76 |
77 | def _get_pannuke_dataset():
78 | pannuke_ds = datasets.get_pannuke_dataset(
79 | path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), download=True,
80 | sampler=sampler, ndim=2, folds=["fold_1", "fold_2"], label_dtype=label_dtype,
81 | label_transform=label_transform, raw_transform=raw_transform,
82 | )
83 | pannuke_train_ds, pannuke_val_ds = _get_train_val_split(ds=pannuke_ds)
84 | if split_choice == "train":
85 | return pannuke_train_ds
86 | else:
87 | return pannuke_val_ds
88 |
89 | _datasets = [
90 | cpm15_ds, lizard_ds, puma_ds, tnbc_ds, _get_cpm17_dataset(), _get_monuseg_dataset(), _get_pannuke_dataset()
91 | ]
92 |
93 | return ConcatDataset(*_datasets)
94 |
95 |
96 | def get_generalist_hp_loaders(patch_shape, data_path):
97 | """This returns the concatenated histopathology datasets implemented in `torch_em`:
98 | https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets/histopathology
99 |     It will automatically download all the datasets, except CPM15 and CPM17 (which require a manual download).
100 |
101 | NOTE: To remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits)
102 | in `get_concat_hp_datasets`. The labels have to be in a label mask instance segmentation format.
103 |     i.e. the tensors (inputs & masks) should be of the same spatial shape, with each object in the mask having its own ID.
104 | IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive.
105 | """
106 | # Get the datasets
107 | generalist_train_dataset = get_concat_hp_datasets(path=data_path, patch_shape=patch_shape, split_choice="train")
108 | generalist_val_dataset = get_concat_hp_datasets(path=data_path, patch_shape=patch_shape, split_choice="val")
109 |
110 | # Get the dataloaders
111 | train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16)
112 | val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16)
113 |
114 | return train_loader, val_loader
115 |
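116 | # Usage sketch (hypothetical path), mirroring train_generalist_histopathology.py:
117 | #   train_loader, val_loader = get_generalist_hp_loaders(patch_shape=(512, 512), data_path="/path/to/data")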
--------------------------------------------------------------------------------
/experiments/training/generalists/run_all_finetuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | from datetime import datetime
5 |
6 |
7 | N_OBJECTS = {"vit_b": 40, "vit_l": 30, "vit_h": 25}
8 |
9 |
10 | def write_batch_script(out_path, _name, model_type, save_root, dry):
11 | "Writing scripts for different patho-sam finetunings."
12 | batch_script = """#!/bin/bash
13 | #SBATCH -t 14-00:00:00
14 | #SBATCH --mem 128G
15 | #SBATCH --nodes=1
16 | #SBATCH --ntasks=1
17 | #SBATCH -p grete:shared
18 | #SBATCH -G A100:1
19 | #SBATCH -A gzz0001
20 | #SBATCH -c 32
21 | #SBATCH --qos=14d
22 | #SBATCH --constraint=80gb
23 | #SBATCH --job-name=patho-sam-generalist
24 |
25 | source ~/.bashrc
26 | micromamba activate sam \n"""
27 |
28 | # python script
29 | python_script = f"python {_name}.py "
30 | python_script += f"-s {save_root} " # save root folder
31 | python_script += f"-m {model_type} " # name of the model configuration
32 | python_script += f"--n_objects {N_OBJECTS[model_type]} " # choice of the number of objects
33 |
34 | # let's add the python script to the bash script
35 | batch_script += python_script
36 |
37 | _op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh"
38 | with open(_op, "w") as f:
39 | f.write(batch_script)
40 |
41 | cmd = ["sbatch", _op]
42 | if not dry:
43 | subprocess.run(cmd)
44 |
45 |
46 | def get_batch_script_names(tmp_folder):
47 | tmp_folder = os.path.expanduser(tmp_folder)
48 | os.makedirs(tmp_folder, exist_ok=True)
49 |
50 | script_name = "patho-sam-finetuning"
51 |
52 | dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
53 | tmp_name = script_name + dt
54 | batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")
55 |
56 | return batch_script
57 |
58 |
59 | def submit_slurm(args):
60 | "Submit python script that needs gpus with given inputs on a slurm node."
61 | tmp_folder = "./gpu_jobs"
62 | if os.path.exists(tmp_folder):
63 | shutil.rmtree(tmp_folder)
64 |
65 | if args.model_type is None:
66 | models = list(N_OBJECTS.keys())
67 | else:
68 | models = [args.model_type]
69 |
70 | script_name = "train_generalist_histopathology"
71 | print(f"Running for {script_name}")
72 | for model_type in models:
73 | write_batch_script(
74 | out_path=get_batch_script_names(tmp_folder),
75 | _name=script_name,
76 | model_type=model_type,
77 | save_root=args.save_root,
78 | dry=args.dry,
79 | )
80 |
81 |
82 | def main(args):
83 | submit_slurm(args)
84 |
85 |
86 | if __name__ == "__main__":
87 | import argparse
88 | parser = argparse.ArgumentParser()
89 | parser.add_argument(
90 | "-s", "--save_root", type=str, default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/models"
91 | )
92 | parser.add_argument("-m", "--model_type", type=str, default=None)
93 | parser.add_argument("--dry", action="store_true")
94 | args = parser.parse_args()
95 | main(args)
96 |
--------------------------------------------------------------------------------
/experiments/training/generalists/train_generalist_histopathology.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 |
6 | import micro_sam.training as sam_training
7 | from micro_sam.util import export_custom_sam_model
8 |
9 | from get_generalist_datasets import get_generalist_hp_loaders
10 |
11 |
12 | def finetune_generalist(args):
13 | """Example code for finetuning SAM on histopathology datasets."""
14 | # override this (below) if you have some more complex set-up and need to specify the exact GPU.
15 | device = "cuda" if torch.cuda.is_available() else "cpu"
16 |
17 | # training settings:
18 | model_type = args.model_type
19 | checkpoint_path = None # override this to start training from a custom checkpoint.
20 | patch_shape = (512, 512) # the patch shape for training.
21 | n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled.
22 | checkpoint_name = f"{args.model_type}/patho_sam"
23 |
24 | # all the stuff we need for training
25 | train_loader, val_loader = get_generalist_hp_loaders(patch_shape=patch_shape, data_path=args.input_path)
26 | scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10}
27 |
28 | # Run training.
29 | sam_training.train_sam(
30 | name=checkpoint_name,
31 | model_type=model_type,
32 | train_loader=train_loader,
33 | val_loader=val_loader,
34 | early_stopping=None,
35 | n_objects_per_batch=n_objects_per_batch,
36 | checkpoint_path=checkpoint_path,
37 | with_segmentation_decoder=True,
38 | device=device,
39 | lr=1e-5,
40 | n_iterations=args.iterations,
41 | save_root=args.save_root,
42 | scheduler_kwargs=scheduler_kwargs,
43 | verify_n_labels_in_loader=None, # NOTE: Setting to 'None' verifies all labels in the loader(s).
44 | )
45 |
46 | if args.export_path is not None:
47 | checkpoint_path = os.path.join(
48 | "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
49 | )
50 | export_custom_sam_model(
51 | checkpoint_path=checkpoint_path,
52 | model_type=model_type,
53 | save_path=args.export_path,
54 | )
55 |
56 |
57 | def main():
58 | parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.")
59 | parser.add_argument(
60 | "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data",
61 | help="The filepath to the datasets. If the data does not exist yet it will be downloaded.",
62 | )
63 | parser.add_argument(
64 | "--model_type", "-m", default="vit_b",
65 | help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h.",
66 | )
67 | parser.add_argument(
68 | "--save_root", "-s", default=None,
69 | help="Where to save the checkpoint and logs. By default they will be saved where this script is run.",
70 | )
71 | parser.add_argument(
72 | "--iterations", type=int, default=int(25e4),
73 | help="For how many iterations should the model be trained? By default 250k.",
74 | )
75 | parser.add_argument(
76 | "--export_path", "-e",
77 | help="Where to export the finetuned model to. The exported model can be used in the annotation tools.",
78 | )
79 | parser.add_argument(
80 | "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning."
81 | )
82 | args = parser.parse_args()
83 | finetune_generalist(args)
84 |
85 |
86 | if __name__ == "__main__":
87 | main()
88 |
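89 | # Example (hypothetical paths): finetune vit_b and export the model for the annotation tools:
90 | #   python train_generalist_histopathology.py -m vit_b -s /path/to/save_root -e /path/to/export/model.pt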
--------------------------------------------------------------------------------
/experiments/training/specialists/finetune_neutrophils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from glob import glob
4 | from natsort import natsorted
5 |
6 | import torch
7 |
8 | import micro_sam.training as sam_training
9 | from micro_sam.util import export_custom_sam_model
10 |
11 | import torch_em
12 | from torch_em.data import MinInstanceSampler
13 | from torch_em.transform.label import PerObjectDistanceTransform
14 |
15 | from patho_sam.training import histopathology_identity
16 |
17 |
18 | def get_dataloaders(patch_shape, data_path):
19 | label_dtype = torch.float32
20 | sampler = MinInstanceSampler(min_num_instances=4, min_size=10)
21 |
22 | # Expected raw and label transforms.
23 | raw_transform = histopathology_identity
24 | label_transform = PerObjectDistanceTransform(
25 | distances=True,
26 | boundary_distances=True,
27 | directed_distances=False,
28 | foreground=True,
29 | instances=True,
30 | min_size=10,
31 | )
32 |
33 | train_images = natsorted(glob(os.path.join(data_path, "train", "images", "*.tiff")))
34 | train_labels = natsorted(glob(os.path.join(data_path, "train", "labels", "*.tiff")))
35 |
36 | val_images = natsorted(glob(os.path.join(data_path, "val", "images", "*.tiff")))
37 | val_labels = natsorted(glob(os.path.join(data_path, "val", "labels", "*.tiff")))
38 |
39 | train_loader = torch_em.default_segmentation_loader(
40 | raw_paths=train_images,
41 | raw_key=None,
42 | label_paths=train_labels,
43 | label_key=None,
44 | patch_shape=patch_shape,
45 | ndim=2,
46 | with_channels=True,
47 | is_seg_dataset=False,
48 | sampler=sampler,
49 | label_dtype=label_dtype,
50 | label_transform=label_transform,
51 | raw_transform=raw_transform,
52 | batch_size=1,
53 | shuffle=True,
54 | num_workers=16,
55 |
56 | )
57 |     val_loader = torch_em.default_segmentation_loader(
58 | raw_paths=val_images,
59 | raw_key=None,
60 | label_paths=val_labels,
61 | label_key=None,
62 | patch_shape=patch_shape,
63 | ndim=2,
64 | with_channels=True,
65 | is_seg_dataset=False,
66 | sampler=sampler,
67 | label_dtype=label_dtype,
68 | label_transform=label_transform,
69 | raw_transform=raw_transform,
70 | batch_size=1,
71 | shuffle=True,
72 | num_workers=16,
73 | )
74 |
75 | return train_loader, val_loader
76 |
77 |
78 | def finetune_specialist(args):
79 |     """Example code for finetuning SAM on a neutrophil segmentation dataset in histopathology."""
80 | # override this (below) if you have some more complex set-up and need to specify the exact GPU.
81 | device = "cuda" if torch.cuda.is_available() else "cpu"
82 |
83 | # training settings:
84 | model_type = args.model_type
85 | checkpoint_path = args.checkpoint # override this to start training from a custom checkpoint.
86 | patch_shape = (512, 512) # the patch shape for training.
87 | n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled.
88 | checkpoint_name = f"{args.model_type}/neutrophil_sam"
89 |
90 | # all the stuff we need for training
91 | train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
92 | scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10}
93 |
94 | # Run training.
95 | sam_training.train_sam(
96 | name=checkpoint_name,
97 | model_type=model_type,
98 | train_loader=train_loader,
99 | val_loader=val_loader,
100 | early_stopping=None,
101 | n_objects_per_batch=n_objects_per_batch,
102 | checkpoint_path=checkpoint_path,
103 | with_segmentation_decoder=True,
104 | device=device,
105 | lr=1e-5,
106 | n_iterations=args.iterations,
107 | save_root=args.save_root,
108 | scheduler_kwargs=scheduler_kwargs,
109 | verify_n_labels_in_loader=None, # NOTE: Setting to 'None' verifies all labels in the loader(s).
110 | )
111 |
112 | if args.export_path is not None:
113 | checkpoint_path = os.path.join(
114 | "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
115 | )
116 | export_custom_sam_model(
117 | checkpoint_path=checkpoint_path,
118 | model_type=model_type,
119 | save_path=args.export_path,
120 | )
121 |
122 |
123 | def main():
124 |     parser = argparse.ArgumentParser(description="Finetune Segment Anything for neutrophil segmentation.")
125 | parser.add_argument(
126 | "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data",
127 |         help="The filepath to the dataset, with 'train' and 'val' splits containing 'images' and 'labels' subfolders.",
128 | )
129 | parser.add_argument(
130 | "--model_type", "-m", default="vit_b_histopathology",
131 |         help="The model type to use for fine-tuning, e.g. 'vit_b' or the generalist 'vit_b_histopathology'.",
132 | )
133 | parser.add_argument(
134 | "--save_root", "-s", default=None,
135 | help="Where to save the checkpoint and logs. By default they will be saved where this script is run.",
136 | )
141 | parser.add_argument(
142 | "--checkpoint", "-c", default=None,
143 |         help="The custom checkpoint to start finetuning from. If not provided, the pretrained model weights are used.",
144 | )
145 | parser.add_argument(
146 | "--iterations", type=int, default=int(1e4),
147 |         help="For how many iterations should the model be trained? By default 10k.",
148 | )
149 | parser.add_argument(
150 | "--export_path", "-e",
151 | help="Where to export the finetuned model to. The exported model can be used in the annotation tools.",
152 | )
153 | parser.add_argument(
154 | "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning."
155 | )
156 | args = parser.parse_args()
157 | finetune_specialist(args)
158 |
159 |
160 | if __name__ == "__main__":
161 | main()
162 |
--------------------------------------------------------------------------------
/experiments/training/specialists/finetune_specialist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 |
6 | import micro_sam.training as sam_training
7 | from micro_sam.util import export_custom_sam_model
8 |
9 | from get_specialist_dataset import get_specialist_loaders
10 |
11 |
12 | def finetune_specialist(args):
13 | """Example code for finetuning SAM on histopathology datasets."""
14 | # override this (below) if you have some more complex set-up and need to specify the exact GPU.
15 | device = "cuda" if torch.cuda.is_available() else "cpu"
16 |
17 | # training settings:
18 | model_type = args.model_type
19 | checkpoint_path = args.checkpoint # override this to start training from a custom checkpoint.
20 | patch_shape = (512, 512) # the patch shape for training.
21 | n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled.
22 | checkpoint_name = f"{args.model_type}/patho_sam"
23 |
24 | # all the stuff we need for training
25 | train_loader, val_loader = get_specialist_loaders(
26 | patch_shape=patch_shape, data_path=args.input_path, dataset=args.dataset
27 | )
28 | scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10}
29 |
30 | # Run training.
31 | sam_training.train_sam(
32 | name=checkpoint_name,
33 | model_type=model_type,
34 | train_loader=train_loader,
35 | val_loader=val_loader,
36 | early_stopping=None,
37 | n_objects_per_batch=n_objects_per_batch,
38 | checkpoint_path=checkpoint_path,
39 | with_segmentation_decoder=True,
40 | device=device,
41 | lr=1e-5,
42 | n_iterations=args.iterations,
43 | save_root=args.save_root,
44 | scheduler_kwargs=scheduler_kwargs,
45 | verify_n_labels_in_loader=None, # NOTE: Setting to 'None' verifies all labels in the loader(s).
46 | )
47 |
48 | if args.export_path is not None:
49 | checkpoint_path = os.path.join(
50 | "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
51 | )
52 | export_custom_sam_model(
53 | checkpoint_path=checkpoint_path,
54 | model_type=model_type,
55 | save_path=args.export_path,
56 | )
57 |
58 |
59 | def main():
60 | parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.")
61 | parser.add_argument(
62 | "--input_path", "-i", default="/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data",
63 | help="The filepath to the datasets. If the data does not exist yet it will be downloaded.",
64 | )
65 | parser.add_argument(
66 | "--model_type", "-m", default="vit_b",
67 | help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h.",
68 | )
69 | parser.add_argument(
70 | "--save_root", "-s", default=None,
71 | help="Where to save the checkpoint and logs. By default they will be saved where this script is run.",
72 | )
73 | parser.add_argument(
74 | "--dataset", "-d", default=None,
75 | help="Which dataset to finetune the specialist model on. Choose from datasets implemented in torch-em.",
76 | )
77 | parser.add_argument(
78 | "--checkpoint", "-c", default=None,
79 |         help="The custom checkpoint to start finetuning from. If not provided, the default SAM weights are used.",
80 | )
81 | parser.add_argument(
82 |         "--iterations", type=int, default=int(1e5),
83 | help="For how many iterations should the model be trained? By default 100k.",
84 | )
85 | parser.add_argument(
86 | "--export_path", "-e",
87 | help="Where to export the finetuned model to. The exported model can be used in the annotation tools.",
88 | )
89 | parser.add_argument(
90 | "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning."
91 | )
92 | args = parser.parse_args()
93 | finetune_specialist(args)
94 |
95 |
96 | if __name__ == "__main__":
97 | main()
98 |
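99 | # Example invocation (a sketch; paths are hypothetical):
100 | #
101 | #     python finetune_specialist.py -d pannuke -m vit_b -s ./checkpoints -e ./finetuned_model.pt
102 | #
103 | # This finetunes 'vit_b' on PanNuke and exports the best checkpoint for use in the annotation tools.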
--------------------------------------------------------------------------------
/experiments/training/specialists/get_specialist_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.utils.data as data_util
5 |
6 | import torch_em
7 | from torch_em.data import MinInstanceSampler, datasets
8 | from torch_em.transform.label import PerObjectDistanceTransform
9 |
10 | from patho_sam.training import histopathology_identity
11 |
12 |
13 | def _get_train_val_split(ds, val_fraction=0.2):
14 | generator = torch.Generator().manual_seed(42)
15 | train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator)
16 | return train_ds, val_ds
17 |
18 |
19 | def get_specialist_dataset(path, patch_shape, split_choice, dataset):
20 | # Important stuff for dataloaders.
21 | label_dtype = torch.float32
22 | sampler = MinInstanceSampler(min_num_instances=4, min_size=10)
23 |
24 | # Expected raw and label transforms.
25 | raw_transform = histopathology_identity
26 | label_transform = PerObjectDistanceTransform(
27 | distances=True,
28 | boundary_distances=True,
29 | directed_distances=False,
30 | foreground=True,
31 | instances=True,
32 | min_size=10,
33 | )
34 |
35 | if dataset == "nuclick":
36 | ds = datasets.get_nuclick_dataset(
37 | path=os.path.join(path, dataset),
38 | patch_shape=patch_shape,
39 | download=True,
40 | sampler=MinInstanceSampler(min_num_instances=2, min_size=10),
41 | split="Train",
42 | label_dtype=label_dtype,
43 | label_transform=label_transform,
44 | raw_transform=raw_transform,
45 | )
46 | nuclick_train_ds, nuclick_val_ds = _get_train_val_split(ds)
47 | if split_choice == "train":
48 | return nuclick_train_ds
49 | else:
50 | return nuclick_val_ds
51 |
52 | elif dataset == "cryonuseg":
53 | return datasets.get_cryonuseg_dataset(
54 | path=os.path.join(path, "cryonuseg"),
55 | patch_shape=(1, *patch_shape),
56 | download=True,
57 | sampler=sampler,
58 | split=split_choice,
59 | rater="b1",
60 | label_dtype=label_dtype,
61 | label_transform=label_transform,
62 | raw_transform=raw_transform,
63 | )
64 |
65 | elif dataset == "pannuke":
66 | ds = datasets.get_pannuke_dataset(
67 | path=os.path.join(path, "pannuke"),
68 | patch_shape=(1, *patch_shape),
69 | download=True,
70 | sampler=sampler,
71 | ndim=2,
72 | folds=["fold_1", "fold_2"],
73 | label_dtype=label_dtype,
74 | label_transform=label_transform,
75 | raw_transform=raw_transform,
76 | )
77 | pannuke_train_ds, pannuke_val_ds = _get_train_val_split(ds=ds)
78 | if split_choice == "train":
79 | return pannuke_train_ds
80 | else:
81 | return pannuke_val_ds
82 |
83 | elif dataset == "glas":
84 | ds = datasets.get_glas_dataset(
85 | path=os.path.join(path, "glas"),
86 | patch_shape=patch_shape,
87 | download=True,
88 | sampler=MinInstanceSampler(min_num_instances=2),
89 | split="train",
90 | label_dtype=label_dtype,
91 | label_transform=label_transform,
92 | raw_transform=raw_transform,
93 | )
94 | glas_train_ds, glas_val_ds = _get_train_val_split(ds=ds)
95 | if split_choice == "train":
96 | return glas_train_ds
97 | else:
98 | return glas_val_ds
99 |
100 | else:
101 |         raise NotImplementedError(f"The chosen dataset '{dataset}' is not supported.")
102 |
103 |
104 | def get_specialist_loaders(patch_shape, data_path, dataset):
105 |     """This returns the train and val dataloaders for a selected histopathology dataset implemented in `torch_em`:
106 |     https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets/histopathology
107 |     It will automatically download the dataset.
108 |
109 |     NOTE: To remove / replace the dataset with another dataset, you need to add the datasets (for train and val
110 |     splits) in `get_specialist_dataset`. The labels have to be in a label mask instance segmentation format,
111 |     i.e. the tensors (inputs & masks) should have the same spatial shape, with each object in the mask having its
112 |     own ID. IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive.
113 |     """
114 | # Get the datasets
115 | specialist_train_dataset = get_specialist_dataset(
116 | path=data_path, patch_shape=patch_shape, split_choice="train", dataset=dataset
117 | )
118 | specialist_val_dataset = get_specialist_dataset(
119 | path=data_path, patch_shape=patch_shape, split_choice="val", dataset=dataset
120 | )
121 | # Get the dataloaders
122 | train_loader = torch_em.get_data_loader(specialist_train_dataset, batch_size=2, shuffle=True, num_workers=16)
123 | val_loader = torch_em.get_data_loader(specialist_val_dataset, batch_size=1, shuffle=True, num_workers=16)
124 |
125 | return train_loader, val_loader
126 |
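127 | # A minimal usage sketch (hypothetical data path; the chosen dataset is downloaded on demand):
128 | #
129 | #     train_loader, val_loader = get_specialist_loaders(
130 | #         patch_shape=(512, 512), data_path="./data", dataset="pannuke",
131 | #     )
132 | #     x, y = next(iter(train_loader))  # Raw RGB patches and per-object distance targets.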
--------------------------------------------------------------------------------
/experiments/training/specialists/run_all_finetuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import subprocess
4 | from datetime import datetime
5 |
6 |
7 | N_OBJECTS = {"vit_b": 40, "vit_l": 30, "vit_h": 25}
8 |
9 |
10 | def write_batch_script(out_path, _name, model_type, save_root, dry, dataset, checkpoint):
11 | "Writing scripts for different patho-sam specialist finetunings."
12 | batch_script = """#!/bin/bash
13 | #SBATCH -t 2-00:00:00
14 | #SBATCH --mem 64G
15 | #SBATCH --nodes=1
16 | #SBATCH --ntasks=1
17 | #SBATCH -p grete:shared
18 | #SBATCH -G A100:1
19 | #SBATCH -A nim00007
20 | #SBATCH -c 16
21 | #SBATCH --constraint=80gb
22 | #SBATCH --job-name=patho-sam-specialist
23 |
24 | source ~/.bashrc
25 | micromamba activate sam \n"""
26 |
27 | # python script
28 |     python_script = f"python {_name}.py "
29 |     python_script += f"-d {dataset} "  # dataset to finetune on
30 |     python_script += f"-m {model_type} "  # name of the model configuration
31 |     python_script += f"--n_objects {N_OBJECTS[model_type]} "  # choice of the number of objects
32 |     python_script += f"-s {save_root} " if save_root is not None else ""  # save root folder
33 |     python_script += f"-c {checkpoint} " if checkpoint is not None else ""  # checkpoint to start training from
34 |
35 | # let's add the python script to the bash script
36 | batch_script += python_script
37 |
38 | _op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh"
39 | with open(_op, "w") as f:
40 | f.write(batch_script)
41 |
42 | cmd = ["sbatch", _op]
43 | if not dry:
44 | subprocess.run(cmd)
45 |
46 |
47 | def get_batch_script_names(tmp_folder):
48 | tmp_folder = os.path.expanduser(tmp_folder)
49 | os.makedirs(tmp_folder, exist_ok=True)
50 |
51 | script_name = "patho-sam-finetuning"
52 |
53 | dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
54 | tmp_name = script_name + dt
55 | batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")
56 |
57 | return batch_script
58 |
59 |
60 | def submit_slurm(args):
61 |     "Submit the python training script with the given inputs to a slurm gpu node."
62 | tmp_folder = "./gpu_jobs"
63 | if os.path.exists(tmp_folder):
64 | shutil.rmtree(tmp_folder)
65 |
66 | if args.model_type is None:
67 | models = list(N_OBJECTS.keys())
68 | else:
69 | models = [args.model_type]
70 |
71 |     script_name = "finetune_specialist"
72 |     print(f"Running '{script_name}' for the '{args.dataset}' dataset")
73 | for model_type in models:
74 | write_batch_script(
75 | out_path=get_batch_script_names(tmp_folder),
76 | _name=script_name,
77 | model_type=model_type,
78 | save_root=args.save_root,
79 | dry=args.dry,
80 | dataset=args.dataset,
81 | checkpoint=args.checkpoint
82 | )
83 |
84 |
85 | def main(args):
86 | submit_slurm(args)
87 |
88 |
89 | if __name__ == "__main__":
90 | import argparse
91 |
92 | parser = argparse.ArgumentParser()
93 | parser.add_argument("-s", "--save_root", type=str, default=None)
94 | parser.add_argument("-d", "--dataset", type=str, default=None, required=True)
95 | parser.add_argument("-c", "--checkpoint", type=str, default=None)
96 | parser.add_argument("-m", "--model_type", type=str, default=None)
97 | parser.add_argument("--dry", action="store_true")
98 | args = parser.parse_args()
99 | main(args)
100 |
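101 | # Example submission (a sketch; '--dry' only writes the batch scripts without submitting them):
102 | #
103 | #     python run_all_finetuning.py -d pannuke -s /path/to/save_root --dry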
--------------------------------------------------------------------------------
/patho_sam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/computational-cell-analytics/patho-sam/482e2e336b9e5b75028c013452a70aba8107207a/patho_sam/__init__.py
--------------------------------------------------------------------------------
/patho_sam/__version__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
2 |
--------------------------------------------------------------------------------
/patho_sam/evaluation/class_weights.py:
--------------------------------------------------------------------------------
1 | import os
3 | from typing import List, Union, Optional
4 |
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from torch_em.util.image import load_data
9 | from torch_em.data.datasets import histopathology
10 |
11 | from ..training.util import CLASS_DICT
12 |
13 |
14 | LABEL_KEYS = {
15 | 'pannuke': {'semantic': 'labels/semantic', 'instance': 'labels/instances'},
16 | 'puma': {'semantic': 'labels/semantic/nuclei', 'instance': 'labels/instance/nuclei'},
17 | }
18 |
19 |
20 | def extract_class_weights(
21 | path: Union[os.PathLike, str], dataset: Optional[str] = None, output_path: Optional[Union[os.PathLike, str]] = None,
22 | ) -> List:
23 | """Extract class weights per semantic class.
24 |
25 | Args:
26 | path: The filepath where the input data is either located or will be downloaded automatically.
27 |         dataset: The choice of dataset to extract the class weights for. Either 'pannuke' or 'puma'.
28 | output_path: The output directory where you would like to store the class weights.
29 |
30 |     Returns:
31 | List of class weights for chosen dataset.
32 | """
33 |
34 | # Get the input filepaths
35 | _get_dataset = {
36 | "puma": lambda: histopathology.puma.get_puma_paths(path=path, split="train", download=True),
37 | "pannuke": lambda: histopathology.pannuke.get_pannuke_paths(
38 | path=path, folds=["fold_1", "fold_2"], download=True,
39 | )
40 | }
41 |
42 | # Next, get the volume paths.
43 | volume_paths = _get_dataset[dataset]()
44 |
45 | # Load the entire instance and semantic stack.
46 |     if isinstance(volume_paths, str):
47 |         instances = load_data(volume_paths, key=LABEL_KEYS[dataset]['instance'])
48 |         semantics = load_data(volume_paths, key=LABEL_KEYS[dataset]['semantic'])
49 |     else:
50 |         all_semantic, all_instances = [], []
51 |         for fpath in volume_paths:
52 | _instance = load_data(fpath, key=LABEL_KEYS[dataset]['instance'])
53 | _semantic = load_data(fpath, key=LABEL_KEYS[dataset]['semantic'])
54 |
55 | # Get all semantic and instance labels.
56 | all_instances.append(_instance)
57 | all_semantic.append(_semantic)
58 |
59 | instances = np.concatenate(all_instances, axis=0)
60 | semantics = np.concatenate(all_semantic, axis=0)
61 |
62 | # We need the following:
63 | # - Count the total number of instances.
64 | total_instance_counts = [
65 | len(np.unique(ilabel)[1:]) for ilabel in instances if len(np.unique(ilabel)) > 1
66 | ] # Counting all valid foreground instances only.
67 | total_instance_counts = sum(total_instance_counts)
68 |
69 | # - Count per-semantic-class instances.
70 | class_ids = CLASS_DICT[dataset].values()
71 | total_per_class_instance_counts = [
72 | [len(np.unique(np.where(slabel == cid, ilabel, 0))[1:]) for cid in class_ids]
73 | for ilabel, slabel in zip(instances, semantics) if len(np.unique(ilabel)) > 1
74 | ]
75 | assert total_instance_counts == sum([sum(t) for t in total_per_class_instance_counts])
76 |
77 |     # Sum the per-class counts across all images.
78 | total_per_class_instance_counts = [sum(x) for x in zip(*total_per_class_instance_counts)]
79 | assert total_instance_counts == sum(total_per_class_instance_counts)
80 |
81 |     # Finally, let's get the weight per class, returned as a list (and optionally saved as .csv below).
82 | per_class_weights = [t / total_instance_counts for t in total_per_class_instance_counts]
83 |
84 |     # Store the class weights locally, if an output folder is provided.
85 |     if output_path is not None:
86 |         os.makedirs(output_path, exist_ok=True)
87 |         result_df = pd.DataFrame.from_dict(dict(zip(CLASS_DICT[dataset].keys(), per_class_weights)), orient='index', columns=["Class Weight"])  # noqa
88 |         result_df.to_csv(os.path.join(output_path, f"{dataset}_class_weights.csv"), index=True)
89 |
90 | return per_class_weights
91 |
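92 | # A minimal usage sketch (hypothetical paths; the chosen dataset is downloaded on demand):
93 | #
94 | #     weights = extract_class_weights(path="./data/pannuke", dataset="pannuke", output_path="./results")
95 | #     # 'weights' holds one relative instance frequency per class in CLASS_DICT["pannuke"].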
--------------------------------------------------------------------------------
/patho_sam/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Union
3 |
4 | import h5py
5 | import numpy as np
6 |
7 | from elf.evaluation import dice_score
8 |
9 | CLASS_IDS = [1, 2, 3, 4, 5]
10 |
11 |
12 | def extract_class_weights_for_pannuke(fpath: Union[os.PathLike, str], class_ids: List = CLASS_IDS) -> List[float]:
13 | """Extract class weights per semantic class.
14 |
15 | Args:
16 |         fpath: The filepath where the input stack for fold 3 of the PanNuke dataset is stored.
17 | Use `torch_em.data.datasets.histopathology.get_pannuke_paths` to get filepath for the stack.
18 | class_ids: The choice of all available class ids.
19 |
20 | Returns:
21 | List of class weights.
22 | """
23 | # Load the entire instance and semantic stack.
24 | with h5py.File(fpath, "r") as f:
25 | instances = f['labels/instances'][:]
26 | semantic = f['labels/semantic'][:]
27 |
28 | # We need the following:
29 | # - Count the total number of instances.
30 | total_instance_counts = [
31 | len(np.unique(ilabel)[1:]) for ilabel in instances if len(np.unique(ilabel)) > 1
32 | ] # Counting all valid foreground instances only.
33 | total_instance_counts = sum(total_instance_counts)
34 |
35 | # - Count per-semantic-class instances.
36 | total_per_class_instance_counts = [
37 | [len(np.unique(np.where(slabel == cid, ilabel, 0))[1:]) for cid in class_ids]
38 | for ilabel, slabel in zip(instances, semantic) if len(np.unique(ilabel)) > 1
39 | ]
40 | assert total_instance_counts == sum([sum(t) for t in total_per_class_instance_counts])
41 |
42 |     # Sum the per-class counts across all images.
43 | total_per_class_instance_counts = [sum(x) for x in zip(*total_per_class_instance_counts)]
44 | assert total_instance_counts == sum(total_per_class_instance_counts)
45 |
46 | # Finally, let's get the weight per class.
47 | per_class_weights = [t / total_instance_counts for t in total_per_class_instance_counts]
48 |
49 | return per_class_weights
50 |
51 |
52 | def semantic_segmentation_quality(
53 | ground_truth: np.ndarray, segmentation: np.ndarray, class_ids: List[int]
54 | ) -> List[float]:
55 | """Evaluation metric for the semantic segmentation task.
56 |
57 | Args:
58 | ground_truth: The ground truth with expected semantic labels.
59 | segmentation: The predicted masks with expected semantic labels.
60 | class_ids: The per-class id available for all tasks, to calculate per class semantic quality score.
61 |
62 | Returns:
63 | List of semantic quality score per class.
64 | """
65 | # First, we iterate over all classes
66 | sq_per_class = []
67 |     for class_id in class_ids:
68 |         # Get the binary mask per semantic class.
69 |         this_gt = (ground_truth == class_id).astype("uint32")
70 |         this_seg = (segmentation == class_id).astype("uint32")
71 |
72 |         # If this class is empty in both the ground truth and the segmentation, we skip the calculation.
73 |         if len(np.unique(this_gt)) == 1 and len(np.unique(this_seg)) == 1:
74 | this_sq = np.nan
75 | else:
76 | this_sq = dice_score(this_seg, this_gt)
77 |
78 | sq_per_class.append(this_sq)
79 |
80 | return sq_per_class
81 |
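82 | # A minimal usage sketch for the metric above (toy arrays):
83 | #
84 | #     gt = np.array([[1, 1], [0, 2]])
85 | #     seg = np.array([[1, 0], [0, 2]])
86 | #     semantic_segmentation_quality(gt, seg, class_ids=[1, 2, 3])
87 | #     # -> [~0.67 for class 1, 1.0 for class 2, nan for class 3 (empty in both)]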
--------------------------------------------------------------------------------
/patho_sam/io/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import read_wsi
2 |
--------------------------------------------------------------------------------
/patho_sam/io/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union, Tuple, Optional
3 |
4 | import numpy as np
5 |
6 | try:
7 | import slideio
8 | except ImportError:
9 | slideio = None
10 |
11 |
12 | def read_wsi(
13 | input_path: Union[os.PathLike, str],
14 | scale: Optional[Tuple[int, int]] = None,
15 | image_size: Optional[Tuple[int, int, int, int]] = None,
16 | ) -> np.ndarray:
17 | """Function to read whole-slide images (WSIs) in histopathology.
18 |
19 | The file formats tested are '.svs', '.scn', '.czi', '.zvi', '.ndpi',
20 |     '.vsi', '.qptiff' and other GDAL-supported formats.
21 |
22 | Args:
23 | input_path: The path to the WSI.
24 |         scale: Relevant for WSIs, to get the image at a desired scale. Provide the desired (H, W) combination.
25 |             You can choose (H, 0) or (0, W) to scale along one dimension and keep the aspect ratio intact.
26 |         image_size: Relevant for WSIs, to get an ROI crop: either (x, y, W, H), or (W, H) for a top-left crop.
27 |
28 | Returns:
29 | The numpy array.
30 | """
31 | if not os.path.exists(input_path):
32 | raise FileNotFoundError(input_path)
33 |
34 | assert slideio is not None, "Please install 'slideio': 'pip install slideio'."
35 | slide = slideio.open_slide(input_path) # Fetches the slide object.
36 |
37 | # Let's check with expected scale.
38 | if scale is None:
39 | scale = (0, 0) # Loads original resolution.
40 | else:
41 |         if not isinstance(scale, tuple) or len(scale) != 2:
42 | raise ValueError(
43 | "The scale parameter is expected to be a tuple of height and width dimensions, "
44 | "such that the new shape is (H', W')"
45 | )
46 |
47 | # Let's check for the expected size of the desired ROI.
48 | # NOTE: Here, we expect all values for placing an ROI precisely: (x, y, W, H)
49 | if image_size is None:
50 | image_size = (0, 0, 0, 0)
51 | else:
52 |         if not isinstance(image_size, tuple) or len(image_size) not in (2, 4):
53 |             raise ValueError(
54 |                 "The image size parameter is expected to be a tuple describing the desired ROI crop, "
55 |                 "provided either as (x, y, W, H) or as (W, H) for a top-left corner crop."
56 |             )
57 |
58 |         # If the user provides shapes in the usual 2d axes format, e.g. (1024, 1024),
59 | # we provide them a top-left corner crop.
60 | if len(image_size) == 2:
61 | image_size = (0, 0, *image_size)
62 |
63 | assert len(scale) == 2
64 | assert len(image_size) == 4
65 |
66 |     # NOTE: Each slide object can contain one or multiple scenes,
67 |     # i.e. continuous raster regions (with the 2d image, other meta-data, etc.)
68 | scene = slide.get_scene(0)
69 | input_array = scene.read_block(size=scale, rect=image_size)
70 |
71 | return input_array
72 |
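73 | # A minimal usage sketch (hypothetical file; requires 'slideio'):
74 | #
75 | #     full = read_wsi("slide.svs")                                 # The full slide at native resolution.
76 | #     crop = read_wsi("slide.svs", image_size=(0, 0, 1024, 1024))  # A 1024 x 1024 ROI at the origin.
77 | #     down = read_wsi("slide.svs", scale=(512, 0))                 # Height 512, width keeps the aspect ratio.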
--------------------------------------------------------------------------------
/patho_sam/training/__init__.py:
--------------------------------------------------------------------------------
1 | from .util import histopathology_identity, get_train_val_split
2 | from .semantic_trainer import SemanticInstanceTrainer
3 |
--------------------------------------------------------------------------------
/patho_sam/training/semantic_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from micro_sam.training import SemanticSamTrainer
7 | from micro_sam.training.semantic_sam_trainer import CustomDiceLoss
8 |
9 |
10 | class SemanticInstanceTrainer(SemanticSamTrainer):
11 | """Modified trainer class for training the Segment Anything Model for semantic (instance) segmentation.
12 | """
13 | def __init__(
14 | self,
15 | convert_inputs,
16 | num_classes: int,
17 | dice_weight: Optional[float] = None,
18 | class_weights: Optional[List[float]] = None,
19 | **kwargs
20 | ):
21 | assert num_classes > 1
22 |
23 | if "loss" not in kwargs:
24 | kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)
25 |
26 | if "metric" not in kwargs:
27 | kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)
28 |
29 | super().__init__(convert_inputs=convert_inputs, num_classes=num_classes, dice_weight=dice_weight, **kwargs)
30 |
31 | self.class_weights = class_weights
32 | if self.class_weights is None:
33 | self.compute_ce_loss = nn.CrossEntropyLoss()
34 | else:
35 | weight_vals = torch.tensor(self.class_weights, dtype=torch.float32, device=self.device)
36 | self.compute_ce_loss = nn.CrossEntropyLoss(weight=weight_vals)
37 |
38 | if self.dice_weight is not None and (self.dice_weight < 0 or self.dice_weight > 1):
39 | raise ValueError("The weight factor should lie between 0 and 1.")
40 |
41 | self._kwargs = kwargs
42 |
43 | def _get_model_outputs(self, batched_inputs):
44 | """Get the predictions from the model.
45 | """
46 |         inputs = torch.stack([bi["image"] for bi in batched_inputs], dim=0).to(self.device)
47 |         outputs = self.model(inputs)
48 |         return outputs
49 |
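50 | # A construction sketch (assuming the usual 'SemanticSamTrainer' / 'torch_em' trainer kwargs,
51 | # e.g. 'model', 'optimizer', 'train_loader', 'val_loader'; the class weights are given in
52 | # class id order, with the background class first):
53 | #
54 | #     trainer = SemanticInstanceTrainer(
55 | #         convert_inputs=convert_inputs, num_classes=6, dice_weight=0.5,
56 | #         class_weights=[1, 1, 7, 6, 10, 8], name="semantic_sam", model=model,
57 | #         optimizer=optimizer, train_loader=train_loader, val_loader=val_loader, device=device,
58 | #     )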
--------------------------------------------------------------------------------
/patho_sam/training/util.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch.utils.data as data_util
7 |
8 | from torch_em.data.datasets.light_microscopy.neurips_cell_seg import to_rgb
9 |
10 |
11 | CLASS_MAP = {
12 | 'puma': {
13 | 2: 1,
14 | 3: 2, 4: 2, 5: 2, 6: 2, 7: 2,
15 | 1: 3, 8: 3,
16 | 10: 4,
17 | 9: 5,
18 | },
19 | }
20 |
21 | CLASS_DICT = {
22 | 'puma': {
23 | "nuclei_stroma": 1,
24 | "nuclei_tumor": 2,
25 | "nuclei_plasma_cell": 3,
26 | "nuclei_histiocyte": 4,
27 | "nuclei_lymphocyte": 5,
28 | "nuclei_melanophage": 6,
29 | "nuclei_neutrophil": 7,
30 | "nuclei_endothelium": 8,
31 | "nuclei_epithelium": 9,
32 | "nuclei_apoptosis": 10
33 | },
34 | 'pannuke': {
35 | "neoplastic": 1,
36 | "inflammatory": 2,
37 | "connective / soft tissue": 3,
38 | "dead cells": 4,
39 | "epithelial": 5,
40 | },
41 | }
42 |
43 |
44 | def histopathology_identity(x, ensure_rgb=True):
45 | """Identity transform.
46 |     Inspired by the 'identity' function in 'micro_sam/training/util.py'.
47 |
48 |     This ensures that data normalization is skipped when finetuning SAM.
49 | Data normalization is performed within the model to SA-1B data statistics
50 | and should thus be skipped as a preprocessing step in training.
51 | """
52 | if ensure_rgb:
53 | x = to_rgb(x)
54 |
55 | return x
56 |
57 |
58 | def get_train_val_split(
59 | ds: torch.utils.data.Dataset, val_fraction: float = 0.2, seed: int = 42,
60 | ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset]:
61 |     """Creates a train / val split of a dataset for a given fraction.
62 |
63 |     Args:
64 |         ds: The segmentation dataset.
65 |         val_fraction: The fraction of the dataset used for validation; the remainder is used for training.
66 |         seed: The seed for the random split, for reproducibility.
67 |
68 | Returns:
69 | Tuple of train and val datasets.
70 | """
71 | generator = torch.Generator().manual_seed(seed)
72 | train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator)
73 | return train_ds, val_ds
74 |
75 |
76 | def remap_labels(y: np.ndarray, name: str) -> np.ndarray:
77 | """Maps the labels to overall meta classes, to match the
78 | semantic class structure of PanNuke dataset.
79 |
80 | Args:
81 | y: The original semantic label.
82 | name: The name of target dataset to remap original class ids to PanNuke class ids.
83 |
84 | Returns:
85 | The remapped labels.
86 | """
87 | if name not in CLASS_MAP:
88 | raise ValueError(f"The chosen dataset '{name}' is not supported.")
89 |
90 | # Get the class id map.
91 | mapping = CLASS_MAP[name]
92 |
93 | # Remap the labels.
94 |     # NOTE: We use a lookup array to make sure that each id is mapped to the exact target value.
95 | per_id_lookup_array = np.array([mapping.get(i, 0) for i in range(max(mapping) + 1)], dtype=np.int32)
96 | y_remapped = per_id_lookup_array[y]
97 | return y_remapped
98 |
99 |
100 | def calculate_class_weights_for_loss_weighting(
101 | foreground_class_weights: List[float] = [0.4702, 0.1797, 0.2229, 0.0159, 0.1113],
102 | ) -> List[float]:
103 | """Calculates the class weights for weighting the cross entropy loss.
104 |
105 | NOTE 1: The default weights originate from weighting both the PanNuke and PUMA labels.
106 | TODO: Scripts coming soon!
107 |
108 |     NOTE 2: We weight the classes using relative integers on a scale of 1 to 10,
109 |     where 1 corresponds to the most frequent class and 10 to the least frequent class.
110 |
111 |     NOTE 3: Make sure that the order of weights matches the class id order.
112 |
113 |     Args:
114 |         foreground_class_weights: The relative frequencies of the foreground classes.
115 |
116 | Returns:
117 | The integer weighting for each class, including the background class.
118 | """
119 | foreground_class_weights = np.array(foreground_class_weights)
120 |
121 | # Define the range for integer weighting.
122 | background_weight, max_weight = 1, 10
123 |
124 | # Normalize the class weights.
125 | min_val, max_val = np.min(foreground_class_weights), np.max(foreground_class_weights)
126 |
127 | # Invert the mapping (i.e. higher for rarer class, lower for common classes)
128 | mapped_weights = max_weight - ((foreground_class_weights - min_val) / (max_val - min_val)) * (max_weight - 1)
129 |
130 | # Make sure that the most common class has weight 1.
131 | mapped_weights[np.argmax(foreground_class_weights)] = background_weight
132 |
133 | # Round the weights and convert them to integer values.
134 | final_weights = np.round(mapped_weights).astype(int)
135 |
136 | # Add background weights in the beginning.
137 | final_weights_with_bg = [background_weight, *final_weights]
138 |
139 | return final_weights_with_bg
140 |
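141 | # Small worked examples for the helpers above (comment sketch):
142 | #
143 | #     remap_labels(np.array([[0, 2], [10, 9]]), "puma")  # -> [[0, 1], [4, 5]]
144 | #
145 | #     # With the default frequencies, the rarest class (0.0159) gets weight 10 and the
146 | #     # most frequent one (0.4702) weight 1, with the background weight prepended:
147 | #     calculate_class_weights_for_loss_weighting()  # -> [1, 1, 7, 6, 10, 8]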
--------------------------------------------------------------------------------
/patho_sam/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | from typing import Union, Optional, OrderedDict
4 |
5 | import pooch
6 |
7 | import torch
8 |
9 | from micro_sam.util import microsam_cachedir, get_cache_directory
10 | from micro_sam.sample_data import fetch_wholeslide_histopathology_example_data
11 |
12 | from .io import read_wsi
13 |
14 |
15 | DECODER_URL = "https://owncloud.gwdg.de/index.php/s/TFbI25UZoixd1hi/download"
16 |
17 |
18 | def export_semantic_segmentation_decoder(
19 | checkpoint_path: Union[str, os.PathLike], save_path: Union[str, os.PathLike],
20 | ):
21 |     """Exports the weights of the trained convolutional decoder for the semantic segmentation task.
22 |
23 | Args:
24 | checkpoint_path: Filepath to the trained semantic segmentation checkpoint.
25 | save_path: Filepath where the decoder weights will be stored.
26 | """
27 | # Load the model state from finetuned checkpoint.
28 | model_state = torch.load(checkpoint_path, map_location="cpu")["model_state"]
29 |
30 | # Get the decoder state only.
31 | decoder_state = collections.OrderedDict(
32 | [(k, v) for k, v in model_state.items() if not k.startswith("encoder")]
33 | )
34 |
35 | # Store the decoder state to a desired path.
36 | torch.save(decoder_state, save_path)
37 |
38 |
39 | def get_semantic_segmentation_decoder_weights(save_path: Optional[Union[str, os.PathLike]] = None) -> OrderedDict:
40 | """Get the semantic segmentation decoder weights for initializing the decoder-only.
41 |
42 | Args:
43 |         save_path: The folder where the decoder weights are stored. By default, the 'micro-sam' cache directory.
44 |
45 | Returns:
46 | The pretrained decoder weights.
47 | """
48 | # By default, we store decoder weights to `micro-sam` cache directory.
49 | save_directory = os.path.join(microsam_cachedir(), "models") if save_path is None else save_path
50 |
51 | # Download the model weights
52 | fname = "vit_b_histopathology_semantic_segmentation_decoder"
53 | pooch.retrieve(
54 | url=DECODER_URL,
55 | known_hash="bdd05a55c72c02abce72a7aa6885c6ec21df9c43fda9cf3c5d11ef5788de0ab0",
56 | fname=fname,
57 | path=save_directory,
58 | progressbar=True,
59 | )
60 |
61 | # Get the checkpoint path.
62 | checkpoint_path = os.path.join(save_directory, fname)
63 |
64 | # Load the decoder state.
65 | state = torch.load(checkpoint_path, map_location="cpu")
66 |
67 | return state
68 |
69 |
70 | def get_example_wsi_data():
71 | """@private"""
72 | import argparse
73 | parser = argparse.ArgumentParser(description="Download and visualize the example whole-slide image (WSI).")
74 | parser.add_argument(
75 | "-s", "--save_path", type=str, default=None,
76 | help=f"The folder to store the whole-slide image. By default, it is stored at '{get_cache_directory()}'."
77 | )
78 | parser.add_argument(
79 |         "--roi", nargs="+", type=int, default=None,
80 |         help="The ROI of the whole-slide image to load, provided as: '--roi X Y W H'. "
81 |         "By default, the entire WSI is loaded.",
82 | )
83 | parser.add_argument(
84 | "--view", action="store_true", help="Whether to view the WSI in napari."
85 | )
86 |
87 | args = parser.parse_args()
88 |
89 | # Get the folder to store the WSI. By default, stores it at 'micro-sam' cache directory.
90 |     save_dir = os.path.join(get_cache_directory(), "sample_data") if args.save_path is None else args.save_path
91 |
92 | # Download the example WSI.
93 | example_data = fetch_wholeslide_histopathology_example_data(save_dir)
94 |
95 | # Check the ROI and convert it to tuple, if provided by user.
96 | roi = None if args.roi is None else tuple(args.roi)
97 |
98 | if args.view:
99 | # Load the WSI image.
100 | image = read_wsi(example_data, image_size=roi)
101 |
102 | # Get multi-scales for the input image.
103 | multiscale_images = [
104 | image,
105 | read_wsi(example_data, image_size=roi, scale=(int(image.shape[0] / 2), 0)),
106 | read_wsi(example_data, image_size=roi, scale=(int(image.shape[0] / 4), 0)),
107 | read_wsi(example_data, image_size=roi, scale=(int(image.shape[0] / 8), 0)),
108 | ]
109 |
110 | import napari
111 | v = napari.Viewer()
112 | v.add_image(multiscale_images, name="Input Image")
113 | napari.run()
114 |
115 | print(f"The example WSI is stored at '{example_data}'.")
116 |
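117 | # A minimal usage sketch for the decoder helpers above (hypothetical checkpoint path):
118 | #
119 | #     export_semantic_segmentation_decoder(
120 | #         checkpoint_path="checkpoints/vit_b/semantic_sam/best.pt", save_path="./decoder.pt",
121 | #     )
122 | #     decoder_state = get_semantic_segmentation_decoder_weights()  # Downloads to the 'micro-sam' cache.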
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
--------------------------------------------------------------------------------
/scripts/plotting/get_dataset_figure.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import matplotlib.pyplot as plt
4 |
5 | from torch_em.data import datasets, MinTwoInstanceSampler
6 | from torch_em.transform.raw import normalize
7 | from micro_sam.training import identity
8 | from micro_sam.evaluation.model_comparison import _overlay_outline
9 |
10 |
11 | ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data"
12 |
13 |
14 | def for_fig_1a():
15 | sampler = MinTwoInstanceSampler()
16 |
17 |     # Get a cumulative image for multiple datasets.
18 | get_loaders = {
19 | "consep": lambda: datasets.get_consep_loader(
20 | path=os.path.join(ROOT, "consep"), batch_size=1, shuffle=True, raw_transform=identity,
21 | patch_shape=(512, 512), split="test", download=True, sampler=sampler,
22 | ),
23 | # "cpm15": lambda: datasets.get_cpm_loader(
24 | # path=os.path.join(ROOT, "cpm15"), batch_size=1, shuffle=True, patch_shape=(512, 512),
25 | # data_choice="cpm15", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity,
26 | # ),
27 | "cpm17": lambda: datasets.get_cpm_loader(
28 | path=os.path.join(ROOT, "cpm17"), batch_size=1, shuffle=True, patch_shape=(512, 512), split="test",
29 | data_choice="cpm17", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity,
30 | ),
31 | "cryonuseg": lambda: datasets.get_cryonuseg_loader(
32 | path=os.path.join(ROOT, "cryonuseg"), batch_size=1, shuffle=True, patch_shape=(512, 512),
33 | split="test", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity,
34 | ),
35 | "lizard": lambda: datasets.get_lizard_loader(
36 | path=os.path.join(ROOT, "lizard"), batch_size=1, patch_shape=(512, 512), split="test",
37 | resize_inputs=True, download=True, shuffle=True, sampler=sampler, raw_transform=identity,
38 | ),
39 | "lynsec": lambda: datasets.get_lynsec_loader(
40 | path=os.path.join(ROOT, "lynsec"), batch_size=1, patch_shape=(512, 512), shuffle=True,
41 | choice="h&e", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity,
42 | ),
43 | "monuseg": lambda: datasets.get_monuseg_loader(
44 | path=os.path.join(ROOT, "monuseg"), batch_size=1, shuffle=True, patch_shape=(512, 512),
45 | resize_inputs=True, download=True, sampler=sampler, raw_transform=identity, split="test",
46 | ),
47 | "nuinsseg": lambda: datasets.get_nuinsseg_loader(
48 | path=os.path.join(ROOT, "nuinsseg"), batch_size=1, shuffle=True, patch_shape=(512, 512),
49 | download=True, sampler=sampler, raw_transform=identity, resize_inputs=True,
50 | ),
51 | # "pannuke": lambda: datasets.get_pannuke_loader(
52 | # path=os.path.join(ROOT, "pannuke"), batch_size=1, patch_shape=(512, 512), folds=["fold_3"],
53 | # download=True, shuffle=True, sampler=sampler, raw_transform=identity,
54 | # ),
55 | "puma": lambda: datasets.get_puma_loader(
56 | path=os.path.join(ROOT, "puma"), batch_size=1, patch_shape=(512, 512), split="test",
57 | download=True, sampler=sampler, raw_transform=identity, resize_inputs=True,
58 | ),
59 | "tnbc": lambda: datasets.get_tnbc_loader(
60 | path=os.path.join(ROOT, "tnbc"), batch_size=1, patch_shape=(512, 512), ndim=2, shuffle=True,
61 | split="train", resize_inputs=True, download=True, sampler=sampler, raw_transform=identity,
62 | )
63 | }
64 |
65 | fig, ax = plt.subplots(3, 3, figsize=(30, 30))
66 | ax = ax.flatten()
67 |
68 | for i, dname in enumerate(get_loaders.keys()):
69 | loader = get_loaders[dname]()
70 | counter = 0
71 | for x, y in loader:
72 | if counter > 0:
73 | break
74 | counter += 1
75 |
76 | x, y = x.squeeze().numpy(), y.squeeze().numpy()
77 |
78 | # Make channels last for RGB images.
79 | if x.shape[0] == 3:
80 | x = x.transpose(1, 2, 0)
81 |
83 |         # Normalize images and convert to uint8 for plotting.
84 |         x = normalize(x) * 255
85 |         x = x.astype("uint8")
86 |
87 | # Finally, plot them into one place.
88 | image = _overlay_outline(x, y, outline_dilation=1)
89 |
90 | ax[i].imshow(image, cmap="gray")
91 | ax[i].axis("off")
92 |
93 | plt.subplots_adjust(hspace=0.01, wspace=0.01)
94 | plt.savefig("./fig_1a_histopathology_dataset_images.png", bbox_inches="tight")
95 |     plt.savefig("./fig_1a_histopathology_dataset_images.svg", bbox_inches="tight")
96 |
97 |
98 | def main():
99 | for_fig_1a()
100 |
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
--------------------------------------------------------------------------------
/scripts/test_pannuke.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | from glob import glob
4 | from natsort import natsorted
5 |
6 | import h5py
7 | import numpy as np
8 | import imageio.v3 as imageio
9 | from skimage.segmentation import relabel_sequential
10 |
11 | from torch_em.data.datasets.histopathology import pannuke, monuseg
12 |
13 | from micro_sam.util import get_sam_model
14 | from micro_sam.evaluation import inference, evaluation
15 |
16 |
17 | def _pad_image(image, target_shape=(512, 512), pad_value=0):
18 | pad = [
19 | (max(t - s, 0) // 2, max(t - s, 0) - max(t - s, 0) // 2) for s, t in zip(image.shape[:2], target_shape)
20 | ]
21 |
22 | if image.ndim == 3:
23 | pad.append((0, 0))
24 |
25 | return np.pad(image, pad, mode="constant", constant_values=pad_value)
26 |
27 |
28 | def get_data_paths(input_path, dataset_name):
29 | # Set specific data folders.
30 | input_path = os.path.join(input_path, dataset_name)
31 |
32 | if dataset_name == "pannuke":
33 | # Create clone of single images in input_path directory.
34 | data_dir = os.path.join(input_path, "benchmark_2d")
35 |
36 | if not os.path.exists(data_dir):
37 | # First, we get the fold 3.
38 | fold_path = pannuke.get_pannuke_paths(path=input_path, folds=["fold_3"], download=True)
39 | fold_path = fold_path[0]
40 |
41 | # Next, simply extract the images.
42 | with h5py.File(fold_path, "r") as f:
43 | image_stack = f["images"][:].transpose(1, 2, 3, 0)
44 | label_stack = f["labels/instances"][:]
45 |
46 | # Store them one-by-one locally in an experiment folder.
47 | os.makedirs(data_dir)
48 |
49 | for i, (image, label) in tqdm(
50 | enumerate(zip(image_stack, label_stack)), total=len(image_stack), desc="Extracting images",
51 | ):
52 | # There has to be some foreground in the image to be considered for interactive segmentation.
53 | if len(np.unique(label)) == 1:
54 | continue
55 |
56 |             # NOTE: We pad the images below to match the shape of the inputs on which the model was trained,
57 |             # i.e. (512, 512), for proper reproducibility (otherwise the results are slightly worse).
58 | image = _pad_image(image)
59 | label = _pad_image(label)
60 |
61 | imageio.imwrite(os.path.join(data_dir, f"pannuke_fold_3_{i:05}_image.tif"), image, compression="zlib")
62 | imageio.imwrite(os.path.join(data_dir, f"pannuke_fold_3_{i:05}_label.tif"), label, compression="zlib")
63 |
64 | # Well, now we have our image and label paths.
65 | image_paths = natsorted(glob(os.path.join(data_dir, "*_image.tif")))
66 | label_paths = natsorted(glob(os.path.join(data_dir, "*_label.tif")))
67 |
68 | assert len(image_paths) == len(label_paths)
69 |
70 |         # HACK: For debugging purposes, we check the first 100 images only. Remove the next line to run on all images.
71 | image_paths, label_paths = image_paths[:100], label_paths[:100]
72 |
73 | elif dataset_name == "monuseg":
74 | # Create clone of cropped images in input_path directory.
75 | data_dir = os.path.join(input_path, "benchmark_2d")
76 | os.makedirs(data_dir, exist_ok=True)
77 |
78 | curr_image_paths, curr_label_paths = monuseg.get_monuseg_paths(path=input_path, split="test", download=True)
79 |
80 | # Let's do a simple cropping to test stuff.
81 | image_paths, label_paths = [], []
82 | for curr_image_path, curr_label_path in zip(curr_image_paths, curr_label_paths):
83 | image = imageio.imread(curr_image_path)
84 | label = imageio.imread(curr_label_path).astype("uint32")
85 |
86 | # Do a simple cropping and relabel instances.
87 | image, label = image[:512, :512, :], label[:512, :512]
88 | label = relabel_sequential(label)[0]
89 |
90 | # And save the cropped image and corresponding label.
91 | image_path = os.path.join(data_dir, os.path.basename(curr_image_path))
92 | image_paths.append(image_path)
93 | imageio.imwrite(image_path, image, compression="zlib")
94 |
95 | label_path = os.path.join(data_dir, os.path.basename(curr_label_path))
96 | label_paths.append(label_path)
97 | imageio.imwrite(label_path, label, compression="zlib")
98 |
99 | else:
100 |         raise ValueError(f"The dataset '{dataset_name}' is not supported.")
101 |
102 | return image_paths, label_paths
103 |
104 |
105 | def run_interactive_segmentation(input_path, experiment_folder, model_type, start_with_box_prompt=True):
106 |
107 | # Setup 1: PanNuke images (pad (256, 256) images up to (512, 512) to match the training patch shape)
108 | # image_paths, label_paths = get_data_paths(input_path, "pannuke") # NOTE: uncomment to run it on PanNuke
109 |
110 | # Setup 2: MoNuSeg images (since the images are larger than training patch shape, we crop them to shape (512, 512))
111 | image_paths, label_paths = get_data_paths(input_path, "monuseg") # NOTE: comment this before running other setups.
112 |
113 | # Get the Segment Anything model.
114 | predictor = get_sam_model(model_type=model_type)
115 |
116 | # Then run interactive segmentation by simulating prompts from labels.
117 | prediction_root = os.path.join(
118 | experiment_folder, ("start_with_box" if start_with_box_prompt else "start_with_point")
119 | )
120 | inference.run_inference_with_iterative_prompting(
121 | predictor=predictor,
122 | image_paths=image_paths,
123 | gt_paths=label_paths,
124 | embedding_dir=None,
125 | prediction_dir=prediction_root,
126 | start_with_box_prompt=start_with_box_prompt,
127 | )
128 |
129 | # And evaluate the results.
130 | results = evaluation.run_evaluation_for_iterative_prompting(
131 | gt_paths=label_paths,
132 | prediction_root=prediction_root,
133 | experiment_folder=experiment_folder,
134 | start_with_box_prompt=start_with_box_prompt,
135 | )
136 |
137 | print(results)
138 |
139 |
140 | def main(args):
141 | run_interactive_segmentation(
142 | input_path=args.input_path,
143 | model_type=args.model_type,
144 | experiment_folder=args.experiment_folder,
145 | start_with_box_prompt=False,
146 | )
147 |
148 |
149 | if __name__ == "__main__":
150 | import argparse
151 | parser = argparse.ArgumentParser()
152 | parser.add_argument("-i", "--input_path", default="/mnt/vast-nhr/projects/cidas/cca/data", type=str)
153 | parser.add_argument("-e", "--experiment_folder", default="./experiments", type=str)
154 | parser.add_argument("-m", "--model_type", default="vit_b_histopathology", type=str)
155 | args = parser.parse_args()
156 | main(args)
157 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import runpy
4 | from setuptools import find_packages, setup
5 |
6 |
7 | __version__ = runpy.run_path("patho_sam/__version__.py")["__version__"]
8 |
9 |
10 | setup(
11 | name='patho_sam',
12 | description='Segment Anything for Histopathology.',
13 | version=__version__,
14 | author='Titus Griebel, Anwai Archit, Constantin Pape',
15 | author_email='titus.griebel@stud.uni-goettingen.de, anwai.archit@uni-goettingen.de, constantin.pape@informatik.uni-goettingen.de', # noqa
16 | url='https://github.com/computational-cell-analytics/patho-sam',
17 |     packages=find_packages(),
18 | license="MIT",
19 | entry_points={
20 | "console_scripts": [
21 | "patho_sam.example_data = patho_sam.util:get_example_wsi_data",
22 | "patho_sam.automatic_segmentation = patho_sam.automatic_segmentation:main",
23 | ]
24 | }
25 | )
26 |
--------------------------------------------------------------------------------