├── .gitmodules
├── eval
│   ├── get_dim.py
│   ├── README.md
│   ├── run_hest.py
│   ├── renormalized.py
│   ├── extract_cls_token.py
│   ├── configs
│   │   └── vision
│   │       └── pathology
│   │           ├── offline
│   │           │   └── classification
│   │           │       ├── gleason_arvaniti.yaml
│   │           │       ├── mhist.yaml
│   │           │       ├── crc.yaml
│   │           │       ├── breakhis.yaml
│   │           │       ├── bach.yaml
│   │           │       ├── bracs.yaml
│   │           │       ├── patch_camelyon.yaml
│   │           │       ├── pcam_10shots.yaml
│   │           │       ├── camelyon16_small.yaml
│   │           │       └── panda_small.yaml
│   │           └── online
│   │               └── segmentation
│   │                   ├── consep.yaml
│   │                   └── monusac.yaml
│   ├── object_tools.py
│   ├── run_eva_internal.sh
│   ├── backbones.py
│   ├── hest_bench_config.yaml
│   └── kaiko.py
├── LICENSE
├── .gitignore
└── README.md

/.gitmodules:
--------------------------------------------------------------------------------
 1 | [submodule "eval/eva"]
 2 |     path = eval/eva
 3 |     url = git@github.com:kaiko-ai/eva.git
 4 | [submodule "eval/HEST"]
 5 |     path = eval/HEST
 6 |     url = git@github.com:mahmoodlab/HEST.git
 7 | 
--------------------------------------------------------------------------------
/eval/get_dim.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | 
 3 | def get_dim(name, check_concat=False):
 4 |     dim = 384
 5 |     if "h_optimus_0" in name.lower():
 6 |         dim = 1536
 7 |     elif "virchow2" in name.lower():
 8 |         dim = 1280
 9 |     elif "owkin_vits16_phikon" in name.lower() or "dino_vits16_phikon" in name.lower():
10 |         dim = 768
11 |     elif "vitl1" in name.lower():
12 |         dim = 1024
13 |     elif "vitb1" in name.lower():
14 |         dim = 768
15 |     elif "vitb8" in name.lower():
16 |         dim = 768
17 |     elif "vitg14" in name.lower():
18 |         dim = 1536
19 | 
20 |     if check_concat and "concat" in name:
21 |         return dim * 2
22 |     return dim
23 | 
24 | if __name__ == '__main__':
25 |     if len(sys.argv) < 2:
26 |         print("Usage: python get_dim.py <model_name> [<check_concat>]")
27 |         exit(1)
28 |     print(get_dim(sys.argv[1], int(sys.argv[2]) if len(sys.argv) >= 3 else False))
29 | 
--------------------------------------------------------------------------------
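`get_dim.py` maps a model name to its embedding dimension (used to set `IN_FEATURES` for the linear probing heads); `*_concat` names double the dimension because the CLS token is concatenated with the mean patch token (see `extract_cls_token.py`). For example, run from `eval/`:

```python
from get_dim import get_dim

# "vitg14" matches the 1536-dim branch; "concat" doubles it to 3072.
assert get_dim("vitg14_Kaiko_Midnight_concat", check_concat=True) == 3072
assert get_dim("dino_vits16") == 384  # falls through to the ViT-S default
```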
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2025 Kaiko
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/eval/README.md:
--------------------------------------------------------------------------------
 1 | # Evaluation
 2 | 
 3 | ### Running HEST
 4 | 
 5 | 1. Get the submodules: `git submodule update --init --recursive`
 6 | 2. Install HEST (`pip install -e HEST`) and download the HEST-1k data following the HEST [manual](https://github.com/mahmoodlab/HEST/tree/568df6ce8e6cd88829866a8cfdb5d2452f72f96e?tab=readme-ov-file#downloadquery-hest-1k-1tb)
 7 | 3. Install hestcore: `pip install hestcore`
 8 | 4. Run the benchmark: `python run_hest.py --config hest_bench_config.yaml`
 9 | 
10 | ### Running EVA
11 | 1. Get the submodules: `git submodule update --init --recursive`
12 | 2. Copy the custom model definitions [kaiko.py](./eva/src/eva/vision/models/networks/backbones/pathology/kaiko.py) and the other helpers into eva: `cp kaiko.py object_tools.py backbones.py renormalized.py ./eva/src/eva/vision/models/networks/backbones/pathology/; mv kaiko.py kaiko.py_;`
13 | 3. Install EVA: `pip install -e eva`
14 | 4. Download the data following the eva [instructions](https://kaiko-ai.github.io/eva/main/datasets) (or set `DOWNLOAD_DATA=true` in the run scripts to download the data automatically during the first eva run); see also the [eva user guide](https://kaiko-ai.github.io/eva/main/user-guide/)
15 | 5. Run the benchmarks via [run_eva_internal.sh](./run_eva_internal.sh), e.g.:
16 | ```bash
17 | MODEL_NAME="vitg14_Kaiko_Midnight_concat";
18 | TASK="camelyon16_small";
19 | MODEL_NAME=$MODEL_NAME \
20 | IN_FEATURES=$(python get_dim.py $MODEL_NAME 1) \
21 | OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS_patience/${TASK}/${MODEL_NAME} \
22 | EMBEDDINGS_ROOT=/dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME} \
23 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \
24 | NORMALIZE_MEAN=[0.5,0.5,0.5] \
25 | NORMALIZE_STD=[0.5,0.5,0.5] \
26 | python -m eva predict_fit --config configs/vision/pathology/offline/classification/${TASK}.yaml
27 | ```
--------------------------------------------------------------------------------
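For orientation: `run_hest.py` below reads `hest_bench_config.yaml` (not reproduced in this listing) and expects, on top of HEST's usual benchmark settings, a `custom_encoders` mapping of model names to `ModelBuilder` arguments (see `object_tools.py`). A minimal sketch of the expected shape — all paths and the encoder definition are illustrative placeholders, not the shipped config:

```yaml
# Hypothetical sketch; only the keys read by run_hest.py are shown.
results_dir: ./hest_results        # run_hest.py appends "/<model name>"
embed_dataroot: ./hest_embeddings  # run_hest.py appends "/<model name>"
weights_root: ./hest_weights       # run_hest.py appends "/<model name>"
custom_encoders:
  my_vits14_distilled:             # hypothetical encoder name
    path: timm.create_model        # any importable callable that builds the model
    arguments:
      model_name: vit_small_patch14_dinov2.lvd142m
      num_classes: 0
    ckpt_path: ./vits14_distilled_from_tcga-nki_099_test.pth  # optional
```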
/eval/run_hest.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | """
  3 | Run the HEST benchmark on custom ViT checkpoints.
  4 | """
  5 | 
  6 | import argparse
  7 | import os
  8 | import socket
  9 | import sys
 10 | from pathlib import Path
 11 | 
 12 | import torch
 13 | import torchvision.transforms as T
 14 | import yaml
 15 | from hest.bench import benchmark
 16 | from loguru import logger
 17 | from torchvision.transforms import v2
 18 | 
 19 | # Local helpers
 20 | from object_tools import ModelBuilder
 21 | 
 22 | # Set the default socket timeout (in seconds)
 23 | socket.setdefaulttimeout(50)
 24 | 
 25 | 
 26 | def main():
 27 |     parser = argparse.ArgumentParser(
 28 |         description="Run the HEST benchmark on a ViT-S/14 DINOv2 model with distilled weights."
 29 |     )
 30 |     parser.add_argument(
 31 |         "--config",
 32 |         type=str,
 33 |         default="./hest_bench_config.yaml",
 34 |         help="Path to the HEST benchmark config YAML file.",
 35 |     )
 36 |     parser.add_argument(
 37 |         "--checkpoint_path",
 38 |         type=str,
 39 |         default="./vits14_distilled_from_tcga-nki_099_test.pth",
 40 |         help="Path to the distilled model checkpoint file.",
 41 |     )
 42 |     parser.add_argument(
 43 |         "--repo-or-dir",
 44 |         type=str,
 45 |         default="facebookresearch/dinov2:main",
 46 |         help="torch.hub repository or local directory for the base model.",
 47 |     )
 48 |     parser.add_argument(
 49 |         "--model-name",
 50 |         type=str,
 51 |         default="dinov2_vits14",
 52 |         help="Name of the model to load from the repository.",
 53 |     )
 54 |     parser.add_argument(
 55 |         "--pretrained",
 56 |         action="store_true",
 57 |         help="Load the base model with pretrained weights (default: False).",
 58 |     )
 59 |     # NOTE: only --config is used below; the models themselves are defined
 60 |     # via the `custom_encoders` section of the config file.
 61 |     args = parser.parse_args()
 62 | 
 63 |     # Initialize logging
 64 |     logger.remove()  # Remove any default handlers to reconfigure
 65 |     logger.add(sys.stderr, level="INFO", format="{time} | {message}")
 66 | 
 67 |     # Resolve paths
 68 |     config_path = Path(args.config).resolve()
 69 | 
 70 |     # Validate config exists
 71 |     if not config_path.is_file():
 72 |         logger.error(f"Config file not found at: {config_path}")
 73 |         sys.exit(1)
 74 | 
 75 |     # Build transforms
 76 |     model_transforms = v2.Compose(
 77 |         [
 78 |             v2.Resize(size=224),  # FYI: change to 392 for the Midnight-92k/392 models
 79 |             # v2.CenterCrop(size=224),
 80 |             T.ToTensor(),
 81 |             T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
 82 |         ]
 83 |     )
 84 | 
 85 |     # Load the benchmark config and pull out the custom encoder definitions
 86 |     with open(config_path) as f:
 87 |         config = yaml.safe_load(f)
 88 | 
 89 |     custom_encoders = config.pop("custom_encoders")
 90 |     for name, model_args in custom_encoders.items():
 91 |         model = ModelBuilder(**model_args).build()
 92 | 
 93 |         # Use a per-model copy of the config so the output paths do not
 94 |         # accumulate the names of previously processed encoders.
 95 |         model_config = dict(config)
 96 |         model_config["results_dir"] += "/" + name
 97 |         model_config["embed_dataroot"] += "/" + name
 98 |         model_config["weights_root"] += "/" + name
 99 |         model_config_path = model_config["results_dir"] + "/config.yaml"
100 |         os.makedirs(model_config["results_dir"], exist_ok=True)
101 |         with open(model_config_path, "w") as yaml_file:
102 |             yaml.dump(model_config, yaml_file)
103 | 
104 |         # Run benchmark
105 |         try:
106 |             logger.info(f"Running HEST benchmark for model: {name}")
107 |             benchmark(
108 |                 model,
109 |                 model_transforms,
110 |                 precision=torch.float32,
111 |                 config=model_config_path,
112 |             )
113 |             logger.info("Benchmark completed successfully.")
114 |         except Exception as e:
115 |             logger.exception(f"Benchmark failed: {e}")
116 | 
117 | 
118 | if __name__ == "__main__":
119 |     main()
120 | 
--------------------------------------------------------------------------------
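Each `custom_encoders` entry is instantiated through `ModelBuilder` (defined in `object_tools.py`, listed at the end of this dump): `path` is imported and called with `arguments`, and an optional `ckpt_path`/`ckpt_submodule` pair loads weights afterwards. A minimal sketch, assuming `timm` is installed; the model name is only an example:

```python
from object_tools import ModelBuilder

# Resolves "timm.create_model", calls it with the given arguments, and
# (if ckpt_path were set) would load weights via load_model_checkpoint().
builder = ModelBuilder(
    path="timm.create_model",
    arguments={"model_name": "vit_small_patch16_224", "num_classes": 0},
)
model = builder.build()  # a plain torch.nn.Module, usable as a feature extractor
```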
24 | """ 25 | 26 | def __init__(self, old: dict[str, list[float]], new: dict[str, list[float]]) -> None: 27 | """Initializes the class. 28 | 29 | Args: 30 | old: A dict with keys `mean` and `std` representing the old normalization. 31 | new: A dict with keys `mean` and `std` representing the new normalization. 32 | """ 33 | super().__init__() 34 | 35 | mean = torch.Tensor(old["mean"]) 36 | new_mean = torch.Tensor(new["mean"]) 37 | 38 | std = torch.Tensor(old["std"]) 39 | new_std = torch.Tensor(new["std"]) 40 | 41 | # x -> y = (x-m)/s -> (y - (-m/s))/(1/s) 42 | # (y - ((m2-m)/s))/(s2/s) 43 | # denormalize is the same as the normalization with mean=-1 and sigma=2 44 | self._renormalize = torchvision.transforms.Normalize( 45 | (new_mean - mean) / std, new_std / std, inplace=False 46 | ) 47 | 48 | def __call__(self, tensor: torch.Tensor) -> torch.Tensor: 49 | return self._renormalize(tensor) 50 | 51 | 52 | class RenormalizingModel(nn.Module): 53 | """Wrapper class for models with custom normalization.""" 54 | 55 | def __init__( # pylint: disable=dangerous-default-value 56 | self, 57 | model: nn.Module, 58 | new_normalization: dict[str, list[float]], 59 | data_normalization: dict[str, list[float]] = { 60 | "mean": [0.5, 0.5, 0.5], 61 | "std": [0.5, 0.5, 0.5], 62 | }, 63 | ) -> None: 64 | """Initializes the model. 65 | 66 | Args: 67 | data_normalization: The normalization that has already been applied to the input data. 68 | new_normalization: The desired normalization. 69 | """ 70 | super().__init__() 71 | 72 | self._renormalize = Renormalize(old=data_normalization, new=new_normalization) 73 | 74 | if isinstance(model, dict) and {"path", "arguments"}.issubset(model): 75 | if model["arguments"].get("act_layer") == "torch.nn.SiLU": 76 | model["arguments"]["act_layer"] = torch.nn.SiLU 77 | if model["arguments"].get("mlp_layer") == "timm.layers.SwiGLUPacked": 78 | model["arguments"]["mlp_layer"] = timm.layers.SwiGLUPacked 79 | self._model = ModelBuilder(**model).build() 80 | else: 81 | self._model = model 82 | 83 | @override 84 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 85 | """Forward pass through the model.""" 86 | 87 | return self._model(self._renormalize(tensor)) 88 | 89 | @override 90 | def get_intermediate_layers( 91 | self, 92 | x: torch.Tensor, 93 | n: int | tuple[int, ...] = 1, # Layers or n last layers to take 94 | reshape: bool = False, 95 | return_class_token: bool = False, 96 | norm=True, 97 | ) -> tuple[torch.Tensor | tuple[torch.Tensor]]: 98 | """Returns the intermediate layers of the model.""" 99 | return self._model.get_intermediate_layers( 100 | self._renormalize(x), 101 | n=n, 102 | reshape=reshape, 103 | return_class_token=return_class_token, 104 | norm=norm, 105 | ) 106 | -------------------------------------------------------------------------------- /eval/extract_cls_token.py: -------------------------------------------------------------------------------- 1 | """Defines the CLS token extractors.""" 2 | 3 | import math 4 | 5 | import torch 6 | from transformers import modeling_outputs 7 | from typing_extensions import override 8 | 9 | 10 | class ExtractCLSToken: 11 | """Extracts the CLS token from a ViT model output.""" 12 | def __call__( 13 | self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling 14 | ) -> torch.Tensor: 15 | """Call method for the transformation. 16 | 17 | Args: 18 | tensor: The tensor representing the model output. 
19 | """ 20 | if isinstance(tensor, torch.Tensor): 21 | return tensor[:, 0, :] 22 | if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling): 23 | return tensor.last_hidden_state[:, 0, :] 24 | raise ValueError(f"Unsupported type {type(tensor)}") 25 | 26 | 27 | class ExtractConcatToken: 28 | """Extracts the CLS with Mean Patch tokens from a ViT model output.""" 29 | 30 | def __init__(self, num_reg_tokens: int = 0) -> None: 31 | self.num_reg_tokens = num_reg_tokens 32 | 33 | @override 34 | def __call__( 35 | self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling 36 | ) -> torch.Tensor: 37 | """Call method for the transformation. 38 | 39 | Args: 40 | tensor: The tensor representing the model output. 41 | """ 42 | if isinstance(tensor, torch.Tensor): 43 | return torch.cat( 44 | [tensor[:, 0, :], tensor[:, 1 + self.num_reg_tokens :, :].mean(1)], dim=-1 45 | ) 46 | if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling): 47 | return torch.cat( 48 | [ 49 | tensor.last_hidden_state[:, 0, :], 50 | tensor.last_hidden_state[:, 1 + self.num_reg_tokens :, :].mean(1), 51 | ], 52 | dim=-1, 53 | ) 54 | raise ValueError(f"Unsupported type {type(tensor)}") 55 | 56 | 57 | class ExtractPatchFeatures: 58 | """Extracts the patch features from a ViT model output.""" 59 | 60 | def __init__( 61 | self, 62 | has_cls_token: bool = True, 63 | num_reg_tokens: int = 0, 64 | ignore_remaining_dims: bool = False, 65 | ) -> None: 66 | """Initializes the transformation. 67 | 68 | Args: 69 | has_cls_token: If set to `True`, the model output is expected to have 70 | a classification token. 71 | num_reg_tokens: The number of register tokens in the model output. 72 | ignore_remaining_dims: If set to `True`, ignore the remaining dimensions 73 | of the patch grid if it is not a square number. 74 | """ 75 | self._has_cls_token = has_cls_token 76 | self._num_reg_tokens = num_reg_tokens 77 | self._ignore_remaining_dims = ignore_remaining_dims 78 | 79 | def __call__( 80 | self, tensor: torch.Tensor | modeling_outputs.BaseModelOutputWithPooling 81 | ) -> list[torch.Tensor]: 82 | """Call method for the transformation. 83 | 84 | Args: 85 | tensor: The raw embeddings of the model. 86 | 87 | Returns: 88 | A tensor (batch_size, hidden_size, n_patches_height, n_patches_width) 89 | representing the model output. 
90 | """ 91 | num_skip = int(self._has_cls_token) + self._num_reg_tokens 92 | if isinstance(tensor, modeling_outputs.BaseModelOutputWithPooling): 93 | features = tensor.last_hidden_state[:, num_skip:, :].permute(0, 2, 1) 94 | else: 95 | features = tensor[:, num_skip:, :].permute(0, 2, 1) 96 | 97 | batch_size, hidden_size, patch_grid = features.shape 98 | height = width = int(math.sqrt(patch_grid)) 99 | if height * width != patch_grid: 100 | if self._ignore_remaining_dims: 101 | features = features[:, :, -height * width :] 102 | else: 103 | raise ValueError(f"Patch grid size must be a square number {patch_grid}.") 104 | patch_embeddings = features.view(batch_size, hidden_size, height, width) 105 | 106 | return [patch_embeddings] 107 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/gleason_arvaniti.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 5} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/gleason_arvaniti} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} 8 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 9 | callbacks: 10 | - class_path: eva.callbacks.ConfigurationLogger 11 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 12 | init_args: 13 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 14 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 15 | init_args: 16 | logging_interval: epoch 17 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 18 | init_args: 19 | filename: best 20 | save_last: true 21 | save_top_k: 1 22 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} 23 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 24 | - class_path: lightning.pytorch.callbacks.EarlyStopping 25 | init_args: 26 | min_delta: 0 27 | patience: ${oc.env:PATIENCE, 21} 28 | monitor: *MONITOR_METRIC 29 | mode: *MONITOR_METRIC_MODE 30 | - class_path: eva.callbacks.ClassificationEmbeddingsWriter 31 | init_args: 32 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/gleason_arvaniti 33 | dataloader_idx_map: 34 | 0: train 35 | 1: val 36 | backbone: 37 | class_path: eva.vision.models.ModelFromRegistry 38 | init_args: 39 | model_name: ${oc.env:MODEL_NAME, 
universal/vit_small_patch16_224_dino} 40 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 41 | overwrite: false 42 | logger: 43 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 44 | init_args: 45 | save_dir: *OUTPUT_ROOT 46 | name: "" 47 | model: 48 | class_path: eva.HeadModule 49 | init_args: 50 | head: 51 | class_path: torch.nn.Linear 52 | init_args: 53 | in_features: ${oc.env:IN_FEATURES, 384} 54 | out_features: &NUM_CLASSES 4 55 | criterion: torch.nn.CrossEntropyLoss 56 | optimizer: 57 | class_path: torch.optim.AdamW 58 | init_args: 59 | lr: ${oc.env:LR_VALUE, 0.0003} 60 | lr_scheduler: 61 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 62 | init_args: 63 | T_max: *MAX_STEPS 64 | eta_min: 0.0 65 | metrics: 66 | common: 67 | - class_path: eva.metrics.AverageLoss 68 | - class_path: eva.metrics.MulticlassClassificationMetrics 69 | init_args: 70 | num_classes: *NUM_CLASSES 71 | data: 72 | class_path: eva.DataModule 73 | init_args: 74 | datasets: 75 | train: 76 | class_path: eva.datasets.EmbeddingsClassificationDataset 77 | init_args: &DATASET_ARGS 78 | root: *DATASET_EMBEDDINGS_ROOT 79 | manifest_file: manifest.csv 80 | split: train 81 | val: 82 | class_path: eva.datasets.EmbeddingsClassificationDataset 83 | init_args: 84 | <<: *DATASET_ARGS 85 | split: val 86 | predict: 87 | - class_path: eva.vision.datasets.GleasonArvaniti 88 | init_args: &PREDICT_DATASET_ARGS 89 | root: ${oc.env:DATA_ROOT, ./data/arvaniti_gleason_patches} 90 | split: train 91 | transforms: 92 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 93 | init_args: 94 | size: ${oc.env:RESIZE_DIM, 224} 95 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 96 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 97 | - class_path: eva.vision.datasets.GleasonArvaniti 98 | init_args: 99 | <<: *PREDICT_DATASET_ARGS 100 | split: val 101 | dataloaders: 102 | train: 103 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} 104 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 105 | shuffle: true 106 | val: 107 | batch_size: *BATCH_SIZE 108 | num_workers: *N_DATA_WORKERS 109 | predict: 110 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} 111 | num_workers: *N_DATA_WORKERS 112 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/mhist.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 5} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/mhist} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} 8 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 9 | callbacks: 10 | - class_path: eva.callbacks.ConfigurationLogger 11 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 12 | init_args: 13 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 14 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 15 | init_args: 16 | logging_interval: epoch 17 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 18 | init_args: 19 | filename: best 20 | save_last: true 21 | save_top_k: 1 22 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy} 23 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 24 | - class_path: lightning.pytorch.callbacks.EarlyStopping 25 | init_args: 26 | min_delta: 0 27 | patience: ${oc.env:PATIENCE, 70} 28 | monitor: *MONITOR_METRIC 29 | 
mode: *MONITOR_METRIC_MODE 30 | - class_path: eva.callbacks.ClassificationEmbeddingsWriter 31 | init_args: 32 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/mhist 33 | dataloader_idx_map: 34 | 0: train 35 | 1: test 36 | backbone: 37 | class_path: eva.vision.models.ModelFromRegistry 38 | init_args: 39 | model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 40 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 41 | overwrite: false 42 | logger: 43 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 44 | init_args: 45 | save_dir: *OUTPUT_ROOT 46 | name: "" 47 | model: 48 | class_path: eva.HeadModule 49 | init_args: 50 | head: 51 | class_path: torch.nn.Linear 52 | init_args: 53 | in_features: ${oc.env:IN_FEATURES, 384} 54 | out_features: 1 55 | criterion: torch.nn.BCEWithLogitsLoss 56 | optimizer: 57 | class_path: torch.optim.AdamW 58 | init_args: 59 | lr: ${oc.env:LR_VALUE, 0.0003} 60 | lr_scheduler: 61 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 62 | init_args: 63 | T_max: *MAX_STEPS 64 | eta_min: 0.0 65 | metrics: 66 | common: 67 | - class_path: eva.metrics.AverageLoss 68 | - class_path: eva.metrics.BinaryClassificationMetrics 69 | data: 70 | class_path: eva.DataModule 71 | init_args: 72 | datasets: 73 | train: 74 | class_path: eva.datasets.EmbeddingsClassificationDataset 75 | init_args: &DATASET_ARGS 76 | root: *DATASET_EMBEDDINGS_ROOT 77 | manifest_file: manifest.csv 78 | split: train 79 | target_transforms: 80 | class_path: torchvision.transforms.v2.ToDtype 81 | init_args: 82 | dtype: torch.float32 83 | val: 84 | class_path: eva.datasets.EmbeddingsClassificationDataset 85 | init_args: 86 | <<: *DATASET_ARGS 87 | split: test 88 | predict: 89 | - class_path: eva.vision.datasets.MHIST 90 | init_args: &PREDICT_DATASET_ARGS 91 | root: ${oc.env:DATA_ROOT, ./data/mhist} 92 | split: train 93 | transforms: 94 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 95 | init_args: 96 | size: ${oc.env:RESIZE_DIM, 224} 97 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 98 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 99 | - class_path: eva.vision.datasets.MHIST 100 | init_args: 101 | <<: *PREDICT_DATASET_ARGS 102 | split: test 103 | dataloaders: 104 | train: 105 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} 106 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 107 | shuffle: true 108 | val: 109 | batch_size: *BATCH_SIZE 110 | num_workers: *N_DATA_WORKERS 111 | predict: 112 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} 113 | num_workers: *N_DATA_WORKERS 114 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/crc.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 5} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/crc} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} 8 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 9 | callbacks: 10 | - class_path: eva.callbacks.ConfigurationLogger 11 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 12 | init_args: 13 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 14 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 15 | init_args: 16 | logging_interval: epoch 17 | - class_path: 
lightning.pytorch.callbacks.ModelCheckpoint 18 | init_args: 19 | filename: best 20 | save_last: true 21 | save_top_k: 1 22 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} 23 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 24 | - class_path: lightning.pytorch.callbacks.EarlyStopping 25 | init_args: 26 | min_delta: 0 27 | patience: ${oc.env:PATIENCE, 24} 28 | monitor: *MONITOR_METRIC 29 | mode: *MONITOR_METRIC_MODE 30 | - class_path: eva.callbacks.ClassificationEmbeddingsWriter 31 | init_args: 32 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/crc 33 | dataloader_idx_map: 34 | 0: train 35 | 1: val 36 | backbone: 37 | class_path: eva.vision.models.ModelFromRegistry 38 | init_args: 39 | model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 40 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 41 | overwrite: false 42 | logger: 43 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 44 | init_args: 45 | save_dir: *OUTPUT_ROOT 46 | name: "" 47 | model: 48 | class_path: eva.HeadModule 49 | init_args: 50 | head: 51 | class_path: torch.nn.Linear 52 | init_args: 53 | in_features: ${oc.env:IN_FEATURES, 384} 54 | out_features: &NUM_CLASSES 9 55 | criterion: torch.nn.CrossEntropyLoss 56 | optimizer: 57 | class_path: torch.optim.AdamW 58 | init_args: 59 | lr: ${oc.env:LR_VALUE, 0.0003} 60 | lr_scheduler: 61 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 62 | init_args: 63 | T_max: *MAX_STEPS 64 | eta_min: 0.0 65 | metrics: 66 | common: 67 | - class_path: eva.metrics.AverageLoss 68 | - class_path: eva.metrics.MulticlassClassificationMetrics 69 | init_args: 70 | num_classes: *NUM_CLASSES 71 | data: 72 | class_path: eva.DataModule 73 | init_args: 74 | datasets: 75 | train: 76 | class_path: eva.datasets.EmbeddingsClassificationDataset 77 | init_args: &DATASET_ARGS 78 | root: *DATASET_EMBEDDINGS_ROOT 79 | manifest_file: manifest.csv 80 | split: train 81 | val: 82 | class_path: eva.datasets.EmbeddingsClassificationDataset 83 | init_args: 84 | <<: *DATASET_ARGS 85 | split: val 86 | predict: 87 | - class_path: eva.vision.datasets.CRC 88 | init_args: &PREDICT_DATASET_ARGS 89 | root: ${oc.env:DATA_ROOT, ./data/crc} 90 | split: train 91 | download: ${oc.env:DOWNLOAD_DATA, false} 92 | # Set `download: true` to download the dataset from https://zenodo.org/records/1214456 93 | # The CRC dataset is distributed under the following license: "CC BY 4.0 LEGAL CODE" 94 | # (see: https://creativecommons.org/licenses/by/4.0/legalcode) 95 | transforms: 96 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 97 | init_args: 98 | size: ${oc.env:RESIZE_DIM, 224} 99 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 100 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 101 | - class_path: eva.vision.datasets.CRC 102 | init_args: 103 | <<: *PREDICT_DATASET_ARGS 104 | split: val 105 | dataloaders: 106 | train: 107 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} 108 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 109 | shuffle: true 110 | val: 111 | batch_size: *BATCH_SIZE 112 | num_workers: *N_DATA_WORKERS 113 | predict: 114 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} 115 | num_workers: *N_DATA_WORKERS 116 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/breakhis.yaml: 
--------------------------------------------------------------------------------
  1 | ---
  2 | trainer:
  3 |   class_path: eva.Trainer
  4 |   init_args:
  5 |     n_runs: &N_RUNS ${oc.env:N_RUNS, 5}
  6 |     default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/breakhis}
  7 |     max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500}
  8 |     checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best}
  9 |     callbacks:
 10 |       - class_path: eva.callbacks.ConfigurationLogger
 11 |       - class_path: lightning.pytorch.callbacks.TQDMProgressBar
 12 |         init_args:
 13 |           refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1}
 14 |       - class_path: lightning.pytorch.callbacks.LearningRateMonitor
 15 |         init_args:
 16 |           logging_interval: epoch
 17 |       - class_path: lightning.pytorch.callbacks.ModelCheckpoint
 18 |         init_args:
 19 |           filename: best
 20 |           save_last: true
 21 |           save_top_k: 1
 22 |           monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy}
 23 |           mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max}
 24 |       - class_path: lightning.pytorch.callbacks.EarlyStopping
 25 |         init_args:
 26 |           min_delta: 0
 27 |           patience: ${oc.env:PATIENCE, 105}
 28 |           monitor: *MONITOR_METRIC
 29 |           mode: *MONITOR_METRIC_MODE
 30 |       - class_path: eva.callbacks.ClassificationEmbeddingsWriter
 31 |         init_args:
 32 |           output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/breakhis
 33 |           dataloader_idx_map:
 34 |             0: train
 35 |             1: val
 36 |           backbone:
 37 |             class_path: eva.vision.models.ModelFromRegistry
 38 |             init_args:
 39 |               model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino}
 40 |               model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null}
 41 |           overwrite: false
 42 |     logger:
 43 |       - class_path: lightning.pytorch.loggers.TensorBoardLogger
 44 |         init_args:
 45 |           save_dir: *OUTPUT_ROOT
 46 |           name: ""
 47 | model:
 48 |   class_path: eva.HeadModule
 49 |   init_args:
 50 |     head:
 51 |       class_path: torch.nn.Linear
 52 |       init_args:
 53 |         in_features: ${oc.env:IN_FEATURES, 384}
 54 |         out_features: &NUM_CLASSES 4
 55 |     criterion: torch.nn.CrossEntropyLoss
 56 |     optimizer:
 57 |       class_path: torch.optim.AdamW
 58 |       init_args:
 59 |         lr: ${oc.env:LR_VALUE, 0.0003}
 60 |     lr_scheduler:
 61 |       class_path: torch.optim.lr_scheduler.CosineAnnealingLR
 62 |       init_args:
 63 |         T_max: *MAX_STEPS
 64 |         eta_min: 0.0
 65 |     metrics:
 66 |       common:
 67 |         - class_path: eva.metrics.AverageLoss
 68 |         - class_path: eva.metrics.MulticlassClassificationMetrics
 69 |           init_args:
 70 |             num_classes: *NUM_CLASSES
 71 | data:
 72 |   class_path: eva.DataModule
 73 |   init_args:
 74 |     datasets:
 75 |       train:
 76 |         class_path: eva.datasets.EmbeddingsClassificationDataset
 77 |         init_args: &DATASET_ARGS
 78 |           root: *DATASET_EMBEDDINGS_ROOT
 79 |           manifest_file: manifest.csv
 80 |           split: train
 81 |       val:
 82 |         class_path: eva.datasets.EmbeddingsClassificationDataset
 83 |         init_args:
 84 |           <<: *DATASET_ARGS
 85 |           split: val
 86 |       predict:
 87 |         - class_path: eva.vision.datasets.BreaKHis
 88 |           init_args: &PREDICT_DATASET_ARGS
 89 |             root: ${oc.env:DATA_ROOT, ./data/breakhis}
 90 |             split: train
 91 |             download: ${oc.env:DOWNLOAD_DATA, false}
 92 |             # Set `download: true` to let eva download the dataset automatically
 93 |             # The BreaKHis dataset is distributed under the following license: "CC BY 4.0"
 94 |             # (see: https://creativecommons.org/licenses/by/4.0/)
 95 |             transforms:
 96 |               class_path: eva.vision.data.transforms.common.ResizeAndCrop
 97 |               init_args:
 98 |                 size: ${oc.env:RESIZE_DIM, 224}
 99 |                 mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]}
100 |                 std:
${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 101 | - class_path: eva.vision.datasets.BreaKHis 102 | init_args: 103 | <<: *PREDICT_DATASET_ARGS 104 | split: val 105 | dataloaders: 106 | train: 107 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} 108 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 109 | shuffle: true 110 | val: 111 | batch_size: *BATCH_SIZE 112 | num_workers: *N_DATA_WORKERS 113 | predict: 114 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} 115 | num_workers: *N_DATA_WORKERS 116 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/online/segmentation/consep.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 5} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/consep} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} 8 | log_every_n_steps: 6 9 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 10 | callbacks: 11 | - class_path: eva.callbacks.ConfigurationLogger 12 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 13 | init_args: 14 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 15 | - class_path: eva.vision.callbacks.SemanticSegmentationLogger 16 | init_args: 17 | log_every_n_epochs: 1 18 | mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 19 | std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 20 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 21 | init_args: 22 | filename: best 23 | save_last: true 24 | save_top_k: 1 25 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, 'val/MonaiDiceScore'} 26 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 27 | - class_path: lightning.pytorch.callbacks.EarlyStopping 28 | init_args: 29 | min_delta: 0 30 | patience: ${oc.env:PATIENCE, 34} 31 | monitor: *MONITOR_METRIC 32 | mode: *MONITOR_METRIC_MODE 33 | logger: 34 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 35 | init_args: 36 | save_dir: *OUTPUT_ROOT 37 | name: "" 38 | model: 39 | class_path: eva.vision.models.modules.SemanticSegmentationModule 40 | init_args: 41 | encoder: 42 | class_path: eva.vision.models.ModelFromRegistry 43 | init_args: 44 | model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 45 | model_kwargs: 46 | out_indices: ${oc.env:OUT_INDICES, 1} 47 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 48 | decoder: 49 | class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage 50 | init_args: 51 | in_features: ${oc.env:IN_FEATURES, 384} 52 | num_classes: &NUM_CLASSES 5 53 | criterion: 54 | class_path: eva.vision.losses.DiceLoss 55 | init_args: 56 | softmax: true 57 | batch: true 58 | lr_multiplier_encoder: 0.0 59 | optimizer: 60 | class_path: torch.optim.AdamW 61 | init_args: 62 | lr: ${oc.env:LR_VALUE, 0.002} 63 | lr_scheduler: 64 | class_path: torch.optim.lr_scheduler.PolynomialLR 65 | init_args: 66 | total_iters: *MAX_STEPS 67 | power: 0.9 68 | postprocess: 69 | predictions_transforms: 70 | - class_path: torch.argmax 71 | init_args: 72 | dim: 1 73 | metrics: 74 | common: 75 | - class_path: eva.metrics.AverageLoss 76 | evaluation: 77 | - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics 78 | init_args: 79 | num_classes: *NUM_CLASSES 80 | - class_path: torchmetrics.ClasswiseWrapper 81 | init_args: 82 | metric: 83 | class_path: 
eva.vision.metrics.MonaiDiceScore 84 | init_args: 85 | include_background: true 86 | num_classes: *NUM_CLASSES 87 | reduction: none 88 | labels: 89 | - background 90 | - other 91 | - inflammatory 92 | - epithelial 93 | - spindle-shaped 94 | data: 95 | class_path: eva.DataModule 96 | init_args: 97 | datasets: 98 | train: 99 | class_path: eva.vision.datasets.CoNSeP 100 | init_args: &DATASET_ARGS 101 | root: ${oc.env:DATA_ROOT, ./data/consep} 102 | split: train 103 | sampler: eva.vision.data.wsi.patching.samplers.GridSampler 104 | width: 250 105 | height: 250 106 | target_mpp: 0.25 107 | transforms: 108 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 109 | init_args: 110 | size: ${oc.env:RESIZE_DIM, 224} 111 | mean: *NORMALIZE_MEAN 112 | std: *NORMALIZE_STD 113 | val: 114 | class_path: eva.vision.datasets.CoNSeP 115 | init_args: 116 | <<: *DATASET_ARGS 117 | split: val 118 | dataloaders: 119 | train: 120 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64} 121 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 122 | shuffle: true 123 | val: 124 | batch_size: *BATCH_SIZE 125 | num_workers: *N_DATA_WORKERS 126 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/bach.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 5} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bach} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} 8 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 9 | callbacks: 10 | - class_path: eva.callbacks.ConfigurationLogger 11 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 12 | init_args: 13 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 14 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 15 | init_args: 16 | logging_interval: epoch 17 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 18 | init_args: 19 | filename: best 20 | save_last: true 21 | save_top_k: 1 22 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} 23 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 24 | - class_path: lightning.pytorch.callbacks.EarlyStopping 25 | init_args: 26 | min_delta: 0 27 | patience: ${oc.env:PATIENCE, 400} 28 | monitor: *MONITOR_METRIC 29 | mode: *MONITOR_METRIC_MODE 30 | - class_path: eva.callbacks.ClassificationEmbeddingsWriter 31 | init_args: 32 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/bach 33 | dataloader_idx_map: 34 | 0: train 35 | 1: val 36 | backbone: 37 | class_path: eva.vision.models.ModelFromRegistry 38 | init_args: 39 | model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 40 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 41 | overwrite: false 42 | logger: 43 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 44 | init_args: 45 | save_dir: *OUTPUT_ROOT 46 | name: "" 47 | model: 48 | class_path: eva.HeadModule 49 | init_args: 50 | head: 51 | class_path: torch.nn.Linear 52 | init_args: 53 | in_features: ${oc.env:IN_FEATURES, 384} 54 | out_features: &NUM_CLASSES 4 55 | criterion: torch.nn.CrossEntropyLoss 56 | optimizer: 57 | class_path: torch.optim.AdamW 58 | init_args: 59 | lr: ${oc.env:LR_VALUE, 0.0003} 60 | lr_scheduler: 61 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 62 | init_args: 
63 | T_max: *MAX_STEPS 64 | eta_min: 0.0 65 | metrics: 66 | common: 67 | - class_path: eva.metrics.AverageLoss 68 | - class_path: eva.metrics.MulticlassClassificationMetrics 69 | init_args: 70 | num_classes: *NUM_CLASSES 71 | data: 72 | class_path: eva.DataModule 73 | init_args: 74 | datasets: 75 | train: 76 | class_path: eva.datasets.EmbeddingsClassificationDataset 77 | init_args: &DATASET_ARGS 78 | root: *DATASET_EMBEDDINGS_ROOT 79 | manifest_file: manifest.csv 80 | split: train 81 | val: 82 | class_path: eva.datasets.EmbeddingsClassificationDataset 83 | init_args: 84 | <<: *DATASET_ARGS 85 | split: val 86 | predict: 87 | - class_path: eva.vision.datasets.BACH 88 | init_args: &PREDICT_DATASET_ARGS 89 | root: ${oc.env:DATA_ROOT, ./data/bach} 90 | split: train 91 | download: ${oc.env:DOWNLOAD_DATA, false} 92 | # Set `download: true` to download the dataset from https://zenodo.org/records/3632035 93 | # The BACH dataset is distributed under the following license 94 | # Attribution-NonCommercial-NoDerivs 4.0 International license 95 | # (see: https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode) 96 | transforms: 97 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 98 | init_args: 99 | size: ${oc.env:RESIZE_DIM, 224} 100 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 101 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 102 | - class_path: eva.vision.datasets.BACH 103 | init_args: 104 | <<: *PREDICT_DATASET_ARGS 105 | split: val 106 | dataloaders: 107 | train: 108 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} 109 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 110 | shuffle: true 111 | val: 112 | batch_size: *BATCH_SIZE 113 | num_workers: *N_DATA_WORKERS 114 | predict: 115 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} 116 | num_workers: *N_DATA_WORKERS 117 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/bracs.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 5} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/bracs} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} 8 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 9 | callbacks: 10 | - class_path: eva.callbacks.ConfigurationLogger 11 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 12 | init_args: 13 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 14 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 15 | init_args: 16 | logging_interval: epoch 17 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 18 | init_args: 19 | filename: best 20 | save_last: true 21 | save_top_k: 1 22 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} 23 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 24 | - class_path: lightning.pytorch.callbacks.EarlyStopping 25 | init_args: 26 | min_delta: 0 27 | patience: ${oc.env:PATIENCE, 74} 28 | monitor: *MONITOR_METRIC 29 | mode: *MONITOR_METRIC_MODE 30 | - class_path: eva.callbacks.ClassificationEmbeddingsWriter 31 | init_args: 32 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/bracs 33 | dataloader_idx_map: 34 | 0: train 35 | 1: val 36 | 2: test 37 | backbone: 38 | class_path: eva.vision.models.ModelFromRegistry 39 | init_args: 40 | 
model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 41 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 42 | overwrite: false 43 | logger: 44 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 45 | init_args: 46 | save_dir: *OUTPUT_ROOT 47 | name: "" 48 | model: 49 | class_path: eva.HeadModule 50 | init_args: 51 | head: 52 | class_path: torch.nn.Linear 53 | init_args: 54 | in_features: ${oc.env:IN_FEATURES, 384} 55 | out_features: &NUM_CLASSES 7 56 | criterion: torch.nn.CrossEntropyLoss 57 | optimizer: 58 | class_path: torch.optim.AdamW 59 | init_args: 60 | lr: ${oc.env:LR_VALUE, 0.0003} 61 | lr_scheduler: 62 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 63 | init_args: 64 | T_max: *MAX_STEPS 65 | eta_min: 0.0 66 | metrics: 67 | common: 68 | - class_path: eva.metrics.AverageLoss 69 | - class_path: eva.metrics.MulticlassClassificationMetrics 70 | init_args: 71 | num_classes: *NUM_CLASSES 72 | data: 73 | class_path: eva.DataModule 74 | init_args: 75 | datasets: 76 | train: 77 | class_path: eva.datasets.EmbeddingsClassificationDataset 78 | init_args: &DATASET_ARGS 79 | root: *DATASET_EMBEDDINGS_ROOT 80 | manifest_file: manifest.csv 81 | split: train 82 | val: 83 | class_path: eva.datasets.EmbeddingsClassificationDataset 84 | init_args: 85 | <<: *DATASET_ARGS 86 | split: val 87 | test: 88 | class_path: eva.datasets.EmbeddingsClassificationDataset 89 | init_args: 90 | <<: *DATASET_ARGS 91 | split: test 92 | predict: 93 | - class_path: eva.vision.datasets.BRACS 94 | init_args: &PREDICT_DATASET_ARGS 95 | root: ${oc.env:DATA_ROOT, ./data/bracs} 96 | split: train 97 | transforms: 98 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 99 | init_args: 100 | size: ${oc.env:RESIZE_DIM, 224} 101 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 102 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 103 | - class_path: eva.vision.datasets.BRACS 104 | init_args: 105 | <<: *PREDICT_DATASET_ARGS 106 | split: val 107 | - class_path: eva.vision.datasets.BRACS 108 | init_args: 109 | <<: *PREDICT_DATASET_ARGS 110 | split: test 111 | dataloaders: 112 | train: 113 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} 114 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 115 | shuffle: true 116 | val: 117 | batch_size: *BATCH_SIZE 118 | num_workers: *N_DATA_WORKERS 119 | test: 120 | batch_size: *BATCH_SIZE 121 | num_workers: *N_DATA_WORKERS 122 | predict: 123 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} 124 | num_workers: *N_DATA_WORKERS 125 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/patch_camelyon.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 5} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/patch_camelyon} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} 8 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 9 | callbacks: 10 | - class_path: eva.callbacks.ConfigurationLogger 11 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 12 | init_args: 13 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 14 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 15 | init_args: 16 | logging_interval: epoch 17 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 18 | init_args: 19 | filename: best 20 | save_last: true 
21 | save_top_k: 1 22 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy} 23 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 24 | - class_path: lightning.pytorch.callbacks.EarlyStopping 25 | init_args: 26 | min_delta: 0 27 | patience: ${oc.env:PATIENCE, 9} 28 | monitor: *MONITOR_METRIC 29 | mode: *MONITOR_METRIC_MODE 30 | - class_path: eva.callbacks.ClassificationEmbeddingsWriter 31 | init_args: 32 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/patch_camelyon 33 | dataloader_idx_map: 34 | 0: train 35 | 1: val 36 | 2: test 37 | backbone: 38 | class_path: eva.vision.models.ModelFromRegistry 39 | init_args: 40 | model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 41 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 42 | overwrite: false 43 | logger: 44 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 45 | init_args: 46 | save_dir: *OUTPUT_ROOT 47 | name: "" 48 | model: 49 | class_path: eva.HeadModule 50 | init_args: 51 | head: 52 | class_path: torch.nn.Linear 53 | init_args: 54 | in_features: ${oc.env:IN_FEATURES, 384} 55 | out_features: 1 56 | criterion: torch.nn.BCEWithLogitsLoss 57 | optimizer: 58 | class_path: torch.optim.AdamW 59 | init_args: 60 | lr: ${oc.env:LR_VALUE, 0.0003} 61 | lr_scheduler: 62 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 63 | init_args: 64 | T_max: *MAX_STEPS 65 | eta_min: 0.0 66 | metrics: 67 | common: 68 | - class_path: eva.metrics.AverageLoss 69 | - class_path: eva.metrics.BinaryClassificationMetrics 70 | data: 71 | class_path: eva.DataModule 72 | init_args: 73 | datasets: 74 | train: 75 | class_path: eva.datasets.EmbeddingsClassificationDataset 76 | init_args: &DATASET_ARGS 77 | root: *DATASET_EMBEDDINGS_ROOT 78 | manifest_file: manifest.csv 79 | split: train 80 | target_transforms: 81 | class_path: torchvision.transforms.v2.ToDtype 82 | init_args: 83 | dtype: torch.float32 84 | val: 85 | class_path: eva.datasets.EmbeddingsClassificationDataset 86 | init_args: 87 | <<: *DATASET_ARGS 88 | split: val 89 | test: 90 | class_path: eva.datasets.EmbeddingsClassificationDataset 91 | init_args: 92 | <<: *DATASET_ARGS 93 | split: test 94 | predict: 95 | - class_path: eva.vision.datasets.PatchCamelyon 96 | init_args: &PREDICT_DATASET_ARGS 97 | root: ${oc.env:DATA_ROOT, ./data/patch_camelyon} 98 | split: train 99 | download: ${oc.env:DOWNLOAD_DATA, false} 100 | # Set `download: true` to download the dataset from https://zenodo.org/records/1494286 101 | # The PatchCamelyon dataset is distributed under the following license: 102 | # "Creative Commons Zero v1.0 Universal" 103 | # (see: https://choosealicense.com/licenses/cc0-1.0/) 104 | transforms: 105 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 106 | init_args: 107 | size: ${oc.env:RESIZE_DIM, 224} 108 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 109 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 110 | - class_path: eva.vision.datasets.PatchCamelyon 111 | init_args: 112 | <<: *PREDICT_DATASET_ARGS 113 | split: val 114 | - class_path: eva.vision.datasets.PatchCamelyon 115 | init_args: 116 | <<: *PREDICT_DATASET_ARGS 117 | split: test 118 | dataloaders: 119 | train: 120 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} 121 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 122 | shuffle: true 123 | val: 124 | batch_size: 4096 125 | num_workers: *N_DATA_WORKERS 126 | test: 127 | batch_size: 4096 128 | num_workers: 
*N_DATA_WORKERS 129 | predict: 130 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 512} 131 | num_workers: *N_DATA_WORKERS 132 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/pcam_10shots.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 50} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/patch_camelyon} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} 8 | num_sanity_val_steps: 0 9 | check_val_every_n_epoch: 10 10 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 11 | callbacks: 12 | - class_path: eva.callbacks.ConfigurationLogger 13 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 14 | init_args: 15 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 16 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 17 | init_args: 18 | logging_interval: epoch 19 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 20 | init_args: 21 | filename: best 22 | save_last: true 23 | save_top_k: 1 24 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy} 25 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 26 | - class_path: lightning.pytorch.callbacks.EarlyStopping 27 | init_args: 28 | min_delta: 0 29 | patience: ${oc.env:PATIENCE, 9} 30 | monitor: *MONITOR_METRIC 31 | mode: *MONITOR_METRIC_MODE 32 | - class_path: eva.callbacks.ClassificationEmbeddingsWriter 33 | init_args: 34 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, dino_vits16}/patch_camelyon 35 | dataloader_idx_map: 36 | 0: train 37 | 1: val 38 | 2: test 39 | backbone: 40 | class_path: eva.vision.models.ModelFromRegistry 41 | init_args: 42 | model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 43 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 44 | overwrite: false 45 | logger: 46 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 47 | init_args: 48 | save_dir: *OUTPUT_ROOT 49 | name: "" 50 | model: 51 | class_path: eva.HeadModule 52 | init_args: 53 | head: 54 | class_path: torch.nn.Linear 55 | init_args: 56 | in_features: ${oc.env:IN_FEATURES, 384} 57 | out_features: 1 58 | criterion: torch.nn.BCEWithLogitsLoss 59 | optimizer: 60 | class_path: torch.optim.AdamW 61 | init_args: 62 | lr: ${oc.env:LR_VALUE, 0.0003} 63 | lr_scheduler: 64 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 65 | init_args: 66 | T_max: *MAX_STEPS 67 | eta_min: 0.0 68 | metrics: 69 | common: 70 | - class_path: eva.metrics.AverageLoss 71 | - class_path: eva.metrics.BinaryClassificationMetrics 72 | data: 73 | class_path: eva.DataModule 74 | init_args: 75 | datasets: 76 | train: 77 | class_path: eva.datasets.EmbeddingsClassificationDataset 78 | init_args: &DATASET_ARGS 79 | root: *DATASET_EMBEDDINGS_ROOT 80 | manifest_file: manifest.csv 81 | split: train 82 | target_transforms: 83 | class_path: torchvision.transforms.v2.ToDtype 84 | init_args: 85 | dtype: torch.float32 86 | val: 87 | class_path: eva.datasets.EmbeddingsClassificationDataset 88 | init_args: 89 | <<: *DATASET_ARGS 90 | split: val 91 | test: 92 | class_path: eva.datasets.EmbeddingsClassificationDataset 93 | init_args: 94 | <<: *DATASET_ARGS 95 | split: test 96 | predict: 97 | - class_path: eva.vision.datasets.PatchCamelyon 98 | init_args: 
&PREDICT_DATASET_ARGS 99 | root: ${oc.env:DATA_ROOT, ./data/patch_camelyon} 100 | split: train 101 | download: ${oc.env:DOWNLOAD_DATA, false} 102 | # Set `download: true` to download the dataset from https://zenodo.org/records/1494286 103 | # The PatchCamelyon dataset is distributed under the following license: 104 | # "Creative Commons Zero v1.0 Universal" 105 | # (see: https://choosealicense.com/licenses/cc0-1.0/) 106 | transforms: 107 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 108 | init_args: 109 | size: ${oc.env:RESIZE_DIM, 224} 110 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 111 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 112 | - class_path: eva.vision.datasets.PatchCamelyon 113 | init_args: 114 | <<: *PREDICT_DATASET_ARGS 115 | split: val 116 | - class_path: eva.vision.datasets.PatchCamelyon 117 | init_args: 118 | <<: *PREDICT_DATASET_ARGS 119 | split: test 120 | dataloaders: 121 | train: 122 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 256} 123 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 124 | shuffle: false 125 | val: 126 | batch_size: 4096 127 | num_workers: *N_DATA_WORKERS 128 | test: 129 | batch_size: 4096 130 | num_workers: *N_DATA_WORKERS 131 | predict: 132 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 512} 133 | num_workers: *N_DATA_WORKERS 134 | samplers: 135 | train: 136 | class_path: eva.core.data.samplers.classification.BalancedSampler 137 | init_args: 138 | num_samples: 10 139 | -------------------------------------------------------------------------------- /eval/object_tools.py: -------------------------------------------------------------------------------- 1 | """Python object related utilities and helpers.""" 2 | 3 | import copy 4 | import dataclasses 5 | import importlib 6 | from typing import Any, Union 7 | 8 | import torch 9 | from jsonargparse import ArgumentParser 10 | from lightning.fabric.utilities.cloud_io import _load as pl_load 11 | from lightning.fabric.utilities.cloud_io import get_filesystem 12 | from loguru import logger 13 | 14 | SERIALIZABLE_TYPES = Union[int, float, str, bool, None] # noqa: UP007 15 | SERIALIZABLE_TYPES_COMPOSITE = Union[ # noqa: UP007 16 | SERIALIZABLE_TYPES, 17 | list[Union[SERIALIZABLE_TYPES, "SERIALIZABLE_TYPES_COMPOSITE"]], 18 | dict[str, Union[SERIALIZABLE_TYPES, "SERIALIZABLE_TYPES_COMPOSITE"]], 19 | ] 20 | 21 | 22 | @dataclasses.dataclass 23 | class ObjectBuilder: 24 | """Helper dataclass which allows to initialize objects on command.""" 25 | 26 | path: str 27 | """The object path (class or function).""" 28 | 29 | arguments: dict[str, SERIALIZABLE_TYPES_COMPOSITE] | None = None 30 | """The initialization arguments of the object.""" 31 | 32 | def build(self) -> Any: 33 | """Initializes and returns the defined object.""" 34 | return _build_object_from_path( 35 | self.path, 36 | get_anyobject_jsonargparse(copy.deepcopy(self.arguments)) if self.arguments else None, 37 | ) 38 | 39 | def to_dict(self) -> dict[str, Any]: 40 | """Converts the object builder to a dictionary.""" 41 | return dataclasses.asdict(self) 42 | 43 | 44 | @dataclasses.dataclass 45 | class ModelBuilder(ObjectBuilder): 46 | """Helper dataclass for model initialization.""" 47 | 48 | ckpt_path: str | None = None 49 | ckpt_submodule: str | None = None 50 | 51 | def build(self) -> Any: 52 | """Initializes and returns the defined model.""" 53 | model = super().build() 54 | if self.ckpt_path is not None: 55 | model = load_model_checkpoint(model, self.ckpt_path, 
ckpt_submodule=self.ckpt_submodule)
 56 |             return model
 57 | 
 58 | 
 59 | def _build_object_from_path(path: str, arguments: dict[str, Any] | None) -> Any:
 60 |     """Initializes and builds an object from a path.
 61 | 
 62 |     Args:
 63 |         path: The path to the object (class or function).
 64 |         arguments: The initialization arguments. Defaults to `None`.
 65 | 
 66 |     Returns:
 67 |         The initialized object.
 68 |     """
 69 |     module_name, class_name = path.rsplit(".", 1)
 70 |     try:
 71 |         _module = importlib.import_module(module_name)
 72 |         try:
 73 |             _object = getattr(_module, class_name)(**arguments or {})
 74 |         except AttributeError as err:
 75 |             raise AttributeError(
 76 |                 f"Class `{class_name}` in `{module_name}` does not exist."
 77 |             ) from err
 78 |     except ImportError as err:
 79 |         raise ImportError(f"Module `{module_name}` does not exist.") from err
 80 |     return _object
 81 | 
 82 | 
 83 | def get_anyobject_jsonargparse(conf_dict: dict[str, Any], expected_type=Any) -> Any:
 84 |     """Use jsonargparse to parse an arbitrary object."""
 85 |     parser = ArgumentParser()
 86 |     parser.add_argument("arg", type=expected_type)
 87 |     anyobject = parser.parse_object({"arg": conf_dict})
 88 |     return parser.instantiate_classes(anyobject).arg
 89 | 
 90 | 
 91 | def load_model_checkpoint(
 92 |     model: torch.nn.Module,
 93 |     checkpoint_path: str,
 94 |     strict: bool = False,
 95 |     ckpt_submodule: str | None = None,
 96 | ) -> torch.nn.Module:
 97 |     """Initializes the model with the weights.
 98 | 
 99 |     Args:
100 |         model: model to initialize.
101 |         checkpoint_path: the path to the checkpoint.
102 |         strict: if `True`, it loads the weights only if the dictionary matches the architecture
103 |             exactly. If `False`, it loads the weights even if the weights of some layers
104 |             are missing.
105 |         ckpt_submodule: the submodule of the checkpoint for loading into the model. If `None`, load
106 |             the entire checkpoint. Default: `None`.
107 | 
108 |     Returns:
109 |         the model initialized with the checkpoint.
110 |     """
111 |     logger.info(f"Loading {model.__class__.__name__} from checkpoint {checkpoint_path}")
112 |     fs = get_filesystem(checkpoint_path, anon=False)
113 |     with fs.open(checkpoint_path, "rb") as f:
114 |         checkpoint = pl_load(f, map_location="cpu")  # type: ignore[arg-type]
115 |     if "state_dict" in checkpoint:
116 |         checkpoint = checkpoint["state_dict"]
117 |     if ckpt_submodule is not None:
118 |         key = ckpt_submodule if ckpt_submodule.endswith(".") else ckpt_submodule + "."
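        # e.g. ckpt_submodule="teacher.backbone" keeps only entries whose keys start
        # with "teacher.backbone." and strips that prefix, so the remaining names
        # line up with the model's own state_dict keys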
119 | checkpoint = { 120 | m.removeprefix(key): w for m, w in checkpoint.items() if m.startswith(key) 121 | } 122 | out = model.load_state_dict(checkpoint, strict=strict) 123 | missing, unexpected = out.missing_keys, out.unexpected_keys 124 | keys = model.state_dict().keys() 125 | if len(missing): 126 | logger.warning( 127 | f"{len(missing)}/{len(keys)} modules are missing in the checkpoint and will not be " 128 | f"initialized: {missing}" 129 | ) 130 | if len(unexpected): 131 | logger.warning( 132 | f"The checkpoint also contains {len(unexpected)} modules ignored by the model: " 133 | f"{unexpected}" 134 | ) 135 | logger.info( 136 | f"Loaded {len(set(keys) - set(missing))}/{len(keys)} modules for " 137 | f"{model.__class__.__name__} from checkpoint {checkpoint_path}" 138 | ) 139 | return model 140 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/camelyon16_small.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 5} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/camelyon16} 7 | max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 100} 8 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 9 | callbacks: 10 | - class_path: eva.callbacks.ConfigurationLogger 11 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 12 | init_args: 13 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 14 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 15 | init_args: 16 | logging_interval: epoch 17 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 18 | init_args: 19 | filename: best 20 | save_last: true 21 | save_top_k: 1 22 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/BinaryBalancedAccuracy} 23 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 24 | - class_path: lightning.pytorch.callbacks.EarlyStopping 25 | init_args: 26 | min_delta: 0 27 | patience: ${oc.env:PATIENCE, 10} 28 | monitor: *MONITOR_METRIC 29 | mode: *MONITOR_METRIC_MODE 30 | - class_path: eva.callbacks.ClassificationEmbeddingsWriter 31 | init_args: 32 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/${oc.env:MODEL_NAME, dino_vits16}/camelyon16} 33 | save_every_n: 10_000 34 | dataloader_idx_map: 35 | 0: train 36 | 1: val 37 | 2: test 38 | metadata_keys: ["wsi_id"] 39 | backbone: 40 | class_path: eva.vision.models.ModelFromRegistry 41 | init_args: 42 | model_name: ${oc.env:MODEL_NAME, universal/} 43 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 44 | overwrite: false 45 | logger: 46 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 47 | init_args: 48 | save_dir: *OUTPUT_ROOT 49 | name: "" 50 | model: 51 | class_path: eva.HeadModule 52 | init_args: 53 | head: 54 | class_path: eva.vision.models.networks.ABMIL 55 | init_args: 56 | input_size: ${oc.env:IN_FEATURES, 384} 57 | output_size: &NUM_CLASSES 1 58 | projected_input_size: 128 59 | criterion: torch.nn.BCEWithLogitsLoss 60 | optimizer: 61 | class_path: torch.optim.AdamW 62 | init_args: 63 | lr: ${oc.env:LR_VALUE, 0.001} 64 | betas: [0.9, 0.999] 65 | lr_scheduler: 66 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 67 | init_args: 68 | T_max: *MAX_EPOCHS 69 | eta_min: 0.0 70 | metrics: 71 | common: 72 | - class_path: eva.metrics.AverageLoss 73 | - class_path: eva.metrics.BinaryClassificationMetrics 74 | data: 75 | 
class_path: eva.DataModule 76 | init_args: 77 | datasets: 78 | train: 79 | class_path: eva.datasets.MultiEmbeddingsClassificationDataset 80 | init_args: &DATASET_ARGS 81 | root: *DATASET_EMBEDDINGS_ROOT 82 | manifest_file: manifest.csv 83 | split: train 84 | embeddings_transforms: 85 | class_path: eva.core.data.transforms.Pad2DTensor 86 | init_args: 87 | pad_size: &N_PATCHES ${oc.env:N_PATCHES, 1000} 88 | target_transforms: 89 | class_path: eva.core.data.transforms.dtype.ArrayToFloatTensor 90 | val: 91 | class_path: eva.datasets.MultiEmbeddingsClassificationDataset 92 | init_args: 93 | <<: *DATASET_ARGS 94 | split: val 95 | test: 96 | class_path: eva.datasets.MultiEmbeddingsClassificationDataset 97 | init_args: 98 | <<: *DATASET_ARGS 99 | split: test 100 | predict: 101 | - class_path: eva.vision.datasets.Camelyon16 102 | init_args: &PREDICT_DATASET_ARGS 103 | root: ${oc.env:DATA_ROOT, ./data/camelyon16} 104 | sampler: 105 | class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler 106 | init_args: 107 | max_samples: *N_PATCHES 108 | width: 224 109 | height: 224 110 | target_mpp: 0.25 111 | split: train 112 | coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv 113 | image_transforms: 114 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 115 | init_args: 116 | size: ${oc.env:RESIZE_DIM, 224} 117 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 118 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 119 | - class_path: eva.vision.datasets.Camelyon16 120 | init_args: 121 | <<: *PREDICT_DATASET_ARGS 122 | split: val 123 | - class_path: eva.vision.datasets.Camelyon16 124 | init_args: 125 | <<: *PREDICT_DATASET_ARGS 126 | split: test 127 | dataloaders: 128 | train: 129 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32} 130 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 131 | shuffle: true 132 | val: 133 | batch_size: *BATCH_SIZE 134 | num_workers: *N_DATA_WORKERS 135 | test: 136 | batch_size: *BATCH_SIZE 137 | num_workers: *N_DATA_WORKERS 138 | predict: 139 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} 140 | num_workers: *N_DATA_WORKERS 141 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/offline/classification/panda_small.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 20} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, dino_vits16}/offline/panda} 7 | max_epochs: &MAX_EPOCHS ${oc.env:MAX_EPOCHS, 49} 8 | num_sanity_val_steps: 0 9 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 10 | callbacks: 11 | - class_path: eva.callbacks.ConfigurationLogger 12 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 13 | init_args: 14 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 15 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 16 | init_args: 17 | logging_interval: epoch 18 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 19 | init_args: 20 | filename: best 21 | save_last: true 22 | save_top_k: 1 23 | monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, val/MulticlassAccuracy} 24 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 25 | - class_path: lightning.pytorch.callbacks.EarlyStopping 26 | init_args: 27 | min_delta: 0 28 | patience: ${oc.env:PATIENCE, 8} 29 | monitor: *MONITOR_METRIC 30 | mode: *MONITOR_METRIC_MODE 31 | - 
class_path: eva.callbacks.ClassificationEmbeddingsWriter 32 | init_args: 33 | output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings/${oc.env:MODEL_NAME, dino_vits16}/panda} 34 | dataloader_idx_map: 35 | 0: train 36 | 1: val 37 | 2: test 38 | metadata_keys: ["wsi_id"] 39 | backbone: 40 | class_path: eva.vision.models.ModelFromRegistry 41 | init_args: 42 | model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 43 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 44 | overwrite: false 45 | logger: 46 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 47 | init_args: 48 | save_dir: *OUTPUT_ROOT 49 | name: "" 50 | model: 51 | class_path: eva.HeadModule 52 | init_args: 53 | head: 54 | class_path: eva.vision.models.networks.ABMIL 55 | init_args: 56 | input_size: ${oc.env:IN_FEATURES, 384} 57 | output_size: &NUM_CLASSES 6 58 | projected_input_size: 128 59 | criterion: torch.nn.CrossEntropyLoss 60 | optimizer: 61 | class_path: torch.optim.AdamW 62 | init_args: 63 | lr: ${oc.env:LR_VALUE, 0.001} 64 | betas: [0.9, 0.999] 65 | lr_scheduler: 66 | class_path: torch.optim.lr_scheduler.CosineAnnealingLR 67 | init_args: 68 | T_max: *MAX_EPOCHS 69 | eta_min: 0.0 70 | metrics: 71 | common: 72 | - class_path: eva.metrics.AverageLoss 73 | - class_path: eva.metrics.MulticlassClassificationMetrics 74 | init_args: 75 | num_classes: *NUM_CLASSES 76 | data: 77 | class_path: eva.DataModule 78 | init_args: 79 | datasets: 80 | train: 81 | class_path: eva.datasets.MultiEmbeddingsClassificationDataset 82 | init_args: &DATASET_ARGS 83 | root: *DATASET_EMBEDDINGS_ROOT 84 | manifest_file: manifest.csv 85 | split: train 86 | embeddings_transforms: 87 | class_path: eva.core.data.transforms.Pad2DTensor 88 | init_args: 89 | pad_size: &N_PATCHES ${oc.env:N_PATCHES, 200} 90 | val: 91 | class_path: eva.datasets.MultiEmbeddingsClassificationDataset 92 | init_args: 93 | <<: *DATASET_ARGS 94 | split: val 95 | test: 96 | class_path: eva.datasets.MultiEmbeddingsClassificationDataset 97 | init_args: 98 | <<: *DATASET_ARGS 99 | split: test 100 | predict: 101 | - class_path: eva.vision.datasets.PANDASmall 102 | init_args: &PREDICT_DATASET_ARGS 103 | root: ${oc.env:DATA_ROOT, ./data/panda/prostate-cancer-grade-assessment} 104 | sampler: 105 | class_path: eva.vision.data.wsi.patching.samplers.ForegroundGridSampler 106 | init_args: 107 | max_samples: *N_PATCHES 108 | width: 448 109 | height: 448 110 | target_mpp: 0.25 # the original mpp is 0.486, so we extract 224x224@0.5 without downsampling 111 | split: train 112 | coords_path: ${data.init_args.datasets.train.init_args.root}/coords_${.split}.csv 113 | image_transforms: 114 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 115 | init_args: 116 | size: ${oc.env:RESIZE_DIM, 224} 117 | mean: ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 118 | std: ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 119 | - class_path: eva.vision.datasets.PANDASmall 120 | init_args: 121 | <<: *PREDICT_DATASET_ARGS 122 | split: val 123 | - class_path: eva.vision.datasets.PANDASmall 124 | init_args: 125 | <<: *PREDICT_DATASET_ARGS 126 | split: test 127 | dataloaders: 128 | train: 129 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 32} 130 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 131 | shuffle: true 132 | val: 133 | batch_size: 4096 134 | num_workers: *N_DATA_WORKERS 135 | test: 136 | batch_size: 4096 137 | num_workers: *N_DATA_WORKERS 138 | predict: 139 | batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 128} 140 | 
num_workers: *N_DATA_WORKERS 141 | -------------------------------------------------------------------------------- /eval/run_eva_internal.sh: -------------------------------------------------------------------------------- 1 | # Main benchmarks 2 | for MODEL_NAME in {vitg14_Kaiko_Midnight_concat,}; do 3 | for TASK in {pcam_10shots,camelyon16_small,panda_small,}; do 4 | # this is supposed to be run in a single-GPU job 5 | bash -c "\ 6 | rm -rf /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}; \ 7 | MODEL_NAME=$MODEL_NAME \ 8 | IN_FEATURES=$(python get_dim.py $MODEL_NAME 1) \ 9 | OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS_patience/${TASK}/${MODEL_NAME} \ 10 | EMBEDDINGS_ROOT=/dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME} \ 11 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \ 12 | NORMALIZE_MEAN=[0.5,0.5,0.5] \ 13 | NORMALIZE_STD=[0.5,0.5,0.5] \ 14 | python -m eva predict_fit --config configs/vision/pathology/offline/classification/${TASK}.yaml; \ 15 | rm -rf /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}"; 16 | done 17 | 18 | for TASK in {bach,crc,mhist,patch_camelyon,breakhis,bracs,gleason_arvaniti}; do 19 | bash -c "\ 20 | rm -rf /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}; \ 21 | MODEL_NAME=$MODEL_NAME \ 22 | IN_FEATURES=$(python get_dim.py $MODEL_NAME 1) \ 23 | PATIENCE=12500 \ 24 | OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS_patience/${TASK}/${MODEL_NAME} \ 25 | EMBEDDINGS_ROOT=/dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME} \ 26 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \ 27 | NORMALIZE_MEAN=[0.5,0.5,0.5] \ 28 | NORMALIZE_STD=[0.5,0.5,0.5] \ 29 | python -m eva predict_fit --config configs/vision/pathology/offline/classification/${TASK}.yaml; \ 30 | rm -rf /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}"; 31 | done 32 | 33 | for TASK in {consep,monusac}; do 34 | bash -c "\ 35 | MODEL_NAME=$MODEL_NAME \ 36 | IN_FEATURES=$(python get_dim.py $MODEL_NAME) \ 37 | PATIENCE=12500 \ 38 | OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS_patience/${TASK}/${MODEL_NAME} \ 39 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \ 40 | NORMALIZE_MEAN=[0.5,0.5,0.5] \ 41 | NORMALIZE_STD=[0.5,0.5,0.5] \ 42 | python -m eva fit --config configs/vision/pathology/online/segmentation/${TASK}.yaml"; 43 | done 44 | done 45 | 46 | 47 | # High-resolution benchmarks 48 | for MODEL_NAME in {DINOv2_vitg14_nki-tcga_post_100_aspect_epoch_059_bicubic_concat_resize392,}; do 49 | for TASK in {pcam_10shots,camelyon16_small,panda_small,}; do 50 | bash -c "\ 51 | rm -rf /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}; \ 52 | MODEL_NAME=$MODEL_NAME \ 53 | IN_FEATURES=$(python get_dim.py $MODEL_NAME 1) \ 54 | RESIZE_DIM=392 \ 55 | OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS_patience/${TASK}/${MODEL_NAME} \ 56 | EMBEDDINGS_ROOT=/dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME} \ 57 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \ 58 | NORMALIZE_MEAN=[0.5,0.5,0.5] \ 59 | NORMALIZE_STD=[0.5,0.5,0.5] \ 60 | python -m eva predict_fit --config configs/vision/pathology/offline/classification/${TASK}.yaml; \ 61 | rm -rf /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}"; 62 | done 63 | 64 | for TASK in {bach,crc,mhist,patch_camelyon,breakhis,bracs,gleason_arvaniti}; do 65 | bash -c "\ 66 | rm -rf /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}; \ 67 | MODEL_NAME=$MODEL_NAME \ 68 | IN_FEATURES=$(python get_dim.py $MODEL_NAME 1) \ 69 | RESIZE_DIM=392 \ 70 | PATIENCE=12500 \ 71 | 
OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS_patience/${TASK}/${MODEL_NAME} \ 72 | EMBEDDINGS_ROOT=/dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME} \ 73 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \ 74 | NORMALIZE_MEAN=[0.5,0.5,0.5] \ 75 | NORMALIZE_STD=[0.5,0.5,0.5] \ 76 | python -m eva predict_fit --config configs/vision/pathology/offline/classification/${TASK}.yaml; \ 77 | rm -rf /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}"; 78 | done 79 | 80 | for TASK in {consep,monusac}; do 81 | bash -c "\ 82 | MODEL_NAME=$MODEL_NAME \ 83 | IN_FEATURES=$(python get_dim.py $MODEL_NAME) \ 84 | RESIZE_DIM=392 \ 85 | PATIENCE=12500 \ 86 | OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS_patience/${TASK}/${MODEL_NAME} \ 87 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \ 88 | NORMALIZE_MEAN=[0.5,0.5,0.5] \ 89 | NORMALIZE_STD=[0.5,0.5,0.5] \ 90 | python -m eva fit --config configs/vision/pathology/online/segmentation/${TASK}.yaml"; 91 | done 92 | done 93 | 94 | 95 | # Benchmarks for ablation studies 96 | for MODEL_NAME in {DINOv2_vitb14_four,}; do 97 | for TASK in {bach,crc,mhist,patch_camelyon,camelyon16_small,panda_small,breakhis,bracs,gleason_arvaniti}; do 98 | bash -c "\ 99 | rm -rf /pathology_fm/mikhail/data/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}; \ 100 | MODEL_NAME=$MODEL_NAME \ 101 | IN_FEATURES=$(python get_dim.py $MODEL_NAME 1) \ 102 | OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS/${TASK}/${MODEL_NAME} \ 103 | EMBEDDINGS_ROOT=/dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME} \ 104 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \ 105 | NORMALIZE_MEAN=[0.5,0.5,0.5] \ 106 | NORMALIZE_STD=[0.5,0.5,0.5] \ 107 | python -m eva predict_fit --config configs/vision/pathology/offline/classification/${TASK}.yaml; \ 108 | rm -r /dev/shm/mikhail/eva/EMBEDDINGS/${TASK}/${MODEL_NAME}"; 109 | done 110 | 111 | for TASK in {consep,monusac}; do 112 | bash -c "\ 113 | MODEL_NAME=$MODEL_NAME \ 114 | IN_FEATURES=$(python get_dim.py $MODEL_NAME) \ 115 | OUTPUT_ROOT=/pathology_fm/mikhail/data/eva/RESULTS/${TASK}/${MODEL_NAME} \ 116 | DATA_ROOT=/pathology_fm/mikhail/data/eva/${TASK} \ 117 | NORMALIZE_MEAN=[0.5,0.5,0.5] \ 118 | NORMALIZE_STD=[0.5,0.5,0.5] \ 119 | python -m eva fit --config configs/vision/pathology/online/segmentation/${TASK}.yaml"; 120 | done 121 | done 122 | -------------------------------------------------------------------------------- /eval/configs/vision/pathology/online/segmentation/monusac.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | trainer: 3 | class_path: eva.Trainer 4 | init_args: 5 | n_runs: &N_RUNS ${oc.env:N_RUNS, 20} 6 | default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/monusac} 7 | max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 2000} 8 | log_every_n_steps: 6 9 | checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} 10 | callbacks: 11 | - class_path: eva.callbacks.ConfigurationLogger 12 | - class_path: lightning.pytorch.callbacks.TQDMProgressBar 13 | init_args: 14 | refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} 15 | - class_path: eva.vision.callbacks.SemanticSegmentationLogger 16 | init_args: 17 | log_every_n_epochs: 1 18 | mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} 19 | std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} 20 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 21 | init_args: 22 | filename: best 23 | save_last: true 24 | save_top_k: 1 25 | monitor: &MONITOR_METRIC 
${oc.env:MONITOR_METRIC, 'val/MonaiDiceScore'} 26 | mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} 27 | - class_path: lightning.pytorch.callbacks.EarlyStopping 28 | init_args: 29 | min_delta: 0 30 | patience: ${oc.env:PATIENCE, 50} 31 | monitor: *MONITOR_METRIC 32 | mode: *MONITOR_METRIC_MODE 33 | logger: 34 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 35 | init_args: 36 | save_dir: *OUTPUT_ROOT 37 | name: "" 38 | model: 39 | class_path: eva.vision.models.modules.SemanticSegmentationModule 40 | init_args: 41 | encoder: 42 | class_path: eva.vision.models.ModelFromRegistry 43 | init_args: 44 | model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} 45 | model_kwargs: 46 | out_indices: ${oc.env:OUT_INDICES, 1} 47 | model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} 48 | decoder: 49 | class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage 50 | init_args: 51 | in_features: ${oc.env:IN_FEATURES, 384} 52 | num_classes: &NUM_CLASSES 5 53 | criterion: 54 | class_path: eva.vision.losses.DiceLoss 55 | init_args: 56 | softmax: true 57 | batch: true 58 | ignore_index: &IGNORE_INDEX 5 59 | lr_multiplier_encoder: 0.0 60 | optimizer: 61 | class_path: torch.optim.AdamW 62 | init_args: 63 | lr: ${oc.env:LR_VALUE, 0.002} 64 | lr_scheduler: 65 | class_path: torch.optim.lr_scheduler.PolynomialLR 66 | init_args: 67 | total_iters: *MAX_STEPS 68 | power: 0.9 69 | postprocess: 70 | predictions_transforms: 71 | - class_path: torch.argmax 72 | init_args: 73 | dim: 1 74 | metrics: 75 | common: 76 | - class_path: eva.metrics.AverageLoss 77 | evaluation: 78 | - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics 79 | init_args: 80 | num_classes: 6 81 | ignore_index: *IGNORE_INDEX 82 | - class_path: torchmetrics.ClasswiseWrapper 83 | init_args: 84 | metric: 85 | class_path: eva.vision.metrics.MonaiDiceScore 86 | init_args: 87 | include_background: true 88 | num_classes: 6 89 | reduction: none 90 | ignore_index: *IGNORE_INDEX 91 | labels: 92 | - background 93 | - epithelial 94 | - lymphocyte 95 | - neutrophil 96 | - macrophage 97 | - ambiguous 98 | data: 99 | class_path: eva.DataModule 100 | init_args: 101 | datasets: 102 | train: 103 | class_path: eva.vision.datasets.MoNuSAC 104 | init_args: &DATASET_ARGS 105 | root: ${oc.env:DATA_ROOT, ./data/monusac} 106 | split: train 107 | download: ${oc.env:DOWNLOAD_DATA, false} 108 | # Set `download: true` to download the dataset from https://monusac-2020.grand-challenge.org/Data/ 109 | # The MoNuSAC dataset is distributed under the following license: 110 | # "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International" 111 | # (see: https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 112 | transforms: 113 | class_path: torchvision.transforms.v2.Compose 114 | init_args: 115 | transforms: 116 | - class_path: torchvision.transforms.v2.RandomResizedCrop 117 | init_args: 118 | size: ${oc.env:RESIZE_DIM, 224} 119 | - class_path: torchvision.transforms.v2.ToDtype 120 | init_args: 121 | dtype: torch.float32 122 | scale: true 123 | - class_path: torchvision.transforms.v2.Normalize 124 | init_args: 125 | mean: *NORMALIZE_MEAN 126 | std: *NORMALIZE_STD 127 | val: 128 | class_path: eva.vision.datasets.MoNuSAC 129 | init_args: 130 | <<: *DATASET_ARGS 131 | split: test 132 | transforms: 133 | class_path: eva.vision.data.transforms.common.ResizeAndCrop 134 | init_args: 135 | size: ${oc.env:RESIZE_DIM, 224} 136 | mean: *NORMALIZE_MEAN 137 | std: *NORMALIZE_STD 138 | dataloaders: 139 
| train: 140 | batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64} 141 | num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} 142 | shuffle: true 143 | val: 144 | batch_size: *BATCH_SIZE 145 | num_workers: *N_DATA_WORKERS 146 | -------------------------------------------------------------------------------- /eval/backbones.py: -------------------------------------------------------------------------------- 1 | """Wrappers for custom models.""" 2 | 3 | import timm 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | from typing_extensions import override 8 | 9 | from object_tools import load_model_checkpoint 10 | 11 | 12 | from typing import Any, Callable 13 | 14 | import torch 15 | import torch.nn as nn 16 | from transformers import AutoConfig, AutoModel 17 | from typing_extensions import override 18 | 19 | from extract_cls_token import ExtractCLSToken 20 | 21 | 22 | class HuggingFaceModel(nn.Module): 23 | """Wrapper class for loading HuggingFace `transformers` models.""" 24 | 25 | def __init__( 26 | self, 27 | model_name_or_path: str, 28 | output_transform: Callable = ExtractCLSToken(), 29 | with_config: bool = True, 30 | **kwargs: Any, 31 | ) -> None: 32 | """Initializes the model. 33 | 34 | Args: 35 | model_name_or_path: The model name or path to load the model from. 36 | This can be a local path or a model name from the `HuggingFace` 37 | model hub. 38 | output_transform: The transform to apply to the output tensor produced by the model. 39 | """ 40 | super().__init__() 41 | 42 | self._output_transform = output_transform 43 | 44 | config = AutoConfig.from_pretrained(model_name_or_path) if with_config else None 45 | self._model = AutoModel.from_pretrained(model_name_or_path, config=config, **kwargs) 46 | 47 | @override 48 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 49 | """Forward pass through the model.""" 50 | 51 | tensor = self._model(tensor) 52 | return self._output_transform(tensor) 53 | 54 | 55 | class TimmModel(nn.Module): 56 | def __init__( 57 | self, 58 | model_name: str, 59 | concat_mean_patch_tokens: bool = False, 60 | **kwargs, 61 | ): 62 | super().__init__() 63 | if kwargs.get("mlp_layer") == "timm.layers.SwiGLUPacked": 64 | kwargs["mlp_layer"] = timm.layers.SwiGLUPacked 65 | if kwargs.get("act_layer") == "torch.nn.SiLU": 66 | kwargs["act_layer"] = torch.nn.SiLU 67 | self.model = timm.create_model(model_name, **kwargs) 68 | self.concat_mean_patch_tokens = concat_mean_patch_tokens 69 | self.out_indices = kwargs.get("out_indices") 70 | 71 | @override 72 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 73 | """Forward pass through the model.""" 74 | 75 | if self.out_indices is not None: 76 | return self.model(tensor) 77 | 78 | output = self.model.forward_features(tensor) 79 | 80 | class_token = output[:, 0] 81 | patch_tokens = output[:, self.model.num_prefix_tokens :] # skip cls token and registers 82 | 83 | if self.concat_mean_patch_tokens: 84 | # concatenate class token and average pool of patch tokens 85 | embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1) 86 | else: 87 | embedding = class_token 88 | return embedding 89 | 90 | 91 | # Based on https://huggingface.co/paige-ai/Virchow2 92 | class Virchow2(TimmModel): 93 | def __init__(self, concat_mean_patch_tokens: bool = True, **kwargs): 94 | super().__init__( 95 | model_name="hf-hub:paige-ai/Virchow2", 96 | pretrained=True, 97 | mlp_layer=timm.layers.SwiGLUPacked, 98 | act_layer=torch.nn.SiLU, 99 | concat_mean_patch_tokens=concat_mean_patch_tokens, 100 | **kwargs, 
101 | ) 102 | if self.out_indices is None: 103 | assert self.model.num_prefix_tokens == 5 104 | 105 | 106 | class Kaiko(nn.Module): 107 | """A wrapper constructing custom embeddings for the standard ViT models. 108 | The final embedding is a concatenation of the class token with the average of the patch tokens. 109 | """ 110 | 111 | def __init__( 112 | self, 113 | repo_or_dir: str, 114 | model: str, 115 | pretrained: bool | None = None, 116 | ckpt_path: str | None = None, 117 | ckpt_submodule: str | None = None, 118 | concat_mean_patch_tokens: bool = False, 119 | resize: int | None = None, 120 | mode: str = "bilinear", 121 | antialias: bool = True, 122 | **kwargs, 123 | ): 124 | super().__init__() 125 | 126 | self.model = torch.hub.load( 127 | repo_or_dir=repo_or_dir, 128 | model=model, 129 | **({"pretrained": pretrained} if pretrained is not None else {}), 130 | **kwargs, 131 | ) 132 | 133 | if ckpt_path is not None: 134 | load_model_checkpoint(self.model, ckpt_path, ckpt_submodule=ckpt_submodule) 135 | 136 | self.concat_mean_patch_tokens = concat_mean_patch_tokens 137 | self.resize = resize 138 | self.mode = mode 139 | self.antialias = antialias 140 | self.out_indices = kwargs.get("out_indices") 141 | 142 | @override 143 | def forward(self, tensor: torch.Tensor) -> torch.Tensor: 144 | """Forward pass through the model.""" 145 | if self.resize and tensor.numel() > 0: 146 | tensor = torchvision.transforms.functional.center_crop(tensor, min(tensor.shape[-2:])) 147 | tensor = nn.functional.interpolate( 148 | tensor, size=(self.resize, self.resize), mode=self.mode, antialias=self.antialias 149 | ) 150 | 151 | if self.out_indices is not None: 152 | return self.model(tensor) 153 | 154 | out = self.model.forward_features(tensor) 155 | 156 | if isinstance(out, torch.Tensor): 157 | class_token = out[:, 0] 158 | patch_tokens = out[:, 1:] 159 | else: 160 | class_token = out["x_norm_clstoken"] 161 | patch_tokens = out["x_norm_patchtokens"] 162 | 163 | if self.concat_mean_patch_tokens: 164 | # concatenate class token and average pool of patch tokens 165 | embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1) 166 | else: 167 | embedding = class_token 168 | return embedding 169 | 170 | def get_intermediate_layers( 171 | self, 172 | x: torch.Tensor, 173 | n: int | tuple[int, ...] = 1, # Layers or n last layers to take 174 | reshape: bool = False, 175 | return_class_token: bool = False, 176 | norm=True, 177 | ) -> tuple[torch.Tensor | tuple[torch.Tensor]]: 178 | """Returns the intermediate layers of the model.""" 179 | return self.model.get_intermediate_layers( 180 | x, n=n, reshape=reshape, return_class_token=return_class_token, norm=norm 181 | ) 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kaiko midnight 2 | Midnight - Training State-of-the-Art Pathology Foundation Models with Orders of Magnitude Less Data 3 | 4 | This repository contains supplementary data for the paper [_Training state-of-the-art pathology foundation models with orders of magnitude less data_](https://arxiv.org/abs/2504.05186v1). Our approach achieves competitive performance compared to leading pathology foundation models (FMs), despite being trained on significantly fewer whole slide images (WSIs). 
5 | 6 | ```bibtex 7 | @InProceedings{KDK_Training_MICCAI2025, 8 | title={Training state-of-the-art pathology foundation models with orders of magnitude less data}, 9 | author={Mikhail Karasikov and Joost van Doorn and Nicolas Känzig and Melis Erdal Cesur and Hugo Mark Horlings and Robert Berke and Fei Tang and Sebastian Otálora}, 10 | booktitle = {Medical Image Computing and Computer Assisted Intervention -- MICCAI 2025}, 11 | year = {2025}, 12 | publisher = {Springer Nature Switzerland}, 13 | volume = {LNCS 15967}, 14 | month = {October}, 15 | pages = {573--583}, 16 | doi={10.1007/978-3-032-04984-1_55}, 17 | } 18 | ``` 19 | 20 | ## Overview 21 | 22 | We propose a refined self-supervised training framework based on DINOv2 with modifications that optimize model performance specifically for computational pathology. Our main contributions include: 23 | 24 | - Three novel pathology FMs trained with significantly reduced data (up to 100x fewer WSIs). 25 | - Introduction of high-resolution post-training to enhance embedding quality. 26 | 27 | ## Model Highlights 28 | 29 | - **Midnight-12k**: Trained exclusively on the publicly available TCGA dataset (12k WSIs). 30 | - **Midnight-92k**: Trained on TCGA and an additional proprietary dataset (NKI-80k). 31 | - **Midnight-92k/392**: Our top-performing model fine-tuned with high-resolution post-training. 32 | 33 | ## Training Datasets 34 | 35 | | Dataset | WSIs | Source | Comment | 36 | |---------|------|---------------|------------| 37 | | TCGA | 12k | Public | FFPE only | 38 | | NKI-80k | 80k | Proprietary | 10,141 patients, 31 organs | 39 | 40 | ## Training Components 41 | 42 | - **DINOv2**: Self-supervised training with [DINOv2](https://github.com/facebookresearch/dinov2). 43 | - **[KDE regularizer](https://proceedings.mlr.press/v119/wang20k/wang20k.pdf)**: Replaced KoLeo in DINOv2 to ensure embedding diversity and training stability. 44 | - **[Online patching](https://arxiv.org/pdf/2404.15217)**: Efficient real-time extraction of informative tiles. 45 | - **Color augmentation ([HED](https://arxiv.org/pdf/1902.06543))**: Robustness to stain variations. 46 | - **Tile [filtering](https://arxiv.org/html/2408.00738v3#S5)**: Removal of low-informative tissue regions. 47 | 48 | ## Evaluation 49 | 50 | We comprehensively evaluated the models using two sets of open-source benchmarks: 51 | 52 | - [eva](https://github.com/kaiko-ai/eva): For both tile (classification, segmentation) and slide-level tasks. 53 | - [HEST](https://github.com/mahmoodlab/HEST): For gene expression prediction tasks (regression). 54 | 55 | Our best model **Midnight-92k/392** consistently outperforms or matches leading models like Virchow2 and UNI-2. 56 | 57 | ## Results Summary 58 | 59 | | Model | AVG. 
| PCam 10 shots | BACH | BRACS | BreaKHis | CRC | Gleason | MHIST | PCam | Cam16 (small) | Panda (small) | CoNSeP | MoNuSAC | HEST | 60 | |-------|------|---------------|------|-------|----------|-----|---------|-------|------|---------------|---------------|--------|---------|------| 61 | | **[Midnight-92k/392](#usage)** | **0.778** | **0.900** | **0.904** | **0.646** | 0.802 | 0.966 | **0.807** | 0.828 | **0.951** | 0.868 | 0.651 | **0.662** | **0.708** | 0.415 | 62 | | [UNI-2](https://huggingface.co/MahmoodLab/UNI2-h) | **0.776** | **0.885** | **0.924** | **0.651** | **0.863** | **0.970** | 0.777 | 0.829 | **0.951** | **0.873** | **0.666** | 0.626 | 0.644 | **0.431** | 63 | | **[Midnight-92k](#usage)** | **0.767** | **0.882** | 0.889 | 0.615 | 0.793 | **0.967** | **0.823** | 0.831 | 0.948 | **0.872** | 0.643 | 0.629 | 0.656 | **0.425** | 64 | | [Virchow2](https://huggingface.co/paige-ai/Virchow2) | 0.766 | 0.835 | 0.890 | 0.633 | 0.818 | 0.966 | **0.791** | **0.865** | 0.938 | 0.860 | 0.646 | 0.640 | 0.674 | 0.403 | 65 | | **[Midnight-12k](#usage)** | 0.763 | 0.803 | **0.907** | 0.639 | 0.840 | **0.967** | 0.790 | 0.815 | 0.931 | **0.869** | 0.656 | 0.625 | 0.664 | 0.412 | 66 | | [Kaiko-B8](https://github.com/kaiko-ai/towards_large_pathology_fms) | 0.757 | 0.799 | 0.876 | 0.641 | **0.842** | 0.960 | 0.761 | 0.830 | 0.920 | 0.836 | 0.650 | **0.644** | 0.686 | 0.391 | 67 | | [H-Optimus-0](https://huggingface.co/bioptimus/H-optimus-0) | 0.755 | 0.831 | 0.752 | 0.620 | 0.813 | 0.962 | 0.769 | **0.850** | 0.943 | 0.847 | **0.672** | **0.644** | **0.687** | **0.425** | 68 | | [Prov_GigaPath](https://github.com/prov-gigapath/prov-gigapath) | 0.752 | 0.853 | 0.794 | 0.626 | **0.846** | 0.959 | 0.727 | 0.831 | 0.944 | 0.812 | 0.657 | 0.628 | **0.688** | 0.405 | 69 | | [Hibou-L](https://huggingface.co/histai/hibou-L) | 0.751 | 0.825 | 0.792 | **0.643** | 0.767 | 0.954 | 0.766 | **0.850** | **0.949** | 0.852 | 0.654 | **0.646** | 0.668 | 0.397 | 70 | | [UNI](https://huggingface.co/MahmoodLab/UNI) | 0.749 | 0.833 | 0.797 | 0.613 | 0.808 | 0.954 | 0.759 | 0.841 | 0.937 | 0.854 | **0.662** | 0.627 | 0.662 | 0.391 | 71 | | [Phikon](https://huggingface.co/owkin/phikon) | 0.724 | 0.826 | 0.744 | 0.579 | 0.715 | 0.946 | 0.743 | 0.824 | 0.919 | 0.822 | 0.648 | 0.624 | 0.644 | 0.377 | 72 | | [Phikon-v2](https://huggingface.co/owkin/phikon-v2) | 0.718 | 0.756 | 0.737 | 0.607 | 0.725 | 0.953 | 0.753 | 0.796 | 0.900 | 0.807 | 0.634 | 0.626 | 0.645 | 0.391 | 73 | | [Lunit](https://github.com/lunit-io/benchmark-ssl-pathology) | 0.714 | 0.763 | 0.785 | 0.627 | 0.759 | 0.943 | 0.758 | 0.785 | 0.905 | 0.759 | 0.604 | 0.600 | 0.630 | 0.362 | 74 | | [vitg14 (nat. img.)](https://github.com/facebookresearch/dinov2) | 0.674 | 0.721 | 0.724 | 0.578 | 0.783 | 0.943 | 0.740 | **0.855** | 0.881 | 0.500 | 0.509 | 0.565 | 0.614 | 0.351 | 75 | | [vitg14 (initial)](https://github.com/facebookresearch/dinov2) | 0.493 | 0.652 | 0.474 | 0.413 | 0.425 | 0.754 | 0.459 | 0.578 | 0.763 | 0.526 | 0.304 | 0.462 | 0.432 | 0.166 | 76 | 77 | ## Model Weights 78 | - **Midnight-12k**: Publicly available at https://huggingface.co/kaiko-ai/midnight. 79 | - **Midnight-92k** & **Midnight-92k/392**: Trained on proprietary data and, hence, subject to restricted access. 80 | 81 | 82 | ## Usage 83 | 84 | **Midnight-12k** is publicly available at [https://huggingface.co/kaiko-ai/midnight](https://huggingface.co/kaiko-ai/midnight). 
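The snippets below assume the `torch`, `torchvision`, `transformers`, `Pillow`, and `requests` packages are installed; these are the only third-party dependencies the examples import.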
85 | 
86 | Our models are trained on 224x224 images normalized with a mean of (0.5, 0.5, 0.5) and a standard deviation of (0.5, 0.5, 0.5). Please ensure you apply these exact normalization parameters when preparing your datasets for embedding extraction.
87 | 
88 | ```python
89 | from transformers import AutoImageProcessor, AutoModel
90 | from PIL import Image
91 | import requests
92 | from torchvision.transforms import v2
93 | 
94 | url = 'https://upload.wikimedia.org/wikipedia/commons/8/80/Breast_DCIS_histopathology_%281%29.jpg'
95 | image = Image.open(requests.get(url, stream=True).raw)
96 | 
97 | transform = v2.Compose(
98 |     [
99 |         v2.Resize(224),
100 |         v2.CenterCrop(224),
101 |         v2.ToTensor(),
102 |         v2.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
103 |     ]
104 | )
105 | model = AutoModel.from_pretrained('kaiko-ai/midnight')
106 | ```
107 | 
108 | ### Extract embeddings for classification
109 | 
110 | ```python
111 | import torch
112 | 
113 | def extract_classification_embedding(tensor):
114 |     cls_embedding, patch_embeddings = tensor[:, 0, :], tensor[:, 1:, :]
115 |     return torch.cat([cls_embedding, patch_embeddings.mean(1)], dim=-1)
116 | 
117 | batch = transform(image).unsqueeze(dim=0)
118 | embedding = extract_classification_embedding(model(batch).last_hidden_state)
119 | print(f"Embedding shape: {embedding[0].shape}")
120 | ```
121 | 
122 | ### Extract embeddings for segmentation
123 | For segmentation tasks, the model output corresponds to 16x16 patch tokens (derived from 224/14=16).
124 | ```python
125 | import math
126 | import torch
127 | 
128 | def extract_segmentation_embedding(tensor):
129 |     features = tensor[:, 1:, :].permute(0, 2, 1)
130 |     batch_size, hidden_size, patch_grid = features.shape
131 |     height = width = int(math.sqrt(patch_grid))
132 |     return features.view(batch_size, hidden_size, height, width)
133 | 
134 | batch = transform(image).unsqueeze(dim=0)
135 | embedding = extract_segmentation_embedding(model(batch).last_hidden_state)
136 | print(f"Embedding shape: {embedding[0].shape}")
137 | ```
138 | 
139 | ### Use via Trident
140 | 
141 | Midnight-12k is now supported in the [Trident toolkit](https://github.com/mahmoodlab/TRIDENT); see its documentation for more details.
142 | 
143 | 
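As a quick sanity check of the shapes above (a minimal sketch, reusing `model`, `transform`, and `image` from the usage snippets; it assumes Midnight's hidden size is 1536, the ViT-g/14 width that `eval/get_dim.py` also reports and doubles for `*_concat` model names):

```python
import torch

batch = transform(image).unsqueeze(dim=0)    # (1, 3, 224, 224)
with torch.no_grad():
    tokens = model(batch).last_hidden_state  # (1, 257, 1536): CLS + 16*16 patch tokens

cls_embedding = tokens[:, 0, :]              # (1, 1536)
patch_embeddings = tokens[:, 1:, :]          # (1, 256, 1536)

# Concatenating the CLS token with the mean patch token doubles the width:
# 2 * 1536 = 3072, the IN_FEATURES value the eva run scripts pass for concat models.
embedding = torch.cat([cls_embedding, patch_embeddings.mean(1)], dim=-1)
assert embedding.shape == (1, 3072)
```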
144 | 
145 | 
146 | 
147 | 
148 | -------------------------------------------------------------------------------- /eval/hest_bench_config.yaml: -------------------------------------------------------------------------------- 1 | # directory containing the data for each task 2 | bench_data_root: '/Users/mike/pathology_fm/hest_eval/bench_data' 3 | 4 | # directory where benchmark results will be dumped 5 | results_dir: '/Users/mike/pathology_fm/hest_eval/ST_pred_results' 6 | 7 | # directory where the vision embeddings will be dumped 8 | embed_dataroot: '/Users/mike/pathology_fm/hest_eval/ST_data_emb' 9 | 10 | # directory to the model weights root 11 | weights_root: '/Users/mike/pathology_fm/hest_eval/fm_v1' 12 | 13 | # inference parameters 14 | batch_size: 128 15 | num_workers: 4 16 | 17 | # encoders to benchmark 18 | encoders: [ 19 | #"plip", 20 | #"uni_v1", # uncomment after requesting the weights 21 | #"resnet50", 22 | #"ctranspath", 23 | #"phikon", 24 | #"remedis", # uncomment after requesting the weights 25 | #"conch_v1", # uncomment after requesting the weights 26 | #"gigapath", # uncomment after requesting the weights 27 | #"virchow", # uncomment after requesting the weights 28 | #"virchow2", # uncomment after requesting the weights 29 | #"hoptimus0", 30 | ] 31 | 32 | # datasets contained in `bench_data_root` to benchmark 33 | datasets: [ 34 | "IDC", 35 | "PRAD", 36 | "PAAD", 37 | "SKCM", 38 | "COAD", 39 | "READ", 40 | "CCRCC", 41 | "HCC", 42 | "LUNG", 43 | "LYMPH_IDC", 44 | ] 45 | 46 | dimreduce: "PCA" 47 | 48 | custom_encoders: 49 | # DINOv2_vitg14_nki-tcga_post_100_aspect_epoch_059_resize392: 50 | # path: backbones.Kaiko 51 | # arguments: 52 | # repo_or_dir: facebookresearch/dinov2:main 53 | # model: dinov2_vitg14 54 | # pretrained: false 55 | # concat_mean_patch_tokens: false 56 | # resize: 392 57 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/E48D4DAD-C48E-49E5-B4D9-AB679A594090/lightning_logs/version_0/checkpoints/epoch_059-step_30000.ckpt 58 | # ckpt_submodule: teacher.backbone 59 | # mode: bicubic 60 | 61 | # DINOv2_vitg14_nki-tcga_post_100_aspect_epoch_059_resize392_concat: 62 | # path: backbones.Kaiko 63 | # arguments: 64 | # repo_or_dir: facebookresearch/dinov2:main 65 | # model: dinov2_vitg14 66 | # pretrained: false 67 | # concat_mean_patch_tokens: true 68 | # resize: 392 69 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/E48D4DAD-C48E-49E5-B4D9-AB679A594090/lightning_logs/version_0/checkpoints/epoch_059-step_30000.ckpt 70 | # ckpt_submodule: teacher.backbone 71 | # mode: bicubic 72 | 73 | # DINOv2_vitg14_tcga_post-three_300_aspect_epoch_029_resize392: 74 | # path: backbones.Kaiko 75 | # arguments: 76 | # repo_or_dir: facebookresearch/dinov2:main 77 | # model: dinov2_vitg14 78 | # pretrained: false 79 | # concat_mean_patch_tokens: false 80 | # resize: 392 81 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/5B2123C9-D2CB-49A7-AD0A-A5F29BCD070E/lightning_logs/version_0/checkpoints/epoch_029-step_15000.ckpt 82 | # ckpt_submodule: teacher.backbone 83 | # mode: bicubic 84 | 85 | # DINOv2_vitg14_tcga_post-three_300_aspect_epoch_029_resize392_concat: 86 | # path: backbones.Kaiko 87 | # arguments: 88 | # repo_or_dir: facebookresearch/dinov2:main 89 | # model: dinov2_vitg14 90 | # pretrained: false 91 | # concat_mean_patch_tokens: true 92 | # resize: 392 93 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/5B2123C9-D2CB-49A7-AD0A-A5F29BCD070E/lightning_logs/version_0/checkpoints/epoch_029-step_15000.ckpt 94 | # 
ckpt_submodule: teacher.backbone 95 | # mode: bicubic 96 | 97 | # DINOv2_vitg14_from_imagenet_tcga_100M_epoch_294: 98 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/1FDA5ADA-8357-4F5E-8A5C-8758F7751471/teacher.backbone/epoch_294-step_491765.pth 99 | # path: torch.hub.load 100 | # arguments: 101 | # repo_or_dir: facebookresearch/dinov2:main 102 | # model: dinov2_vitg14 103 | # pretrained: false 104 | 105 | # DINOv2_vitg14_from_imagenet_tcga_100M_epoch_294_concat: 106 | # path: backbones.Kaiko 107 | # arguments: 108 | # repo_or_dir: facebookresearch/dinov2:main 109 | # model: dinov2_vitg14 110 | # pretrained: false 111 | # concat_mean_patch_tokens: true 112 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/1FDA5ADA-8357-4F5E-8A5C-8758F7751471/teacher.backbone/epoch_294-step_491765.pth 113 | 114 | # vitL16_UNI: 115 | # path: renormalized.RenormalizingModel 116 | # arguments: 117 | # new_normalization: 118 | # mean: [0.485, 0.456, 0.406] 119 | # std: [0.229, 0.224, 0.225] 120 | # model: 121 | # path: backbones.TimmModel 122 | # arguments: 123 | # concat_mean_patch_tokens: false 124 | # model_name: hf-hub:MahmoodLab/uni 125 | # init_values: 1.0e-5 126 | # pretrained: true 127 | # dynamic_img_size: true 128 | # num_classes: 0 129 | 130 | # vitL16_UNI_resize512: 131 | # path: renormalized.RenormalizingModel 132 | # arguments: 133 | # new_normalization: 134 | # mean: [0.485, 0.456, 0.406] 135 | # std: [0.229, 0.224, 0.225] 136 | # model: 137 | # path: backbones.TimmModel 138 | # arguments: 139 | # concat_mean_patch_tokens: false 140 | # model_name: hf-hub:MahmoodLab/uni 141 | # init_values: 1.0e-5 142 | # pretrained: true 143 | # dynamic_img_size: true 144 | # num_classes: 0 145 | 146 | # vitL16_UNI_resize512_concat: 147 | # path: renormalized.RenormalizingModel 148 | # arguments: 149 | # new_normalization: 150 | # mean: [0.485, 0.456, 0.406] 151 | # std: [0.229, 0.224, 0.225] 152 | # model: 153 | # path: backbones.TimmModel 154 | # arguments: 155 | # concat_mean_patch_tokens: true 156 | # model_name: hf-hub:MahmoodLab/uni 157 | # init_values: 1.0e-5 158 | # pretrained: true 159 | # dynamic_img_size: true 160 | # num_classes: 0 161 | 162 | # vitL16_UNI_concat: 163 | # path: renormalized.RenormalizingModel 164 | # arguments: 165 | # new_normalization: 166 | # mean: [0.485, 0.456, 0.406] 167 | # std: [0.229, 0.224, 0.225] 168 | # model: 169 | # path: backbones.TimmModel 170 | # arguments: 171 | # concat_mean_patch_tokens: true 172 | # model_name: hf-hub:MahmoodLab/uni 173 | # init_values: 1.0e-5 174 | # pretrained: true 175 | # dynamic_img_size: true 176 | # num_classes: 0 177 | 178 | # virchow2: 179 | # path: renormalized.RenormalizingModel 180 | # arguments: 181 | # new_normalization: 182 | # mean: [0.485, 0.456, 0.406] 183 | # std: [0.229, 0.224, 0.225] 184 | # model: 185 | # path: backbones.Virchow2 186 | # arguments: {concat_mean_patch_tokens: false} 187 | 188 | # virchow2_concat: 189 | # path: renormalized.RenormalizingModel 190 | # arguments: 191 | # new_normalization: 192 | # mean: [0.485, 0.456, 0.406] 193 | # std: [0.229, 0.224, 0.225] 194 | # model: 195 | # path: backbones.Virchow2 196 | # arguments: {concat_mean_patch_tokens: true} 197 | 198 | # Bioptimus_h_optimus_0: 199 | # path: renormalized.RenormalizingModel 200 | # arguments: 201 | # new_normalization: 202 | # mean: [0.707223, 0.578729, 0.703617] 203 | # std: [0.211883, 0.230117, 0.177517] 204 | # model: 205 | # path: timm.create_model 206 | # arguments: 207 | # model_name: 
hf-hub:bioptimus/H-optimus-0 208 | # init_values: 1.0e-5 209 | # pretrained: true 210 | # dynamic_img_size: true 211 | # num_classes: 0 212 | 213 | # Bioptimus_h_optimus_0_concat: 214 | # path: renormalized.RenormalizingModel 215 | # arguments: 216 | # new_normalization: 217 | # mean: [0.707223, 0.578729, 0.703617] 218 | # std: [0.211883, 0.230117, 0.177517] 219 | # model: 220 | # path: backbones.TimmModel 221 | # arguments: 222 | # concat_mean_patch_tokens: true 223 | # model_name: hf-hub:bioptimus/H-optimus-0 224 | # init_values: 1.0e-5 225 | # pretrained: true 226 | # dynamic_img_size: true 227 | # num_classes: 0 228 | 229 | # vit_giant_patch14_224_UNI_resize392: 230 | # path: renormalized.RenormalizingModel 231 | # arguments: 232 | # new_normalization: 233 | # mean: [0.485, 0.456, 0.406] 234 | # std: [0.229, 0.224, 0.225] 235 | # model: 236 | # path: timm.create_model 237 | # arguments: 238 | # model_name: hf-hub:MahmoodLab/UNI2-h 239 | # pretrained: True 240 | # img_size: 224 241 | # patch_size: 14 242 | # depth: 24 243 | # num_heads: 24 244 | # init_values: 1.0e-5 245 | # embed_dim: 1536 246 | # mlp_ratio: 5.33334 # 2.66667*2 247 | # num_classes: 0 248 | # no_embed_class: True 249 | # mlp_layer: timm.layers.SwiGLUPacked 250 | # act_layer: torch.nn.SiLU 251 | # reg_tokens: 8 252 | # dynamic_img_size: True 253 | 254 | # vit_giant_patch14_224_UNI_resize392_concat: 255 | # path: renormalized.RenormalizingModel 256 | # arguments: 257 | # new_normalization: 258 | # mean: [0.485, 0.456, 0.406] 259 | # std: [0.229, 0.224, 0.225] 260 | # model: 261 | # path: backbones.TimmModel 262 | # arguments: 263 | # concat_mean_patch_tokens: true 264 | # model_name: hf-hub:MahmoodLab/UNI2-h 265 | # pretrained: True 266 | # img_size: 224 267 | # patch_size: 14 268 | # depth: 24 269 | # num_heads: 24 270 | # init_values: 1.0e-5 271 | # embed_dim: 1536 272 | # mlp_ratio: 5.33334 # 2.66667*2 273 | # num_classes: 0 274 | # no_embed_class: True 275 | # mlp_layer: timm.layers.SwiGLUPacked 276 | # act_layer: torch.nn.SiLU 277 | # reg_tokens: 8 278 | # dynamic_img_size: True 279 | 280 | # vit_giant_patch14_224_UNI_concat: 281 | # path: renormalized.RenormalizingModel 282 | # arguments: 283 | # new_normalization: 284 | # mean: [0.485, 0.456, 0.406] 285 | # std: [0.229, 0.224, 0.225] 286 | # model: 287 | # path: backbones.TimmModel 288 | # arguments: 289 | # concat_mean_patch_tokens: true 290 | # model_name: hf-hub:MahmoodLab/UNI2-h 291 | # pretrained: True 292 | # img_size: 224 293 | # patch_size: 14 294 | # depth: 24 295 | # num_heads: 24 296 | # init_values: 1.0e-5 297 | # embed_dim: 1536 298 | # mlp_ratio: 5.33334 # 2.66667*2 299 | # num_classes: 0 300 | # no_embed_class: True 301 | # mlp_layer: timm.layers.SwiGLUPacked 302 | # act_layer: torch.nn.SiLU 303 | # reg_tokens: 8 304 | # dynamic_img_size: True 305 | 306 | # vitL14_histai_hibou_l: 307 | # path: renormalized.RenormalizingModel 308 | # arguments: 309 | # new_normalization: 310 | # mean: [0.7068,0.5755,0.722] 311 | # std: [0.195,0.2316,0.1816] 312 | # model: 313 | # path: backbones.HuggingFaceModel 314 | # arguments: 315 | # model_name_or_path: histai/hibou-L 316 | # trust_remote_code: true 317 | # with_config: false 318 | 319 | # vitL14_histai_hibou_l_concat: 320 | # path: renormalized.RenormalizingModel 321 | # arguments: 322 | # new_normalization: 323 | # mean: [0.7068,0.5755,0.722] 324 | # std: [0.195,0.2316,0.1816] 325 | # model: 326 | # path: backbones.HuggingFaceModel 327 | # arguments: 328 | # model_name_or_path: histai/hibou-L 329 | # 
trust_remote_code: true 330 | # with_config: false 331 | # output_transform: 332 | # class_path: extract_cls_token.ExtractConcatToken 333 | # init_args: 334 | # num_reg_tokens: 4 335 | 336 | # vitg14_Prov_GigaPath: 337 | # path: renormalized.RenormalizingModel 338 | # arguments: 339 | # new_normalization: 340 | # mean: [0.485, 0.456, 0.406] 341 | # std: [0.229, 0.224, 0.225] 342 | # model: 343 | # path: backbones.TimmModel 344 | # arguments: 345 | # concat_mean_patch_tokens: false 346 | # model_name: hf_hub:prov-gigapath/prov-gigapath 347 | # pretrained: true 348 | # dynamic_img_size: true 349 | # num_classes: 0 350 | 351 | # vitg14_Prov_GigaPath_concat: 352 | # path: renormalized.RenormalizingModel 353 | # arguments: 354 | # new_normalization: 355 | # mean: [0.485, 0.456, 0.406] 356 | # std: [0.229, 0.224, 0.225] 357 | # model: 358 | # path: backbones.TimmModel 359 | # arguments: 360 | # concat_mean_patch_tokens: true 361 | # model_name: hf_hub:prov-gigapath/prov-gigapath 362 | # pretrained: true 363 | # dynamic_img_size: true 364 | # num_classes: 0 365 | 366 | # dino_vitL16_phikon2: 367 | # path: renormalized.RenormalizingModel 368 | # arguments: 369 | # new_normalization: 370 | # mean: [0.485, 0.456, 0.406] 371 | # std: [0.229, 0.224, 0.225] 372 | # model: 373 | # path: backbones.HuggingFaceModel 374 | # arguments: 375 | # model_name_or_path: owkin/phikon-v2 376 | # output_transform: 377 | # class_path: extract_cls_token.ExtractCLSToken 378 | 379 | # dino_vits16_phikon: 380 | # path: renormalized.RenormalizingModel 381 | # arguments: 382 | # new_normalization: 383 | # mean: [0.485, 0.456, 0.406] 384 | # std: [0.229, 0.224, 0.225] 385 | # model: 386 | # path: backbones.HuggingFaceModel 387 | # arguments: 388 | # model_name_or_path: owkin/phikon 389 | # output_transform: 390 | # class_path: extract_cls_token.ExtractCLSToken 391 | 392 | # dino_vitL16_phikon2_concat: 393 | # path: renormalized.RenormalizingModel 394 | # arguments: 395 | # new_normalization: 396 | # mean: [0.485, 0.456, 0.406] 397 | # std: [0.229, 0.224, 0.225] 398 | # model: 399 | # path: backbones.HuggingFaceModel 400 | # arguments: 401 | # model_name_or_path: owkin/phikon-v2 402 | # output_transform: 403 | # class_path: extract_cls_token.ExtractConcatToken 404 | 405 | # dino_vits16_phikon_concat: 406 | # path: renormalized.RenormalizingModel 407 | # arguments: 408 | # new_normalization: 409 | # mean: [0.485, 0.456, 0.406] 410 | # std: [0.229, 0.224, 0.225] 411 | # model: 412 | # path: backbones.HuggingFaceModel 413 | # arguments: 414 | # model_name_or_path: owkin/phikon 415 | # output_transform: 416 | # class_path: extract_cls_token.ExtractConcatToken 417 | 418 | # KAIKO-vitB8: 419 | # path: backbones.Kaiko 420 | # arguments: 421 | # repo_or_dir: kaiko-ai/towards_large_pathology_fms 422 | # model: vitb8 423 | # concat_mean_patch_tokens: false 424 | # trust_repo: true 425 | # dynamic_img_size: true 426 | # out_indices: null 427 | 428 | # KAIKO-vitB8_concat: 429 | # path: backbones.Kaiko 430 | # arguments: 431 | # repo_or_dir: kaiko-ai/towards_large_pathology_fms 432 | # model: vitb8 433 | # concat_mean_patch_tokens: true 434 | # trust_repo: true 435 | # dynamic_img_size: true 436 | # out_indices: null 437 | 438 | # kaikofm: 439 | # path: torch.hub.load 440 | # arguments: 441 | # repo_or_dir: facebookresearch/dinov2:main 442 | # model: dinov2_vitg14 443 | # pretrained: false 444 | # concat_mean_patch_tokens: false 445 | # ckpt_path: 
/mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-1738EDAF-99E8-48E1-B1F8-498B280E098F/teacher.backbone/epoch_274-step_458425.pth 446 | 447 | # DINOv2_vitg14_imagenet: 448 | # path: torch.hub.load 449 | # arguments: 450 | # repo_or_dir: facebookresearch/dinov2:main 451 | # model: dinov2_vitg14 452 | # pretrained: true 453 | 454 | # DINOv2_vitg14_imagenet_concat: 455 | # path: backbones.Kaiko 456 | # arguments: 457 | # repo_or_dir: facebookresearch/dinov2:main 458 | # model: dinov2_vitg14 459 | # pretrained: true 460 | # concat_mean_patch_tokens: true 461 | 462 | # vitg14: 463 | # path: torch.hub.load 464 | # arguments: 465 | # repo_or_dir: facebookresearch/dinov2:main 466 | # model: dinov2_vitg14 467 | # pretrained: false 468 | 469 | # vitg14_concat: 470 | # path: backbones.Kaiko 471 | # arguments: 472 | # repo_or_dir: facebookresearch/dinov2:main 473 | # model: dinov2_vitg14 474 | # pretrained: false 475 | # concat_mean_patch_tokens: true 476 | 477 | # dino_vits16_lunit_renorm: 478 | # path: renormalized.RenormalizingModel 479 | # arguments: 480 | # new_normalization: 481 | # mean: [0.70322989, 0.53606487, 0.66096631] 482 | # std: [0.21716536, 0.26081574, 0.20723464] 483 | # model: 484 | # path: timm.create_model 485 | # arguments: 486 | # model_name: hf-hub:1aurent/vit_small_patch16_224.lunit_dino 487 | # pretrained: true 488 | # dynamic_img_size: true 489 | # num_classes: 0 490 | 491 | # vits16_Lunit_renorm_concat: 492 | # path: renormalized.RenormalizingModel 493 | # arguments: 494 | # new_normalization: 495 | # mean: [0.70322989, 0.53606487, 0.66096631] 496 | # std: [0.21716536, 0.26081574, 0.20723464] 497 | # model: 498 | # path: backbones.TimmModel 499 | # arguments: 500 | # concat_mean_patch_tokens: true 501 | # model_name: hf-hub:1aurent/vit_small_patch16_224.lunit_dino 502 | # pretrained: true 503 | # dynamic_img_size: true 504 | # num_classes: 0 505 | 506 | # DINOv2_vitg14_from_imagenet_tcga_epoch_244: 507 | # # ckpt_path: az://experimental@stkaikodtpprdlab.blob.core.windows.net/pathology_fm/runs/nebul/vitg14_tcga_epoch_244-step_408415.pth 508 | # ckpt_path: /mnt/vast01/shared/outputs/mikhail/runs/FM-18A8DB4F-B29F-474E-BFA1-C5E8ABD39986/teacher.backbone/epoch_244-step_408415.pth 509 | # path: torch.hub.load 510 | # arguments: 511 | # repo_or_dir: facebookresearch/dinov2:main 512 | # model: dinov2_vitg14 513 | # pretrained: false 514 | 515 | # DINOv2_vitg14_from_imagenet_tcga_epoch_244_concat: 516 | # path: backbones.Kaiko 517 | # arguments: 518 | # repo_or_dir: facebookresearch/dinov2:main 519 | # model: dinov2_vitg14 520 | # pretrained: false 521 | # concat_mean_patch_tokens: true 522 | # # ckpt_path: az://experimental@stkaikodtpprdlab.blob.core.windows.net/pathology_fm/runs/nebul/vitg14_tcga_epoch_244-step_408415.pth 523 | # ckpt_path: /mnt/vast01/shared/outputs/mikhail/runs/FM-18A8DB4F-B29F-474E-BFA1-C5E8ABD39986/teacher.backbone/epoch_244-step_408415.pth 524 | 525 | # DINOv2_vitg14_from_imagenet_tcga-nki_epoch_274: 526 | # # ckpt_path: az://experimental@stkaikodtpprdlab.blob.core.windows.net/pathology_fm/runs/nebul/vitg14_tcga-nki_epoch_274-step_458425.pth 527 | # # ckpt_path: /Users/mike/Downloads/vitg14_tcga-nki_epoch_274-step_458425.pth 528 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-1738EDAF-99E8-48E1-B1F8-498B280E098F/teacher.backbone/epoch_274-step_458425.pth 529 | # path: torch.hub.load 530 | # arguments: 531 | # repo_or_dir: facebookresearch/dinov2:main 532 | # model: dinov2_vitg14 533 | # pretrained: false 534 | 535 | # 
DINOv2_vitg14_from_imagenet_tcga-nki_epoch_274_concat: 536 | # path: backbones.Kaiko 537 | # arguments: 538 | # repo_or_dir: facebookresearch/dinov2:main 539 | # model: dinov2_vitg14 540 | # pretrained: false 541 | # concat_mean_patch_tokens: true 542 | # # ckpt_path: az://experimental@stkaikodtpprdlab.blob.core.windows.net/pathology_fm/runs/nebul/vitg14_tcga-nki_epoch_274-step_458425.pth 543 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-1738EDAF-99E8-48E1-B1F8-498B280E098F/teacher.backbone/epoch_274-step_458425.pth 544 | 545 | # DINOv2_vits14_distilled_from_tcga-nki_099_concat: 546 | # path: backbones.Kaiko 547 | # arguments: 548 | # repo_or_dir: facebookresearch/dinov2:main 549 | # model: dinov2_vits14 550 | # pretrained: false 551 | # concat_mean_patch_tokens: true 552 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-0ABEE04B-0294-43EB-B682-18138CFD37F5/distilled.backbone/epoch_099-step_125000.pth 553 | 554 | # DINOv2_vits14_distilled_from_tcga_099_no_rgzr_concat: 555 | # path: backbones.Kaiko 556 | # arguments: 557 | # repo_or_dir: facebookresearch/dinov2:main 558 | # model: dinov2_vits14 559 | # pretrained: false 560 | # concat_mean_patch_tokens: true 561 | # ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/2386C448-CDA2-4C23-A1B0-9A2CC6CBF0F5/lightning_logs/version_0/checkpoints/distilled.backbone/epoch_099-step_125000.pth 562 | 563 | # vitg14_Kaiko_Midnight_concat: 564 | # path: renormalized.RenormalizingModel 565 | # arguments: 566 | # new_normalization: 567 | # mean: [0.5, 0.5, 0.5] 568 | # std: [0.5, 0.5, 0.5] 569 | # model: 570 | # path: backbones.TimmModel 571 | # arguments: 572 | # concat_mean_patch_tokens: true 573 | # model_name: hf-hub:kaiko-ai/midnight 574 | # pretrained: true 575 | # num_classes: 0 576 | 577 | vitg14_Kaiko_Midnight_concat: 578 | path: backbones.HuggingFaceModel 579 | arguments: 580 | model_name_or_path: kaiko-ai/midnight 581 | output_transform: 582 | class_path: extract_cls_token.ExtractConcatToken 583 | -------------------------------------------------------------------------------- /eval/kaiko.py: -------------------------------------------------------------------------------- 1 | """Internal Pathology FMs from kaiko.ai.""" 2 | 3 | from typing import Tuple 4 | 5 | import torch 6 | import yaml 7 | from eva.core.models import wrappers 8 | from eva.vision.models.networks.backbones.registry import register_model 9 | from torch import nn 10 | from typing_extensions import override 11 | 12 | from object_tools import ModelBuilder 13 | 14 | 15 | @register_model("pathology/kaiko_vits16") 16 | def kaiko_vits16( 17 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 18 | ) -> nn.Module: 19 | """Initializes the ViTS-16 pathology FM by kaiko.ai. 20 | 21 | Args: 22 | dynamic_img_size: Support different input image sizes by allowing to change 23 | the grid size (interpolate abs and/or ROPE pos) in the forward pass. 24 | out_indices: Whether and which multi-level patch embeddings to return. 25 | 26 | Returns: 27 | The model instance. 28 | """ 29 | return torch.hub.load( # type: ignore 30 | repo_or_dir="kaiko-ai/towards_large_pathology_fms", 31 | model="vits16", 32 | trust_repo=True, 33 | dynamic_img_size=dynamic_img_size, 34 | out_indices=out_indices, 35 | ) 36 | 37 | 38 | @register_model("pathology/kaiko_vits8") 39 | def kaiko_vits8( 40 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 41 | ) -> nn.Module: 42 | """Initializes the ViTS-8 pathology FM by kaiko.ai. 43 | 44 | Args: 45 | dynamic_img_size: Support different input image sizes by allowing to change 46 | the grid size (interpolate abs and/or ROPE pos) in the forward pass. 47 | out_indices: Whether and which multi-level patch embeddings to return. 48 | 49 | Returns: 50 | The model instance. 51 | """ 52 | return torch.hub.load( # type: ignore 53 | repo_or_dir="kaiko-ai/towards_large_pathology_fms", 54 | model="vits8", 55 | trust_repo=True, 56 | dynamic_img_size=dynamic_img_size, 57 | out_indices=out_indices, 58 | ) 59 | 60 | 61 | @register_model("pathology/kaiko_vitb16") 62 | def kaiko_vitb16( 63 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 64 | ) -> nn.Module: 65 | """Initializes the ViTB-16 pathology FM by kaiko.ai. 66 | 67 | Args: 68 | dynamic_img_size: Support different input image sizes by allowing to change 69 | the grid size (interpolate abs and/or ROPE pos) in the forward pass. 70 | out_indices: Whether and which multi-level patch embeddings to return. 71 | 72 | Returns: 73 | The model instance. 74 | """ 75 | return torch.hub.load( # type: ignore 76 | repo_or_dir="kaiko-ai/towards_large_pathology_fms", 77 | model="vitb16", 78 | trust_repo=True, 79 | dynamic_img_size=dynamic_img_size, 80 | out_indices=out_indices, 81 | ) 82 | 83 | 84 | @register_model("pathology/kaiko_vitb8") 85 | def kaiko_vitb8( 86 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 87 | ) -> nn.Module: 88 | """Initializes the ViTB-8 pathology FM by kaiko.ai. 89 | 90 | Args: 91 | dynamic_img_size: Support different input image sizes by allowing to change 92 | the grid size (interpolate abs and/or ROPE pos) in the forward pass. 93 | out_indices: Whether and which multi-level patch embeddings to return. 94 | 95 | Returns: 96 | The model instance. 97 | """ 98 | return torch.hub.load( # type: ignore 99 | repo_or_dir="kaiko-ai/towards_large_pathology_fms", 100 | model="vitb8", 101 | trust_repo=True, 102 | dynamic_img_size=dynamic_img_size, 103 | out_indices=out_indices, 104 | ) 105 | 106 | 107 | @register_model("pathology/kaiko_vitl14") 108 | def kaiko_vitl14( 109 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 110 | ) -> nn.Module: 111 | """Initializes the ViTL-14 pathology FM by kaiko.ai. 112 | 113 | Args: 114 | dynamic_img_size: Support different input image sizes by allowing to change 115 | the grid size (interpolate abs and/or ROPE pos) in the forward pass. 116 | out_indices: Whether and which multi-level patch embeddings to return. 117 | 118 | Returns: 119 | The model instance. 120 | """ 121 | return torch.hub.load( # type: ignore 122 | repo_or_dir="kaiko-ai/towards_large_pathology_fms", 123 | model="vitl14", 124 | trust_repo=True, 125 | dynamic_img_size=dynamic_img_size, 126 | out_indices=out_indices, 127 | ) 128 | 129 | 130 | class KaikoModel(wrappers.BaseModel): 131 | """Model wrapper for `torch.hub` models.""" 132 | 133 | def __init__( 134 | self, 135 | model_yaml_str: str, 136 | out_indices: int | Tuple[int, ...] | None = None, 137 | norm: bool = True, 138 | ) -> None: 139 | """Initializes the encoder. 140 | 141 | Args: 142 | model_yaml_str: Model config in yaml str. 143 | out_indices: Returns last n blocks if `int`, all if `None`, select 144 | matching indices if sequence. 
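            norm: Whether to apply the backbone's final layer norm to the
                intermediate features returned via `get_intermediate_layers`.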
145 | """ 146 | super().__init__() 147 | 148 | self._model_yaml_str = model_yaml_str 149 | self._out_indices = out_indices 150 | self._norm = norm 151 | self.load_model() 152 | 153 | @override 154 | def load_model(self) -> None: 155 | """Builds and loads the torch.hub model.""" 156 | self._model: nn.Module = ModelBuilder(**yaml.safe_load(self._model_yaml_str)).build() 157 | 158 | @override 159 | def model_forward(self, tensor: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: 160 | if self._out_indices is None: 161 | return self._model(tensor) 162 | 163 | return list( 164 | self._model.get_intermediate_layers( 165 | tensor, 166 | self._out_indices, 167 | reshape=True, 168 | return_class_token=False, 169 | norm=self._norm, 170 | ) 171 | ) 172 | 173 | 174 | @register_model("DINOv2_vitb14_tcga") 175 | def model( 176 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 177 | ) -> nn.Module: 178 | model_str = """ 179 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/1A83778A-859E-4DE4-9EA1-073A263E5BDD/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 180 | ckpt_submodule: teacher.backbone 181 | path: torch.hub.load 182 | arguments: 183 | repo_or_dir: facebookresearch/dinov2:main 184 | model: dinov2_vitb14 185 | pretrained: false 186 | """ 187 | return KaikoModel(model_str, out_indices) 188 | 189 | 190 | @register_model("DINOv2_vitb14_tcga_SRA-0.2-apply0.5") 191 | def model( 192 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 193 | ) -> nn.Module: 194 | model_str = """ 195 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/200B3AB5-4439-42C0-BA95-A6314013E75C/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 196 | ckpt_submodule: teacher.backbone 197 | path: torch.hub.load 198 | arguments: 199 | repo_or_dir: facebookresearch/dinov2:main 200 | model: dinov2_vitb14 201 | pretrained: false 202 | """ 203 | return KaikoModel(model_str, out_indices) 204 | 205 | 206 | @register_model("DINOv2_vitb14_tcga_SRA") 207 | def model( 208 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 209 | ) -> nn.Module: 210 | model_str = """ 211 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/3749129B-B887-4584-99D8-29ECEE3A5A75/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 212 | ckpt_submodule: teacher.backbone 213 | path: torch.hub.load 214 | arguments: 215 | repo_or_dir: facebookresearch/dinov2:main 216 | model: dinov2_vitb14 217 | pretrained: false 218 | """ 219 | return KaikoModel(model_str, out_indices) 220 | 221 | 222 | @register_model("DINOv2_vitb14_tcga+NKI") 223 | def model( 224 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 225 | ) -> nn.Module: 226 | model_str = """ 227 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/116FF48B-12AE-4212-B33E-9DEF855D42D0/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 228 | ckpt_submodule: teacher.backbone 229 | path: torch.hub.load 230 | arguments: 231 | repo_or_dir: facebookresearch/dinov2:main 232 | model: dinov2_vitb14 233 | pretrained: false 234 | """ 235 | return KaikoModel(model_str, out_indices) 236 | 237 | 238 | @register_model("DINOv2_vitb14_tcga+NKI_300-long") 239 | def model( 240 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 241 | ) -> nn.Module: 242 | model_str = """ 243 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/D3D61EE3-1027-4722-B9D3-F2FA3EB53939/lightning_logs/version_0/checkpoints/epoch_299-step_500100.ckpt 244 | ckpt_submodule: teacher.backbone 245 | path: torch.hub.load 246 | arguments: 247 | repo_or_dir: facebookresearch/dinov2:main 248 | model: dinov2_vitb14 249 | pretrained: false 250 | """ 251 | return KaikoModel(model_str, out_indices) 252 | 253 | 254 | @register_model("DINOv2_vitb14_tcga+NKI_300-long_momentum-0.9985") 255 | def model( 256 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 257 | ) -> nn.Module: 258 | model_str = """ 259 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/9C0585FF-CEA7-4829-AAFE-01A4EAD0780E/lightning_logs/version_0/checkpoints/epoch_299-step_500100.ckpt 260 | ckpt_submodule: teacher.backbone 261 | path: torch.hub.load 262 | arguments: 263 | repo_or_dir: facebookresearch/dinov2:main 264 | model: dinov2_vitb14 265 | pretrained: false 266 | """ 267 | return KaikoModel(model_str, out_indices) 268 | 269 | 270 | @register_model("DINOv2_vitb14_tcga+CPTAC") 271 | def model( 272 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 273 | ) -> nn.Module: 274 | model_str = """ 275 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/F787A20E-3EF5-47B2-9F3F-FAC8365B5DB8/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 276 | ckpt_submodule: teacher.backbone 277 | path: torch.hub.load 278 | arguments: 279 | repo_or_dir: facebookresearch/dinov2:main 280 | model: dinov2_vitb14 281 | pretrained: false 282 | """ 283 | return KaikoModel(model_str, out_indices) 284 | 285 | 286 | @register_model("DINOv2_vitb14_tcga+GTEx") 287 | def model( 288 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 289 | ) -> nn.Module: 290 | model_str = """ 291 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/E15DA9D1-AA5B-452E-BFFA-EAF5AD03C618/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 292 | ckpt_submodule: teacher.backbone 293 | path: torch.hub.load 294 | arguments: 295 | repo_or_dir: facebookresearch/dinov2:main 296 | model: dinov2_vitb14 297 | pretrained: false 298 | """ 299 | return KaikoModel(model_str, out_indices) 300 | 301 | 302 | @register_model("DINOv2_vitb14_tcga_RandStainNA") 303 | def model( 304 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 305 | ) -> nn.Module: 306 | model_str = """ 307 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/B429E466-B181-4B0D-9456-26DF4C40B819/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 308 | ckpt_submodule: teacher.backbone 309 | path: torch.hub.load 310 | arguments: 311 | repo_or_dir: facebookresearch/dinov2:main 312 | model: dinov2_vitb14 313 | pretrained: false 314 | """ 315 | return KaikoModel(model_str, out_indices) 316 | 317 | 318 | @register_model("DINOv2_vitb14_tcga_noHED_noHSV") 319 | def model( 320 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 321 | ) -> nn.Module: 322 | model_str = """ 323 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/693EE61B-DA87-4725-8338-B145BEF9D495/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 324 | ckpt_submodule: teacher.backbone 325 | path: torch.hub.load 326 | arguments: 327 | repo_or_dir: facebookresearch/dinov2:main 328 | model: dinov2_vitb14 329 | pretrained: false 330 | """ 331 | return KaikoModel(model_str, out_indices) 332 | 333 | 334 | @register_model("DINOv2_vitb14_four") 335 | def DINOv2_vitb14_four( 336 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 337 | ) -> nn.Module: 338 | model_str = """ 339 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-A9A4C593-4CC0-4844-8896-5D15E9A2052E/epoch_099-step_166700.ckpt 340 | ckpt_submodule: teacher.backbone 341 | path: torch.hub.load 342 | arguments: 343 | repo_or_dir: facebookresearch/dinov2:main 344 | model: dinov2_vitb14 345 | pretrained: false 346 | """ 347 | return KaikoModel(model_str, out_indices) 348 | 349 | 350 | @register_model("DINOv2_vitb14_four_restarts") 351 | def DINOv2_vitb14_four( 352 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 353 | ) -> nn.Module: 354 | model_str = """ 355 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/551D1520-BAD2-4DE7-91F5-F0FFA0EC8925/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 356 | ckpt_submodule: teacher.backbone 357 | path: torch.hub.load 358 | arguments: 359 | repo_or_dir: facebookresearch/dinov2:main 360 | model: dinov2_vitb14 361 | pretrained: false 362 | """ 363 | return KaikoModel(model_str, out_indices) 364 | 365 | 366 | @register_model("DINOv2_vitb14_tcga_KoLeo_noHED") 367 | def DINOv2_vitb14_four( 368 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 369 | ) -> nn.Module: 370 | model_str = """ 371 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/76953D51-8F71-4201-9F96-7567ACB038BE/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 372 | ckpt_submodule: teacher.backbone 373 | path: torch.hub.load 374 | arguments: 375 | repo_or_dir: facebookresearch/dinov2:main 376 | model: dinov2_vitb14 377 | pretrained: false 378 | """ 379 | return KaikoModel(model_str, out_indices) 380 | 381 | 382 | @register_model("DINOv2_vitb14_tcga_noHED") 383 | def DINOv2_vitb14_four( 384 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 385 | ) -> nn.Module: 386 | model_str = """ 387 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/99C75885-14C8-4595-B19F-B588896B90B7/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 388 | ckpt_submodule: teacher.backbone 389 | path: torch.hub.load 390 | arguments: 391 | repo_or_dir: facebookresearch/dinov2:main 392 | model: dinov2_vitb14 393 | pretrained: false 394 | """ 395 | return KaikoModel(model_str, out_indices) 396 | 397 | 398 | @register_model("DINOv2_vitb14_four_KoLeo_noHED_centering") 399 | def model( 400 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 401 | ) -> nn.Module: 402 | model_str = """ 403 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/5E226107-07EE-4612-A492-72DA18FD7C0E/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 404 | ckpt_submodule: teacher.backbone 405 | path: torch.hub.load 406 | arguments: 407 | repo_or_dir: facebookresearch/dinov2:main 408 | model: dinov2_vitb14 409 | pretrained: false 410 | """ 411 | return KaikoModel(model_str, out_indices) 412 | 413 | 414 | @register_model("DINOv2_vitb14_four_noHED") 415 | def model( 416 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 417 | ) -> nn.Module: 418 | model_str = """ 419 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/0CB3BB8D-E9C3-4BCE-8ABD-6AA1AF2F6979/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 420 | ckpt_submodule: teacher.backbone 421 | path: torch.hub.load 422 | arguments: 423 | repo_or_dir: facebookresearch/dinov2:main 424 | model: dinov2_vitb14 425 | pretrained: false 426 | """ 427 | return KaikoModel(model_str, out_indices) 428 | 429 | 430 | @register_model("DINOv2_vitb14_four_KoLeo") 431 | def model( 432 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 433 | ) -> nn.Module: 434 | model_str = """ 435 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/A050D239-0686-4704-84EA-4BEDA0C55030/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 436 | ckpt_submodule: teacher.backbone 437 | path: torch.hub.load 438 | arguments: 439 | repo_or_dir: facebookresearch/dinov2:main 440 | model: dinov2_vitb14 441 | pretrained: false 442 | """ 443 | return KaikoModel(model_str, out_indices) 444 | 445 | 446 | @register_model("DINOv2_vitb14_four_KoLeo_noHED") 447 | def model( 448 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 449 | ) -> nn.Module: 450 | model_str = """ 451 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/B399989A-4309-4676-B457-0AFF9486D5CC/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 452 | ckpt_submodule: teacher.backbone 453 | path: torch.hub.load 454 | arguments: 455 | repo_or_dir: facebookresearch/dinov2:main 456 | model: dinov2_vitb14 457 | pretrained: false 458 | """ 459 | return KaikoModel(model_str, out_indices) 460 | 461 | 462 | @register_model("DINOv2_vitb14_four-TCGA") 463 | def model( 464 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 465 | ) -> nn.Module: 466 | model_str = """ 467 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/D5ABCEE4-D075-498C-84F5-6B3E28864174/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 468 | ckpt_submodule: teacher.backbone 469 | path: torch.hub.load 470 | arguments: 471 | repo_or_dir: facebookresearch/dinov2:main 472 | model: dinov2_vitb14 473 | pretrained: false 474 | """ 475 | return KaikoModel(model_str, out_indices) 476 | 477 | 478 | @register_model("DINOv2_vitb14_four-GTEx_hsv0.45") 479 | def model( 480 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 481 | ) -> nn.Module: 482 | model_str = """ 483 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/00F52B7E-1D47-443A-ADB5-B481E5513E23/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 484 | ckpt_submodule: teacher.backbone 485 | path: torch.hub.load 486 | arguments: 487 | repo_or_dir: facebookresearch/dinov2:main 488 | model: dinov2_vitb14 489 | pretrained: false 490 | """ 491 | return KaikoModel(model_str, out_indices) 492 | 493 | 494 | @register_model("DINOv2_vitb14_four-GTEx_hsv0.6") 495 | def model( 496 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 497 | ) -> nn.Module: 498 | model_str = """ 499 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/59CBA6A7-B2AD-4D9B-8C26-356A0C05BBF7/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 500 | ckpt_submodule: teacher.backbone 501 | path: torch.hub.load 502 | arguments: 503 | repo_or_dir: facebookresearch/dinov2:main 504 | model: dinov2_vitb14 505 | pretrained: false 506 | """ 507 | return KaikoModel(model_str, out_indices) 508 | 509 | 510 | @register_model("DINOv2_vitb14_four_hsv0.6") 511 | def DINOv2_vitb14_four_hsv_06( 512 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 513 | ) -> nn.Module: 514 | model_str = """ 515 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-7C006271-6762-4D2E-B831-6447B37B1BDD/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 516 | ckpt_submodule: teacher.backbone 517 | path: torch.hub.load 518 | arguments: 519 | repo_or_dir: facebookresearch/dinov2:main 520 | model: dinov2_vitb14 521 | pretrained: false 522 | """ 523 | return KaikoModel(model_str, out_indices) 524 | 525 | 526 | @register_model("DINOv2_vitg14_from_imagenet_tcga-nki_epoch_274") 527 | def DINOv2_vitg14_from_imagenet_tcga_nki_epoch_274( 528 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 529 | ) -> nn.Module: 530 | model_str = """ 531 | path: backbones.Kaiko 532 | arguments: 533 | repo_or_dir: facebookresearch/dinov2:main 534 | model: dinov2_vitg14 535 | pretrained: false 536 | concat_mean_patch_tokens: false 537 | # ckpt_path: az://experimental@stkaikodtpprdlab.blob.core.windows.net/pathology_fm/runs/nebul/vitg14_tcga-nki_epoch_274-step_458425.pth 538 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-1738EDAF-99E8-48E1-B1F8-498B280E098F/teacher.backbone/epoch_274-step_458425.pth 539 | """ 540 | return KaikoModel(model_str, out_indices) 541 | 542 | 543 | @register_model("DINOv2_vitg14_from_imagenet_tcga-nki_epoch_274_concat") 544 | def DINOv2_vitg14_from_imagenet_tcga_nki_epoch_274_concat( 545 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 546 | ) -> nn.Module: 547 | model_str = """ 548 | path: backbones.Kaiko 549 | arguments: 550 | repo_or_dir: facebookresearch/dinov2:main 551 | model: dinov2_vitg14 552 | pretrained: false 553 | concat_mean_patch_tokens: true 554 | # ckpt_path: az://experimental@stkaikodtpprdlab.blob.core.windows.net/pathology_fm/runs/nebul/vitg14_tcga-nki_epoch_274-step_458425.pth 555 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-1738EDAF-99E8-48E1-B1F8-498B280E098F/teacher.backbone/epoch_274-step_458425.pth 556 | """ 557 | return KaikoModel(model_str, out_indices) 558 | 559 | 560 | @register_model("vitB14") 561 | def model( 562 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 563 | ) -> nn.Module: 564 | model_str = """ 565 | path: torch.hub.load 566 | arguments: 567 | repo_or_dir: facebookresearch/dinov2:main 568 | model: dinov2_vitb14 569 | pretrained: false 570 | """ 571 | return KaikoModel(model_str, out_indices) 572 | 573 | 574 | @register_model("DINOv2_vitB14_imagenet") 575 | def model( 576 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 577 | ) -> nn.Module: 578 | model_str = """ 579 | path: torch.hub.load 580 | arguments: 581 | repo_or_dir: facebookresearch/dinov2:main 582 | model: dinov2_vitb14 583 | pretrained: true 584 | """ 585 | return KaikoModel(model_str, out_indices) 586 | 587 | 588 | @register_model("DINOv2_vitg14_imagenet") 589 | def DINOv2_vitg14_imagenet( 590 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 591 | ) -> nn.Module: 592 | model_str = """ 593 | path: torch.hub.load 594 | arguments: 595 | repo_or_dir: facebookresearch/dinov2:main 596 | model: dinov2_vitg14 597 | pretrained: true 598 | """ 599 | return KaikoModel(model_str, out_indices) 600 | 601 | 602 | @register_model("DINOv2_vitg14_imagenet_concat") 603 | def vitg14( 604 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 605 | ) -> nn.Module: 606 | model_str = """ 607 | path: backbones.Kaiko 608 | arguments: 609 | repo_or_dir: facebookresearch/dinov2:main 610 | model: dinov2_vitg14 611 | pretrained: true 612 | concat_mean_patch_tokens: true 613 | """ 614 | return KaikoModel(model_str, out_indices) 615 | 616 | 617 | @register_model("vitg14") 618 | def vitg14( 619 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 620 | ) -> nn.Module: 621 | model_str = """ 622 | path: torch.hub.load 623 | arguments: 624 | repo_or_dir: facebookresearch/dinov2:main 625 | model: dinov2_vitg14 626 | pretrained: false 627 | """ 628 | return KaikoModel(model_str, out_indices) 629 | 630 | 631 | @register_model("vitg14_concat") 632 | def vitg14( 633 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 634 | ) -> nn.Module: 635 | model_str = """ 636 | path: backbones.Kaiko 637 | arguments: 638 | repo_or_dir: facebookresearch/dinov2:main 639 | model: dinov2_vitg14 640 | pretrained: false 641 | concat_mean_patch_tokens: true 642 | """ 643 | return KaikoModel(model_str, out_indices) 644 | 645 | 646 | @register_model("vitg14_init_1e-5") 647 | def vitg14_init( 648 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 649 | ) -> nn.Module: 650 | model_str = """ 651 | path: torch.hub.load 652 | arguments: 653 | repo_or_dir: facebookresearch/dinov2:main 654 | model: dinov2_vitg14 655 | pretrained: false 656 | init_values: 1.0e-5 657 | """ 658 | return KaikoModel(model_str, out_indices) 659 | 660 | 661 | @register_model("vitg14_init_0") 662 | def vitg14_init_0( 663 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 664 | ) -> nn.Module: 665 | model_str = """ 666 | path: torch.hub.load 667 | arguments: 668 | repo_or_dir: facebookresearch/dinov2:main 669 | model: dinov2_vitg14 670 | pretrained: false 671 | init_values: 0.0 672 | """ 673 | return KaikoModel(model_str, out_indices) 674 | 675 | 676 | @register_model("DINOv2_vitg14_from_imagenet_tcga_100M_epoch_294") 677 | def model( 678 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 679 | ) -> nn.Module: 680 | model_str = """ 681 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/1FDA5ADA-8357-4F5E-8A5C-8758F7751471/teacher.backbone/epoch_294-step_491765.pth 682 | path: torch.hub.load 683 | arguments: 684 | repo_or_dir: facebookresearch/dinov2:main 685 | model: dinov2_vitg14 686 | pretrained: false 687 | """ 688 | return KaikoModel(model_str, out_indices) 689 | 690 | 691 | @register_model("DINOv2_vitg14_from_imagenet_tcga_100M_epoch_294_concat") 692 | def model( 693 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 694 | ) -> nn.Module: 695 | model_str = """ 696 | path: backbones.Kaiko 697 | arguments: 698 | repo_or_dir: facebookresearch/dinov2:main 699 | model: dinov2_vitg14 700 | pretrained: false 701 | concat_mean_patch_tokens: true 702 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/1FDA5ADA-8357-4F5E-8A5C-8758F7751471/teacher.backbone/epoch_294-step_491765.pth 703 | """ 704 | return KaikoModel(model_str, out_indices) 705 | 706 | 707 | @register_model("DINOv2_vitg14_from_imagenet_tcga_epoch_244") 708 | def model( 709 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 710 | ) -> nn.Module: 711 | model_str = """ 712 | # ckpt_path: az://experimental@stkaikodtpprdlab.blob.core.windows.net/pathology_fm/runs/nebul/vitg14_tcga_epoch_244-step_408415.pth 713 | ckpt_path: /mnt/vast01/shared/outputs/mikhail/runs/FM-18A8DB4F-B29F-474E-BFA1-C5E8ABD39986/teacher.backbone/epoch_244-step_408415.pth 714 | path: torch.hub.load 715 | arguments: 716 | repo_or_dir: facebookresearch/dinov2:main 717 | model: dinov2_vitg14 718 | pretrained: false 719 | """ 720 | return KaikoModel(model_str, out_indices) 721 | 722 | 723 | @register_model("DINOv2_vitg14_from_imagenet_tcga_epoch_244_concat") 724 | def model( 725 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 726 | ) -> nn.Module: 727 | model_str = """ 728 | path: backbones.Kaiko 729 | arguments: 730 | repo_or_dir: facebookresearch/dinov2:main 731 | model: dinov2_vitg14 732 | pretrained: false 733 | concat_mean_patch_tokens: true 734 | # ckpt_path: az://experimental@stkaikodtpprdlab.blob.core.windows.net/pathology_fm/runs/nebul/vitg14_tcga_epoch_244-step_408415.pth 735 | ckpt_path: /mnt/vast01/shared/outputs/mikhail/runs/FM-18A8DB4F-B29F-474E-BFA1-C5E8ABD39986/teacher.backbone/epoch_244-step_408415.pth 736 | """ 737 | return KaikoModel(model_str, out_indices) 738 | 739 | 740 | @register_model("DINOv2_vitb14_four_hsv0.4") 741 | def model( 742 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 743 | ) -> nn.Module: 744 | model_str = """ 745 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-9BD5863B-43F0-4C65-A005-9BA47E5BDFEE/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 746 | ckpt_submodule: teacher.backbone 747 | path: torch.hub.load 748 | arguments: 749 | repo_or_dir: facebookresearch/dinov2:main 750 | model: dinov2_vitb14 751 | pretrained: false 752 | """ 753 | return KaikoModel(model_str, out_indices) 754 | 755 | 756 | @register_model("DINOv2_vitb14_four+TCGAFrozen") 757 | def model( 758 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 759 | ) -> nn.Module: 760 | model_str = """ 761 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-7988DC68-9306-4ADF-ACE4-73AB5FB3F257/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 762 | ckpt_submodule: teacher.backbone 763 | path: torch.hub.load 764 | arguments: 765 | repo_or_dir: facebookresearch/dinov2:main 766 | model: dinov2_vitb14 767 | pretrained: false 768 | """ 769 | return KaikoModel(model_str, out_indices) 770 | 771 | 772 | @register_model("DINOv2_vitb14_four-GTEx") 773 | def model( 774 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 775 | ) -> nn.Module: 776 | model_str = """ 777 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-E21FA649-6231-4BC1-9E40-7DE1FE6DEBF6/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 778 | ckpt_submodule: teacher.backbone 779 | path: torch.hub.load 780 | arguments: 781 | repo_or_dir: facebookresearch/dinov2:main 782 | model: dinov2_vitb14 783 | pretrained: false 784 | """ 785 | return KaikoModel(model_str, out_indices) 786 | 787 | 788 | @register_model("DINOv2_vitb14_four-NKI") 789 | def model( 790 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 791 | ) -> nn.Module: 792 | model_str = """ 793 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-45BAF471-5E99-4139-BCDB-552E4642EA0B/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 794 | ckpt_submodule: teacher.backbone 795 | path: torch.hub.load 796 | arguments: 797 | repo_or_dir: facebookresearch/dinov2:main 798 | model: dinov2_vitb14 799 | pretrained: false 800 | """ 801 | return KaikoModel(model_str, out_indices) 802 | 803 | 804 | @register_model("DINOv2_vitb14_four-CPTAC") 805 | def model( 806 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 807 | ) -> nn.Module: 808 | model_str = """ 809 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-B54A1C09-EBE8-4927-AF11-9DE3ECE2DB1B/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 810 | ckpt_submodule: teacher.backbone 811 | path: torch.hub.load 812 | arguments: 813 | repo_or_dir: facebookresearch/dinov2:main 814 | model: dinov2_vitb14 815 | pretrained: false 816 | """ 817 | return KaikoModel(model_str, out_indices) 818 | 819 | 820 | @register_model("DINOv2_vitb14_four_noHSV") 821 | def model( 822 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 823 | ) -> nn.Module: 824 | model_str = """ 825 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-A24D6478-3D04-4D09-B61A-3BBF97CCC503/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 826 | ckpt_submodule: teacher.backbone 827 | path: torch.hub.load 828 | arguments: 829 | repo_or_dir: facebookresearch/dinov2:main 830 | model: dinov2_vitb14 831 | pretrained: false 832 | """ 833 | return KaikoModel(model_str, out_indices) 834 | 835 | 836 | @register_model("DINOv2_vitb14_four_hsv0.2") 837 | def model( 838 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 839 | ) -> nn.Module: 840 | model_str = """ 841 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-EE5084B1-88CC-472F-B080-ADBE19FA34E9/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 842 | ckpt_submodule: teacher.backbone 843 | path: torch.hub.load 844 | arguments: 845 | repo_or_dir: facebookresearch/dinov2:main 846 | model: dinov2_vitb14 847 | pretrained: false 848 | """ 849 | return KaikoModel(model_str, out_indices) 850 | 851 | 852 | @register_model("DINOv2_vitb14_four_hsv0.6-20") 853 | def model( 854 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 855 | ) -> nn.Module: 856 | model_str = """ 857 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-2D897A1E-F43F-4E66-920B-A686B59C0C35/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 858 | ckpt_submodule: teacher.backbone 859 | path: torch.hub.load 860 | arguments: 861 | repo_or_dir: facebookresearch/dinov2:main 862 | model: dinov2_vitb14 863 | pretrained: false 864 | """ 865 | return KaikoModel(model_str, out_indices) 866 | 867 | 868 | @register_model("DINOv2_vitb14_four_hsv0.6-10") 869 | def model( 870 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 871 | ) -> nn.Module: 872 | model_str = """ 873 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/FM-1B3C1C64-E641-4333-B129-A5984BBDBA84/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 874 | ckpt_submodule: teacher.backbone 875 | path: torch.hub.load 876 | arguments: 877 | repo_or_dir: facebookresearch/dinov2:main 878 | model: dinov2_vitb14 879 | pretrained: false 880 | """ 881 | return KaikoModel(model_str, out_indices) 882 | 883 | 884 | @register_model("DINOv2_vitb14_four_hsv0.4+TCGAfrozen") 885 | def model( 886 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 887 | ) -> nn.Module: 888 | model_str = """ 889 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/6F8BB4BA-BCC4-47DF-B11D-65E0E2E95156/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 890 | ckpt_submodule: teacher.backbone 891 | path: torch.hub.load 892 | arguments: 893 | repo_or_dir: facebookresearch/dinov2:main 894 | model: dinov2_vitb14 895 | pretrained: false 896 | """ 897 | return KaikoModel(model_str, out_indices) 898 | 899 | 900 | @register_model("DINOv2_vitb14_four_mf0.2_hsv0.4") 901 | def model( 902 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 903 | ) -> nn.Module: 904 | model_str = """ 905 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/7C85A1BD-BA78-4BEE-A7EB-2DC3F2F1309A/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 906 | ckpt_submodule: teacher.backbone 907 | path: torch.hub.load 908 | arguments: 909 | repo_or_dir: facebookresearch/dinov2:main 910 | model: dinov2_vitb14 911 | pretrained: false 912 | """ 913 | return KaikoModel(model_str, out_indices) 914 | 915 | 916 | @register_model("DINOv2_vitb14_four-GTEx_mf0.2_hsv0.4") 917 | def model( 918 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 919 | ) -> nn.Module: 920 | model_str = """ 921 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/D5B72462-D6B2-4E20-B6AC-42B756642639/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 922 | ckpt_submodule: teacher.backbone 923 | path: torch.hub.load 924 | arguments: 925 | repo_or_dir: facebookresearch/dinov2:main 926 | model: dinov2_vitb14 927 | pretrained: false 928 | """ 929 | return KaikoModel(model_str, out_indices) 930 | 931 | 932 | @register_model("DINOv2_vitb14_four-GTEx+TCGAFrozen_mf0.2_hsv0.4") 933 | def model( 934 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 935 | ) -> nn.Module: 936 | model_str = """ 937 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/A45721C7-F3C2-46DC-9B41-1B271A8D7DCF/lightning_logs/version_0/checkpoints/epoch_099-step_166700.ckpt 938 | ckpt_submodule: teacher.backbone 939 | path: torch.hub.load 940 | arguments: 941 | repo_or_dir: facebookresearch/dinov2:main 942 | model: dinov2_vitb14 943 | pretrained: false 944 | """ 945 | return KaikoModel(model_str, out_indices) 946 | 947 | 948 | @register_model("DINOv2_vitg14_nki-tcga_post_100_aspect_epoch_059_resize392") 949 | def model( 950 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 951 | ) -> nn.Module: 952 | model_str = """ 953 | path: backbones.Kaiko 954 | arguments: 955 | repo_or_dir: facebookresearch/dinov2:main 956 | model: dinov2_vitg14 957 | pretrained: false 958 | concat_mean_patch_tokens: false 959 | resize: 392 960 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/E48D4DAD-C48E-49E5-B4D9-AB679A594090/lightning_logs/version_0/checkpoints/epoch_059-step_30000.ckpt 961 | ckpt_submodule: teacher.backbone 962 | mode: bicubic 963 | """ 964 | return KaikoModel(model_str, out_indices) 965 | 966 | 967 | @register_model("DINOv2_vitg14_nki-tcga_post_100_aspect_epoch_059_bicubic_concat_resize392") 968 | def model( 969 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 970 | ) -> nn.Module: 971 | model_str = """ 972 | path: backbones.Kaiko 973 | arguments: 974 | repo_or_dir: facebookresearch/dinov2:main 975 | model: dinov2_vitg14 976 | pretrained: false 977 | concat_mean_patch_tokens: true 978 | resize: 392 979 | ckpt_path: /mnt/vast01/shared/experimental/pathology_fm/mikhail/runs/E48D4DAD-C48E-49E5-B4D9-AB679A594090/lightning_logs/version_0/checkpoints/epoch_059-step_30000.ckpt 980 | ckpt_submodule: teacher.backbone 981 | mode: bicubic 982 | """ 983 | return KaikoModel(model_str, out_indices) 984 | 985 | 986 | @register_model("KAIKO-vitB8") 987 | def model( 988 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 989 | ) -> nn.Module: 990 | return torch.hub.load( # type: ignore 991 | repo_or_dir="kaiko-ai/towards_large_pathology_fms", 992 | model="vitb8", 993 | trust_repo=True, 994 | dynamic_img_size=dynamic_img_size, 995 | out_indices=out_indices, 996 | ) 997 | 998 | 999 | @register_model("KAIKO-vitB8_concat") 1000 | def model( 1001 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 1002 | ) -> nn.Module: 1003 | model_str = f""" 1004 | path: backbones.Kaiko 1005 | arguments: 1006 | repo_or_dir: kaiko-ai/towards_large_pathology_fms 1007 | model: vitb8 1008 | concat_mean_patch_tokens: true 1009 | trust_repo: true 1010 | dynamic_img_size: {dynamic_img_size} 1011 | out_indices: {"null" if out_indices is None else out_indices} 1012 | """ 1013 | return KaikoModel(model_str, None) 1014 | 1015 | 1016 | @register_model("dino_vitL16_phikon2") 1017 | def model( 1018 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1019 | ) -> nn.Module: 1020 | model_str = f""" 1021 | path: renormalized.RenormalizingModel 1022 | arguments: 1023 | new_normalization: 1024 | mean: [0.485, 0.456, 0.406] 1025 | std: [0.229, 0.224, 0.225] 1026 | model: 1027 | path: backbones.HuggingFaceModel 1028 | arguments: 1029 | model_name_or_path: owkin/phikon-v2 1030 | output_transform: 1031 | class_path: extract_cls_token.{"ExtractCLSToken" if out_indices is None else "ExtractPatchFeatures"} 1032 | """ 1033 | return KaikoModel(model_str, None) 1034 | 1035 | 1036 | @register_model("dino_vits16_phikon") 1037 | def model( 1038 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1039 | ) -> nn.Module: 1040 | model_str = f""" 1041 | path: renormalized.RenormalizingModel 1042 | arguments: 1043 | new_normalization: 1044 | mean: [0.485, 0.456, 0.406] 1045 | std: [0.229, 0.224, 0.225] 1046 | model: 1047 | path: backbones.HuggingFaceModel 1048 | arguments: 1049 | model_name_or_path: owkin/phikon 1050 | output_transform: 1051 | class_path: extract_cls_token.{"ExtractCLSToken" if out_indices is None else "ExtractPatchFeatures"} 1052 | """ 1053 | return KaikoModel(model_str, None) 1054 | 1055 | 1056 | @register_model("dino_vitL16_phikon2_concat") 1057 | def model( 1058 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1059 | ) -> nn.Module: 1060 | model_str = f""" 1061 | path: renormalized.RenormalizingModel 1062 | arguments: 1063 | new_normalization: 1064 | mean: [0.485, 0.456, 0.406] 1065 | std: [0.229, 0.224, 0.225] 1066 | model: 1067 | path: backbones.HuggingFaceModel 1068 | arguments: 1069 | model_name_or_path: owkin/phikon-v2 1070 | output_transform: 1071 | class_path: extract_cls_token.{"ExtractConcatToken" if out_indices is None else "ExtractPatchFeatures"} 1072 | """ 1073 | return KaikoModel(model_str, None) 1074 | 1075 | 1076 | @register_model("vitg14_Kaiko_Midnight_concat") 1077 | def model( 1078 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1079 | ) -> nn.Module: 1080 | model_str = f""" 1081 | path: backbones.HuggingFaceModel 1082 | arguments: 1083 | model_name_or_path: kaiko-ai/midnight 1084 | output_transform: 1085 | class_path: extract_cls_token.{"ExtractConcatToken" if out_indices is None else "ExtractPatchFeatures"} 1086 | """ 1087 | return KaikoModel(model_str, None) 1088 | 1089 | 1090 | @register_model("dino_vits16_phikon_concat") 1091 | def model( 1092 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 1093 | ) -> nn.Module: 1094 | model_str = f""" 1095 | path: renormalized.RenormalizingModel 1096 | arguments: 1097 | new_normalization: 1098 | mean: [0.485, 0.456, 0.406] 1099 | std: [0.229, 0.224, 0.225] 1100 | model: 1101 | path: backbones.HuggingFaceModel 1102 | arguments: 1103 | model_name_or_path: owkin/phikon 1104 | output_transform: 1105 | class_path: extract_cls_token.{"ExtractConcatToken" if out_indices is None else "ExtractPatchFeatures"} 1106 | """ 1107 | return KaikoModel(model_str, None) 1108 | 1109 | 1110 | @register_model("vitL14_histai_hibou_l") 1111 | def model( 1112 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1113 | ) -> nn.Module: 1114 | model_str = f""" 1115 | path: renormalized.RenormalizingModel 1116 | arguments: 1117 | new_normalization: 1118 | mean: [0.7068,0.5755,0.722] 1119 | std: [0.195,0.2316,0.1816] 1120 | model: 1121 | path: backbones.HuggingFaceModel 1122 | arguments: 1123 | model_name_or_path: histai/hibou-L 1124 | trust_remote_code: true 1125 | with_config: false 1126 | output_transform: 1127 | class_path: extract_cls_token.{"ExtractCLSToken" if out_indices is None else "ExtractPatchFeatures"} 1128 | init_args: {'{}' if out_indices is None else '{num_reg_tokens: 4}'} 1129 | """ 1130 | return KaikoModel(model_str, None) 1131 | 1132 | 1133 | @register_model("vitL14_histai_hibou_l_concat") 1134 | def model( 1135 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1136 | ) -> nn.Module: 1137 | model_str = f""" 1138 | path: renormalized.RenormalizingModel 1139 | arguments: 1140 | new_normalization: 1141 | mean: [0.7068,0.5755,0.722] 1142 | std: [0.195,0.2316,0.1816] 1143 | model: 1144 | path: backbones.HuggingFaceModel 1145 | arguments: 1146 | model_name_or_path: histai/hibou-L 1147 | trust_remote_code: true 1148 | with_config: false 1149 | output_transform: 1150 | class_path: extract_cls_token.{"ExtractConcatToken" if out_indices is None else "ExtractPatchFeatures"} 1151 | init_args: 1152 | num_reg_tokens: 4 1153 | """ 1154 | return KaikoModel(model_str, None) 1155 | 1156 | 1157 | @register_model("vitg14_Prov_GigaPath") 1158 | def model( 1159 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1160 | ) -> nn.Module: 1161 | model_str = f""" 1162 | path: renormalized.RenormalizingModel 1163 | arguments: 1164 | new_normalization: 1165 | mean: [0.485, 0.456, 0.406] 1166 | std: [0.229, 0.224, 0.225] 1167 | model: 1168 | path: backbones.TimmModel 1169 | arguments: 1170 | concat_mean_patch_tokens: false 1171 | out_indices: {"null" if out_indices is None else out_indices} 1172 | features_only: {out_indices is not None} 1173 | model_name: hf_hub:prov-gigapath/prov-gigapath 1174 | pretrained: true 1175 | dynamic_img_size: {dynamic_img_size} 1176 | """ 1177 | return KaikoModel(model_str, None) 1178 | 1179 | 1180 | @register_model("vitg14_Prov_GigaPath_concat") 1181 | def model( 1182 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 1183 | ) -> nn.Module: 1184 | model_str = f""" 1185 | path: renormalized.RenormalizingModel 1186 | arguments: 1187 | new_normalization: 1188 | mean: [0.485, 0.456, 0.406] 1189 | std: [0.229, 0.224, 0.225] 1190 | model: 1191 | path: backbones.TimmModel 1192 | arguments: 1193 | concat_mean_patch_tokens: true 1194 | out_indices: {"null" if out_indices is None else out_indices} 1195 | features_only: {out_indices is not None} 1196 | model_name: hf_hub:prov-gigapath/prov-gigapath 1197 | pretrained: true 1198 | dynamic_img_size: {dynamic_img_size} 1199 | """ 1200 | return KaikoModel(model_str, None) 1201 | 1202 | 1203 | @register_model("vits16_Lunit_renorm") 1204 | def model( 1205 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1206 | ) -> nn.Module: 1207 | model_str = f""" 1208 | path: renormalized.RenormalizingModel 1209 | arguments: 1210 | new_normalization: 1211 | mean: [0.70322989, 0.53606487, 0.66096631] 1212 | std: [0.21716536, 0.26081574, 0.20723464] 1213 | model: 1214 | path: backbones.TimmModel 1215 | arguments: 1216 | concat_mean_patch_tokens: false 1217 | out_indices: {"null" if out_indices is None else out_indices} 1218 | features_only: {out_indices is not None} 1219 | model_name: hf-hub:1aurent/vit_small_patch16_224.lunit_dino 1220 | pretrained: true 1221 | dynamic_img_size: true 1222 | num_classes: 0 1223 | """ 1224 | return KaikoModel(model_str, None) 1225 | 1226 | 1227 | @register_model("vits16_Lunit_renorm_concat") 1228 | def model( 1229 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1230 | ) -> nn.Module: 1231 | model_str = f""" 1232 | path: renormalized.RenormalizingModel 1233 | arguments: 1234 | new_normalization: 1235 | mean: [0.70322989, 0.53606487, 0.66096631] 1236 | std: [0.21716536, 0.26081574, 0.20723464] 1237 | model: 1238 | path: backbones.TimmModel 1239 | arguments: 1240 | concat_mean_patch_tokens: true 1241 | out_indices: {"null" if out_indices is None else out_indices} 1242 | features_only: {out_indices is not None} 1243 | model_name: hf-hub:1aurent/vit_small_patch16_224.lunit_dino 1244 | pretrained: true 1245 | dynamic_img_size: true 1246 | num_classes: 0 1247 | """ 1248 | return KaikoModel(model_str, None) 1249 | 1250 | 1251 | @register_model("vitL16_UNI") 1252 | def model( 1253 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1254 | ) -> nn.Module: 1255 | model_str = f""" 1256 | path: renormalized.RenormalizingModel 1257 | arguments: 1258 | new_normalization: 1259 | mean: [0.485, 0.456, 0.406] 1260 | std: [0.229, 0.224, 0.225] 1261 | model: 1262 | path: backbones.TimmModel 1263 | arguments: 1264 | concat_mean_patch_tokens: false 1265 | out_indices: {"null" if out_indices is None else out_indices} 1266 | features_only: {out_indices is not None} 1267 | model_name: hf-hub:MahmoodLab/uni 1268 | init_values: 1.0e-5 1269 | pretrained: true 1270 | dynamic_img_size: true 1271 | num_classes: 0 1272 | """ 1273 | return KaikoModel(model_str, None) 1274 | 1275 | 1276 | @register_model("vitL16_UNI_resize512") 1277 | def model( 1278 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 1279 | ) -> nn.Module: 1280 | model_str = f""" 1281 | path: renormalized.RenormalizingModel 1282 | arguments: 1283 | new_normalization: 1284 | mean: [0.485, 0.456, 0.406] 1285 | std: [0.229, 0.224, 0.225] 1286 | model: 1287 | path: backbones.TimmModel 1288 | arguments: 1289 | concat_mean_patch_tokens: false 1290 | out_indices: {"null" if out_indices is None else out_indices} 1291 | features_only: {out_indices is not None} 1292 | model_name: hf-hub:MahmoodLab/uni 1293 | init_values: 1.0e-5 1294 | pretrained: true 1295 | dynamic_img_size: true 1296 | num_classes: 0 1297 | """ 1298 | return KaikoModel(model_str, None) 1299 | 1300 | 1301 | @register_model("vitL16_UNI_concat") 1302 | def model( 1303 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1304 | ) -> nn.Module: 1305 | model_str = f""" 1306 | path: renormalized.RenormalizingModel 1307 | arguments: 1308 | new_normalization: 1309 | mean: [0.485, 0.456, 0.406] 1310 | std: [0.229, 0.224, 0.225] 1311 | model: 1312 | path: backbones.TimmModel 1313 | arguments: 1314 | concat_mean_patch_tokens: true 1315 | out_indices: {"null" if out_indices is None else out_indices} 1316 | features_only: {out_indices is not None} 1317 | model_name: hf-hub:MahmoodLab/uni 1318 | init_values: 1.0e-5 1319 | pretrained: true 1320 | dynamic_img_size: true 1321 | num_classes: 0 1322 | """ 1323 | return KaikoModel(model_str, None) 1324 | 1325 | 1326 | @register_model("vitL16_UNI_concat_resize512") 1327 | def model( 1328 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1329 | ) -> nn.Module: 1330 | model_str = f""" 1331 | path: renormalized.RenormalizingModel 1332 | arguments: 1333 | new_normalization: 1334 | mean: [0.485, 0.456, 0.406] 1335 | std: [0.229, 0.224, 0.225] 1336 | model: 1337 | path: backbones.TimmModel 1338 | arguments: 1339 | concat_mean_patch_tokens: true 1340 | out_indices: {"null" if out_indices is None else out_indices} 1341 | features_only: {out_indices is not None} 1342 | model_name: hf-hub:MahmoodLab/uni 1343 | init_values: 1.0e-5 1344 | pretrained: true 1345 | dynamic_img_size: true 1346 | num_classes: 0 1347 | """ 1348 | return KaikoModel(model_str, None) 1349 | 1350 | 1351 | @register_model("vitg14_224_UNI2") 1352 | def model( 1353 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1354 | ) -> nn.Module: 1355 | model_str = f""" 1356 | path: renormalized.RenormalizingModel 1357 | arguments: 1358 | new_normalization: 1359 | mean: [0.485, 0.456, 0.406] 1360 | std: [0.229, 0.224, 0.225] 1361 | model: 1362 | path: backbones.TimmModel 1363 | arguments: 1364 | concat_mean_patch_tokens: false 1365 | out_indices: {"null" if out_indices is None else out_indices} 1366 | features_only: {out_indices is not None} 1367 | model_name: hf-hub:MahmoodLab/UNI2-h 1368 | pretrained: True 1369 | img_size: 224 1370 | patch_size: 14 1371 | depth: 24 1372 | num_heads: 24 1373 | init_values: 1.0e-5 1374 | embed_dim: 1536 1375 | mlp_ratio: 5.33334 # 2.66667*2 1376 | num_classes: 0 1377 | no_embed_class: True 1378 | mlp_layer: timm.layers.SwiGLUPacked 1379 | act_layer: torch.nn.SiLU 1380 | reg_tokens: 8 1381 | dynamic_img_size: True 1382 | """ 1383 | return KaikoModel(model_str, None) 1384 | 1385 | 1386 | @register_model("vitg14_224_UNI2_resize392") 1387 | def model( 1388 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 1389 | ) -> nn.Module: 1390 | model_str = f""" 1391 | path: renormalized.RenormalizingModel 1392 | arguments: 1393 | new_normalization: 1394 | mean: [0.485, 0.456, 0.406] 1395 | std: [0.229, 0.224, 0.225] 1396 | model: 1397 | path: backbones.TimmModel 1398 | arguments: 1399 | concat_mean_patch_tokens: false 1400 | out_indices: {"null" if out_indices is None else out_indices} 1401 | features_only: {out_indices is not None} 1402 | model_name: hf-hub:MahmoodLab/UNI2-h 1403 | pretrained: True 1404 | img_size: 224 1405 | patch_size: 14 1406 | depth: 24 1407 | num_heads: 24 1408 | init_values: 1.0e-5 1409 | embed_dim: 1536 1410 | mlp_ratio: 5.33334 # 2.66667*2 1411 | num_classes: 0 1412 | no_embed_class: True 1413 | mlp_layer: timm.layers.SwiGLUPacked 1414 | act_layer: torch.nn.SiLU 1415 | reg_tokens: 8 1416 | dynamic_img_size: True 1417 | """ 1418 | return KaikoModel(model_str, None) 1419 | 1420 | 1421 | @register_model("vitg14_224_UNI2_concat") 1422 | def model( 1423 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1424 | ) -> nn.Module: 1425 | model_str = f""" 1426 | path: renormalized.RenormalizingModel 1427 | arguments: 1428 | new_normalization: 1429 | mean: [0.485, 0.456, 0.406] 1430 | std: [0.229, 0.224, 0.225] 1431 | model: 1432 | path: backbones.TimmModel 1433 | arguments: 1434 | concat_mean_patch_tokens: true 1435 | out_indices: {"null" if out_indices is None else out_indices} 1436 | features_only: {out_indices is not None} 1437 | model_name: hf-hub:MahmoodLab/UNI2-h 1438 | pretrained: True 1439 | img_size: 224 1440 | patch_size: 14 1441 | depth: 24 1442 | num_heads: 24 1443 | init_values: 1.0e-5 1444 | embed_dim: 1536 1445 | mlp_ratio: 5.33334 # 2.66667*2 1446 | num_classes: 0 1447 | no_embed_class: True 1448 | mlp_layer: timm.layers.SwiGLUPacked 1449 | act_layer: torch.nn.SiLU 1450 | reg_tokens: 8 1451 | dynamic_img_size: True 1452 | """ 1453 | return KaikoModel(model_str, None) 1454 | 1455 | 1456 | @register_model("vitg14_224_UNI2_concat_resize392") 1457 | def model( 1458 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1459 | ) -> nn.Module: 1460 | model_str = f""" 1461 | path: renormalized.RenormalizingModel 1462 | arguments: 1463 | new_normalization: 1464 | mean: [0.485, 0.456, 0.406] 1465 | std: [0.229, 0.224, 0.225] 1466 | model: 1467 | path: backbones.TimmModel 1468 | arguments: 1469 | concat_mean_patch_tokens: true 1470 | out_indices: {"null" if out_indices is None else out_indices} 1471 | features_only: {out_indices is not None} 1472 | model_name: hf-hub:MahmoodLab/UNI2-h 1473 | pretrained: True 1474 | img_size: 224 1475 | patch_size: 14 1476 | depth: 24 1477 | num_heads: 24 1478 | init_values: 1.0e-5 1479 | embed_dim: 1536 1480 | mlp_ratio: 5.33334 # 2.66667*2 1481 | num_classes: 0 1482 | no_embed_class: True 1483 | mlp_layer: timm.layers.SwiGLUPacked 1484 | act_layer: torch.nn.SiLU 1485 | reg_tokens: 8 1486 | dynamic_img_size: True 1487 | """ 1488 | return KaikoModel(model_str, None) 1489 | 1490 | 1491 | @register_model("vitH14_Virchow2") 1492 | def model( 1493 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] 
| None = None 1494 | ) -> nn.Module: 1495 | model_str = f""" 1496 | path: renormalized.RenormalizingModel 1497 | arguments: 1498 | new_normalization: 1499 | mean: [0.485, 0.456, 0.406] 1500 | std: [0.229, 0.224, 0.225] 1501 | model: 1502 | path: backbones.Virchow2 1503 | arguments: 1504 | concat_mean_patch_tokens: false 1505 | out_indices: {"null" if out_indices is None else out_indices} 1506 | features_only: {out_indices is not None} 1507 | """ 1508 | return KaikoModel(model_str, None) 1509 | 1510 | 1511 | @register_model("vitH14_Virchow2_concat") 1512 | def model( 1513 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1514 | ) -> nn.Module: 1515 | model_str = f""" 1516 | path: renormalized.RenormalizingModel 1517 | arguments: 1518 | new_normalization: 1519 | mean: [0.485, 0.456, 0.406] 1520 | std: [0.229, 0.224, 0.225] 1521 | model: 1522 | path: backbones.Virchow2 1523 | arguments: 1524 | concat_mean_patch_tokens: true 1525 | out_indices: {"null" if out_indices is None else out_indices} 1526 | features_only: {out_indices is not None} 1527 | """ 1528 | return KaikoModel(model_str, None) 1529 | 1530 | 1531 | @register_model("Bioptimus_h_optimus_0") 1532 | def model( 1533 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1534 | ) -> nn.Module: 1535 | model_str = f""" 1536 | path: renormalized.RenormalizingModel 1537 | arguments: 1538 | new_normalization: 1539 | mean: [0.707223, 0.578729, 0.703617] 1540 | std: [0.211883, 0.230117, 0.177517] 1541 | model: 1542 | path: backbones.TimmModel 1543 | arguments: 1544 | concat_mean_patch_tokens: false 1545 | model_name: hf-hub:bioptimus/H-optimus-0 1546 | init_values: 1.0e-5 1547 | pretrained: true 1548 | dynamic_img_size: true 1549 | num_classes: 0 1550 | out_indices: {"null" if out_indices is None else out_indices} 1551 | features_only: {out_indices is not None} 1552 | """ 1553 | return KaikoModel(model_str, None) 1554 | 1555 | 1556 | @register_model("Bioptimus_h_optimus_0_concat") 1557 | def model( 1558 | dynamic_img_size: bool = True, out_indices: int | Tuple[int, ...] | None = None 1559 | ) -> nn.Module: 1560 | model_str = f""" 1561 | path: renormalized.RenormalizingModel 1562 | arguments: 1563 | new_normalization: 1564 | mean: [0.707223, 0.578729, 0.703617] 1565 | std: [0.211883, 0.230117, 0.177517] 1566 | model: 1567 | path: backbones.TimmModel 1568 | arguments: 1569 | concat_mean_patch_tokens: true 1570 | model_name: hf-hub:bioptimus/H-optimus-0 1571 | init_values: 1.0e-5 1572 | pretrained: true 1573 | dynamic_img_size: true 1574 | num_classes: 0 1575 | out_indices: {"null" if out_indices is None else out_indices} 1576 | features_only: {out_indices is not None} 1577 | """ 1578 | return KaikoModel(model_str, None) 1579 | --------------------------------------------------------------------------------
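
For orientation, a minimal sketch of how the entries above are consumed, assuming eva and torch are installed, the snippet runs from `eval/` so that `kaiko.py` and `object_tools.py` are importable, and `torch.hub` can reach the DINOv2 repo; the YAML string is illustrative, not one of the checkpointed configs above.

```python
import torch

from kaiko import KaikoModel

# Same schema as the registry entries above: ModelBuilder resolves `path`
# to a callable (torch.hub.load here) and invokes it with `arguments`.
model_yaml = """
path: torch.hub.load
arguments:
  repo_or_dir: facebookresearch/dinov2:main
  model: dinov2_vits14
  pretrained: true
"""

backbone = KaikoModel(model_yaml)       # out_indices=None -> plain forward pass
tile = torch.randn(1, 3, 224, 224)      # one normalized RGB patch
with torch.no_grad():
    cls_embedding = backbone.model_forward(tile)
print(cls_embedding.shape)              # torch.Size([1, 384]) for ViT-S/14

# With out_indices set, model_forward instead returns the reshaped
# patch-token maps of the last n blocks via get_intermediate_layers.
multi_level = KaikoModel(model_yaml, out_indices=2)
with torch.no_grad():
    feature_maps = multi_level.model_forward(tile)
print([f.shape for f in feature_maps])  # 2 maps of shape [1, 384, 16, 16]
```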
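
The many `*_concat` variants differ from their base entries only in the output transform (`extract_cls_token.ExtractConcatToken`, or `concat_mean_patch_tokens: true` for the timm/hub wrappers). `extract_cls_token.py` is not shown here, so the helper below is a hypothetical re-implementation of the presumed behavior, CLS token concatenated with the mean of the patch tokens (skipping register tokens where `num_reg_tokens` is given, e.g. 4 for hibou-L), which is consistent with these variants reporting twice the base embedding dimension.

```python
import torch


def concat_cls_and_mean_patch(
    hidden_states: torch.Tensor, num_reg_tokens: int = 0
) -> torch.Tensor:
    """Hypothetical sketch of the ExtractConcatToken idea.

    hidden_states: [B, 1 + num_reg_tokens + N, D] token sequence
    returns:       [B, 2 * D] embedding (CLS ++ mean of patch tokens)
    """
    cls_token = hidden_states[:, 0]                       # [B, D]
    patch_tokens = hidden_states[:, 1 + num_reg_tokens:]  # drop CLS + registers
    return torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=-1)
```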