├── .gitignore ├── LICENSE.txt ├── README.md ├── base_ml ├── base_cli.py ├── base_early_stopping.py ├── base_experiment.py ├── base_loss.py ├── base_optim.py ├── base_trainer.py ├── base_utils.py ├── base_validator.py ├── optim_factory.py └── unireplknet_layer_decay_optimizer_constructor.py ├── cell_segmentation ├── __init__.py ├── datasets │ ├── base_cell.py │ ├── cell_graph_datamodel.py │ ├── conic.py │ ├── consep.py │ ├── dataset_coordinator.py │ ├── monuseg.py │ ├── pannuke.py │ ├── prepare_monuseg.py │ └── prepare_pannuke.py ├── experiments │ ├── __init__.py │ └── experiment_cellvit_pannuke.py ├── inference │ ├── __init__.py │ ├── cell_detection.py │ ├── cell_detection_256.py │ ├── cell_detection_mp.py │ ├── inference_cellvit_experiment_monuseg.py │ └── inference_cellvit_experiment_pannuke.py ├── run_cellvit.py ├── trainer │ ├── __init__.py │ └── trainer_cellvit.py └── utils │ ├── __init__.py │ ├── metrics.py │ ├── post_proc_cellvit.py │ ├── template_geojson.py │ └── tools.py ├── config.yaml ├── datamodel ├── __init__.py ├── graph_datamodel.py └── wsi_datamodel.py ├── docs ├── datasets │ └── PanNuke │ │ ├── dataset_config.yaml │ │ ├── fold0 │ │ ├── cell_count.csv │ │ └── types.csv │ │ ├── fold1 │ │ ├── cell_count.csv │ │ └── types.csv │ │ ├── fold2 │ │ ├── cell_count.csv │ │ └── types.csv │ │ └── weight_config.yaml ├── model.png └── readmes │ ├── cell_segmentation.md │ ├── monuseg.md │ └── pannuke.md ├── models ├── __init__.py ├── segmentation │ ├── __init__.py │ └── cell_segmentation │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── UNet_v2.py │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── cellvit.cpython-310.pyc │ │ ├── cellvit.cpython-38.pyc │ │ ├── cellvit_shared.cpython-38.pyc │ │ ├── cellvit_unirepLKnet.cpython-310.pyc │ │ ├── cellvit_unirepLKnet.cpython-38.pyc │ │ ├── replknet.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ │ ├── cellvit.py │ │ ├── cellvit_unirepLKnet.py │ │ ├── replknet.py │ │ └── utils.py └── utils │ ├── __init__.py │ ├── attention.py │ ├── dense.py │ ├── residual.py │ ├── tf_utils.py │ └── tools.py ├── preprocessing ├── encoding │ └── datasets │ │ ├── __init__.py │ │ └── patched_wsi_inference.py └── patch_extraction │ ├── main_extraction.py │ ├── scripts │ └── macenko.py │ └── src │ └── cli.py ├── requirements.txt └── utils ├── __init__.py ├── file_handling.py ├── logger.py └── tools.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | /.vscode 3 | test.py -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Hust Vision Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | LKCell🔬 3 | Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels
4 | 5 | 6 | 7 | [![Paper](https://img.shields.io/badge/cs.CV-2407.18054-b31b1b?logo=arxiv&logoColor=red)](https://arxiv.org/abs/2407.18054) 8 | [![Python 3.9.7](https://img.shields.io/badge/python-3.9.7-blue.svg)](https://www.python.org/downloads/release/python-360/) 9 | [![license](https://img.shields.io/badge/license-MIT-orange)](LICENSE) 10 | [![authors](https://img.shields.io/badge/by-hustvl-green)](https://github.com/hustvl) 11 | [![HuggingFace](https://img.shields.io/badge/%F0%9F%A4%97%20Space-Demo-yellow)](https://huggingface.co/spaces/xiazhi/LKCell) 12 | 13 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/lkcell-efficient-cell-nuclei-instance/panoptic-segmentation-on-pannuke)](https://paperswithcode.com/sota/panoptic-segmentation-on-pannuke?p=lkcell-efficient-cell-nuclei-instance) 14 | 15 | [Ziwei Cui](https://github.com/ziwei-cui) 1*, [Jingfeng Yao](https://github.com/JingfengYao) 1*, [Lunbin Zeng](https://github.com/xiazhi1) 1, [Juan Yang]() 2, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu) 1, [Xinggang Wang](https://xwcv.github.io/) 1,📧 16 | 17 | 1 School of Electronic Information and Communications, Huazhong University of Science and Technology \ 18 | 2 Department of Cardiology, Huanggang Central Hospital 19 | 20 | (\* equal contribution, 📧 corresponding author) 21 | 22 | [Key Features](#key-features) • [Installation](#installation) • [Usage](#usage) • [Training](#training) • [Inference](#inference) • [Citation](#Citation) 23 | 24 | 25 | 26 | --- 27 |
28 | 29 | 30 | ## Key Features 31 | 32 | **Click and try LKCell on our [🤗 Hugging Face Space](https://huggingface.co/spaces/xiazhi/LKCell)!** 33 | 34 | This repository contains the code implementation of LKCell, a deep learning-based method for automated instance segmentation of cell nuclei in digitized tissue samples. LKCell utilizes an architecture based on large convolutional kernels and achieves state-of-the-art performance on the [PanNuke](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke) dataset, a challenging nuclei instance segmentation benchmark. 35 | 36 | 37 |

38 | 39 |

40 | 41 | 42 | 43 | 44 | 45 | 46 | ## Installation 47 | 48 | ``` 49 | git clone https://github.com/hustvl/LKCell.git 50 | conda create -n lkcell 51 | conda activate lkcell 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | Note: (1) the preferred torch version is 2.0; (2) if you encounter problems installing `depthwise-conv2d-implicit-gemm==0.0.0`, please follow the instructions [here](https://github.com/AILab-CVC/UniRepLKNet#use-an-efficient-large-kernel-convolution-with-pytorch). 56 | 57 | 65 | 66 | 67 | 68 | ## Project Structure 69 | 70 | We are currently using the following folder structure: 71 | 72 | ```bash 73 | ├── base_ml # Basic Machine Learning Code: Trainer, Experiment, ... 74 | ├── cell_segmentation # Cell Segmentation training and inference files 75 | │ ├── datasets # Datasets (PyTorch) 76 | │ ├── experiments # Specific Experiment Code for different experiments 77 | │ ├── inference # Inference code for experiment statistics and plots 78 | │ ├── trainer # Trainer functions to train networks 79 | │ ├── utils # Utils code 80 | │ └── run_cellvit.py # Run file to start an experiment 81 | ├── config # Python configuration file for global Python settings 82 | ├── docs # Documentation files (in addition to this main README.md) 83 | ├── models # Machine Learning Models (PyTorch implementations) 84 | │ └── segmentation # LKCell Code 85 | ├── datamodel # Dataclass code: graph data, WSI objects, ... 86 | ├── preprocessing # Preprocessing code: encoding, patch extraction, ... 87 | ``` 88 | 89 | 90 | ## Training 91 | 92 | ### Dataset preparation 93 | We use a customized dataset structure for the PanNuke and the MoNuSeg datasets. 94 | The dataset structures are explained in the [pannuke.md](docs/readmes/pannuke.md) and [monuseg.md](docs/readmes/monuseg.md) documentation files. 95 | We also provide preparation scripts in the [`cell_segmentation/datasets/`](cell_segmentation/datasets/) folder. 96 | 97 | ### Training script 98 | The CLI for an ML experiment to train the LKCell network is as follows (here the [```run_cellvit.py```](cell_segmentation/run_cellvit.py) script is used): 99 | ```bash 100 | usage: run_cellvit.py [-h] --config CONFIG [--gpu GPU] [--sweep | --agent AGENT | --checkpoint CHECKPOINT] 101 | Start an experiment with given configuration file. 102 | 103 | python ./cell_segmentation/run_cellvit.py --config ./config.yaml 104 | ``` 105 | 106 | The important file is the configuration file, in which all paths are set, the model configuration is given, and the hyperparameters or sweeps are defined. 107 | 108 | 109 | 110 | **Pre-trained UniRepLKNet models** for training initialization can be downloaded from Google Drive: [UniRepLKNet-Models](https://drive.google.com/drive/folders/1pqjCBZIv4WwEsE5raUPz5AUM7I-UPtMJ). 111 | 112 | 113 | ### Evaluation 114 | In our paper, we did not (!) use early stopping, but rather trained all models for 100 epochs to eliminate selection bias and to use the largest possible database for training. Therefore, evaluation needs to be performed with the `latest_checkpoint.pth` model and not the best early-stopping model. 115 | We provide scripts to create evaluation results: [`inference_cellvit_experiment_pannuke.py`](cell_segmentation/inference/inference_cellvit_experiment_pannuke.py) for PanNuke and [`inference_cellvit_experiment_monuseg.py`](cell_segmentation/inference/inference_cellvit_experiment_monuseg.py) for MoNuSeg.
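For orientation, here is a minimal sketch of how such a checkpoint could be inspected before evaluation; the path is a placeholder, and the key names follow `BaseTrainer.save_checkpoint` in `base_ml/base_trainer.py`:

```python
import torch

# Placeholder path: point this at the checkpoint folder of your own training run.
ckpt = torch.load("path/to/run/checkpoints/latest_checkpoint.pth", map_location="cpu")

print(ckpt["arch"], ckpt["epoch"])      # model class name and last finished epoch
state_dict = ckpt["model_state_dict"]   # weights to load into the LKCell model
train_config = ckpt["config"]           # flattened training configuration of the run
```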
116 | 117 | ### Inference 118 | 119 | Model checkpoints can be downloaded here, You can choose to download from Google Drive or HuggingFace : 120 | 121 | - Google Drive 122 | - [LKCell-L](https://drive.google.com/drive/folders/1r4vCwcyHgLtMJkr2rhFLox6SDldB2p7F?usp=drive_link) 🚀 123 | - [LKCell-B](https://drive.google.com/drive/folders/1i7SrHloSsGZSbesDZ9hBxbOG4RnaPhQU?usp=drive_link) 124 | - HuggingFace 125 | - [LKCell-L](https://huggingface.co/xiazhi/LKCell-L) 🚀 126 | - [LKCell-B](https://huggingface.co/xiazhi/LKCell-B) 127 | 128 | 129 | You can click [🤗 Hugging Face Space](https://huggingface.co/spaces/xiazhi/LKCell) to quickly perform model inference. 130 | 131 | 132 | ## Acknowledgement 133 | 134 | This project is built upon [CellViT](https://github.com/TIO-IKIM/CellViT) and [UniRepLKNet](https://github.com/AILab-CVC/UniRepLKNet). Thanks for these awesome repos! 135 | 136 | ## Citation 137 | ```latex 138 | @misc{cui2024lkcellefficientcellnuclei, 139 | title={LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels}, 140 | author={Ziwei Cui and Jingfeng Yao and Lunbin Zeng and Juan Yang and Wenyu Liu and Xinggang Wang}, 141 | year={2024}, 142 | eprint={2407.18054}, 143 | archivePrefix={arXiv}, 144 | primaryClass={eess.IV}, 145 | url={https://arxiv.org/abs/2407.18054}, 146 | } 147 | ``` 148 | -------------------------------------------------------------------------------- /base_ml/base_cli.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Base CLI to parse Arguments 3 | 4 | import argparse 5 | import logging 6 | from abc import ABC, abstractmethod 7 | from typing import Tuple, Union 8 | 9 | import yaml 10 | from pydantic import BaseModel 11 | 12 | 13 | class ABCParser(ABC): 14 | """Blueprint for Argument Parser""" 15 | 16 | @abstractmethod 17 | def __init__(self) -> None: 18 | pass 19 | 20 | @abstractmethod 21 | def get_config(self) -> Tuple[Union[BaseModel, dict], logging.Logger]: 22 | """Load configuration and create a logger 23 | 24 | Returns: 25 | Tuple[PreProcessingConfig, logging.Logger]: Configuration and Logger 26 | """ 27 | pass 28 | 29 | @abstractmethod 30 | def store_config(self) -> None: 31 | """Store the config file in the logging directory to keep track of the configuration.""" 32 | pass 33 | 34 | 35 | class ExperimentBaseParser: 36 | """Configuration Parser for Machine Learning Experiments""" 37 | 38 | def __init__(self) -> None: 39 | parser = argparse.ArgumentParser( 40 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 41 | description="Start an experiment with given configuration file.", 42 | ) 43 | requiredNamed = parser.add_argument_group("required named arguments") 44 | requiredNamed.add_argument( 45 | "--config", type=str, help="Path to a config file", required=True 46 | ) 47 | parser.add_argument("--gpu", type=int, help="Cuda-GPU ID") 48 | group = parser.add_mutually_exclusive_group(required=False) 49 | group.add_argument( 50 | "--sweep", 51 | action="store_true", 52 | help="Starting a sweep. For this the configuration file must be structured according to WandB sweeping. " 53 | "Compare https://docs.wandb.ai/guides/sweeps and https://community.wandb.ai/t/nested-sweep-configuration/3369/3 " 54 | "for further information. This parameter cannot be set in the config file!", 55 | ) 56 | group.add_argument( 57 | "--agent", 58 | type=str, 59 | help="Add a new agent to the sweep. 
" 60 | "Please pass the sweep ID as argument in the way entity/project/sweep_id, e.g., user1/test_project/v4hwbijh. " 61 | "The agent configuration can be found in the WandB dashboard for the running sweep in the sweep overview tab " 62 | "under launch agent. Just paste the entity/project/sweep_id given there. The provided config file must be a sweep config file." 63 | "This parameter cannot be set in the config file!", 64 | ) 65 | group.add_argument( 66 | "--checkpoint", 67 | type=str, 68 | help="Path to a PyTorch checkpoint file. " 69 | "The file is loaded and continued to train with the provided settings. " 70 | "If this is passed, no sweeps are possible. " 71 | "This parameter cannot be set in the config file!", 72 | ) 73 | 74 | self.parser = parser 75 | 76 | def parse_arguments(self) -> Tuple[Union[BaseModel, dict]]: 77 | """Parse the arguments from CLI and load yaml config 78 | 79 | Returns: 80 | Tuple[Union[BaseModel, dict]]: Parsed arguments 81 | """ 82 | # parse the arguments 83 | opt = self.parser.parse_args() 84 | with open(opt.config, "r") as config_file: 85 | yaml_config = yaml.safe_load(config_file) 86 | yaml_config_dict = dict(yaml_config) 87 | 88 | opt_dict = vars(opt) 89 | # check for gpu to overwrite with cli argument 90 | if "gpu" in opt_dict: 91 | if opt_dict["gpu"] is not None: 92 | yaml_config_dict["gpu"] = opt_dict["gpu"] 93 | 94 | # check if either training, sweep, checkpoint or start agent should be called 95 | # first step: remove such keys from the config file 96 | if "run_sweep" in yaml_config_dict: 97 | yaml_config_dict.pop("run_sweep") 98 | if "agent" in yaml_config_dict: 99 | yaml_config_dict.pop("agent") 100 | if "checkpoint" in yaml_config_dict: 101 | yaml_config_dict.pop("checkpoint") 102 | 103 | # select one of the options 104 | if "sweep" in opt_dict and opt_dict["sweep"] is True: 105 | yaml_config_dict["run_sweep"] = True 106 | else: 107 | yaml_config_dict["run_sweep"] = False 108 | if "agent" in opt_dict: 109 | yaml_config_dict["agent"] = opt_dict["agent"] 110 | if "checkpoint" in opt_dict: 111 | if opt_dict["checkpoint"] is not None: 112 | yaml_config_dict["checkpoint"] = opt_dict["checkpoint"] 113 | 114 | self.config = yaml_config_dict #将yaml_config_dict赋给self.config 115 | 116 | return self.config 117 | -------------------------------------------------------------------------------- /base_ml/base_early_stopping.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Base Machine Learning Experiment 3 | 4 | import logging 5 | 6 | logger = logging.getLogger("__main__") 7 | logger.addHandler(logging.NullHandler()) 8 | 9 | import wandb 10 | 11 | 12 | class EarlyStopping: 13 | """Early Stopping Class 14 | 15 | Args: 16 | patience (int): Patience to wait before stopping 17 | strategy (str, optional): Optimization strategy. 18 | Please select 'minimize' or 'maximize' for strategy. Defaults to "minimize". 
19 | """ 20 | 21 | def __init__(self, patience: int, strategy: str = "minimize"): 22 | assert strategy.lower() in [ 23 | "minimize", 24 | "maximize", 25 | ], "Please select 'minimize' or 'maximize' for strategy" 26 | 27 | self.patience = patience 28 | self.counter = 0 29 | self.strategy = strategy.lower() 30 | self.best_metric = None 31 | self.best_epoch = None 32 | self.early_stop = False 33 | 34 | logger.info( 35 | f"Using early stopping with a range of {self.patience} and {self.strategy} strategy" 36 | ) 37 | 38 | def __call__(self, metric: float, epoch: int) -> bool: 39 | """Early stopping update call 40 | 41 | Args: 42 | metric (float): Metric for early stopping 43 | epoch (int): Current epoch 44 | 45 | Returns: 46 | bool: Returns true if the model is performing better than the current best model, 47 | otherwise false 48 | """ 49 | if self.best_metric is None: 50 | self.best_metric = metric 51 | self.best_epoch = epoch 52 | return True 53 | else: 54 | if self.strategy == "minimize": 55 | if self.best_metric >= metric: 56 | self.best_metric = metric 57 | self.best_epoch = epoch 58 | self.counter = 0 59 | wandb.run.summary["Best-Epoch"] = epoch 60 | wandb.run.summary["Best-Metric"] = metric 61 | return True 62 | else: 63 | self.counter += 1 64 | if self.counter >= self.patience: 65 | self.early_stop = True 66 | return False 67 | elif self.strategy == "maximize": 68 | if self.best_metric <= metric: 69 | self.best_metric = metric 70 | self.best_epoch = epoch 71 | self.counter = 0 72 | wandb.run.summary["Best-Epoch"] = epoch 73 | wandb.run.summary["Best-Metric"] = metric 74 | return True 75 | else: 76 | self.counter += 1 77 | if self.counter >= self.patience: 78 | self.early_stop = True 79 | return False 80 | -------------------------------------------------------------------------------- /base_ml/base_optim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Wrappping all available PyTorch Optimizer 3 | 4 | 5 | from torch.optim import ( 6 | ASGD, 7 | LBFGS, 8 | SGD, 9 | Adadelta, 10 | Adagrad, 11 | Adam, 12 | Adamax, 13 | AdamW, 14 | RAdam, 15 | RMSprop, 16 | Rprop, 17 | SparseAdam, 18 | ) 19 | 20 | OPTI_DICT = { 21 | "Adadelta": Adadelta, 22 | "Adagrad": Adagrad, 23 | "Adam": Adam, 24 | "AdamW": AdamW, 25 | "SparseAdam": SparseAdam, 26 | "Adamax": Adamax, 27 | "ASGD": ASGD, 28 | "LBFGS": LBFGS, 29 | "RAdam": RAdam, 30 | "RMSprop": RMSprop, 31 | "Rprop": Rprop, 32 | "SGD": SGD, 33 | } 34 | -------------------------------------------------------------------------------- /base_ml/base_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Base Trainer Class 3 | 4 | 5 | import logging 6 | from abc import abstractmethod 7 | from typing import Tuple, Union 8 | 9 | import torch 10 | import torch.nn as nn 11 | import wandb 12 | from base_ml.base_early_stopping import EarlyStopping 13 | from pathlib import Path 14 | from torch.nn.modules.loss import _Loss 15 | from torch.optim import Optimizer 16 | from torch.optim.lr_scheduler import _LRScheduler 17 | from torch.utils.data import DataLoader 18 | from utils.tools import flatten_dict 19 | 20 | 21 | class BaseTrainer: 22 | """ 23 | Base class for all trainers with important ML components 24 | 25 | Args: 26 | model (nn.Module): Model that should be trained 27 | loss_fn (_Loss): Loss function 28 | optimizer (Optimizer): Optimizer 29 | scheduler (_LRScheduler): Learning rate scheduler 30 | device (str): Cuda device 
to use, e.g., cuda:0. 31 | logger (logging.Logger): Logger module 32 | logdir (Union[Path, str]): Logging directory 33 | experiment_config (dict): Configuration of this experiment 34 | early_stopping (EarlyStopping, optional): Early Stopping Class. Defaults to None. 35 | accum_iter (int, optional): Accumulation steps for gradient accumulation. 36 | Provide a number greater than 1 for activating gradient accumulation. Defaults to 1. 37 | mixed_precision (bool, optional): If mixed-precision should be used. Defaults to False. 38 | log_images (bool, optional): If images should be logged to WandB. Defaults to False. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | model: nn.Module, 44 | loss_fn: _Loss, 45 | optimizer: Optimizer, 46 | scheduler: _LRScheduler, 47 | device: str, 48 | logger: logging.Logger, 49 | logdir: Union[Path, str], 50 | experiment_config: dict, 51 | early_stopping: EarlyStopping = None, 52 | accum_iter: int = 1, 53 | mixed_precision: bool = False, 54 | log_images: bool = False, 55 | #model_ema: bool = True, 56 | ) -> None: 57 | self.model = model 58 | 59 | self.loss_fn = loss_fn 60 | self.optimizer = optimizer 61 | self.scheduler = scheduler 62 | self.device = device 63 | self.logger = logger 64 | self.logdir = Path(logdir) 65 | self.early_stopping = early_stopping 66 | self.accum_iter = accum_iter 67 | self.start_epoch = 0 68 | self.experiment_config = experiment_config 69 | self.log_images = log_images 70 | self.mixed_precision = mixed_precision 71 | if self.mixed_precision: 72 | self.scaler = torch.cuda.amp.GradScaler(enabled=True) 73 | else: 74 | self.scaler = None 75 | 76 | @abstractmethod 77 | def train_epoch( 78 | self, epoch: int, train_loader: DataLoader, **kwargs 79 | ) -> Tuple[dict, dict]: 80 | """Training logic for a training epoch 81 | 82 | Args: 83 | epoch (int): Current epoch number 84 | train_loader (DataLoader): Train dataloader 85 | 86 | Raises: 87 | NotImplementedError: Needs to be implemented 88 | 89 | Returns: 90 | Tuple[dict, dict]: wandb logging dictionaries 91 | * Scalar metrics 92 | * Image metrics 93 | """ 94 | raise NotImplementedError 95 | 96 | @abstractmethod 97 | def validation_epoch( 98 | self, epoch: int, val_dataloader: DataLoader 99 | ) -> Tuple[dict, dict, float]: 100 | """Training logic for an validation epoch 101 | 102 | Args: 103 | epoch (int): Current epoch number 104 | val_dataloader (DataLoader): Validation dataloader 105 | 106 | Raises: 107 | NotImplementedError: Needs to be implemented 108 | 109 | Returns: 110 | Tuple[dict, dict, float]: wandb logging dictionaries and early_stopping_metric 111 | * Scalar metrics 112 | * Image metrics 113 | * Early Stopping metric as float 114 | """ 115 | raise NotImplementedError 116 | 117 | @abstractmethod 118 | def train_step(self, batch: object, batch_idx: int, num_batches: int): 119 | """Training logic for one training batch 120 | 121 | Args: 122 | batch (object): A training batch 123 | batch_idx (int): Current batch index 124 | num_batches (int): Maximum number of batches 125 | 126 | Raises: 127 | NotImplementedError: Needs to be implemented 128 | """ 129 | 130 | raise NotImplementedError 131 | 132 | @abstractmethod 133 | def validation_step(self, batch, batch_idx: int): 134 | """Training logic for one validation batch 135 | 136 | Args: 137 | batch (object): A training batch 138 | batch_idx (int): Current batch index 139 | 140 | Raises: 141 | NotImplementedError: Needs to be implemented 142 | """ 143 | 144 | def fit( 145 | self, 146 | epochs: int, 147 | train_dataloader: DataLoader, 148 | 
val_dataloader: DataLoader, 149 | metric_init: dict = None, 150 | eval_every: int = 1, 151 | **kwargs, 152 | ): 153 | """Fitting function to start training and validation of the trainer 154 | 155 | Args: 156 | epochs (int): Number of epochs the network should be training 157 | train_dataloader (DataLoader): Dataloader with training data 158 | val_dataloader (DataLoader): Dataloader with validation data 159 | metric_init (dict, optional): Initialization dictionary with scalar metrics that should be initialized for startup. 160 | This is just import for logging with wandb if you want to have the plots properly scaled. 161 | The data in the the metric dictionary is used as values for epoch 0 (before training has startetd). 162 | If not provided, step 0 (epoch 0) is not logged. Should have the same scalar keys as training and validation epochs report. 163 | For more information, you should have a look into the train_epoch and val_epoch methods where the wandb logging dicts are assembled. 164 | Defaults to None. 165 | eval_every (int, optional): How often the network should be evaluated (after how many epochs). Defaults to 1. 166 | **kwargs 167 | """ 168 | 169 | self.logger.info(f"Starting training, total number of epochs: {epochs}") 170 | if metric_init is not None and self.start_epoch == 0: 171 | wandb.log(metric_init, step=0) 172 | for epoch in range(self.start_epoch, epochs): 173 | # training epoch 174 | #train_sampler.set_epoch(epoch) # for distributed training 175 | self.logger.info(f"Epoch: {epoch+1}/{epochs}") 176 | train_scalar_metrics, train_image_metrics = self.train_epoch( 177 | epoch, train_dataloader, **kwargs 178 | ) 179 | wandb.log(train_scalar_metrics, step=epoch + 1) 180 | if self.log_images: 181 | wandb.log(train_image_metrics, step=epoch + 1) 182 | if epoch >=95 and ((epoch + 1)) % eval_every == 0: 183 | # validation epoch 184 | ( 185 | val_scalar_metrics, 186 | val_image_metrics, 187 | early_stopping_metric, 188 | ) = self.validation_epoch(epoch, val_dataloader) 189 | wandb.log(val_scalar_metrics, step=epoch + 1) 190 | if self.log_images: 191 | wandb.log(val_image_metrics, step=epoch + 1) 192 | 193 | #self.save_checkpoint(epoch, f"checkpoint_{epoch}.pth") 194 | 195 | # log learning rate 196 | curr_lr = self.optimizer.param_groups[0]["lr"] 197 | wandb.log( 198 | { 199 | "Learning-Rate/Learning-Rate": curr_lr, 200 | }, 201 | step=epoch + 1, 202 | ) 203 | if epoch >=95 and ((epoch + 1)) % eval_every == 0: 204 | # early stopping 205 | if self.early_stopping is not None: 206 | best_model = self.early_stopping(early_stopping_metric, epoch) 207 | if best_model: 208 | self.logger.info("New best model - save checkpoint") 209 | self.save_checkpoint(epoch, "model_best.pth") 210 | elif self.early_stopping.early_stop: 211 | self.logger.info("Performing early stopping!") 212 | break 213 | self.save_checkpoint(epoch, "latest_checkpoint.pth") 214 | 215 | # scheduling 216 | if type(self.scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau: 217 | self.scheduler.step(float(val_scalar_metrics["Loss/Validation"])) 218 | else: 219 | self.scheduler.step() 220 | new_lr = self.optimizer.param_groups[0]["lr"] 221 | self.logger.debug(f"Old lr: {curr_lr:.6f} - New lr: {new_lr:.6f}") 222 | 223 | def save_checkpoint(self, epoch: int, checkpoint_name: str): 224 | if self.early_stopping is None: 225 | best_metric = None 226 | best_epoch = None 227 | else: 228 | best_metric = self.early_stopping.best_metric 229 | best_epoch = self.early_stopping.best_epoch 230 | 231 | arch = type(self.model).__name__ 
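# The checkpoint bundles everything needed to resume training (see resume_checkpoint below) or to run standalone evaluation: model weights, optimizer/scheduler/scaler state, and the flattened run configuration.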
232 | state = { 233 | "arch": arch, 234 | "epoch": epoch, 235 | "model_state_dict": self.model.state_dict(), 236 | "optimizer_state_dict": self.optimizer.state_dict(), 237 | "scheduler_state_dict": self.scheduler.state_dict(), 238 | "best_metric": best_metric, 239 | "best_epoch": best_epoch, 240 | "config": flatten_dict(wandb.config), 241 | "wandb_id": wandb.run.id, 242 | "logdir": str(self.logdir.resolve()), 243 | "run_name": str(Path(self.logdir).name), 244 | "scaler_state_dict": self.scaler.state_dict() 245 | if self.scaler is not None 246 | else None, 247 | } 248 | 249 | checkpoint_dir = self.logdir / "checkpoints" 250 | checkpoint_dir.mkdir(exist_ok=True, parents=True) 251 | 252 | filename = str(checkpoint_dir / checkpoint_name) 253 | torch.save(state, filename) 254 | 255 | def resume_checkpoint(self, checkpoint): 256 | self.logger.info("Loading checkpoint") 257 | self.logger.info("Loading Model") 258 | self.model.load_state_dict(checkpoint["model_state_dict"]) 259 | self.logger.info("Loading Optimizer state dict") 260 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 261 | self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) 262 | 263 | if self.early_stopping is not None: 264 | self.early_stopping.best_metric = checkpoint["best_metric"] 265 | self.early_stopping.best_epoch = checkpoint["best_epoch"] 266 | if self.scaler is not None: 267 | self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) 268 | 269 | self.logger.info(f"Checkpoint epoch: {int(checkpoint['epoch'])}") 270 | self.start_epoch = int(checkpoint["epoch"]) 271 | self.logger.info(f"Next epoch is: {self.start_epoch + 1}") 272 | -------------------------------------------------------------------------------- /base_ml/base_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | __all__ = ["filter2D", "gaussian", "gaussian_kernel2d", "sobel_hv"] 6 | 7 | 8 | def filter2D(input_tensor: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor: 9 | """Convolves a given kernel on input tensor without losing dimensional shape. 10 | 11 | Parameters 12 | ---------- 13 | input_tensor : torch.Tensor 14 | Input image/tensor. 15 | kernel : torch.Tensor 16 | Convolution kernel/window. 17 | 18 | Returns 19 | ------- 20 | torch.Tensor: 21 | The convolved tensor of same shape as the input. 22 | """ 23 | (_, channel, _, _) = input_tensor.size() 24 | 25 | # "SAME" padding to avoid losing height and width 26 | pad = [ 27 | kernel.size(2) // 2, 28 | kernel.size(2) // 2, 29 | kernel.size(3) // 2, 30 | kernel.size(3) // 2, 31 | ] 32 | pad_tensor = F.pad(input_tensor, pad, "replicate") 33 | 34 | out = F.conv2d(pad_tensor, kernel, groups=channel) 35 | return out 36 | 37 | 38 | def gaussian( 39 | window_size: int, sigma: float, device: torch.device = None 40 | ) -> torch.Tensor: 41 | """Create a gaussian 1D tensor. 42 | 43 | Parameters 44 | ---------- 45 | window_size : int 46 | Number of elements for the output tensor. 47 | sigma : float 48 | Std of the gaussian distribution. 49 | device : torch.device 50 | Device for the tensor. 51 | 52 | Returns 53 | ------- 54 | torch.Tensor: 55 | A gaussian 1D tensor. Shape: (window_size, ). 
56 | """ 57 | x = torch.arange(window_size, device=device).float() - window_size // 2 58 | if window_size % 2 == 0: 59 | x = x + 0.5 60 | 61 | gauss = torch.exp((-x.pow(2.0) / float(2 * sigma**2))) 62 | 63 | return gauss / gauss.sum() 64 | 65 | 66 | def gaussian_kernel2d( 67 | window_size: int, sigma: float, n_channels: int = 1, device: torch.device = None 68 | ) -> torch.Tensor: 69 | """Create 2D window_size**2 sized kernel a gaussial kernel. 70 | 71 | Parameters 72 | ---------- 73 | window_size : int 74 | Number of rows and columns for the output tensor. 75 | sigma : float 76 | Std of the gaussian distribution. 77 | n_channel : int 78 | Number of channels in the image that will be convolved with 79 | this kernel. 80 | device : torch.device 81 | Device for the kernel. 82 | 83 | Returns: 84 | ----------- 85 | torch.Tensor: 86 | A tensor of shape (1, 1, window_size, window_size) 87 | """ 88 | kernel_x = gaussian(window_size, sigma, device=device) 89 | kernel_y = gaussian(window_size, sigma, device=device) 90 | 91 | kernel_2d = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t()) 92 | kernel_2d = kernel_2d.expand(n_channels, 1, window_size, window_size) 93 | 94 | return kernel_2d 95 | 96 | 97 | def sobel_hv(window_size: int = 5, device: torch.device = None): 98 | """Create a kernel that is used to compute 1st order derivatives. 99 | 100 | Parameters 101 | ---------- 102 | window_size : int 103 | Size of the convolution kernel. 104 | device : torch.device: 105 | Device for the kernel. 106 | 107 | Returns 108 | ------- 109 | torch.Tensor: 110 | the computed 1st order derivatives of the input tensor. 111 | Shape (B, 2, H, W) 112 | 113 | Raises 114 | ------ 115 | ValueError: 116 | If `window_size` is not an odd number. 117 | """ 118 | if not window_size % 2 == 1: 119 | raise ValueError(f"window_size must be odd. 
Got: {window_size}") 120 | 121 | # Generate the sobel kernels 122 | range_h = torch.arange( 123 | -window_size // 2 + 1, window_size // 2 + 1, dtype=torch.float32, device=device 124 | ) 125 | range_v = torch.arange( 126 | -window_size // 2 + 1, window_size // 2 + 1, dtype=torch.float32, device=device 127 | ) 128 | h, v = torch.meshgrid(range_h, range_v) 129 | 130 | kernel_h = h / (h * h + v * v + 1e-6) 131 | kernel_h = kernel_h.unsqueeze(0).unsqueeze(0) 132 | 133 | kernel_v = v / (h * h + v * v + 1e-6) 134 | kernel_v = kernel_v.unsqueeze(0).unsqueeze(0) 135 | 136 | return torch.cat([kernel_h, kernel_v], dim=0) 137 | -------------------------------------------------------------------------------- /base_ml/base_validator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Validators 3 | 4 | 5 | from schema import Schema, Or 6 | 7 | sweep_schema = Schema( 8 | { 9 | "method": Or("grid", "random", "bayes"), 10 | "name": str, 11 | "metric": {"name": str, "goal": Or("maximize", "minimize")}, 12 | "run_cap": int, 13 | }, 14 | ignore_extra_keys=True, 15 | ) 16 | -------------------------------------------------------------------------------- /base_ml/optim_factory.py: -------------------------------------------------------------------------------- 1 | # Based on RepLKNet, ConvNeXt, timm, DINO and DeiT code bases 2 | # https://github.com/DingXiaoH/RepLKNet-pytorch 3 | # https://github.com/facebookresearch/ConvNeXt 4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 5 | # https://github.com/facebookresearch/deit/ 6 | # https://github.com/facebookresearch/dino 7 | # --------------------------------------------------------' 8 | import torch 9 | from torch import optim as optim 10 | 11 | from timm.optim.adafactor import Adafactor 12 | from timm.optim.adahessian import Adahessian 13 | from timm.optim.adamp import AdamP 14 | from timm.optim.lookahead import Lookahead 15 | from timm.optim.nadam import Nadam 16 | from timm.optim.radam import RAdam 17 | from timm.optim.rmsprop_tf import RMSpropTF 18 | from timm.optim.sgdp import SGDP 19 | 20 | import json 21 | 22 | try: 23 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 24 | has_apex = True 25 | except ImportError: 26 | has_apex = False 27 | 28 | 29 | def get_num_layer_for_convnext(var_name): 30 | """ 31 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 32 | consecutive blocks, including possible neighboring downsample layers; 33 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 34 | """ 35 | num_max_layer = 12 36 | if var_name.startswith("downsample_layers"): 37 | stage_id = int(var_name.split('.')[1]) 38 | if stage_id == 0: 39 | layer_id = 0 40 | elif stage_id == 1 or stage_id == 2: 41 | layer_id = stage_id + 1 42 | elif stage_id == 3: 43 | layer_id = 12 44 | return layer_id 45 | 46 | elif var_name.startswith("stages"): 47 | stage_id = int(var_name.split('.')[1]) 48 | block_id = int(var_name.split('.')[2]) 49 | if stage_id == 0 or stage_id == 1: 50 | layer_id = stage_id + 1 51 | elif stage_id == 2: 52 | layer_id = 3 + block_id // 3 53 | elif stage_id == 3: 54 | layer_id = 12 55 | return layer_id 56 | else: 57 | return num_max_layer + 1 58 | 59 | class LayerDecayValueAssigner(object): 60 | def __init__(self, values): 61 | self.values = values 62 | 63 | def get_scale(self, layer_id): 64 | return self.values[layer_id] 65 | 66 | def get_layer_id(self, var_name): 67 | return 
get_num_layer_for_convnext(var_name) 68 | 69 | 70 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 71 | parameter_group_names = {} 72 | parameter_group_vars = {} 73 | 74 | for name, param in model.named_parameters(): 75 | if not param.requires_grad: 76 | continue # frozen weights 77 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 78 | group_name = "no_decay" 79 | this_weight_decay = 0. 80 | else: 81 | group_name = "decay" 82 | this_weight_decay = weight_decay 83 | if get_num_layer is not None: 84 | layer_id = get_num_layer(name) 85 | group_name = "layer_%d_%s" % (layer_id, group_name) 86 | else: 87 | layer_id = None 88 | 89 | if group_name not in parameter_group_names: 90 | if get_layer_scale is not None: 91 | scale = get_layer_scale(layer_id) 92 | else: 93 | scale = 1. 94 | 95 | parameter_group_names[group_name] = { 96 | "weight_decay": this_weight_decay, 97 | "params": [], 98 | "lr_scale": scale 99 | } 100 | parameter_group_vars[group_name] = { 101 | "weight_decay": this_weight_decay, 102 | "params": [], 103 | "lr_scale": scale 104 | } 105 | 106 | parameter_group_vars[group_name]["params"].append(param) 107 | parameter_group_names[group_name]["params"].append(name) 108 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 109 | return list(parameter_group_vars.values()) 110 | 111 | 112 | def create_optimizer(model, weight_decay, lr, opt, get_num_layer=None, opt_eps=None, opt_betas=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None, momentum = 0.9): 113 | opt_lower = opt.lower() 114 | weight_decay = weight_decay 115 | # if weight_decay and filter_bias_and_bn: 116 | if filter_bias_and_bn: 117 | skip = {} 118 | if skip_list is not None: 119 | skip = skip_list 120 | elif hasattr(model, 'no_weight_decay'): 121 | skip = model.no_weight_decay() 122 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 123 | weight_decay = 0. 
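# Each parameter group above already carries its own weight_decay value, so the optimizer-level default passed via opt_args is zeroed to avoid applying decay twice.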
124 | else: 125 | parameters = model.parameters() 126 | 127 | if 'fused' in opt_lower: 128 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 129 | 130 | opt_args = dict(lr=lr, weight_decay=weight_decay) 131 | if opt_eps is not None: 132 | opt_args['eps'] = opt_eps 133 | if opt_betas is not None: 134 | opt_args['betas'] = opt_betas 135 | 136 | opt_split = opt_lower.split('_') 137 | opt_lower = opt_split[-1] 138 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 139 | opt_args.pop('eps', None) 140 | optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args) 141 | elif opt_lower == 'momentum': 142 | opt_args.pop('eps', None) 143 | optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args) 144 | elif opt_lower == 'adam': 145 | optimizer = optim.Adam(parameters, **opt_args) 146 | if opt_lower == 'adamw': 147 | optimizer = optim.AdamW(parameters, **opt_args) 148 | elif opt_lower == 'nadam': 149 | optimizer = Nadam(parameters, **opt_args) 150 | elif opt_lower == 'radam': 151 | optimizer = RAdam(parameters, **opt_args) 152 | elif opt_lower == 'adamp': 153 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 154 | elif opt_lower == 'sgdp': 155 | optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args) 156 | elif opt_lower == 'adadelta': 157 | optimizer = optim.Adadelta(parameters, **opt_args) 158 | 159 | elif opt_lower == 'adahessian': 160 | optimizer = Adahessian(parameters, **opt_args) 161 | elif opt_lower == 'rmsprop': 162 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args) 163 | elif opt_lower == 'rmsproptf': 164 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args) 165 | elif opt_lower == 'fusedsgd': 166 | opt_args.pop('eps', None) 167 | optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args) 168 | elif opt_lower == 'fusedmomentum': 169 | opt_args.pop('eps', None) 170 | optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args) 171 | elif opt_lower == 'fusedadam': 172 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 173 | elif opt_lower == 'fusedadamw': 174 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 175 | elif opt_lower == 'fusedlamb': 176 | optimizer = FusedLAMB(parameters, **opt_args) 177 | elif opt_lower == 'fusednovograd': 178 | opt_args.setdefault('betas', (0.95, 0.98)) 179 | optimizer = FusedNovoGrad(parameters, **opt_args) 180 | else: 181 | assert False and "Invalid optimizer" 182 | 183 | if len(opt_split) > 1: 184 | if opt_split[0] == 'lookahead': 185 | optimizer = Lookahead(optimizer) 186 | 187 | return optimizer 188 | -------------------------------------------------------------------------------- /base_ml/unireplknet_layer_decay_optimizer_constructor.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # UniRepLKNet 3 | # https://github.com/AILab-CVC/UniRepLKNet 4 | # Licensed under The Apache 2.0 License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | import json 7 | from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor 8 | from mmcv.runner import get_dist_info 9 | from mmdet.utils import get_root_logger 10 | 11 | def get_layer_id(var_name, max_layer_id,): 12 | """Get the layer id to set the different learning rates in ``layer_wise`` 13 | decay_type. 
14 | 15 | Args: 16 | var_name (str): The key of the model. 17 | max_layer_id (int): Maximum layer id. 18 | 19 | Returns: 20 | int: The id number corresponding to different learning rate in 21 | ``LearningRateDecayOptimizerConstructor``. 22 | """ 23 | 24 | if var_name in ('backbone.cls_token', 'backbone.mask_token', 25 | 'backbone.pos_embed'): 26 | return 0 27 | 28 | elif var_name.startswith('backbone.downsample_layers'): 29 | stage_id = int(var_name.split('.')[2]) 30 | if stage_id == 0: 31 | layer_id = 0 32 | elif stage_id == 1: 33 | layer_id = 2 34 | elif stage_id == 2: 35 | layer_id = 3 36 | elif stage_id == 3: 37 | layer_id = max_layer_id 38 | return layer_id 39 | 40 | elif var_name.startswith('backbone.stages'): 41 | stage_id = int(var_name.split('.')[2]) 42 | block_id = int(var_name.split('.')[3]) 43 | if stage_id == 0: 44 | layer_id = 1 45 | elif stage_id == 1: 46 | layer_id = 2 47 | elif stage_id == 2: 48 | layer_id = 3 + block_id // 3 49 | elif stage_id == 3: 50 | layer_id = max_layer_id 51 | return layer_id 52 | 53 | else: 54 | return max_layer_id + 1 55 | 56 | 57 | 58 | def get_stage_id(var_name, max_stage_id): 59 | """Get the stage id to set the different learning rates in ``stage_wise`` 60 | decay_type. 61 | 62 | Args: 63 | var_name (str): The key of the model. 64 | max_stage_id (int): Maximum stage id. 65 | 66 | Returns: 67 | int: The id number corresponding to different learning rate in 68 | ``LearningRateDecayOptimizerConstructor``. 69 | """ 70 | 71 | if var_name in ('backbone.cls_token', 'backbone.mask_token', 72 | 'backbone.pos_embed'): 73 | return 0 74 | elif var_name.startswith('backbone.downsample_layers'): 75 | return 0 76 | elif var_name.startswith('backbone.stages'): 77 | stage_id = int(var_name.split('.')[2]) 78 | return stage_id + 1 79 | else: 80 | return max_stage_id - 1 81 | 82 | 83 | @OPTIMIZER_BUILDERS.register_module() 84 | class UniRepLKNetLearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor): 85 | # Different learning rates are set for different layers of backbone. 86 | # The design is inspired by and adapted from ConvNeXt. 87 | 88 | def add_params(self, params, module, **kwargs): 89 | """Add all parameters of module to the params list. 90 | 91 | The parameters of the given module will be added to the list of param 92 | groups, with specific rules defined by paramwise_cfg. 93 | 94 | Args: 95 | params (list[dict]): A list of param groups, it will be modified 96 | in place. 97 | module (nn.Module): The module to be added. 98 | """ 99 | logger = get_root_logger() 100 | 101 | parameter_groups = {} 102 | logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}') 103 | num_layers = self.paramwise_cfg.get('num_layers') + 2 104 | decay_rate = self.paramwise_cfg.get('decay_rate') 105 | decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise') 106 | dw_scale = self.paramwise_cfg.get('dw_scale', 1) 107 | logger.info('Build UniRepLKNetLearningRateDecayOptimizerConstructor ' 108 | f'{decay_type} {decay_rate} - {num_layers}') 109 | weight_decay = self.base_wd 110 | for name, param in module.named_parameters(): 111 | if not param.requires_grad: 112 | continue # frozen weights 113 | if len(param.shape) == 1 or name.endswith('.bias') or name in ( 114 | 'pos_embed', 'cls_token'): 115 | group_name = 'no_decay' 116 | this_weight_decay = 0. 
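# 1-D parameters (norm scales, biases) and token embeddings are exempted from weight decay; all other parameters use the base weight decay.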
117 | else: 118 | group_name = 'decay' 119 | this_weight_decay = weight_decay 120 | if 'layer_wise' in decay_type: 121 | layer_id = get_layer_id(name, self.paramwise_cfg.get('num_layers')) 122 | logger.info(f'set param {name} as id {layer_id}') 123 | elif decay_type == 'stage_wise': 124 | layer_id = get_stage_id(name, num_layers) 125 | logger.info(f'set param {name} as id {layer_id}') 126 | 127 | if dw_scale == 1 or 'dwconv' not in name: 128 | group_name = f'layer_{layer_id}_{group_name}' 129 | if group_name not in parameter_groups: 130 | scale = decay_rate ** (num_layers - layer_id - 1) 131 | parameter_groups[group_name] = { 132 | 'weight_decay': this_weight_decay, 133 | 'params': [], 134 | 'param_names': [], 135 | 'lr_scale': scale, 136 | 'group_name': group_name, 137 | 'lr': scale * self.base_lr, 138 | } 139 | 140 | parameter_groups[group_name]['params'].append(param) 141 | parameter_groups[group_name]['param_names'].append(name) 142 | else: 143 | group_name = f'layer_{layer_id}_{group_name}_dwconv' 144 | if group_name not in parameter_groups: 145 | scale = decay_rate ** (num_layers - layer_id - 1) * dw_scale 146 | parameter_groups[group_name] = { 147 | 'weight_decay': this_weight_decay, 148 | 'params': [], 149 | 'param_names': [], 150 | 'lr_scale': scale, 151 | 'group_name': group_name, 152 | 'lr': scale * self.base_lr, 153 | } 154 | 155 | parameter_groups[group_name]['params'].append(param) 156 | parameter_groups[group_name]['param_names'].append(name) 157 | 158 | rank, _ = get_dist_info() 159 | if rank == 0: 160 | to_display = {} 161 | for key in parameter_groups: 162 | to_display[key] = { 163 | 'param_names': parameter_groups[key]['param_names'], 164 | 'lr_scale': parameter_groups[key]['lr_scale'], 165 | 'lr': parameter_groups[key]['lr'], 166 | 'weight_decay': parameter_groups[key]['weight_decay'], 167 | } 168 | logger.info(f'Param groups = {json.dumps(to_display, indent=2)}') 169 | params.extend(parameter_groups.values()) 170 | -------------------------------------------------------------------------------- /cell_segmentation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Cell Segmentation and detection using our cellvit model 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | -------------------------------------------------------------------------------- /cell_segmentation/datasets/base_cell.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Base cell segmentation dataset, based on torch Dataset implementation 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | import logging 9 | from typing import Callable 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | 14 | logger = logging.getLogger() 15 | logger.addHandler(logging.NullHandler()) 16 | 17 | from abc import abstractmethod 18 | 19 | 20 | class CellDataset(Dataset): 21 | def set_transforms(self, transforms: Callable) -> None: 22 | self.transforms = transforms 23 | 24 | @abstractmethod 25 | def load_cell_count(self): 26 | """Load Cell count from cell_count.csv file. 
File must be located inside the fold folder 27 | 28 | Example file beginning: 29 | Image,Neoplastic,Inflammatory,Connective,Dead,Epithelial 30 | 0_0.png,4,2,2,0,0 31 | 0_1.png,8,1,1,0,0 32 | 0_10.png,17,0,1,0,0 33 | 0_100.png,10,0,11,0,0 34 | ... 35 | """ 36 | pass 37 | 38 | @abstractmethod 39 | def get_sampling_weights_tissue(self, gamma: float = 1) -> torch.Tensor: 40 | """Get sampling weights calculated by tissue type statistics 41 | 42 | For this, a file named "weight_config.yaml" with the content: 43 | tissue: 44 | tissue_1: xxx 45 | tissue_2: xxx (name of tissue: count) 46 | ... 47 | Must exists in the dataset main folder (parent path, not inside the folds) 48 | 49 | Args: 50 | gamma (float, optional): Gamma scaling factor, between 0 and 1. 51 | 1 means total balancing, 0 means original weights. Defaults to 1. 52 | 53 | Returns: 54 | torch.Tensor: Weights for each sample 55 | """ 56 | 57 | @abstractmethod 58 | def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor: 59 | """Get sampling weights calculated by cell type statistics 60 | 61 | Args: 62 | gamma (float, optional): Gamma scaling factor, between 0 and 1. 63 | 1 means total balancing, 0 means original weights. Defaults to 1. 64 | 65 | Returns: 66 | torch.Tensor: Weights for each sample 67 | """ 68 | 69 | def get_sampling_weights_cell_tissue(self, gamma: float = 1) -> torch.Tensor: 70 | """Get combined sampling weights by calculating tissue and cell sampling weights, 71 | normalizing them and adding them up to yield one score. 72 | 73 | Args: 74 | gamma (float, optional): Gamma scaling factor, between 0 and 1. 75 | 1 means total balancing, 0 means original weights. Defaults to 1. 76 | 77 | Returns: 78 | torch.Tensor: Weights for each sample 79 | """ 80 | assert 0 <= gamma <= 1, "Gamma must be between 0 and 1" 81 | tw = self.get_sampling_weights_tissue(gamma) 82 | cw = self.get_sampling_weights_cell(gamma) 83 | weights = tw / torch.max(tw) + cw / torch.max(cw) 84 | 85 | return weights 86 | -------------------------------------------------------------------------------- /cell_segmentation/datasets/cell_graph_datamodel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Graph Data model 3 | # 4 | # For more information, please check out docs/readmes/graphs.md 5 | # 6 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 7 | # Institute for Artifical Intelligence in Medicine, 8 | # University Medicine Essen 9 | 10 | from dataclasses import dataclass 11 | from typing import List 12 | 13 | import torch 14 | 15 | from datamodel.graph_datamodel import GraphDataWSI 16 | 17 | 18 | @dataclass 19 | class CellGraphDataWSI(GraphDataWSI): 20 | """Dataclass for Graph Data 21 | 22 | Args: 23 | contours (List[torch.Tensor]): Contour Data for each object. 
24 | """ 25 | 26 | contours: List[torch.Tensor] 27 | -------------------------------------------------------------------------------- /cell_segmentation/datasets/conic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # PanNuke Dataset 3 | # 4 | # Dataset information: https://arxiv.org/abs/2108.11195 5 | # Please Prepare Dataset as described here: docs/readmes/pannuke.md # TODO: write own documentation 6 | # 7 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 8 | # Institute for Artifical Intelligence in Medicine, 9 | # University Medicine Essen 10 | 11 | 12 | import logging 13 | from pathlib import Path 14 | from typing import Callable, Tuple, Union, List 15 | 16 | import numpy as np 17 | import pandas as pd 18 | import torch 19 | from PIL import Image 20 | 21 | from cell_segmentation.datasets.base_cell import CellDataset 22 | from cell_segmentation.datasets.pannuke import PanNukeDataset 23 | 24 | logger = logging.getLogger() 25 | logger.addHandler(logging.NullHandler()) 26 | 27 | 28 | class CoNicDataset(CellDataset): 29 | """Lizzard dataset 30 | 31 | This dataset is always cached 32 | 33 | Args: 34 | dataset_path (Union[Path, str]): Path to Lizzard dataset. Structure is described under ./docs/readmes/cell_segmentation.md 35 | folds (Union[int, list[int]]): Folds to use for this dataset 36 | transforms (Callable, optional): PyTorch transformations. Defaults to None. 37 | stardist (bool, optional): Return StarDist labels. Defaults to False 38 | regression (bool, optional): Return Regression of cells in x and y direction. Defaults to False 39 | **kwargs are irgnored 40 | """ 41 | 42 | def __init__( 43 | self, 44 | dataset_path: Union[Path, str], 45 | folds: Union[int, List[int]], 46 | transforms: Callable = None, 47 | stardist: bool = False, 48 | regression: bool = False, 49 | **kwargs, 50 | ) -> None: 51 | if isinstance(folds, int): 52 | folds = [folds] 53 | 54 | self.dataset = Path(dataset_path).resolve() 55 | self.transforms = transforms 56 | self.images = [] 57 | self.masks = [] 58 | self.img_names = [] 59 | self.folds = folds 60 | self.stardist = stardist 61 | self.regression = regression 62 | for fold in folds: 63 | image_path = self.dataset / f"fold{fold}" / "images" 64 | fold_images = [f for f in sorted(image_path.glob("*.png")) if f.is_file()] 65 | 66 | # sanity_check: mask must exist for image 67 | for fold_image in fold_images: 68 | mask_path = ( 69 | self.dataset / f"fold{fold}" / "labels" / f"{fold_image.stem}.npy" 70 | ) 71 | if mask_path.is_file(): 72 | self.images.append(fold_image) 73 | self.masks.append(mask_path) 74 | self.img_names.append(fold_image.name) 75 | 76 | else: 77 | logger.debug( 78 | "Found image {fold_image}, but no corresponding annotation file!" 
79 | ) 80 | 81 | # load everything in advance to speedup, as the dataset is rather small 82 | self.loaded_imgs = [] 83 | self.loaded_masks = [] 84 | for idx in range(len(self.images)): 85 | img_path = self.images[idx] 86 | img = np.array(Image.open(img_path)).astype(np.uint8) 87 | 88 | mask_path = self.masks[idx] 89 | mask = np.load(mask_path, allow_pickle=True) 90 | inst_map = mask[()]["inst_map"].astype(np.int32) 91 | type_map = mask[()]["type_map"].astype(np.int32) 92 | mask = np.stack([inst_map, type_map], axis=-1) 93 | self.loaded_imgs.append(img) 94 | self.loaded_masks.append(mask) 95 | 96 | logger.info(f"Created Pannuke Dataset by using fold(s) {self.folds}") 97 | logger.info(f"Resulting dataset length: {self.__len__()}") 98 | 99 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str, str]: 100 | """Get one dataset item consisting of transformed image, 101 | masks (instance_map, nuclei_type_map, nuclei_binary_map, hv_map) and tissue type as string 102 | 103 | Args: 104 | index (int): Index of element to retrieve 105 | 106 | Returns: 107 | Tuple[torch.Tensor, dict, str, str]: 108 | torch.Tensor: Image, with shape (3, H, W), shape is arbitrary for Lizzard (H and W approx. between 500 and 2000) 109 | dict: 110 | "instance_map": Instance-Map, each instance is has one integer starting by 1 (zero is background), Shape (256, 256) 111 | "nuclei_type_map": Nuclei-Type-Map, for each nucleus (instance) the class is indicated by an integer. Shape (256, 256) 112 | "nuclei_binary_map": Binary Nuclei-Mask, Shape (256, 256) 113 | "hv_map": Horizontal and vertical instance map. 114 | Shape: (H, W, 2). First dimension is horizontal (horizontal gradient (-1 to 1)), 115 | last is vertical (vertical gradient (-1 to 1)) Shape (256, 256, 2) 116 | "dist_map": Probability distance map. Shape (256, 256) 117 | "stardist_map": Stardist vector map. Shape (n_rays, 256, 256) 118 | [Optional if regression] 119 | "regression_map": Regression map. Shape (2, 256, 256). First is vertical, second horizontal. 
120 | str: Tissue type 121 | str: Image Name 122 | """ 123 | img_path = self.images[index] 124 | img = self.loaded_imgs[index] 125 | mask = self.loaded_masks[index] 126 | 127 | if self.transforms is not None: 128 | transformed = self.transforms(image=img, mask=mask) 129 | img = transformed["image"] 130 | mask = transformed["mask"] 131 | 132 | inst_map = mask[:, :, 0].copy() 133 | type_map = mask[:, :, 1].copy() 134 | np_map = mask[:, :, 0].copy() 135 | np_map[np_map > 0] = 1 136 | hv_map = PanNukeDataset.gen_instance_hv_map(inst_map) 137 | 138 | # torch convert 139 | img = torch.Tensor(img).type(torch.float32) 140 | img = img.permute(2, 0, 1) 141 | if torch.max(img) >= 5: 142 | img = img / 255 143 | 144 | masks = { 145 | "instance_map": torch.Tensor(inst_map).type(torch.int64), 146 | "nuclei_type_map": torch.Tensor(type_map).type(torch.int64), 147 | "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64), 148 | "hv_map": torch.Tensor(hv_map).type(torch.float32), 149 | } 150 | if self.stardist: 151 | dist_map = PanNukeDataset.gen_distance_prob_maps(inst_map) 152 | stardist_map = PanNukeDataset.gen_stardist_maps(inst_map) 153 | masks["dist_map"] = torch.Tensor(dist_map).type(torch.float32) 154 | masks["stardist_map"] = torch.Tensor(stardist_map).type(torch.float32) 155 | if self.regression: 156 | masks["regression_map"] = PanNukeDataset.gen_regression_map(inst_map) 157 | 158 | return img, masks, "Colon", Path(img_path).name 159 | 160 | def __len__(self) -> int: 161 | """Length of Dataset 162 | 163 | Returns: 164 | int: Length of Dataset 165 | """ 166 | return len(self.images) 167 | 168 | def set_transforms(self, transforms: Callable) -> None: 169 | """Set the transformations, can be used tp exchange transformations 170 | 171 | Args: 172 | transforms (Callable): PyTorch transformations 173 | """ 174 | self.transforms = transforms 175 | 176 | def load_cell_count(self): 177 | """Load Cell count from cell_count.csv file. File must be located inside the fold folder 178 | and named "cell_count.csv" 179 | 180 | Example file beginning: 181 | Image,Neutrophil,Epithelial,Lymphocyte,Plasma,Eosinophil,Connective 182 | consep_1_0000.png,0,117,0,0,0,0 183 | consep_1_0001.png,0,95,1,0,0,8 184 | consep_1_0002.png,0,172,3,0,0,2 185 | ... 186 | """ 187 | df_placeholder = [] 188 | for fold in self.folds: 189 | csv_path = self.dataset / f"fold{fold}" / "cell_count.csv" 190 | cell_count = pd.read_csv(csv_path, index_col=0) 191 | df_placeholder.append(cell_count) 192 | self.cell_count = pd.concat(df_placeholder) 193 | self.cell_count = self.cell_count.reindex(self.img_names) 194 | 195 | def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor: 196 | """Get sampling weights calculated by cell type statistics 197 | 198 | Args: 199 | gamma (float, optional): Gamma scaling factor, between 0 and 1. 200 | 1 means total balancing, 0 means original weights. Defaults to 1. 201 | 202 | Returns: 203 | torch.Tensor: Weights for each sample 204 | """ 205 | assert 0 <= gamma <= 1, "Gamma must be between 0 and 1" 206 | assert hasattr(self, "cell_count"), "Please run .load_cell_count() in advance!" 
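# Presumably dataset-wide per-class factors ordered like the cell_count.csv columns (Neutrophil, Epithelial, Lymphocyte, Plasma, Eosinophil, Connective); images containing rarer classes receive larger sampling weights below.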
207 | binary_weight_factors = np.array([1069, 4189, 4356, 3103, 1025, 4527]) 208 | k = np.sum(binary_weight_factors) 209 | cell_counts_imgs = np.clip(self.cell_count.to_numpy(), 0, 1) 210 | weight_vector = k / (gamma * binary_weight_factors + (1 - gamma) * k) 211 | img_weight = (1 - gamma) * np.max(cell_counts_imgs, axis=-1) + gamma * np.sum( 212 | cell_counts_imgs * weight_vector, axis=-1 213 | ) 214 | img_weight[np.where(img_weight == 0)] = np.min( 215 | img_weight[np.nonzero(img_weight)] 216 | ) 217 | 218 | return torch.Tensor(img_weight) 219 | 220 | # def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor: 221 | # """Get sampling weights calculated by cell type statistics 222 | 223 | # Args: 224 | # gamma (float, optional): Gamma scaling factor, between 0 and 1. 225 | # 1 means total balancing, 0 means original weights. Defaults to 1. 226 | 227 | # Returns: 228 | # torch.Tensor: Weights for each sample 229 | # """ 230 | # assert 0 <= gamma <= 1, "Gamma must be between 0 and 1" 231 | # assert hasattr(self, "cell_count"), "Please run .load_cell_count() in advance!" 232 | # binary_weight_factors = np.array([4012, 222017, 93612, 24793, 2999, 98783]) 233 | # k = np.sum(binary_weight_factors) 234 | # cell_counts_imgs = self.cell_count.to_numpy() 235 | # weight_vector = k / (gamma * binary_weight_factors + (1 - gamma) * k) 236 | # img_weight = (1 - gamma) * np.max(cell_counts_imgs, axis=-1) + gamma * np.sum( 237 | # cell_counts_imgs * weight_vector, axis=-1 238 | # ) 239 | # img_weight[np.where(img_weight == 0)] = np.min( 240 | # img_weight[np.nonzero(img_weight)] 241 | # ) 242 | 243 | # return torch.Tensor(img_weight) 244 | -------------------------------------------------------------------------------- /cell_segmentation/datasets/consep.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MoNuSeg Dataset 3 | # 4 | # Dataset information: https://monuseg.grand-challenge.org/Home/ 5 | # Please Prepare Dataset as described here: docs/readmes/monuseg.md 6 | # 7 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 8 | # Institute for Artifical Intelligence in Medicine, 9 | # University Medicine Essen 10 | 11 | import logging 12 | from pathlib import Path 13 | from typing import Callable, Union, Tuple 14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image 18 | from torch.utils.data import Dataset 19 | 20 | from cell_segmentation.datasets.pannuke import PanNukeDataset 21 | 22 | logger = logging.getLogger() 23 | logger.addHandler(logging.NullHandler()) 24 | 25 | 26 | class CoNSePDataset(Dataset): 27 | def __init__( 28 | self, 29 | dataset_path: Union[Path, str], 30 | transforms: Callable = None, 31 | ) -> None: 32 | """MoNuSeg Dataset 33 | 34 | Args: 35 | dataset_path (Union[Path, str]): Path to dataset 36 | transforms (Callable, optional): Transformations to apply on images. Defaults to None. 
37 | Raises: 38 | FileNotFoundError: If no ground-truth annotation file was found in path 39 | """ 40 | self.dataset = Path(dataset_path).resolve() 41 | self.transforms = transforms 42 | self.masks = [] 43 | self.img_names = [] 44 | 45 | image_path = self.dataset / "images" 46 | label_path = self.dataset / "labels" 47 | self.images = [f for f in sorted(image_path.glob("*.png")) if f.is_file()] 48 | self.masks = [f for f in sorted(label_path.glob("*.npy")) if f.is_file()] 49 | 50 | # sanity_check 51 | for idx, image in enumerate(self.images): 52 | image_name = image.stem 53 | mask_name = self.masks[idx].stem 54 | if image_name != mask_name: 55 | raise FileNotFoundError(f"Annotation for file {image_name} is missing") 56 | 57 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str]: 58 | """Get one item from dataset 59 | 60 | Args: 61 | index (int): Item to get 62 | 63 | Returns: 64 | Tuple[torch.Tensor, dict, str]: Trainings-Batch 65 | * torch.Tensor: Image 66 | * dict: Ground-Truth values: keys are "instance map", "nuclei_binary_map" and "hv_map" 67 | * str: filename 68 | """ 69 | img_path = self.images[index] 70 | img = np.array(Image.open(img_path)).astype(np.uint8) 71 | 72 | mask_path = self.masks[index] 73 | mask = np.load(mask_path, allow_pickle=True) 74 | inst_map = mask[()]["inst_map"].astype(np.int32) 75 | type_map = mask[()]["type_map"].astype(np.int32) 76 | mask = np.stack([inst_map, type_map], axis=-1) 77 | 78 | if self.transforms is not None: 79 | transformed = self.transforms(image=img, mask=mask) 80 | img = transformed["image"] 81 | mask = transformed["mask"] 82 | 83 | inst_map = mask[:, :, 0].copy() 84 | type_map = mask[:, :, 1].copy() 85 | np_map = mask[:, :, 0].copy() 86 | np_map[np_map > 0] = 1 87 | hv_map = PanNukeDataset.gen_instance_hv_map(inst_map) 88 | 89 | # torch convert 90 | img = torch.Tensor(img).type(torch.float32) 91 | img = img.permute(2, 0, 1) 92 | if torch.max(img) >= 5: 93 | img = img / 255 94 | 95 | masks = { 96 | "instance_map": torch.Tensor(inst_map).type(torch.int64), 97 | "nuclei_type_map": torch.Tensor(type_map).type(torch.int64), 98 | "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64), 99 | "hv_map": torch.Tensor(hv_map).type(torch.float32), 100 | } 101 | 102 | return img, masks, Path(img_path).name 103 | 104 | def __len__(self) -> int: 105 | """Length of Dataset 106 | 107 | Returns: 108 | int: Length of Dataset 109 | """ 110 | return len(self.images) 111 | 112 | def set_transforms(self, transforms: Callable) -> None: 113 | """Set the transformations, can be used tp exchange transformations 114 | 115 | Args: 116 | transforms (Callable): PyTorch transformations 117 | """ 118 | self.transforms = transforms 119 | -------------------------------------------------------------------------------- /cell_segmentation/datasets/dataset_coordinator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Coordinate the datasets, used to select the right dataset with corresponding setting 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | from typing import Callable 9 | 10 | from torch.utils.data import Dataset 11 | from cell_segmentation.datasets.conic import CoNicDataset 12 | 13 | from cell_segmentation.datasets.pannuke import PanNukeDataset 14 | 15 | 16 | def select_dataset( 17 | dataset_name: str, split: str, dataset_config: dict, transforms: Callable = None 18 | ) -> Dataset: 19 
| """Select a cell segmentation dataset from the provided ones, currently just PanNuke is implemented here 20 | 21 | Args: 22 | dataset_name (str): Name of dataset to use. 23 | Must be one of: [pannuke, lizzard] 24 | split (str): Split to use. 25 | Must be one of: ["train", "val", "validation", "test"] 26 | dataset_config (dict): Dictionary with dataset configuration settings 27 | transforms (Callable, optional): PyTorch Image and Mask transformations. Defaults to None. 28 | 29 | Raises: 30 | NotImplementedError: Unknown dataset 31 | 32 | Returns: 33 | Dataset: Cell segmentation dataset 34 | """ 35 | assert split.lower() in [ 36 | "train", 37 | "val", 38 | "validation", 39 | "test", 40 | ], "Unknown split type!" 41 | 42 | if dataset_name.lower() == "pannuke": 43 | if split == "train": 44 | folds = dataset_config["train_folds"] 45 | if split == "val" or split == "validation": 46 | folds = dataset_config["val_folds"] 47 | if split == "test": 48 | folds = dataset_config["test_folds"] 49 | dataset = PanNukeDataset( 50 | dataset_path=dataset_config["dataset_path"], 51 | folds=folds, 52 | transforms=transforms, 53 | stardist=dataset_config.get("stardist", False), 54 | regression=dataset_config.get("regression_loss", False), 55 | ) 56 | elif dataset_name.lower() == "conic": 57 | if split == "train": 58 | folds = dataset_config["train_folds"] 59 | if split == "val" or split == "validation": 60 | folds = dataset_config["val_folds"] 61 | if split == "test": 62 | folds = dataset_config["test_folds"] 63 | dataset = CoNicDataset( 64 | dataset_path=dataset_config["dataset_path"], 65 | folds=folds, 66 | transforms=transforms, 67 | stardist=dataset_config.get("stardist", False), 68 | regression=dataset_config.get("regression_loss", False), 69 | # TODO: Stardist and regression loss 70 | ) 71 | else: 72 | raise NotImplementedError(f"Unknown dataset: {dataset_name}") 73 | return dataset 74 | -------------------------------------------------------------------------------- /cell_segmentation/datasets/monuseg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # MoNuSeg Dataset 3 | # 4 | # Dataset information: https://monuseg.grand-challenge.org/Home/ 5 | # Please Prepare Dataset as described here: docs/readmes/monuseg.md 6 | # 7 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 8 | # Institute for Artifical Intelligence in Medicine, 9 | # University Medicine Essen 10 | 11 | import logging 12 | from pathlib import Path 13 | from typing import Callable, Union, Tuple 14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image 18 | from torch.utils.data import Dataset 19 | 20 | from cell_segmentation.datasets.pannuke import PanNukeDataset 21 | from einops import rearrange 22 | 23 | logger = logging.getLogger() 24 | logger.addHandler(logging.NullHandler()) 25 | 26 | 27 | class MoNuSegDataset(Dataset): 28 | def __init__( 29 | self, 30 | dataset_path: Union[Path, str], 31 | transforms: Callable = None, 32 | patching: bool = False, 33 | overlap: int = 0, 34 | ) -> None: 35 | """MoNuSeg Dataset 36 | 37 | Args: 38 | dataset_path (Union[Path, str]): Path to dataset 39 | transforms (Callable, optional): Transformations to apply on images. Defaults to None. 40 | patching (bool, optional): If patches with size 256px should be used Otherwise, the entire MoNuSeg images are loaded. Defaults to False. 41 | overlap: (bool, optional): If overlap should be used for patch sampling. Overlap in pixels. 42 | Recommended value other than 0 is 64. Defaults to 0. 
43 | Raises: 44 | FileNotFoundError: If no ground-truth annotation file was found in path 45 | """ 46 | self.dataset = Path(dataset_path).resolve() 47 | self.transforms = transforms 48 | self.masks = [] 49 | self.img_names = [] 50 | self.patching = patching 51 | self.overlap = overlap 52 | 53 | image_path = self.dataset / "images" 54 | label_path = self.dataset / "labels" 55 | self.images = [f for f in sorted(image_path.glob("*.png")) if f.is_file()] 56 | self.masks = [f for f in sorted(label_path.glob("*.npy")) if f.is_file()] 57 | 58 | # sanity_check 59 | for idx, image in enumerate(self.images): 60 | image_name = image.stem 61 | mask_name = self.masks[idx].stem 62 | if image_name != mask_name: 63 | raise FileNotFoundError(f"Annotation for file {image_name} is missing") 64 | 65 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str]: 66 | """Get one item from dataset 67 | 68 | Args: 69 | index (int): Item to get 70 | 71 | Returns: 72 | Tuple[torch.Tensor, dict, str]: Trainings-Batch 73 | * torch.Tensor: Image 74 | * dict: Ground-Truth values: keys are "instance map", "nuclei_binary_map" and "hv_map" 75 | * str: filename 76 | """ 77 | img_path = self.images[index] 78 | img = np.array(Image.open(img_path)).astype(np.uint8) 79 | 80 | mask_path = self.masks[index] 81 | mask = np.load(mask_path, allow_pickle=True) 82 | mask = mask.astype(np.int64) 83 | 84 | if self.transforms is not None: 85 | transformed = self.transforms(image=img, mask=mask) 86 | img = transformed["image"] 87 | mask = transformed["mask"] 88 | 89 | hv_map = PanNukeDataset.gen_instance_hv_map(mask) 90 | np_map = mask.copy() 91 | np_map[np_map > 0] = 1 92 | 93 | # torch convert 94 | img = torch.Tensor(img).type(torch.float32) 95 | img = img.permute(2, 0, 1) 96 | if torch.max(img) >= 5: 97 | img = img / 255 98 | 99 | if self.patching and self.overlap == 0: 100 | img = rearrange(img, "c (h i) (w j) -> c h w i j", i=256, j=256) 101 | if self.patching and self.overlap != 0: 102 | img = img.unfold(1, 256, 256 - self.overlap).unfold( 103 | 2, 256, 256 - self.overlap 104 | ) 105 | 106 | masks = { 107 | "instance_map": torch.Tensor(mask).type(torch.int64), 108 | "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64), 109 | "hv_map": torch.Tensor(hv_map).type(torch.float32), 110 | } 111 | 112 | return img, masks, Path(img_path).name 113 | 114 | def __len__(self) -> int: 115 | """Length of Dataset 116 | 117 | Returns: 118 | int: Length of Dataset 119 | """ 120 | return len(self.images) 121 | 122 | def set_transforms(self, transforms: Callable) -> None: 123 | """Set the transformations, can be used tp exchange transformations 124 | 125 | Args: 126 | transforms (Callable): PyTorch transformations 127 | """ 128 | self.transforms = transforms 129 | -------------------------------------------------------------------------------- /cell_segmentation/datasets/prepare_monuseg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Prepare MoNuSeg Dataset By converting and resorting files 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | from PIL import Image 9 | import xml.etree.ElementTree as ET 10 | from skimage import draw 11 | import numpy as np 12 | from pathlib import Path 13 | from typing import Union 14 | import argparse 15 | 16 | 17 | def convert_monuseg( 18 | input_path: Union[Path, str], output_path: Union[Path, str] 19 | ) -> None: 20 | """Convert the 
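# --- Illustrative sketch (added annotation, not part of the original file) ---------
# Shape check for the two MoNuSeg patching branches above, assuming the 1024x1024
# images produced by prepare_monuseg.py; nothing beyond torch/einops is assumed:
#
#   import torch
#   from einops import rearrange
#   img = torch.zeros(3, 1024, 1024)
#   # overlap == 0: non-overlapping 256 px tiles -> (C, 4, 4, 256, 256)
#   tiles = rearrange(img, "c (h i) (w j) -> c h w i j", i=256, j=256)
#   # overlap == 64: stride 256 - 64 = 192 -> (C, 5, 5, 256, 256)
#   tiles_ov = img.unfold(1, 256, 192).unfold(2, 256, 192)
#   print(tiles.shape, tiles_ov.shape)
# ------------------------------------------------------------------------------------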
MoNuSeg dataset to a new format (1000 -> 1024, tiff to png and xml to npy) 21 | 22 | Args: 23 | input_path (Union[Path, str]): Input dataset 24 | output_path (Union[Path, str]): Output path 25 | """ 26 | input_path = Path(input_path) 27 | output_path = Path(output_path) 28 | output_path.mkdir(exist_ok=True, parents=True) 29 | 30 | # testing and training 31 | parts = ["testing", "training"] 32 | for part in parts: 33 | print(f"Prepare: {part}") 34 | input_path_part = input_path / part 35 | output_path_part = output_path / part 36 | output_path_part.mkdir(exist_ok=True, parents=True) 37 | (output_path_part / "images").mkdir(exist_ok=True, parents=True) 38 | (output_path_part / "labels").mkdir(exist_ok=True, parents=True) 39 | 40 | # images 41 | images = [f for f in sorted((input_path_part / "images").glob("*.tif"))] 42 | for img_path in images: 43 | loaded_image = Image.open(img_path) 44 | resized = loaded_image.resize( 45 | (1024, 1024), resample=Image.Resampling.LANCZOS 46 | ) 47 | new_img_path = output_path_part / "images" / f"{img_path.stem}.png" 48 | resized.save(new_img_path) 49 | # masks 50 | annotations = [f for f in sorted((input_path_part / "labels").glob("*.xml"))] 51 | for annot_path in annotations: 52 | binary_mask = np.transpose(np.zeros((1000, 1000))) 53 | 54 | # extract xml file 55 | tree = ET.parse(annot_path) 56 | root = tree.getroot() 57 | child = root[0] 58 | 59 | for x in child: 60 | r = x.tag 61 | if r == "Regions": 62 | element_idx = 1 63 | for y in x: 64 | y_tag = y.tag 65 | 66 | if y_tag == "Region": 67 | regions = [] 68 | vertices = y[1] 69 | coords = np.zeros((len(vertices), 2)) 70 | for i, vertex in enumerate(vertices): 71 | coords[i][0] = vertex.attrib["X"] 72 | coords[i][1] = vertex.attrib["Y"] 73 | regions.append(coords) 74 | vertex_row_coords = regions[0][:, 0] 75 | vertex_col_coords = regions[0][:, 1] 76 | fill_row_coords, fill_col_coords = draw.polygon( 77 | vertex_col_coords, vertex_row_coords, binary_mask.shape 78 | ) 79 | binary_mask[fill_row_coords, fill_col_coords] = element_idx 80 | 81 | element_idx = element_idx + 1 82 | inst_image = Image.fromarray(binary_mask) 83 | resized_mask = np.array( 84 | inst_image.resize((1024, 1024), resample=Image.Resampling.NEAREST) 85 | ) 86 | new_mask_path = output_path_part / "labels" / f"{annot_path.stem}.npy" 87 | np.save(new_mask_path, resized_mask) 88 | print("Finished") 89 | 90 | 91 | parser = argparse.ArgumentParser( 92 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 93 | description="Convert the MoNuSeg dataset", 94 | ) 95 | parser.add_argument( 96 | "--input_path", 97 | type=str, 98 | help="Input path of the original MoNuSeg dataset", 99 | required=True, 100 | ) 101 | parser.add_argument( 102 | "--output_path", 103 | type=str, 104 | help="Output path to store the processed MoNuSeg dataset", 105 | required=True, 106 | ) 107 | 108 | if __name__ == "__main__": 109 | opt = parser.parse_args() 110 | configuration = vars(opt) 111 | 112 | input_path = Path(configuration["input_path"]) 113 | output_path = Path(configuration["output_path"]) 114 | 115 | convert_monuseg(input_path=input_path, output_path=output_path) 116 | -------------------------------------------------------------------------------- /cell_segmentation/datasets/prepare_pannuke.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Prepare MoNuSeg Dataset By converting and resorting files 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical 
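# --- Illustrative sketch (added annotation, not part of the original file) ---------
# skimage.draw.polygon(r, c, shape) returns the row/column indices of all pixels that
# lie inside a polygon; that is how each XML region in convert_monuseg() above is
# rasterised into the instance mask. Tiny self-contained example with a made-up
# triangle:
#
#   import numpy as np
#   from skimage import draw
#   mask = np.zeros((10, 10), dtype=np.int32)
#   rr, cc = draw.polygon([1, 1, 8], [1, 8, 4], mask.shape)
#   mask[rr, cc] = 1        # first region gets id 1, the next one id 2, ...
#
# Note that the label mask is resized with NEAREST resampling (images use LANCZOS),
# so instance ids stay integral and are not blended at region borders.
# ------------------------------------------------------------------------------------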
Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | import inspect 9 | import os 10 | import sys 11 | 12 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 13 | parentdir = os.path.dirname(currentdir) 14 | sys.path.insert(0, parentdir) 15 | parentdir = os.path.dirname(parentdir) 16 | sys.path.insert(0, parentdir) 17 | 18 | import numpy as np 19 | from pathlib import Path 20 | from PIL import Image 21 | from tqdm import tqdm 22 | import argparse 23 | from cell_segmentation.utils.metrics import remap_label 24 | 25 | 26 | def process_fold(fold, input_path, output_path) -> None: 27 | fold_path = Path(input_path) / f"fold{fold}" 28 | output_fold_path = Path(output_path) / f"fold{fold}" 29 | output_fold_path.mkdir(exist_ok=True, parents=True) 30 | (output_fold_path / "images").mkdir(exist_ok=True, parents=True) 31 | (output_fold_path / "labels").mkdir(exist_ok=True, parents=True) 32 | 33 | print(f"Fold: {fold}") 34 | print("Loading large numpy files, this may take a while") 35 | images = np.load(fold_path / "images.npy") 36 | masks = np.load(fold_path / "masks.npy") 37 | 38 | print("Process images") 39 | for i in tqdm(range(len(images)), total=len(images)): 40 | outname = f"{fold}_{i}.png" 41 | out_img = images[i] 42 | im = Image.fromarray(out_img.astype(np.uint8)) 43 | im.save(output_fold_path / "images" / outname) 44 | 45 | print("Process masks") 46 | for i in tqdm(range(len(images)), total=len(images)): 47 | outname = f"{fold}_{i}.npy" 48 | 49 | # need to create instance map and type map with shape 256x256 50 | mask = masks[i] 51 | inst_map = np.zeros((256, 256)) 52 | num_nuc = 0 53 | for j in range(5): 54 | # copy value from new array if value is not equal 0 55 | layer_res = remap_label(mask[:, :, j]) 56 | # inst_map = np.where(mask[:,:,j] != 0, mask[:,:,j], inst_map) 57 | inst_map = np.where(layer_res != 0, layer_res + num_nuc, inst_map) 58 | num_nuc = num_nuc + np.max(layer_res) 59 | inst_map = remap_label(inst_map) 60 | 61 | type_map = np.zeros((256, 256)).astype(np.int32) 62 | for j in range(5): 63 | layer_res = ((j + 1) * np.clip(mask[:, :, j], 0, 1)).astype(np.int32) 64 | type_map = np.where(layer_res != 0, layer_res, type_map) 65 | 66 | outdict = {"inst_map": inst_map, "type_map": type_map} 67 | np.save(output_fold_path / "labels" / outname, outdict) 68 | 69 | 70 | parser = argparse.ArgumentParser( 71 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 72 | description="Perform CellViT inference for given run-directory with model checkpoints and logs", 73 | ) 74 | parser.add_argument( 75 | "--input_path", 76 | type=str, 77 | help="Input path of the original PanNuke dataset", 78 | required=True, 79 | ) 80 | parser.add_argument( 81 | "--output_path", 82 | type=str, 83 | help="Output path to store the processed PanNuke dataset", 84 | required=True, 85 | ) 86 | 87 | if __name__ == "__main__": 88 | opt = parser.parse_args() 89 | configuration = vars(opt) 90 | 91 | input_path = Path(configuration["input_path"]) 92 | output_path = Path(configuration["output_path"]) 93 | 94 | for fold in [0, 1, 2]: 95 | process_fold(fold, input_path, output_path) 96 | -------------------------------------------------------------------------------- /cell_segmentation/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Experiment related methods for each network type 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in 
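# --- Illustrative sketch (added annotation, not part of the original file) ---------
# Why the `+ num_nuc` offset in process_fold() above is needed: every PanNuke mask
# channel stores its own instance ids starting at 1, so without an offset nuclei from
# different classes would collide in the merged instance map. Toy example with two
# made-up 1x3 channels that are already contiguously labelled (the real code first
# runs remap_label on each channel):
#
#   import numpy as np
#   ch0 = np.array([[1, 0, 2]])          # two nuclei of class 0 (hypothetical)
#   ch1 = np.array([[0, 1, 0]])          # one nucleus of class 1 (hypothetical)
#   inst = np.zeros_like(ch0); num_nuc = 0
#   for ch in (ch0, ch1):
#       inst = np.where(ch != 0, ch + num_nuc, inst)
#       num_nuc += ch.max()
#   print(inst)                          # [[1 3 2]] -> three distinct instance ids
# ------------------------------------------------------------------------------------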
Medicine, 6 | # University Medicine Essen 7 | -------------------------------------------------------------------------------- /cell_segmentation/inference/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Inference related methods for each network type 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | -------------------------------------------------------------------------------- /cell_segmentation/run_cellvit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Running an Experiment Using CellViT cell segmentation network 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | import inspect 9 | import os 10 | import sys 11 | 12 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 13 | parentdir = os.path.dirname(currentdir) 14 | sys.path.insert(0, parentdir) 15 | 16 | import wandb 17 | 18 | from base_ml.base_cli import ExperimentBaseParser 19 | from cell_segmentation.experiments.experiment_cellvit_pannuke import ( 20 | ExperimentCellVitPanNuke, 21 | ) 22 | from cell_segmentation.experiments.experiment_cellvit_conic import ( 23 | ExperimentCellViTCoNic, 24 | ) 25 | 26 | from cell_segmentation.inference.inference_cellvit_experiment_pannuke import ( 27 | InferenceCellViT, 28 | ) 29 | 30 | if __name__ == "__main__": 31 | # Parse arguments 32 | configuration_parser = ExperimentBaseParser() 33 | configuration = configuration_parser.parse_arguments() 34 | 35 | if configuration["data"]["dataset"].lower() == "pannuke": 36 | experiment_class = ExperimentCellVitPanNuke 37 | elif configuration["data"]["dataset"].lower() == "conic": 38 | experiment_class = ExperimentCellViTCoNic 39 | # Setup experiment 40 | if "checkpoint" in configuration: 41 | # continue checkpoint 42 | experiment = experiment_class( 43 | default_conf=configuration, checkpoint=configuration["checkpoint"] 44 | ) 45 | outdir = experiment.run_experiment() 46 | inference = InferenceCellViT( 47 | run_dir=outdir, 48 | gpu=configuration["gpu"], 49 | checkpoint_name=configuration["eval_checkpoint"], 50 | magnification=configuration["data"].get("magnification", 40), 51 | ) 52 | ( 53 | trained_model, 54 | inference_dataloader, 55 | dataset_config, 56 | ) = inference.setup_patch_inference() 57 | inference.run_patch_inference( 58 | trained_model, inference_dataloader, dataset_config, generate_plots=False 59 | ) 60 | else: 61 | experiment = experiment_class(default_conf=configuration) 62 | if configuration["run_sweep"] is True: 63 | # run new sweep 64 | sweep_configuration = experiment_class.extract_sweep_arguments( 65 | configuration 66 | ) 67 | os.environ["WANDB_DIR"] = os.path.abspath( 68 | configuration["logging"]["wandb_dir"] 69 | ) 70 | sweep_id = wandb.sweep( 71 | sweep=sweep_configuration, project=configuration["logging"]["project"] 72 | ) 73 | wandb.agent(sweep_id=sweep_id, function=experiment.run_experiment) 74 | elif "agent" in configuration and configuration["agent"] is not None: 75 | # add agent to already existing sweep, not run sweep must be set to true 76 | configuration["run_sweep"] = True 77 | os.environ["WANDB_DIR"] = os.path.abspath( 78 | configuration["logging"]["wandb_dir"] 79 | ) 80 | wandb.agent( 81 | sweep_id=configuration["agent"], 
function=experiment.run_experiment 82 | ) 83 | else: 84 | # casual run 85 | outdir = experiment.run_experiment() 86 | inference = InferenceCellViT( 87 | run_dir=outdir, 88 | gpu=configuration["gpu"], 89 | checkpoint_name=configuration["eval_checkpoint"], 90 | magnification=configuration["data"].get("magnification", 40), 91 | ) 92 | ( 93 | trained_model, 94 | inference_dataloader, 95 | dataset_config, 96 | ) = inference.setup_patch_inference() 97 | inference.run_patch_inference( 98 | trained_model, 99 | inference_dataloader, 100 | dataset_config, 101 | generate_plots=False, 102 | ) 103 | wandb.finish() 104 | -------------------------------------------------------------------------------- /cell_segmentation/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Trainer for each network type 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | -------------------------------------------------------------------------------- /cell_segmentation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Utils 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | -------------------------------------------------------------------------------- /cell_segmentation/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Implemented Metrics for Cell detection 3 | # 4 | # This code is based on the following repository: https://github.com/TissueImageAnalytics/PanNuke-metrics 5 | # 6 | # Implemented metrics are: 7 | # 8 | # Instance Segmentation Metrics 9 | # Binary PQ 10 | # Multiclass PQ 11 | # Neoplastic PQ 12 | # Non-Neoplastic PQ 13 | # Inflammatory PQ 14 | # Dead PQ 15 | # Inflammatory PQ 16 | # Dead PQ 17 | # 18 | # Detection and Classification Metrics 19 | # Precision, Recall, F1 20 | # 21 | # Other 22 | # dice1, dice2, aji, aji_plus 23 | # 24 | # Binary PQ (bPQ): Assumes all nuclei belong to same class and reports the average PQ across tissue types. 25 | # Multi-Class PQ (mPQ): Reports the average PQ across the classes and tissue types. 26 | # Neoplastic PQ: Reports the PQ for the neoplastic class on all tissues. 27 | # Non-Neoplastic PQ: Reports the PQ for the non-neoplastic class on all tissues. 28 | # Inflammatory PQ: Reports the PQ for the inflammatory class on all tissues. 29 | # Connective PQ: Reports the PQ for the connective class on all tissues. 30 | # Dead PQ: Reports the PQ for the dead class on all tissues. 31 | 32 | 33 | from typing import List 34 | import numpy as np 35 | from scipy.optimize import linear_sum_assignment 36 | 37 | 38 | def get_fast_pq(true, pred, match_iou=0.5): 39 | """ 40 | `match_iou` is the IoU threshold level to determine the pairing between 41 | GT instances `p` and prediction instances `g`. `p` and `g` is a pair 42 | if IoU > `match_iou`. However, pair of `p` and `g` must be unique 43 | (1 prediction instance to 1 GT instance mapping). 44 | 45 | If `match_iou` < 0.5, Munkres assignment (solving minimum weight matching 46 | in bipartite graphs) is caculated to find the maximal amount of unique pairing. 47 | 48 | If `match_iou` >= 0.5, all IoU(p,g) > 0.5 pairing is proven to be unique and 49 | the number of pairs is also maximal. 
50 | 51 | Fast computation requires instance IDs are in contiguous orderding 52 | i.e [1, 2, 3, 4] not [2, 3, 6, 10]. Please call `remap_label` beforehand 53 | and `by_size` flag has no effect on the result. 54 | 55 | Returns: 56 | [dq, sq, pq]: measurement statistic 57 | 58 | [paired_true, paired_pred, unpaired_true, unpaired_pred]: 59 | pairing information to perform measurement 60 | 61 | """ 62 | assert match_iou >= 0.0, "Cant' be negative" 63 | 64 | true = np.copy(true) #[256,256] 65 | pred = np.copy(pred) #(256,256) 66 | true_id_list = list(np.unique(true)) 67 | pred_id_list = list(np.unique(pred)) 68 | 69 | # if there is no background, fixing by adding it 70 | if 0 not in pred_id_list: 71 | pred_id_list = [0] + pred_id_list 72 | 73 | true_masks = [ 74 | None, 75 | ] 76 | for t in true_id_list[1:]: 77 | t_mask = np.array(true == t, np.uint8) 78 | true_masks.append(t_mask) 79 | 80 | pred_masks = [ 81 | None, 82 | ] 83 | for p in pred_id_list[1:]: 84 | p_mask = np.array(pred == p, np.uint8) 85 | pred_masks.append(p_mask) 86 | 87 | # prefill with value 88 | pairwise_iou = np.zeros( 89 | [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64 90 | ) 91 | 92 | # caching pairwise iou for all instances 93 | for true_id in true_id_list[1:]: # 0-th is background 94 | t_mask = true_masks[true_id] 95 | pred_true_overlap = pred[t_mask > 0] 96 | pred_true_overlap_id = np.unique(pred_true_overlap) 97 | pred_true_overlap_id = list(pred_true_overlap_id) 98 | for pred_id in pred_true_overlap_id: 99 | if pred_id == 0: # ignore 100 | continue # overlaping background 101 | p_mask = pred_masks[pred_id] 102 | total = (t_mask + p_mask).sum() 103 | inter = (t_mask * p_mask).sum() 104 | iou = inter / (total - inter) 105 | pairwise_iou[true_id - 1, pred_id - 1] = iou 106 | # 107 | if match_iou >= 0.5: 108 | paired_iou = pairwise_iou[pairwise_iou > match_iou] 109 | pairwise_iou[pairwise_iou <= match_iou] = 0.0 110 | paired_true, paired_pred = np.nonzero(pairwise_iou) 111 | paired_iou = pairwise_iou[paired_true, paired_pred] 112 | paired_true += 1 # index is instance id - 1 113 | paired_pred += 1 # hence return back to original 114 | else: # * Exhaustive maximal unique pairing 115 | #### Munkres pairing with scipy library 116 | # the algorithm return (row indices, matched column indices) 117 | # if there is multiple same cost in a row, index of first occurence 118 | # is return, thus the unique pairing is ensure 119 | # inverse pair to get high IoU as minimum 120 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou) 121 | ### extract the paired cost and remove invalid pair 122 | paired_iou = pairwise_iou[paired_true, paired_pred] 123 | 124 | # now select those above threshold level 125 | # paired with iou = 0.0 i.e no intersection => FP or FN 126 | paired_true = list(paired_true[paired_iou > match_iou] + 1) 127 | paired_pred = list(paired_pred[paired_iou > match_iou] + 1) 128 | paired_iou = paired_iou[paired_iou > match_iou] 129 | 130 | # get the actual FP and FN 131 | unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true] 132 | unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred] 133 | # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred)) 134 | 135 | # 136 | tp = len(paired_true) 137 | fp = len(unpaired_pred) 138 | fn = len(unpaired_true) 139 | # get the F1-score i.e DQ 140 | dq = tp / (tp + 0.5 * fp + 0.5 * fn + 1.0e-6) # good practice? 
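# --- Illustrative note (added annotation, not part of the original file) -----------
# Panoptic-quality bookkeeping at this point: DQ (detection quality) is the F1 score
# over matched instances and SQ (segmentation quality) is the mean IoU of the matched
# pairs, so PQ = DQ * SQ. A quick numeric check with made-up values:
#
#   tp, fp, fn = 8, 2, 2
#   summed_iou = 6.4
#   dq = tp / (tp + 0.5 * fp + 0.5 * fn)     # 8 / 10  = 0.8
#   sq = summed_iou / tp                     # 6.4 / 8 = 0.8
#   pq = dq * sq                             # 0.64
# ------------------------------------------------------------------------------------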
141 | # get the SQ, no paired has 0 iou so not impact 142 | sq = paired_iou.sum() / (tp + 1.0e-6) 143 | 144 | return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred] 145 | 146 | 147 | ##### 148 | 149 | 150 | def remap_label(pred, by_size=False): 151 | """ 152 | Rename all instance id so that the id is contiguous i.e [0, 1, 2, 3] 153 | not [0, 2, 4, 6]. The ordering of instances (which one comes first) 154 | is preserved unless by_size=True, then the instances will be reordered 155 | so that bigger nucler has smaller ID 156 | 157 | Args: 158 | pred : the 2d array contain instances where each instances is marked 159 | by non-zero integer 160 | by_size : renaming with larger nuclei has smaller id (on-top) 161 | """ 162 | pred_id = list(np.unique(pred)) 163 | if 0 in pred_id: 164 | pred_id.remove(0) 165 | if len(pred_id) == 0: 166 | return pred # no label 167 | if by_size: 168 | pred_size = [] 169 | for inst_id in pred_id: 170 | size = (pred == inst_id).sum() 171 | pred_size.append(size) 172 | # sort the id by size in descending order 173 | pair_list = zip(pred_id, pred_size) 174 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True) 175 | pred_id, pred_size = zip(*pair_list) 176 | 177 | new_pred = np.zeros(pred.shape, np.int32) 178 | for idx, inst_id in enumerate(pred_id): 179 | new_pred[pred == inst_id] = idx + 1 180 | return new_pred 181 | 182 | 183 | #### 184 | 185 | 186 | def binarize(x): 187 | """ 188 | convert multichannel (multiclass) instance segmetation tensor 189 | to binary instance segmentation (bg and nuclei), 190 | 191 | :param x: B*B*C (for PanNuke 256*256*5 ) 192 | :return: Instance segmentation 193 | """ 194 | #x = np.transpose(x, (1, 2, 0)) #[256,256,5] 195 | 196 | out = np.zeros([x.shape[0], x.shape[1]]) 197 | count = 1 198 | for i in range(x.shape[2]): 199 | x_ch = x[:, :, i] #[256,256] 200 | unique_vals = np.unique(x_ch) 201 | unique_vals = unique_vals.tolist() 202 | unique_vals.remove(0) 203 | for j in unique_vals: 204 | x_tmp = x_ch == j 205 | x_tmp_c = 1 - x_tmp 206 | out *= x_tmp_c 207 | out += count * x_tmp 208 | count += 1 209 | out = out.astype("int32") 210 | return out 211 | 212 | 213 | def get_tissue_idx(tissue_indices, idx): 214 | for i in range(len(tissue_indices)): 215 | if tissue_indices[i].count(idx) == 1: 216 | tiss_idx = i 217 | return tiss_idx 218 | 219 | 220 | def cell_detection_scores( 221 | paired_true, paired_pred, unpaired_true, unpaired_pred, w: List = [1, 1] 222 | ): 223 | tp_d = paired_pred.shape[0] 224 | fp_d = unpaired_pred.shape[0] 225 | fn_d = unpaired_true.shape[0] 226 | 227 | # tp_tn_dt = (paired_pred == paired_true).sum() 228 | # fp_fn_dt = (paired_pred != paired_true).sum() 229 | prec_d = tp_d / (tp_d + fp_d) 230 | rec_d = tp_d / (tp_d + fn_d) 231 | 232 | f1_d = 2 * tp_d / (2 * tp_d + w[0] * fp_d + w[1] * fn_d) 233 | 234 | return f1_d, prec_d, rec_d 235 | 236 | 237 | def cell_type_detection_scores( 238 | paired_true, 239 | paired_pred, 240 | unpaired_true, 241 | unpaired_pred, 242 | type_id, 243 | w: List = [2, 2, 1, 1], 244 | exhaustive: bool = True, 245 | ): 246 | type_samples = (paired_true == type_id) | (paired_pred == type_id) 247 | 248 | paired_true = paired_true[type_samples] 249 | paired_pred = paired_pred[type_samples] 250 | 251 | tp_dt = ((paired_true == type_id) & (paired_pred == type_id)).sum() 252 | tn_dt = ((paired_true != type_id) & (paired_pred != type_id)).sum() 253 | fp_dt = ((paired_true != type_id) & (paired_pred == type_id)).sum() 254 | fn_dt = ((paired_true == type_id) & 
(paired_pred != type_id)).sum() 255 | 256 | if not exhaustive: 257 | ignore = (paired_true == -1).sum() 258 | fp_dt -= ignore 259 | 260 | fp_d = (unpaired_pred == type_id).sum() # 261 | fn_d = (unpaired_true == type_id).sum() 262 | 263 | prec_type = (tp_dt + tn_dt) / (tp_dt + tn_dt + w[0] * fp_dt + w[2] * fp_d) 264 | rec_type = (tp_dt + tn_dt) / (tp_dt + tn_dt + w[1] * fn_dt + w[3] * fn_d) 265 | 266 | f1_type = (2 * (tp_dt + tn_dt)) / ( 267 | 2 * (tp_dt + tn_dt) + w[0] * fp_dt + w[1] * fn_dt + w[2] * fp_d + w[3] * fn_d 268 | ) 269 | return f1_type, prec_type, rec_type 270 | -------------------------------------------------------------------------------- /cell_segmentation/utils/post_proc_cellvit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # PostProcessing Pipeline 3 | # 4 | # Adapted from HoverNet 5 | # HoverNet Network (https://doi.org/10.1016/j.media.2019.101563) 6 | # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net) 7 | 8 | import warnings 9 | from typing import Tuple, Literal,List 10 | 11 | import cv2 12 | import numpy as np 13 | from scipy.ndimage import measurements 14 | from scipy.ndimage.morphology import binary_fill_holes 15 | from skimage.segmentation import watershed 16 | import torch 17 | 18 | from .tools import get_bounding_box, remove_small_objects 19 | 20 | 21 | def noop(*args, **kargs): 22 | pass 23 | 24 | 25 | warnings.warn = noop 26 | 27 | 28 | class DetectionCellPostProcessor: 29 | def __init__( 30 | self, 31 | nr_types: int = None, 32 | magnification: Literal[20, 40] = 40, 33 | gt: bool = False, 34 | ) -> None: 35 | """DetectionCellPostProcessor for postprocessing prediction maps and get detected cells 36 | 37 | Args: 38 | nr_types (int, optional): Number of cell types, including background (background = 0). Defaults to None. 39 | magnification (Literal[20, 40], optional): Which magnification the data has. Defaults to 40. 40 | gt (bool, optional): If this is gt data (used that we do not suppress tiny cells that may be noise in a prediction map). 41 | Defaults to False. 42 | 43 | Raises: 44 | NotImplementedError: Unknown magnification 45 | """ 46 | self.nr_types = nr_types 47 | self.magnification = magnification 48 | self.gt = gt 49 | 50 | if magnification == 40: 51 | self.object_size = 10 52 | self.k_size = 21 53 | elif magnification == 20: 54 | self.object_size = 3 # 3 or 40, we used 5 55 | self.k_size = 11 # 11 or 41, we used 13 56 | else: 57 | raise NotImplementedError("Unknown magnification") 58 | if gt: # to not supress something in gt! 59 | self.object_size = 100 60 | self.k_size = 21 61 | 62 | def post_process_cell_segmentation( 63 | self, 64 | pred_map: np.ndarray, 65 | ) -> Tuple[np.ndarray, dict]: 66 | """Post processing of one image tile 67 | 68 | Args: 69 | pred_map (np.ndarray): Combined output of tp, np and hv branches, in the same order. Shape: (H, W, 4) 70 | 71 | Returns: 72 | Tuple[np.ndarray, dict]: 73 | np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W) 74 | dict: Instance dictionary. Main Key is the nuclei instance number (int), with a dict as value. 
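# --- Illustrative sketch (added annotation, not part of the original file) ---------
# Minimal call of the post-processor defined above; the input array here is random
# and only meant to show the expected layout (channel 0: type map, channel 1: nuclei
# probability, channels 2-3: horizontal/vertical maps), with nr_types=6 matching the
# PanNuke setting in config.yaml:
#
#   import numpy as np
#   post_proc = DetectionCellPostProcessor(nr_types=6, magnification=40, gt=False)
#   pred_map = np.random.rand(256, 256, 4).astype(np.float32)
#   inst_map, inst_info = post_proc.post_process_cell_segmentation(pred_map)
#   # inst_map: (256, 256) integer instance map
#   # inst_info: {instance_id: {"bbox", "centroid", "contour", "type", "type_prob"}}
# ------------------------------------------------------------------------------------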
75 | For each instance, the dictionary contains the keys: bbox (bounding box), centroid (centroid coordinates), 76 | contour, type_prob (probability), type (nuclei type) 77 | """ 78 | if self.nr_types is not None: 79 | pred_type = pred_map[..., :1] 80 | pred_inst = pred_map[..., 1:] 81 | pred_type = pred_type.astype(np.int32) 82 | else: 83 | pred_inst = pred_map 84 | 85 | pred_inst = np.squeeze(pred_inst) 86 | pred_inst = self.__proc_np_hv( 87 | pred_inst, object_size=self.object_size, ksize=self.k_size 88 | ) 89 | 90 | inst_id_list = np.unique(pred_inst)[1:] # exlcude background 91 | inst_info_dict = {} 92 | for inst_id in inst_id_list: 93 | inst_map = pred_inst == inst_id 94 | rmin, rmax, cmin, cmax = get_bounding_box(inst_map) 95 | inst_bbox = np.array([[rmin, cmin], [rmax, cmax]]) 96 | inst_map = inst_map[ 97 | inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1] 98 | ] 99 | inst_map = inst_map.astype(np.uint8) 100 | inst_moment = cv2.moments(inst_map) 101 | inst_contour = cv2.findContours( 102 | inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 103 | ) 104 | # * opencv protocol format may break 105 | inst_contour = np.squeeze(inst_contour[0][0].astype("int32")) 106 | # < 3 points dont make a contour, so skip, likely artifact too 107 | # as the contours obtained via approximation => too small or sthg 108 | if inst_contour.shape[0] < 3: 109 | continue 110 | if len(inst_contour.shape) != 2: 111 | continue # ! check for trickery shape 112 | inst_centroid = [ 113 | (inst_moment["m10"] / inst_moment["m00"]), 114 | (inst_moment["m01"] / inst_moment["m00"]), 115 | ] 116 | inst_centroid = np.array(inst_centroid) 117 | inst_contour[:, 0] += inst_bbox[0][1] # X 118 | inst_contour[:, 1] += inst_bbox[0][0] # Y 119 | inst_centroid[0] += inst_bbox[0][1] # X 120 | inst_centroid[1] += inst_bbox[0][0] # Y 121 | inst_info_dict[inst_id] = { # inst_id should start at 1 122 | "bbox": inst_bbox, 123 | "centroid": inst_centroid, 124 | "contour": inst_contour, 125 | "type_prob": None, 126 | "type": None, 127 | } 128 | 129 | #### * Get class of each instance id, stored at index id-1 (inst_id = number of deteced nucleus) 130 | for inst_id in list(inst_info_dict.keys()): 131 | rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten() 132 | inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] 133 | inst_type_crop = pred_type[rmin:rmax, cmin:cmax] 134 | inst_map_crop = inst_map_crop == inst_id 135 | inst_type = inst_type_crop[inst_map_crop] 136 | type_list, type_pixels = np.unique(inst_type, return_counts=True) 137 | type_list = list(zip(type_list, type_pixels)) 138 | type_list = sorted(type_list, key=lambda x: x[1], reverse=True) 139 | inst_type = type_list[0][0] 140 | if inst_type == 0: # ! 
pick the 2nd most dominant if exist 141 | if len(type_list) > 1: 142 | inst_type = type_list[1][0] 143 | type_dict = {v[0]: v[1] for v in type_list} 144 | type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6) 145 | inst_info_dict[inst_id]["type"] = int(inst_type) 146 | inst_info_dict[inst_id]["type_prob"] = float(type_prob) 147 | 148 | return pred_inst, inst_info_dict 149 | 150 | def __proc_np_hv( 151 | self, pred: np.ndarray, object_size: int = 10, ksize: int = 21 152 | ) -> np.ndarray: 153 | """Process Nuclei Prediction with XY Coordinate Map and generate instance map (each instance has unique integer) 154 | 155 | Separate Instances (also overlapping ones) from binary nuclei map and hv map by using morphological operations and watershed 156 | 157 | Args: 158 | pred (np.ndarray): Prediction output, assuming. Shape: (H, W, 3) 159 | * channel 0 contain probability map of nuclei 160 | * channel 1 containing the regressed X-map 161 | * channel 2 containing the regressed Y-map 162 | object_size (int, optional): Smallest oject size for filtering. Defaults to 10 163 | k_size (int, optional): Sobel Kernel size. Defaults to 21 164 | Returns: 165 | np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W) 166 | """ 167 | pred = np.array(pred, dtype=np.float32) 168 | 169 | blb_raw = pred[..., 0] 170 | h_dir_raw = pred[..., 1] 171 | v_dir_raw = pred[..., 2] 172 | 173 | # processing 174 | blb = np.array(blb_raw >= 0.5, dtype=np.int32) 175 | 176 | blb = measurements.label(blb)[0] # ndimage.label(blb)[0] 177 | blb = remove_small_objects(blb, min_size=10) # 10 178 | blb[blb > 0] = 1 # background is 0 already 179 | 180 | h_dir = cv2.normalize( 181 | h_dir_raw, 182 | None, 183 | alpha=0, 184 | beta=1, 185 | norm_type=cv2.NORM_MINMAX, 186 | dtype=cv2.CV_32F, 187 | ) 188 | v_dir = cv2.normalize( 189 | v_dir_raw, 190 | None, 191 | alpha=0, 192 | beta=1, 193 | norm_type=cv2.NORM_MINMAX, 194 | dtype=cv2.CV_32F, 195 | ) 196 | 197 | # ksize = int((20 * scale_factor) + 1) # 21 vs 41 198 | # obj_size = math.ceil(10 * (scale_factor**2)) #10 vs 40 199 | 200 | sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize) 201 | sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize) 202 | 203 | sobelh = 1 - ( 204 | cv2.normalize( 205 | sobelh, 206 | None, 207 | alpha=0, 208 | beta=1, 209 | norm_type=cv2.NORM_MINMAX, 210 | dtype=cv2.CV_32F, 211 | ) 212 | ) 213 | sobelv = 1 - ( 214 | cv2.normalize( 215 | sobelv, 216 | None, 217 | alpha=0, 218 | beta=1, 219 | norm_type=cv2.NORM_MINMAX, 220 | dtype=cv2.CV_32F, 221 | ) 222 | ) 223 | 224 | overall = np.maximum(sobelh, sobelv) 225 | overall = overall - (1 - blb) 226 | overall[overall < 0] = 0 227 | 228 | dist = (1.0 - overall) * blb 229 | ## nuclei values form mountains so inverse to get basins 230 | dist = -cv2.GaussianBlur(dist, (3, 3), 0) 231 | 232 | overall = np.array(overall >= 0.4, dtype=np.int32) 233 | 234 | marker = blb - overall 235 | marker[marker < 0] = 0 236 | marker = binary_fill_holes(marker).astype("uint8") 237 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) 238 | marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel) 239 | marker = measurements.label(marker)[0] 240 | marker = remove_small_objects(marker, min_size=object_size) 241 | 242 | proced_pred = watershed(dist, markers=marker, mask=blb) 243 | 244 | return proced_pred 245 | 246 | 247 | def calculate_instances( 248 | pred_types: torch.Tensor, pred_insts: torch.Tensor 249 | ) -> List[dict]: 250 | """Best used for GT 251 | 252 | Args: 253 | pred_types 
(torch.Tensor): Binary or type map ground-truth. 254 | Shape must be (B, C, H, W) with C=1 for binary or num_nuclei_types for multi-class. 255 | pred_insts (torch.Tensor): Ground-Truth instance map with shape (B, H, W) 256 | 257 | Returns: 258 | list[dict]: Dictionary with nuclei informations, output similar to post_process_cell_segmentation 259 | """ 260 | type_preds = [] 261 | pred_types = pred_types.permute(0, 2, 3, 1) 262 | for i in range(pred_types.shape[0]): 263 | pred_type = torch.argmax(pred_types, dim=-1)[i].detach().cpu().numpy() 264 | pred_inst = pred_insts[i].detach().cpu().numpy() 265 | inst_id_list = np.unique(pred_inst)[1:] # exlcude background 266 | inst_info_dict = {} 267 | for inst_id in inst_id_list: 268 | inst_map = pred_inst == inst_id 269 | rmin, rmax, cmin, cmax = get_bounding_box(inst_map) 270 | inst_bbox = np.array([[rmin, cmin], [rmax, cmax]]) 271 | inst_map = inst_map[ 272 | inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1] 273 | ] 274 | inst_map = inst_map.astype(np.uint8) 275 | inst_moment = cv2.moments(inst_map) 276 | inst_contour = cv2.findContours( 277 | inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 278 | ) 279 | # * opencv protocol format may break 280 | inst_contour = np.squeeze(inst_contour[0][0].astype("int32")) 281 | # < 3 points dont make a contour, so skip, likely artifact too 282 | # as the contours obtained via approximation => too small or sthg 283 | if inst_contour.shape[0] < 3: 284 | continue 285 | if len(inst_contour.shape) != 2: 286 | continue # ! check for trickery shape 287 | inst_centroid = [ 288 | (inst_moment["m10"] / inst_moment["m00"]), 289 | (inst_moment["m01"] / inst_moment["m00"]), 290 | ] 291 | inst_centroid = np.array(inst_centroid) 292 | inst_contour[:, 0] += inst_bbox[0][1] # X 293 | inst_contour[:, 1] += inst_bbox[0][0] # Y 294 | inst_centroid[0] += inst_bbox[0][1] # X 295 | inst_centroid[1] += inst_bbox[0][0] # Y 296 | inst_info_dict[inst_id] = { # inst_id should start at 1 297 | "bbox": inst_bbox, 298 | "centroid": inst_centroid, 299 | "contour": inst_contour, 300 | "type_prob": None, 301 | "type": None, 302 | } 303 | #### * Get class of each instance id, stored at index id-1 (inst_id = number of deteced nucleus) 304 | for inst_id in list(inst_info_dict.keys()): 305 | rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten() 306 | inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] 307 | inst_type_crop = pred_type[rmin:rmax, cmin:cmax] 308 | inst_map_crop = inst_map_crop == inst_id 309 | inst_type = inst_type_crop[inst_map_crop] 310 | type_list, type_pixels = np.unique(inst_type, return_counts=True) 311 | type_list = list(zip(type_list, type_pixels)) 312 | type_list = sorted(type_list, key=lambda x: x[1], reverse=True) 313 | inst_type = type_list[0][0] 314 | if inst_type == 0: # ! 
pick the 2nd most dominant if exist 315 | if len(type_list) > 1: 316 | inst_type = type_list[1][0] 317 | type_dict = {v[0]: v[1] for v in type_list} 318 | type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6) 319 | inst_info_dict[inst_id]["type"] = int(inst_type) 320 | inst_info_dict[inst_id]["type_prob"] = float(type_prob) 321 | type_preds.append(inst_info_dict) 322 | 323 | return type_preds 324 | -------------------------------------------------------------------------------- /cell_segmentation/utils/template_geojson.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # GeoJson templates 3 | 4 | 5 | 6 | def get_template_point() -> dict: 7 | """Return a template for a Point geojson object 8 | 9 | Returns: 10 | dict: Template 11 | """ 12 | template_point = { 13 | "type": "Feature", 14 | "id": "TODO", 15 | "geometry": { 16 | "type": "MultiPoint", 17 | "coordinates": [ 18 | [], 19 | ], 20 | }, 21 | "properties": { 22 | "objectType": "annotation", 23 | "classification": {"name": "TODO", "color": []}, 24 | }, 25 | } 26 | return template_point 27 | 28 | 29 | def get_template_segmentation() -> dict: 30 | """Return a template for a MultiPolygon geojson object 31 | 32 | Returns: 33 | dict: Template 34 | """ 35 | template_multipolygon = { 36 | "type": "Feature", 37 | "id": "TODO", 38 | "geometry": { 39 | "type": "MultiPolygon", 40 | "coordinates": [ 41 | [], 42 | ], 43 | }, 44 | "properties": { 45 | "objectType": "annotation", 46 | "classification": {"name": "TODO", "color": []}, 47 | }, 48 | } 49 | return template_multipolygon 50 | -------------------------------------------------------------------------------- /cell_segmentation/utils/tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Helpful functions Pipeline 3 | # 4 | # Adapted from HoverNet 5 | # HoverNet Network (https://doi.org/10.1016/j.media.2019.101563) 6 | # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net) 7 | 8 | import math 9 | from typing import Tuple 10 | 11 | import numpy as np 12 | import scipy 13 | from numba import njit, prange 14 | from scipy import ndimage 15 | from scipy.optimize import linear_sum_assignment 16 | from skimage.draw import polygon 17 | 18 | 19 | def get_bounding_box(img): 20 | """Get bounding box coordinate information.""" 21 | rows = np.any(img, axis=1) 22 | cols = np.any(img, axis=0) 23 | rmin, rmax = np.where(rows)[0][[0, -1]] 24 | cmin, cmax = np.where(cols)[0][[0, -1]] 25 | # due to python indexing, need to add 1 to max 26 | # else accessing will be 1px in the box, not out 27 | rmax += 1 28 | cmax += 1 29 | return [rmin, rmax, cmin, cmax] 30 | 31 | 32 | @njit 33 | def cropping_center(x, crop_shape, batch=False): 34 | """Crop an input image at the centre. 35 | 36 | Args: 37 | x: input array 38 | crop_shape: dimensions of cropped array 39 | 40 | Returns: 41 | x: cropped array 42 | 43 | """ 44 | orig_shape = x.shape 45 | if not batch: 46 | h0 = int((orig_shape[0] - crop_shape[0]) * 0.5) 47 | w0 = int((orig_shape[1] - crop_shape[1]) * 0.5) 48 | x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1], ...] 49 | else: 50 | h0 = int((orig_shape[1] - crop_shape[0]) * 0.5) 51 | w0 = int((orig_shape[2] - crop_shape[1]) * 0.5) 52 | x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1], ...] 
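# --- Illustrative sketch (added annotation, not part of the original file) ---------
# get_bounding_box() above returns [rmin, rmax, cmin, cmax] with the max indices
# already incremented by one, so they can be used directly as slice ends:
#
#   import numpy as np
#   m = np.zeros((6, 6), dtype=np.uint8)
#   m[2:4, 1:5] = 1
#   rmin, rmax, cmin, cmax = get_bounding_box(m)   # -> 2, 4, 1, 5
#   crop = m[rmin:rmax, cmin:cmax]                 # exactly the non-zero region
# ------------------------------------------------------------------------------------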
53 | return x 54 | 55 | 56 | def remove_small_objects(pred, min_size=64, connectivity=1): 57 | """Remove connected components smaller than the specified size. 58 | 59 | This function is taken from skimage.morphology.remove_small_objects, but the warning 60 | is removed when a single label is provided. 61 | 62 | Args: 63 | pred: input labelled array 64 | min_size: minimum size of instance in output array 65 | connectivity: The connectivity defining the neighborhood of a pixel. 66 | 67 | Returns: 68 | out: output array with instances removed under min_size 69 | 70 | """ 71 | out = pred 72 | 73 | if min_size == 0: # shortcut for efficiency 74 | return out 75 | 76 | if out.dtype == bool: 77 | selem = ndimage.generate_binary_structure(pred.ndim, connectivity) 78 | ccs = np.zeros_like(pred, dtype=np.int32) 79 | ndimage.label(pred, selem, output=ccs) 80 | else: 81 | ccs = out 82 | 83 | try: 84 | component_sizes = np.bincount(ccs.ravel()) 85 | except ValueError: 86 | raise ValueError( 87 | "Negative value labels are not supported. Try " 88 | "relabeling the input with `scipy.ndimage.label` or " 89 | "`skimage.morphology.label`." 90 | ) 91 | 92 | too_small = component_sizes < min_size 93 | too_small_mask = too_small[ccs] 94 | out[too_small_mask] = 0 95 | 96 | return out 97 | 98 | 99 | def pair_coordinates( 100 | setA: np.ndarray, setB: np.ndarray, radius: float 101 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 102 | """Use the Munkres or Kuhn-Munkres algorithm to find the most optimal 103 | unique pairing (largest possible match) when pairing points in set B 104 | against points in set A, using distance as cost function. 105 | 106 | Args: 107 | setA (np.ndarray): np.array (float32) of size Nx2 contains the of XY coordinate 108 | of N different points 109 | setB (np.ndarray): np.array (float32) of size Nx2 contains the of XY coordinate 110 | of N different points 111 | radius (float): valid area around a point in setA to consider 112 | a given coordinate in setB a candidate for match 113 | 114 | Returns: 115 | Tuple[np.ndarray, np.ndarray, np.ndarray]: 116 | pairing: pairing is an array of indices 117 | where point at index pairing[0] in set A paired with point 118 | in set B at index pairing[1] 119 | unparedA: remaining point in set A unpaired 120 | unparedB: remaining point in set B unpaired 121 | """ 122 | # * Euclidean distance as the cost matrix 123 | pair_distance = scipy.spatial.distance.cdist(setA, setB, metric="euclidean") 124 | 125 | # * Munkres pairing with scipy library 126 | # the algorithm return (row indices, matched column indices) 127 | # if there is multiple same cost in a row, index of first occurence 128 | # is return, thus the unique pairing is ensured 129 | indicesA, paired_indicesB = linear_sum_assignment(pair_distance) 130 | 131 | # extract the paired cost and remove instances 132 | # outside of designated radius 133 | pair_cost = pair_distance[indicesA, paired_indicesB] 134 | 135 | pairedA = indicesA[pair_cost <= radius] 136 | pairedB = paired_indicesB[pair_cost <= radius] 137 | 138 | pairing = np.concatenate([pairedA[:, None], pairedB[:, None]], axis=-1) 139 | unpairedA = np.delete(np.arange(setA.shape[0]), pairedA) 140 | unpairedB = np.delete(np.arange(setB.shape[0]), pairedB) 141 | 142 | return pairing, unpairedA, unpairedB 143 | 144 | 145 | def fix_duplicates(inst_map: np.ndarray) -> np.ndarray: 146 | """Re-label duplicated instances in an instance labelled mask. 147 | 148 | Parameters 149 | ---------- 150 | inst_map : np.ndarray 151 | Instance labelled mask. 
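# --- Illustrative sketch (added annotation, not part of the original file) ---------
# pair_coordinates() above solves an optimal one-to-one assignment on the pairwise
# distance matrix and then drops pairs farther apart than `radius`. Tiny made-up
# example:
#
#   import numpy as np
#   gt   = np.array([[0.0, 0.0], [10.0, 10.0]], dtype=np.float32)
#   pred = np.array([[1.0, 0.0], [50.0, 50.0]], dtype=np.float32)
#   pairing, unpaired_gt, unpaired_pred = pair_coordinates(gt, pred, radius=6)
#   # pairing -> [[0, 0]] (first GT point matched to first prediction),
#   # unpaired_gt -> [1], unpaired_pred -> [1]
# ------------------------------------------------------------------------------------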
Shape (H, W). 152 | 153 | Returns 154 | ------- 155 | np.ndarray: 156 | The instance labelled mask without duplicated indices. 157 | Shape (H, W). 158 | """ 159 | current_max_id = np.amax(inst_map) 160 | inst_list = list(np.unique(inst_map)) 161 | if 0 in inst_list: 162 | inst_list.remove(0) 163 | 164 | for inst_id in inst_list: 165 | inst = np.array(inst_map == inst_id, np.uint8) 166 | remapped_ids = ndimage.label(inst)[0] 167 | remapped_ids[remapped_ids > 1] += current_max_id 168 | inst_map[remapped_ids > 1] = remapped_ids[remapped_ids > 1] 169 | current_max_id = np.amax(inst_map) 170 | 171 | return inst_map 172 | 173 | 174 | def polygons_to_label_coord( 175 | coord: np.ndarray, shape: Tuple[int, int], labels: np.ndarray = None 176 | ) -> np.ndarray: 177 | """Render polygons to image given a shape. 178 | 179 | Parameters 180 | ---------- 181 | coord.shape : np.ndarray 182 | Shape: (n_polys, n_rays) 183 | shape : Tuple[int, int] 184 | Shape of the output mask. 185 | labels : np.ndarray, optional 186 | Sorted indices of the centroids. 187 | 188 | Returns 189 | ------- 190 | np.ndarray: 191 | Instance labelled mask. Shape: (H, W). 192 | """ 193 | coord = np.asarray(coord) 194 | if labels is None: 195 | labels = np.arange(len(coord)) 196 | 197 | assert coord.ndim == 3 and coord.shape[1] == 2 and len(coord) == len(labels) 198 | 199 | lbl = np.zeros(shape, np.int32) 200 | 201 | for i, c in zip(labels, coord): 202 | rr, cc = polygon(*c, shape) 203 | lbl[rr, cc] = i + 1 204 | 205 | return lbl 206 | 207 | 208 | def ray_angles(n_rays: int = 32): 209 | """Get linearly spaced angles for rays.""" 210 | return np.linspace(0, 2 * np.pi, n_rays, endpoint=False) 211 | 212 | 213 | def dist_to_coord( 214 | dist: np.ndarray, points: np.ndarray, scale_dist: Tuple[int, int] = (1, 1) 215 | ) -> np.ndarray: 216 | """Convert list of distances and centroids from polar to cartesian coordinates. 217 | 218 | Parameters 219 | ---------- 220 | dist : np.ndarray 221 | The centerpoint pixels of the radial distance map. Shape (n_polys, n_rays). 222 | points : np.ndarray 223 | The centroids of the instances. Shape: (n_polys, 2). 224 | scale_dist : Tuple[int, int], default=(1, 1) 225 | Scaling factor. 226 | 227 | Returns 228 | ------- 229 | np.ndarray: 230 | Cartesian cooridnates of the polygons. Shape (n_polys, 2, n_rays). 231 | """ 232 | dist = np.asarray(dist) 233 | points = np.asarray(points) 234 | assert ( 235 | dist.ndim == 2 236 | and points.ndim == 2 237 | and len(dist) == len(points) 238 | and points.shape[1] == 2 239 | and len(scale_dist) == 2 240 | ) 241 | n_rays = dist.shape[1] 242 | phis = ray_angles(n_rays) 243 | coord = (dist[:, np.newaxis] * np.array([np.sin(phis), np.cos(phis)])).astype( 244 | np.float32 245 | ) 246 | coord *= np.asarray(scale_dist).reshape(1, 2, 1) 247 | coord += points[..., np.newaxis] 248 | return coord 249 | 250 | 251 | def polygons_to_label( 252 | dist: np.ndarray, 253 | points: np.ndarray, 254 | shape: Tuple[int, int], 255 | prob: np.ndarray = None, 256 | thresh: float = -np.inf, 257 | scale_dist: Tuple[int, int] = (1, 1), 258 | ) -> np.ndarray: 259 | """Convert distances and center points to instance labelled mask. 260 | 261 | Parameters 262 | ---------- 263 | dist : np.ndarray 264 | The centerpoint pixels of the radial distance map. Shape (n_polys, n_rays). 265 | points : np.ndarray 266 | The centroids of the instances. Shape: (n_polys, 2). 267 | shape : Tuple[int, int]: 268 | Shape of the output mask. 
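# --- Illustrative sketch (added annotation, not part of the original file) ---------
# dist_to_coord() above turns per-ray radial distances plus a centroid into cartesian
# contour points of shape (n_polys, 2, n_rays). With a constant distance the result
# is a regular polygon approximating a circle around the centroid:
#
#   import numpy as np
#   dist = np.full((1, 32), 5.0)            # one instance, 32 rays of length 5
#   points = np.array([[100.0, 100.0]])     # its centroid (y, x)
#   coord = dist_to_coord(dist, points)     # -> shape (1, 2, 32)
# ------------------------------------------------------------------------------------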
269 | prob : np.ndarray, optional 270 | The centerpoint pixels of the regressed distance transform. 271 | Shape: (n_polys, n_rays). 272 | thresh : float, default=-np.inf 273 | Threshold for the regressed distance transform. 274 | scale_dist : Tuple[int, int], default=(1, 1) 275 | Scaling factor. 276 | 277 | Returns 278 | ------- 279 | np.ndarray: 280 | Instance labelled mask. Shape (H, W). 281 | """ 282 | dist = np.asarray(dist) 283 | points = np.asarray(points) 284 | prob = np.inf * np.ones(len(points)) if prob is None else np.asarray(prob) 285 | 286 | assert dist.ndim == 2 and points.ndim == 2 and len(dist) == len(points) 287 | assert len(points) == len(prob) and points.shape[1] == 2 and prob.ndim == 1 288 | 289 | ind = prob > thresh 290 | points = points[ind] 291 | dist = dist[ind] 292 | prob = prob[ind] 293 | 294 | ind = np.argsort(prob, kind="stable") 295 | points = points[ind] 296 | dist = dist[ind] 297 | 298 | coord = dist_to_coord(dist, points, scale_dist=scale_dist) 299 | 300 | return polygons_to_label_coord(coord, shape=shape, labels=ind) 301 | 302 | 303 | @njit(cache=True, fastmath=True) 304 | def intersection(boxA: np.ndarray, boxB: np.ndarray): 305 | """Compute area of intersection of two boxes. 306 | 307 | Parameters 308 | ---------- 309 | boxA : np.ndarray 310 | First boxes 311 | boxB : np.ndarray 312 | Second box 313 | 314 | Returns 315 | ------- 316 | float64: 317 | Area of intersection 318 | """ 319 | xA = max(boxA[..., 0], boxB[..., 0]) 320 | xB = min(boxA[..., 2], boxB[..., 2]) 321 | dx = xB - xA 322 | if dx <= 0: 323 | return 0.0 324 | 325 | yA = max(boxA[..., 1], boxB[..., 1]) 326 | yB = min(boxA[..., 3], boxB[..., 3]) 327 | dy = yB - yA 328 | if dy <= 0.0: 329 | return 0.0 330 | 331 | return dx * dy 332 | 333 | 334 | @njit(parallel=True) 335 | def get_bboxes( 336 | dist: np.ndarray, points: np.ndarray 337 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]: 338 | """Get bounding boxes from the non-zero pixels of the radial distance maps. 339 | 340 | This is basically a translation from the stardist repo cpp code to python 341 | 342 | NOTE: jit compiled and parallelized with numba. 343 | 344 | Parameters 345 | ---------- 346 | dist : np.ndarray 347 | The non-zero values of the radial distance maps. Shape: (n_nonzero, n_rays). 348 | points : np.ndarray 349 | The yx-coordinates of the non-zero points. Shape (n_nonzero, 2). 350 | 351 | Returns 352 | ------- 353 | Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]: 354 | Returns the x0, y0, x1, y1 bbox coordinates, bbox areas and the maximum 355 | radial distance in the image. 
356 | """ 357 | n_polys = dist.shape[0] 358 | n_rays = dist.shape[1] 359 | 360 | bbox_x1 = np.zeros(n_polys) 361 | bbox_x2 = np.zeros(n_polys) 362 | bbox_y1 = np.zeros(n_polys) 363 | bbox_y2 = np.zeros(n_polys) 364 | 365 | areas = np.zeros(n_polys) 366 | angle_pi = 2 * math.pi / n_rays 367 | max_dist = 0 368 | 369 | for i in prange(n_polys): 370 | max_radius_outer = 0 371 | py = points[i, 0] 372 | px = points[i, 1] 373 | 374 | for k in range(n_rays): 375 | d = dist[i, k] 376 | y = py + d * np.sin(angle_pi * k) 377 | x = px + d * np.cos(angle_pi * k) 378 | 379 | if k == 0: 380 | bbox_x1[i] = x 381 | bbox_x2[i] = x 382 | bbox_y1[i] = y 383 | bbox_y2[i] = y 384 | else: 385 | bbox_x1[i] = min(x, bbox_x1[i]) 386 | bbox_x2[i] = max(x, bbox_x2[i]) 387 | bbox_y1[i] = min(y, bbox_y1[i]) 388 | bbox_y2[i] = max(y, bbox_y2[i]) 389 | 390 | max_radius_outer = max(d, max_radius_outer) 391 | 392 | areas[i] = (bbox_x2[i] - bbox_x1[i]) * (bbox_y2[i] - bbox_y1[i]) 393 | max_dist = max(max_dist, max_radius_outer) 394 | 395 | return bbox_x1, bbox_y1, bbox_x2, bbox_y2, areas, max_dist 396 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES: 3 2 | logging: 3 | log_dir: /data5/ziweicui/cellvit256-unireplknet-n 4 | mode: online 5 | project: Cell-Segmentation 6 | notes: CellViT-256 7 | log_comment: CellViT-256-resnet50-tiny 8 | tags: 9 | - Fold-1 10 | - ViT256 11 | wandb_dir: /data5/ziweicui/UniRepLKNet-optimizerconfig-unetdecoder-inputconv/results 12 | level: Debug 13 | group: CellViT256 14 | run_id: anifw9ux 15 | wandb_file: anifw9ux 16 | random_seed: 19 17 | gpu: 0 18 | data: 19 | dataset: PanNuke 20 | dataset_path: /data5/ziweicui/cellvit-png 21 | train_folds: 22 | - 0 23 | val_folds: 24 | - 1 25 | test_folds: 26 | - 2 27 | num_nuclei_classes: 6 28 | num_tissue_classes: 19 29 | model: 30 | backbone: default 31 | pretrained_encoder: /data5/ziweicui/semi_supervised_resnet50-08389792.pth 32 | shared_skip_connections: true 33 | loss: 34 | nuclei_binary_map: 35 | focaltverskyloss: 36 | loss_fn: FocalTverskyLoss 37 | weight: 1 38 | dice: 39 | loss_fn: dice_loss 40 | weight: 1 41 | hv_map: 42 | mse: 43 | loss_fn: mse_loss_maps 44 | weight: 2.5 45 | msge: 46 | loss_fn: msge_loss_maps 47 | weight: 8 48 | nuclei_type_map: 49 | bce: 50 | loss_fn: xentropy_loss 51 | weight: 0.5 52 | dice: 53 | loss_fn: dice_loss 54 | weight: 0.2 55 | mcfocaltverskyloss: 56 | loss_fn: MCFocalTverskyLoss 57 | weight: 0.5 58 | args: 59 | num_classes: 6 60 | tissue_types: 61 | ce: 62 | loss_fn: CrossEntropyLoss 63 | weight: 0.1 64 | training: 65 | drop_rate: 0 66 | attn_drop_rate: 0.1 67 | drop_path_rate: 0.1 68 | batch_size: 32 69 | epochs: 130 70 | optimizer: AdamW 71 | early_stopping_patience: 130 72 | scheduler: 73 | scheduler_type: cosine 74 | hyperparameters: 75 | #gamma: 0.85 76 | eta_min: 1e-5 77 | optimizer_hyperparameter: 78 | # betas: 79 | # - 0.85 80 | # - 0.95 81 | #lr: 0.004 82 | opt_lower: 'AdamW' 83 | lr: 0.0008 84 | opt_betas: [0.85,0.95] 85 | weight_decay: 0.05 86 | opt_eps: 0.00000008 87 | unfreeze_epoch: 25 88 | sampling_gamma: 0.85 89 | sampling_strategy: cell+tissue 90 | mixed_precision: true 91 | transformations: 92 | randomrotate90: 93 | p: 0.5 94 | horizontalflip: 95 | p: 0.5 96 | verticalflip: 97 | p: 0.5 98 | downscale: 99 | p: 0.15 100 | scale: 0.5 101 | blur: 102 | p: 0.2 103 | blur_limit: 10 104 | gaussnoise: 105 | p: 0.25 106 | var_limit: 50 107 | colorjitter: 108 
| p: 0.2 109 | scale_setting: 0.25 110 | scale_color: 0.1 111 | superpixels: 112 | p: 0.1 113 | zoomblur: 114 | p: 0.1 115 | randomsizedcrop: 116 | p: 0.1 117 | elastictransform: 118 | p: 0.2 119 | normalize: 120 | mean: 121 | - 0.5 122 | - 0.5 123 | - 0.5 124 | std: 125 | - 0.5 126 | - 0.5 127 | - 0.5 128 | eval_checkpoint: latest_checkpoint.pth 129 | dataset_config: 130 | tissue_types: 131 | Adrenal_gland: 0 132 | Bile-duct: 1 133 | Bladder: 2 134 | Breast: 3 135 | Cervix: 4 136 | Colon: 5 137 | Esophagus: 6 138 | HeadNeck: 7 139 | Kidney: 8 140 | Liver: 9 141 | Lung: 10 142 | Ovarian: 11 143 | Pancreatic: 12 144 | Prostate: 13 145 | Skin: 14 146 | Stomach: 15 147 | Testis: 16 148 | Thyroid: 17 149 | Uterus: 18 150 | nuclei_types: 151 | Background: 0 152 | Neoplastic: 1 153 | Inflammatory: 2 154 | Connective: 3 155 | Dead: 4 156 | Epithelial: 5 157 | run_sweep: false 158 | agent: null 159 | -------------------------------------------------------------------------------- /datamodel/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Data models 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | -------------------------------------------------------------------------------- /datamodel/graph_datamodel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Graph Data model 3 | # 4 | # For more information, please check out docs/readmes/graphs.md 5 | # 6 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 7 | # Institute for Artifical Intelligence in Medicine, 8 | # University Medicine Essen 9 | 10 | from dataclasses import dataclass 11 | 12 | import torch 13 | 14 | 15 | @dataclass 16 | class GraphDataWSI: 17 | """Dataclass for Graph Data 18 | 19 | Args: 20 | x (torch.Tensor): Node feature matrix with shape (num_nodes, num_nodes_features) 21 | positions(torch.Tensor): Each of the objects defined in x has a physical position in a Cartesian coordinate system, 22 | be it detected cells or extracted patches. That's why we store the 2D position here, globally for the WSI. 23 | Shape (num_nodes, 2) 24 | metadata (dict, optional): Metadata about the object is stored here. Defaults to None 25 | """ 26 | 27 | x: torch.Tensor 28 | positions: torch.Tensor 29 | metadata: dict 30 | -------------------------------------------------------------------------------- /datamodel/wsi_datamodel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # WSI Model 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | 9 | import json 10 | from pathlib import Path 11 | from typing import Union, List, Callable, Tuple 12 | 13 | from dataclasses import dataclass, field 14 | import numpy as np 15 | import yaml 16 | import logging 17 | import torch 18 | from PIL import Image 19 | 20 | 21 | @dataclass 22 | class WSI: 23 | """WSI object 24 | 25 | Args: 26 | name (str): WSI name 27 | patient (str): Patient name 28 | slide_path (Union[str, Path]): Full path to the WSI file. 29 | patched_slide_path (Union[str, Path], optional): Full path to preprocessed WSI files (patches). Defaults to None. 30 | embedding_name (Union[str, Path], optional): Defaults to None. 31 | label (Union[str, int, float, np.ndarray], optional): Label of the WSI. 
Defaults to None. 32 | logger (logging.logger, optional): Logger module for logging information. Defaults to None. 33 | """ 34 | 35 | name: str 36 | patient: str 37 | slide_path: Union[str, Path] 38 | patched_slide_path: Union[str, Path] = None 39 | embedding_name: Union[str, Path] = None 40 | label: Union[str, int, float, np.ndarray] = None 41 | logger: logging.Logger = None 42 | 43 | # unset attributes used in this class 44 | metadata: dict = field(init=False, repr=False) 45 | all_patch_metadata: List[dict] = field(init=False, repr=False) 46 | patches_list: List = field(init=False, repr=False) 47 | patch_transform: Callable = field(init=False, repr=False) 48 | 49 | # name without ending (e.g. slide1 instead of slide1.svs) 50 | def __post_init__(self): 51 | """Post-Processing object""" 52 | super().__init__() 53 | # define paramaters that are used, but not defined at startup 54 | 55 | # convert string to path 56 | self.slide_path = Path(self.slide_path).resolve() 57 | if self.patched_slide_path is not None: 58 | self.patched_slide_path = Path(self.patched_slide_path).resolve() 59 | # load metadata 60 | self._get_metadata() 61 | self._get_wsi_patch_metadata() 62 | self.patch_transform = None # hardcode to None (should not be a parameter, but should be defined) 63 | 64 | if self.logger is not None: 65 | self.logger.debug(self.__repr__()) 66 | 67 | def _get_metadata(self) -> None: 68 | """Load metadata yaml file""" 69 | self.metadata_path = self.patched_slide_path / "metadata.yaml" 70 | with open(self.metadata_path.resolve(), "r") as metadata_yaml: 71 | try: 72 | self.metadata = yaml.safe_load(metadata_yaml) 73 | except yaml.YAMLError as exc: 74 | print(exc) 75 | self.metadata["label_map_inverse"] = { 76 | v: k for k, v in self.metadata["label_map"].items() 77 | } 78 | 79 | def _get_wsi_patch_metadata(self) -> None: 80 | """Load patch_metadata json file and convert to dict and lists""" 81 | with open(self.patched_slide_path / "patch_metadata.json", "r") as json_file: 82 | metadata = json.load(json_file) 83 | self.patches_list = [str(list(elem.keys())[0]) for elem in metadata] 84 | self.all_patch_metadata = { 85 | str(list(elem.keys())[0]): elem[str(list(elem.keys())[0])] 86 | for elem in metadata 87 | } 88 | 89 | def load_patch_metadata(self, patch_name: str) -> dict: 90 | """Return the metadata of a patch with given name (including patch suffix, e.g., wsi_1_1.png) 91 | 92 | This function assumes that metadata path is a subpath of the patches dataset path 93 | 94 | Args: 95 | patch_name (str): Name of patch 96 | 97 | Returns: 98 | dict: metadata 99 | """ 100 | patch_metadata_path = self.all_patch_metadata[patch_name]["metadata_path"] 101 | patch_metadata_path = self.patched_slide_path / patch_metadata_path 102 | 103 | # open 104 | with open(patch_metadata_path, "r") as metadata_yaml: 105 | patch_metadata = yaml.safe_load(metadata_yaml) 106 | patch_metadata["name"] = patch_name 107 | 108 | return patch_metadata 109 | 110 | def set_patch_transform(self, transform: Callable) -> None: 111 | """Set the transformation function to process a patch 112 | 113 | Args: 114 | transform (Callable): Transformation function 115 | """ 116 | self.patch_transform = transform 117 | 118 | # patch processing 119 | def process_patch_image( 120 | self, patch_name: str, transform: Callable = None 121 | ) -> Tuple[torch.Tensor, dict]: 122 | """Process one patch: Load from disk, apply transformation if needed. 
ToTensor is applied automatically 123 | 124 | Args: 125 | patch_name (Path): Name of patch to load, including patch suffix, e.g., wsi_1_1.png 126 | transform (Callable, optional): Optional Patch-Transformation 127 | Returns: 128 | Tuple[torch.Tensor, dict]: 129 | 130 | * torch.Tensor: patch as torch.tensor (:,:,3) 131 | * dict: patch metadata as dictionary 132 | """ 133 | patch = Image.open(self.patched_slide_path / "patches" / patch_name) 134 | if transform: 135 | patch = transform(patch) 136 | 137 | metadata = self.load_patch_metadata(patch_name) 138 | return patch, metadata 139 | 140 | def get_number_patches(self) -> int: 141 | """Return the number of patches for this WSI 142 | 143 | Returns: 144 | int: number of patches 145 | """ 146 | return int(len(self.patches_list)) 147 | 148 | def get_patches( 149 | self, transform: Callable = None 150 | ) -> Tuple[torch.Tensor, list, list]: 151 | """Get all patches for one image 152 | 153 | Args: 154 | transform (Callable, optional): Optional Patch-Transformation 155 | 156 | Returns: 157 | Tuple[torch.Tensor, list]: 158 | 159 | * patched image: Shape of torch.Tensor(num_patches, 3, :, :) 160 | * coordinates as list metadata_dictionary 161 | 162 | """ 163 | if self.logger is not None: 164 | self.logger.warning(f"Loading {self.get_number_patches()} patches!") 165 | patches = [] 166 | metadata = [] 167 | for patch in self.patches_list: 168 | transformed_patch, meta = self.process_patch_image(patch, transform) 169 | patches.append(transformed_patch) 170 | metadata.append(meta) 171 | patches = torch.stack(patches) 172 | 173 | return patches, metadata 174 | 175 | def load_embedding(self) -> torch.Tensor: 176 | """Load embedding from subfolder patched_slide_path/embedding/ 177 | 178 | Raises: 179 | FileNotFoundError: If embedding is not given 180 | 181 | Returns: 182 | torch.Tensor: WSI embedding 183 | """ 184 | embedding_path = ( 185 | self.patched_slide_path / "embeddings" / f"{self.embedding_name}.pt" 186 | ) 187 | if embedding_path.is_file(): 188 | embedding = torch.load(embedding_path) 189 | return embedding 190 | else: 191 | raise FileNotFoundError( 192 | f"Embeddings for WSI {self.slide_path} cannot be found in path {embedding_path}" 193 | ) 194 | -------------------------------------------------------------------------------- /docs/datasets/PanNuke/dataset_config.yaml: -------------------------------------------------------------------------------- 1 | tissue_types: 2 | "Adrenal_gland": 0 3 | "Bile-duct": 1 4 | "Bladder": 2 5 | "Breast": 3 6 | "Cervix": 4 7 | "Colon": 5 8 | "Esophagus": 6 9 | "HeadNeck": 7 10 | "Kidney": 8 11 | "Liver": 9 12 | "Lung": 10 13 | "Ovarian": 11 14 | "Pancreatic": 12 15 | "Prostate": 13 16 | "Skin": 14 17 | "Stomach": 15 18 | "Testis": 16 19 | "Thyroid": 17 20 | "Uterus": 18 21 | 22 | nuclei_types: 23 | "Background": 0 24 | "Neoplastic": 1 25 | "Inflammatory": 2 26 | "Connective": 3 27 | "Dead": 4 28 | "Epithelial": 5 29 | -------------------------------------------------------------------------------- /docs/datasets/PanNuke/weight_config.yaml: -------------------------------------------------------------------------------- 1 | tissue: 2 | "Adrenal_gland": 437 3 | "Bile-duct": 420 4 | "Bladder": 146 5 | "Breast": 2351 6 | "Cervix": 293 7 | "Colon": 1440 8 | "Esophagus": 424 9 | "HeadNeck": 384 10 | "Kidney": 134 11 | "Liver": 224 12 | "Lung": 184 13 | "Ovarian": 146 14 | "Pancreatic": 195 15 | "Prostate": 182 16 | "Skin": 187 17 | "Stomach": 146 18 | "Testis": 196 19 | "Thyroid": 226 20 | "Uterus": 186 21 | 
-------------------------------------------------------------------------------- /docs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/docs/model.png -------------------------------------------------------------------------------- /docs/readmes/cell_segmentation.md: -------------------------------------------------------------------------------- 1 | # Cell Segmentation 2 | 3 | ## Training 4 | 5 | The data structure used to train cell segmentation networks is different than to train classification networks on WSI/Patient level. Cureently, due to the massive amount of cells inside a WSI, all famous cell segmentation datasets (such like [PanNuke](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke), https://doi.org/10.48550/arXiv.2003.10778) provide just patches with cell annotations. Therefore, we use the following dataset structure (with k folds): 6 | 7 | ```bash 8 | dataset 9 | ├── dataset_config.yaml 10 | ├── fold0 11 | │ ├── images 12 | | | ├── 0_imgname0.png 13 | | | ├── 0_imgname1.png 14 | | | ├── 0_imgname2.png 15 | ... 16 | | | └── 0_imgnameN.png 17 | │ ├── labels 18 | | | ├── 0_imgname0.npy 19 | | | ├── 0_imgname1.npy 20 | | | ├── 0_imgname2.npy 21 | ... 22 | | | └── 0_imgnameN.npy 23 | | └── types.csv 24 | ├── fold1 25 | │ ├── images 26 | | | ├── 1_imgname0.png 27 | | | ├── 1_imgname1.png 28 | ... 29 | │ ├── labels 30 | | | ├── 1_imgname0.npy 31 | | | ├── 1_imgname1.npy 32 | ... 33 | | └── types.csv 34 | ... 35 | └── foldk 36 | │ ├── images 37 | | ├── k_imgname0.png 38 | | ├── k_imgname1.png 39 | ... 40 | ├── labels 41 | | ├── k_imgname0.npy 42 | | ├── k_imgname1.npy 43 | └── types.csv 44 | ``` 45 | 46 | Each type csv should have the following header: 47 | ```csv 48 | img,type # Header 49 | foldnum_imgname0.png,SetTypeHeare # Each row is one patch with tissue type 50 | ``` 51 | 52 | The labels are numpy masks with the following structure: 53 | TBD 54 | 55 | ## Add a new dataset 56 | add to dataset coordnator. 57 | 58 | All settings of the dataset must be performed in the correspondinng yaml file, under the data section 59 | 60 | dataset name is **not** case sensitive! 61 | -------------------------------------------------------------------------------- /docs/readmes/monuseg.md: -------------------------------------------------------------------------------- 1 | ## MoNuSeg Preparation 2 | The original PanNuke dataset has the following style using .xml annotations and .tiff files with a size of $1000 \times 1000$ pixels: 3 | 4 | ```bash 5 | ├── testing 6 | │ ├── images 7 | │ │ ├── TCGA-2Z-A9J9-01A-01-TS1.tif 8 | │ │ ├── TCGA-44-2665-01B-06-BS6.tif 9 | ... 10 | │ └── labels 11 | │ ├── TCGA-2Z-A9J9-01A-01-TS1.xml 12 | │ ├── TCGA-44-2665-01B-06-BS6.xml 13 | ... 14 | └── training 15 | ├── images 16 | └── labels 17 | ``` 18 | For our experiments, we resized the dataset images to $1024 \times 1024$ pixels and convert the .xml annotations to binary masks: 19 | ```bash 20 | ├── testing 21 | │ ├── images 22 | │ │ ├── TCGA-2Z-A9J9-01A-01-TS1.png 23 | │ │ ├── TCGA-44-2665-01B-06-BS6.png 24 | ... 25 | │ └── labels 26 | │ │ ├── TCGA-2Z-A9J9-01A-01-TS1.npy 27 | │ │ ├── TCGA-44-2665-01B-06-BS6.npy 28 | ... 29 | └── training 30 | ├── images 31 | └── labels 32 | ``` 33 | 34 | Everythin can be extracted using the [`cell_segmentation/datasets/prepare_monuseg.py`](cell_segmentation/datasets/prepare_monuseg.py) script. 
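A minimal sketch of this conversion is given below for illustration. It is **not** the actual `prepare_monuseg.py` implementation: the XML layout (`Region`/`Vertex` elements with `X`/`Y` attributes) follows the public MoNuSeg annotation format, and `scikit-image`/`Pillow` are assumed to be installed. Each annotated nucleus polygon is rasterized into a labelled mask (`mask > 0` yields the binary mask), and image and mask are resized to $1024 \times 1024$ pixels:

```python
# Illustrative sketch only -- see prepare_monuseg.py for the script actually used.
import xml.etree.ElementTree as ET

import numpy as np
from PIL import Image
from skimage.draw import polygon


def xml_to_mask(xml_path: str, shape=(1000, 1000)) -> np.ndarray:
    """Rasterize every annotated nucleus polygon into an int32 labelled mask."""
    mask = np.zeros(shape, dtype=np.int32)
    root = ET.parse(xml_path).getroot()
    for inst_id, region in enumerate(root.iter("Region"), start=1):
        xs = [float(v.get("X")) for v in region.iter("Vertex")]
        ys = [float(v.get("Y")) for v in region.iter("Vertex")]
        rr, cc = polygon(np.asarray(ys), np.asarray(xs), shape)
        mask[rr, cc] = inst_id  # one id per nucleus; mask > 0 gives the binary mask
    return mask


# resize image (bilinear) and mask (nearest neighbour) from 1000x1000 to 1024x1024
img = Image.open("TCGA-2Z-A9J9-01A-01-TS1.tif").resize((1024, 1024), Image.BILINEAR)
mask = Image.fromarray(xml_to_mask("TCGA-2Z-A9J9-01A-01-TS1.xml"))
mask = np.array(mask.resize((1024, 1024), Image.NEAREST))
img.save("TCGA-2Z-A9J9-01A-01-TS1.png")
np.save("TCGA-2Z-A9J9-01A-01-TS1.npy", mask)
```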
35 | -------------------------------------------------------------------------------- /docs/readmes/pannuke.md: -------------------------------------------------------------------------------- 1 | ## PanNuke Preparation 2 | The original PanNuke dataset has the following style using just one big array for each dataset split: 3 | 4 | ```bash 5 | ├── fold0 6 | │ ├── images.npy 7 | │ ├── masks.npy 8 | │ └── types.npy 9 | ├── fold1 10 | │ ├── images.npy 11 | │ ├── masks.npy 12 | │ └── types.npy 13 | └── fold2 14 | ├── images.npy 15 | ├── masks.npy 16 | └── types.npy 17 | ``` 18 | 19 | For memory efficieny and to make us of multi-threading dataloading with our augmentation pipeline, we reassemble the dataset to the following structure: 20 | ```bash 21 | ├── fold0 22 | │ ├── cell_count.csv # cell-count for each image to be used in sampling 23 | │ ├── images # H&E Image for each sample as .png files 24 | │   ├── images 25 | │   │   ├── 0_0.png 26 | │   │   ├── 0_1.png 27 | │   │   ├── 0_2.png 28 | ... 29 | │ ├── labels # label as .npy arrays for each sample 30 | │   │   ├── 0_0.npy 31 | │   │   ├── 0_1.npy 32 | │   │   ├── 0_2.npy 33 | ... 34 | │ └── types.csv # csv file with type for each image 35 | ├── fold1 36 | │ ├── cell_count.csv 37 | │ ├── images 38 | │   │   ├── 1_0.png 39 | ... 40 | │ ├── labels 41 | │   │   ├── 1_0.npy 42 | ... 43 | │ └── types.csv 44 | ├── fold2 45 | │ ├── cell_count.csv 46 | │ ├── images 47 | │   │   ├── 2_0.png 48 | ... 49 | │ ├── labels 50 | │   │   ├── 2_0.npy 51 | ... 52 | │ └── types.csv 53 | ├── dataset_config.yaml # dataset config with dataset information 54 | └── weight_config.yaml # config file for our sampling 55 | ``` 56 | 57 | We provide all configuration files for the PanNuke dataset in the [`configs/datasets/PanNuke`](configs/datasets/PanNuke) folder. Please copy them in your dataset folder. Images and masks have to be extracted using the [`cell_segmentation/datasets/prepare_pannuke.py`](cell_segmentation/datasets/prepare_pannuke.py) script. 
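The re-assembly itself can be sketched as follows. This is only an illustration, not the exact `prepare_pannuke.py` implementation; the `pannuke_original`/`pannuke_prepared` paths are placeholders, and the usual PanNuke array shapes are assumed (`images.npy`: `(N, 256, 256, 3)`, `masks.npy`: `(N, 256, 256, 6)`, `types.npy`: one tissue-type string per patch). `cell_count.csv` would be derived from the masks in the same loop:

```python
# Illustrative sketch only -- see prepare_pannuke.py for the script actually used.
import csv
from pathlib import Path

import numpy as np
from PIL import Image

fold = 0  # placeholder fold index
src = Path(f"pannuke_original/fold{fold}")  # folder holding images.npy / masks.npy / types.npy
dst = Path(f"pannuke_prepared/fold{fold}")  # target folder in the structure shown above
(dst / "images").mkdir(parents=True, exist_ok=True)
(dst / "labels").mkdir(parents=True, exist_ok=True)

images = np.load(src / "images.npy")  # (N, 256, 256, 3) H&E patches
masks = np.load(src / "masks.npy")    # (N, 256, 256, 6) per-class instance maps
types = np.load(src / "types.npy")    # (N,) tissue type per patch

with open(dst / "types.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["img", "type"])  # header as described in cell_segmentation.md
    for i in range(len(images)):
        name = f"{fold}_{i}"
        Image.fromarray(images[i].astype(np.uint8)).save(dst / "images" / f"{name}.png")
        np.save(dst / "labels" / f"{name}.npy", masks[i])
        writer.writerow([f"{name}.png", str(types[i])])
```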
58 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Model implementations and pretrained models 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | -------------------------------------------------------------------------------- /models/segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/__init__.py -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__init__.py -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/UNet_v2.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import warnings 3 | 4 | import torch 5 | from torch import nn 6 | from unet_v2.pvtv2 import * 7 | import torch.nn.functional as F 8 | 9 | 10 | class ChannelAttention(nn.Module): 11 | def __init__(self, in_planes, ratio=16): 12 | super(ChannelAttention, self).__init__() 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | self.max_pool = nn.AdaptiveMaxPool2d(1) 15 | 16 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 17 | self.relu1 = nn.ReLU() 18 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 19 | 20 | self.sigmoid = nn.Sigmoid() 21 | 22 | def forward(self, x): 23 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 24 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 25 | out = avg_out + max_out 26 | return self.sigmoid(out) 27 | 28 | 29 | class SpatialAttention(nn.Module): 30 | def __init__(self, kernel_size=7): 31 | super(SpatialAttention, self).__init__() 32 | 33 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 34 | padding = 3 if kernel_size == 7 else 1 35 | 36 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 37 | self.sigmoid = nn.Sigmoid() 38 | 39 | def forward(self, x): 40 | avg_out = torch.mean(x, dim=1, keepdim=True) 41 | max_out, _ = torch.max(x, dim=1, keepdim=True) 42 | x = torch.cat([avg_out, max_out], dim=1) 43 | x = self.conv1(x) 44 | return self.sigmoid(x) 45 | 46 | 47 | class BasicConv2d(nn.Module): 48 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 49 | super(BasicConv2d, self).__init__() 50 | 51 | self.conv = nn.Conv2d(in_planes, out_planes, 52 | kernel_size=kernel_size, stride=stride, 53 | padding=padding, dilation=dilation, bias=False) 54 | self.bn = nn.BatchNorm2d(out_planes) 55 | self.relu = nn.ReLU(inplace=True) 56 | 57 | def forward(self, x): 58 | x = self.conv(x) 59 | x = self.bn(x) 60 | return x 61 | 62 | 63 | class Encoder(nn.Module): 64 | def __init__(self, pretrain_path): 65 | super().__init__() 66 | self.backbone = pvt_v2_b2() 67 | 68 | if pretrain_path is None: 69 | warnings.warn('please provide the pretrained pvt model. 
Not using pretrained model.') 70 | elif not os.path.isfile(pretrain_path): 71 | warnings.warn(f'path: {pretrain_path} does not exists. Not using pretrained model.') 72 | else: 73 | print(f"using pretrained file: {pretrain_path}") 74 | save_model = torch.load(pretrain_path) 75 | model_dict = self.backbone.state_dict() 76 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()} 77 | model_dict.update(state_dict) 78 | 79 | self.backbone.load_state_dict(model_dict) 80 | 81 | def forward(self, x): 82 | f1, f2, f3, f4 = self.backbone(x) # (x: 3, 352, 352) 83 | return f1, f2, f3, f4 84 | 85 | 86 | class SDI(nn.Module): 87 | def __init__(self, channel): 88 | super().__init__() 89 | 90 | self.convs = nn.ModuleList( 91 | [nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) for _ in range(4)]) 92 | 93 | def forward(self, xs, anchor): 94 | ans = torch.ones_like(anchor) 95 | target_size = anchor.shape[-1] 96 | 97 | for i, x in enumerate(xs): 98 | if x.shape[-1] > target_size: 99 | x = F.adaptive_avg_pool2d(x, (target_size, target_size)) 100 | elif x.shape[-1] < target_size: 101 | x = F.interpolate(x, size=(target_size, target_size), 102 | mode='bilinear', align_corners=True) 103 | 104 | ans = ans * self.convs[i](x) 105 | 106 | return ans 107 | 108 | 109 | class UNetV2(nn.Module): 110 | """ 111 | use SpatialAtt + ChannelAtt 112 | """ 113 | def __init__(self, channel=32, n_classes=1, deep_supervision=True, pretrained_path=None): 114 | super().__init__() 115 | self.deep_supervision = deep_supervision 116 | 117 | self.encoder = Encoder(pretrained_path) 118 | 119 | self.ca_1 = ChannelAttention(64) 120 | self.sa_1 = SpatialAttention() 121 | 122 | self.ca_2 = ChannelAttention(128) 123 | self.sa_2 = SpatialAttention() 124 | 125 | self.ca_3 = ChannelAttention(320) 126 | self.sa_3 = SpatialAttention() 127 | 128 | self.ca_4 = ChannelAttention(512) 129 | self.sa_4 = SpatialAttention() 130 | 131 | self.Translayer_1 = BasicConv2d(64, channel, 1) 132 | self.Translayer_2 = BasicConv2d(128, channel, 1) 133 | self.Translayer_3 = BasicConv2d(320, channel, 1) 134 | self.Translayer_4 = BasicConv2d(512, channel, 1) 135 | 136 | self.sdi_1 = SDI(channel) 137 | self.sdi_2 = SDI(channel) 138 | self.sdi_3 = SDI(channel) 139 | self.sdi_4 = SDI(channel) 140 | 141 | self.seg_outs = nn.ModuleList([ 142 | nn.Conv2d(channel, n_classes, 1, 1) for _ in range(4)]) 143 | 144 | self.deconv2 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, padding=1, 145 | bias=False) 146 | self.deconv3 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, 147 | padding=1, bias=False) 148 | self.deconv4 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, 149 | padding=1, bias=False) 150 | self.deconv5 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, 151 | padding=1, bias=False) 152 | 153 | def forward(self, x): 154 | seg_outs = [] 155 | f1, f2, f3, f4 = self.encoder(x) 156 | 157 | f1 = self.ca_1(f1) * f1 158 | f1 = self.sa_1(f1) * f1 159 | f1 = self.Translayer_1(f1) 160 | 161 | f2 = self.ca_2(f2) * f2 162 | f2 = self.sa_2(f2) * f2 163 | f2 = self.Translayer_2(f2) 164 | 165 | f3 = self.ca_3(f3) * f3 166 | f3 = self.sa_3(f3) * f3 167 | f3 = self.Translayer_3(f3) 168 | 169 | f4 = self.ca_4(f4) * f4 170 | f4 = self.sa_4(f4) * f4 171 | f4 = self.Translayer_4(f4) 172 | 173 | f41 = self.sdi_4([f1, f2, f3, f4], f4) 174 | f31 = self.sdi_3([f1, f2, f3, f4], f3) 175 | f21 = self.sdi_2([f1, f2, f3, f4], f2) 176 | f11 = self.sdi_1([f1, f2, f3, f4], f1) 177 | 178 | 
seg_outs.append(self.seg_outs[0](f41)) 179 | 180 | y = self.deconv2(f41) + f31 181 | seg_outs.append(self.seg_outs[1](y)) 182 | 183 | y = self.deconv3(y) + f21 184 | seg_outs.append(self.seg_outs[2](y)) 185 | 186 | y = self.deconv4(y) + f11 187 | seg_outs.append(self.seg_outs[3](y)) 188 | 189 | for i, o in enumerate(seg_outs): 190 | seg_outs[i] = F.interpolate(o, scale_factor=4, mode='bilinear') 191 | 192 | if self.deep_supervision: 193 | return seg_outs[::-1] 194 | else: 195 | return seg_outs[-1] 196 | 197 | 198 | if __name__ == "__main__": 199 | pretrained_path = "/afs/crc.nd.edu/user/y/ypeng4/Polyp-PVT_2/pvt_pth/pvt_v2_b2.pth" 200 | model = UNetV2(n_classes=2, deep_supervision=True, pretrained_path=None) 201 | x = torch.rand((2, 3, 256, 256)) 202 | ys = model(x) 203 | for y in ys: 204 | print(y.shape) 205 | -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/cellvit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit.cpython-310.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/cellvit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit.cpython-38.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/cellvit_shared.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit_shared.cpython-38.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/cellvit_unirepLKnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit_unirepLKnet.cpython-310.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/cellvit_unirepLKnet.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit_unirepLKnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/replknet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/replknet.cpython-38.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /models/segmentation/cell_segmentation/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from einops import rearrange 3 | from models.encoders.VIT.SAM.image_encoder import ImageEncoderViT 4 | from models.encoders.VIT.vits_histo import VisionTransformer 5 | 6 | import torch 7 | import torch.nn as nn 8 | from typing import Callable, Tuple, Type, List 9 | 10 | 11 | class Conv2DBlock(nn.Module): 12 | """Conv2DBlock with convolution followed by batch-normalisation, ReLU activation and dropout 13 | 14 | Args: 15 | in_channels (int): Number of input channels for convolution 16 | out_channels (int): Number of output channels for convolution 17 | kernel_size (int, optional): Kernel size for convolution. Defaults to 3. 18 | dropout (float, optional): Dropout. Defaults to 0. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | in_channels: int, 24 | out_channels: int, 25 | kernel_size: int = 3, 26 | dropout: float = 0, 27 | ) -> None: 28 | super().__init__() 29 | self.block = nn.Sequential( 30 | nn.Conv2d( 31 | in_channels=in_channels, 32 | out_channels=out_channels, 33 | kernel_size=kernel_size, 34 | stride=1, 35 | padding=((kernel_size - 1) // 2), 36 | ), 37 | nn.BatchNorm2d(out_channels), 38 | nn.ReLU(True), 39 | nn.Dropout(dropout), 40 | ) 41 | 42 | def forward(self, x): 43 | return self.block(x) 44 | 45 | 46 | class Deconv2DBlock(nn.Module): 47 | """Deconvolution block with ConvTranspose2d followed by Conv2d, batch-normalisation, ReLU activation and dropout 48 | 49 | Args: 50 | in_channels (int): Number of input channels for deconv block 51 | out_channels (int): Number of output channels for deconv and convolution. 52 | kernel_size (int, optional): Kernel size for convolution. Defaults to 3. 53 | dropout (float, optional): Dropout. Defaults to 0. 
54 | """ 55 | 56 | def __init__( 57 | self, 58 | in_channels: int, 59 | out_channels: int, 60 | kernel_size: int = 3, 61 | dropout: float = 0, 62 | ) -> None: 63 | super().__init__() 64 | self.block = nn.Sequential( 65 | nn.ConvTranspose2d( 66 | in_channels=in_channels, 67 | out_channels=out_channels, 68 | kernel_size=2, 69 | stride=2, 70 | padding=0, 71 | output_padding=0, 72 | ), 73 | nn.Conv2d( 74 | in_channels=out_channels, 75 | out_channels=out_channels, 76 | kernel_size=kernel_size, 77 | stride=1, 78 | padding=((kernel_size - 1) // 2), 79 | ), 80 | nn.BatchNorm2d(out_channels), 81 | nn.ReLU(True), 82 | nn.Dropout(dropout), 83 | ) 84 | 85 | def forward(self, x): 86 | return self.block(x) 87 | 88 | 89 | class ViTCellViT(VisionTransformer): 90 | def __init__( 91 | self, 92 | extract_layers: List[int], 93 | img_size: List[int] = [224], 94 | patch_size: int = 16, 95 | in_chans: int = 3, 96 | num_classes: int = 0, 97 | embed_dim: int = 768, 98 | depth: int = 12, 99 | num_heads: int = 12, 100 | mlp_ratio: float = 4, 101 | qkv_bias: bool = False, 102 | qk_scale: float = None, 103 | drop_rate: float = 0, 104 | attn_drop_rate: float = 0, 105 | drop_path_rate: float = 0, 106 | norm_layer: Callable = nn.LayerNorm, 107 | **kwargs 108 | ): 109 | """Vision Transformer with 1D positional embedding 110 | 111 | Args: 112 | extract_layers: (List[int]): List of Transformer Blocks whose outputs should be returned in addition to the tokens. First blocks starts with 1, and maximum is N=depth. 113 | img_size (int, optional): Input image size. Defaults to 224. 114 | patch_size (int, optional): Patch Token size (one dimension only, cause tokens are squared). Defaults to 16. 115 | in_chans (int, optional): Number of input channels. Defaults to 3. 116 | num_classes (int, optional): Number of output classes. if num classes = 0, raw tokens are returned (nn.Identity). 117 | Default to 0. 118 | embed_dim (int, optional): Embedding dimension. Defaults to 768. 119 | depth(int, optional): Number of Transformer Blocks. Defaults to 12. 120 | num_heads (int, optional): Number of attention heads per Transformer Block. Defaults to 12. 121 | mlp_ratio (float, optional): MLP ratio for hidden MLP dimension (Bottleneck = dim*mlp_ratio). 122 | Defaults to 4.0. 123 | qkv_bias (bool, optional): If bias should be used for query (q), key (k), and value (v). Defaults to False. 124 | qk_scale (float, optional): Scaling parameter. Defaults to None. 125 | drop_rate (float, optional): Dropout in MLP. Defaults to 0.0. 126 | attn_drop_rate (float, optional): Dropout for attention layer. Defaults to 0.0. 127 | drop_path_rate (float, optional): Dropout for skip connection. Defaults to 0.0. 128 | norm_layer (Callable, optional): Normalization layer. Defaults to nn.LayerNorm. 
129 | 130 | """ 131 | super().__init__( 132 | img_size=img_size, 133 | patch_size=patch_size, 134 | in_chans=in_chans, 135 | num_classes=num_classes, 136 | embed_dim=embed_dim, 137 | depth=depth, 138 | num_heads=num_heads, 139 | mlp_ratio=mlp_ratio, 140 | qkv_bias=qkv_bias, 141 | qk_scale=qk_scale, 142 | drop_rate=drop_rate, 143 | attn_drop_rate=attn_drop_rate, 144 | drop_path_rate=drop_path_rate, 145 | norm_layer=norm_layer, 146 | ) 147 | self.extract_layers = extract_layers 148 | 149 | def forward( 150 | self, x: torch.Tensor 151 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 152 | """Forward pass with returning intermediate outputs for skip connections 153 | 154 | Args: 155 | x (torch.Tensor): Input batch 156 | 157 | Returns: 158 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 159 | torch.Tensor: Output of last layers (all tokens, without classification) 160 | torch.Tensor: Classification output 161 | torch.Tensor: Skip connection outputs from extract_layer selection 162 | """ 163 | extracted_layers = [] 164 | x = self.prepare_tokens(x) 165 | 166 | for depth, blk in enumerate(self.blocks): 167 | x = blk(x) 168 | if depth + 1 in self.extract_layers: 169 | extracted_layers.append(x) 170 | 171 | x = self.norm(x) 172 | output = self.head(x[:, 0]) 173 | 174 | return output, x[:, 0], extracted_layers 175 | 176 | 177 | class ViTCellViTDeit(ImageEncoderViT): 178 | def __init__( 179 | self, 180 | extract_layers: List[int], 181 | img_size: int = 1024, 182 | patch_size: int = 16, 183 | in_chans: int = 3, 184 | embed_dim: int = 768, 185 | depth: int = 12, 186 | num_heads: int = 12, 187 | mlp_ratio: float = 4, 188 | out_chans: int = 256, 189 | qkv_bias: bool = True, 190 | norm_layer: Type[nn.Module] = nn.LayerNorm, 191 | act_layer: Type[nn.Module] = nn.GELU, 192 | use_abs_pos: bool = True, 193 | use_rel_pos: bool = False, 194 | rel_pos_zero_init: bool = True, 195 | window_size: int = 0, 196 | global_attn_indexes: Tuple[int, ...] 
= (), 197 | ) -> None: 198 | super().__init__( 199 | img_size, 200 | patch_size, 201 | in_chans, 202 | embed_dim, 203 | depth, 204 | num_heads, 205 | mlp_ratio, 206 | out_chans, 207 | qkv_bias, 208 | norm_layer, 209 | act_layer, 210 | use_abs_pos, 211 | use_rel_pos, 212 | rel_pos_zero_init, 213 | window_size, 214 | global_attn_indexes, 215 | ) 216 | self.extract_layers = extract_layers 217 | 218 | def forward(self, x: torch.Tensor) -> torch.Tensor: 219 | extracted_layers = [] 220 | x = self.patch_embed(x) 221 | 222 | if self.pos_embed is not None: 223 | token_size = x.shape[1] 224 | x = x + self.pos_embed[:, :token_size, :token_size, :] 225 | 226 | for depth, blk in enumerate(self.blocks): 227 | x = blk(x) 228 | if depth + 1 in self.extract_layers: 229 | extracted_layers.append(x) 230 | output = self.neck(x.permute(0, 3, 1, 2)) 231 | _output = rearrange(output, "b c h w -> b c (h w)") 232 | 233 | return torch.mean(_output, axis=-1), output, extracted_layers 234 | -------------------------------------------------------------------------------- /models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/utils/__init__.py -------------------------------------------------------------------------------- /models/utils/attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # PyTorch Implementation of Attention Modules 3 | # 4 | # Implementation based on: https://github.com/mahmoodlab/CLAM 5 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 6 | # Institute for Artifical Intelligence in Medicine, 7 | # University Medicine Essen 8 | 9 | from typing import Tuple 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class Attention(nn.Module): 15 | """Basic Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL 16 | 17 | Args: 18 | in_features (int, optional): Input shape of attention module. Defaults to 1024. 19 | attention_features (int, optional): Number of attention features. Defaults to 128. 20 | num_classes (int, optional): Number of output classes. Defaults to 2. 21 | dropout (bool, optional): If True, dropout is used. Defaults to False. 22 | dropout_rate (float, optional): Dropout rate, just applies if dropout parameter is true. 23 | Needs to be between 0.0 and 1.0. Defaults to 0.25. 
24 | """ 25 | 26 | def __init__( 27 | self, 28 | in_features: int = 1024, 29 | attention_features: int = 128, 30 | num_classes: int = 2, 31 | dropout: bool = False, 32 | dropout_rate: float = 0.25, 33 | ): 34 | super(Attention, self).__init__() 35 | # naming 36 | self.model_name = "AttentionModule" 37 | 38 | # set parameter dimensions for attention 39 | self.attention_features = attention_features 40 | self.in_features = in_features 41 | self.num_classes = num_classes 42 | self.dropout = dropout 43 | self.d_rate = dropout_rate 44 | 45 | if self.dropout: 46 | assert self.d_rate < 1 47 | self.attention = nn.Sequential( 48 | nn.Linear(self.in_features, self.attention_features), 49 | nn.Tanh(), 50 | nn.Dropout(self.d_rate), 51 | nn.Linear(self.attention_features, self.num_classes), 52 | ) 53 | else: 54 | self.attention = nn.Sequential( 55 | nn.Linear(self.in_features, self.attention_features), 56 | nn.Tanh(), 57 | nn.Linear(self.attention_features, self.num_classes), 58 | ) 59 | 60 | def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 61 | """Forward pass, calculating attention scores for given input vector 62 | 63 | Args: 64 | H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions) 65 | 66 | Returns: 67 | Tuple[torch.Tensor, torch.Tensor]: 68 | 69 | * Attention-Scores 70 | * H. Shape: Bag of instances. Shape: (Number of instances, Feature-dimensions) 71 | """ 72 | A = self.attention(H) 73 | return A, H 74 | 75 | 76 | class AttentionGated(nn.Module): 77 | """Gated Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL 78 | 79 | Args: 80 | in_features (int, optional): Input shape of attention module. Defaults to 1024. 81 | attention_features (int, optional): Number of attention features. Defaults to 128. 82 | num_classes (int, optional): Number of output classes. Defaults to 2. 83 | dropout (bool, optional): If True, dropout is used. Defaults to False. 84 | dropout_rate (float, optional): Dropout rate, just applies if dropout parameter is true. 85 | needs to be between 0.0 and 1.0. Defaults to 0.25. 
86 | """ 87 | 88 | def __init__( 89 | self, 90 | in_features: int = 1024, 91 | attention_features: int = 128, 92 | num_classes: int = 2, 93 | dropout: bool = False, 94 | dropout_rate: float = 0.25, 95 | ): 96 | super(AttentionGated, self).__init__() 97 | # naming 98 | self.model_name = "AttentionModuleGated" 99 | 100 | # set Parameter dimensions for attention 101 | self.attention_features = attention_features 102 | self.in_features = in_features 103 | self.num_classes = num_classes 104 | self.dropout = dropout 105 | self.d_rate = dropout_rate 106 | 107 | if self.dropout: 108 | assert self.d_rate < 1 109 | self.attention_V = nn.Sequential( 110 | nn.Linear(self.in_features, self.attention_features), 111 | nn.Tanh(), 112 | nn.Dropout(self.d_rate), 113 | ) 114 | self.attention_U = nn.Sequential( 115 | nn.Linear(self.in_features, self.attention_features), 116 | nn.Sigmoid(), 117 | nn.Dropout(self.d_rate), 118 | ) 119 | self.attention_W = nn.Sequential( 120 | nn.Linear(self.attention_features, self.num_classes) 121 | ) 122 | 123 | else: 124 | self.attention_V = nn.Sequential( 125 | nn.Linear(self.in_features, self.attention_features), nn.Tanh() 126 | ) 127 | self.attention_U = nn.Sequential( 128 | nn.Linear(self.in_features, self.attention_features), nn.Sigmoid() 129 | ) 130 | self.attention_W = nn.Sequential( 131 | nn.Linear(self.attention_features, self.num_classes) 132 | ) 133 | 134 | def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 135 | """Forward pass, calculating attention scores for given input vector 136 | 137 | Args: 138 | H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions) 139 | 140 | Returns: 141 | Tuple[torch.Tensor, torch.Tensor]: 142 | 143 | * Attention-Scores. Shape: (Number of instances) 144 | * H. Shape: Bag of instances. Shape: (Number of instances, Feature-dimensions) 145 | """ 146 | v = self.attention_V(H) 147 | u = self.attention_U(H) 148 | A = self.attention_W(v * u) 149 | return A, H 150 | -------------------------------------------------------------------------------- /models/utils/dense.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Dense Block as defined in: 3 | # Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger. 4 | # "Densely connected convolutional networks." In Proceedings of the IEEE conference 5 | # on computer vision and pattern recognition, pp. 4700-4708. 2017. 6 | # 7 | # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net) 8 | # 9 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 10 | # Institute for Artifical Intelligence in Medicine, 11 | # University Medicine Essen 12 | 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from collections import OrderedDict 18 | 19 | 20 | class DenseBlock(nn.Module): 21 | """Dense Block as defined in: 22 | 23 | Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger. 24 | "Densely connected convolutional networks." In Proceedings of the IEEE conference 25 | on computer vision and pattern recognition, pp. 4700-4708. 2017. 26 | 27 | Only performs `valid` convolution. 28 | 29 | """ 30 | 31 | def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, split=1): 32 | super(DenseBlock, self).__init__() 33 | assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" 34 | 35 | self.nr_unit = unit_count 36 | self.in_ch = in_ch 37 | self.unit_ch = unit_ch 38 | 39 | # ! 
For inference only so init values for batchnorm may not match tensorflow 40 | unit_in_ch = in_ch 41 | self.units = nn.ModuleList() 42 | for idx in range(unit_count): 43 | self.units.append( 44 | nn.Sequential( 45 | OrderedDict( 46 | [ 47 | ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), 48 | ("preact_bna/relu", nn.ReLU(inplace=True)), 49 | ( 50 | "conv1", 51 | nn.Conv2d( 52 | unit_in_ch, 53 | unit_ch[0], 54 | unit_ksize[0], 55 | stride=1, 56 | padding=0, 57 | bias=False, 58 | ), 59 | ), 60 | ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), 61 | ("conv1/relu", nn.ReLU(inplace=True)), 62 | # ('conv2/pool', TFSamepaddingLayer(ksize=unit_ksize[1], stride=1)), 63 | ( 64 | "conv2", 65 | nn.Conv2d( 66 | unit_ch[0], 67 | unit_ch[1], 68 | unit_ksize[1], 69 | groups=split, 70 | stride=1, 71 | padding=0, 72 | bias=False, 73 | ), 74 | ), 75 | ] 76 | ) 77 | ) 78 | ) 79 | unit_in_ch += unit_ch[1] 80 | 81 | self.blk_bna = nn.Sequential( 82 | OrderedDict( 83 | [ 84 | ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), 85 | ("relu", nn.ReLU(inplace=True)), 86 | ] 87 | ) 88 | ) 89 | 90 | def out_ch(self): 91 | return self.in_ch + self.nr_unit * self.unit_ch[-1] 92 | 93 | def init_weights(self): 94 | """Kaiming (HE) initialization for convolutional layers and constant initialization for normalization and linear layers""" 95 | for m in self.modules(): 96 | classname = m.__class__.__name__ 97 | 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 100 | 101 | if "norm" in classname.lower(): 102 | nn.init.constant_(m.weight, 1) 103 | nn.init.constant_(m.bias, 0) 104 | 105 | if "linear" in classname.lower(): 106 | if m.bias is not None: 107 | nn.init.constant_(m.bias, 0) 108 | 109 | def forward(self, prev_feat): 110 | for idx in range(self.nr_unit): 111 | new_feat = self.units[idx](prev_feat) 112 | prev_feat = crop_to_shape(prev_feat, new_feat) 113 | prev_feat = torch.cat([prev_feat, new_feat], dim=1) 114 | prev_feat = self.blk_bna(prev_feat) 115 | 116 | return prev_feat 117 | 118 | 119 | # helper functions for cropping 120 | def crop_op(x, cropping, data_format="NCHW"): 121 | """Center crop image. 122 | 123 | Args: 124 | x: input image 125 | cropping: the substracted amount 126 | data_format: choose either `NCHW` or `NHWC` 127 | 128 | """ 129 | crop_t = cropping[0] // 2 130 | crop_b = cropping[0] - crop_t 131 | crop_l = cropping[1] // 2 132 | crop_r = cropping[1] - crop_l 133 | if data_format == "NCHW": 134 | x = x[:, :, crop_t:-crop_b, crop_l:-crop_r] 135 | else: 136 | x = x[:, crop_t:-crop_b, crop_l:-crop_r, :] 137 | return x 138 | 139 | 140 | def crop_to_shape(x, y, data_format="NCHW"): 141 | """Centre crop x so that x has shape of y. y dims must be smaller than x dims. 142 | 143 | Args: 144 | x: input array 145 | y: array with desired shape. 146 | 147 | """ 148 | assert ( 149 | y.shape[0] <= x.shape[0] and y.shape[1] <= x.shape[1] 150 | ), "Ensure that y dimensions are smaller than x dimensions!" 
151 | 152 | x_shape = x.size() 153 | y_shape = y.size() 154 | if data_format == "NCHW": 155 | crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3]) 156 | else: 157 | crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2]) 158 | return crop_op(x, crop_shape, data_format) 159 | -------------------------------------------------------------------------------- /models/utils/residual.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Residual block as defined in: 3 | # He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning 4 | # for image recognition." In Proceedings of the IEEE conference on computer vision 5 | # and pattern recognition, pp. 770-778. 2016. 6 | # 7 | # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net) 8 | # 9 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 10 | # Institute for Artifical Intelligence in Medicine, 11 | # University Medicine Essen 12 | 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from collections import OrderedDict 18 | 19 | from models.utils.tf_utils import TFSamepaddingLayer 20 | 21 | 22 | class ResidualBlock(nn.Module): 23 | """Residual block as defined in: 24 | 25 | He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning 26 | for image recognition." In Proceedings of the IEEE conference on computer vision 27 | and pattern recognition, pp. 770-778. 2016. 28 | 29 | """ 30 | 31 | def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1): 32 | super(ResidualBlock, self).__init__() 33 | assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info" 34 | 35 | self.nr_unit = unit_count 36 | self.in_ch = in_ch 37 | self.unit_ch = unit_ch 38 | 39 | # ! 
For inference only so init values for batchnorm may not match tensorflow 40 | unit_in_ch = in_ch 41 | self.units = nn.ModuleList() 42 | for idx in range(unit_count): 43 | unit_layer = [ 44 | ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), 45 | ("preact/relu", nn.ReLU(inplace=True)), 46 | ( 47 | "conv1", 48 | nn.Conv2d( 49 | unit_in_ch, 50 | unit_ch[0], 51 | unit_ksize[0], 52 | stride=1, 53 | padding=0, 54 | bias=False, 55 | ), 56 | ), 57 | ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)), 58 | ("conv1/relu", nn.ReLU(inplace=True)), 59 | ( 60 | "conv2/pad", 61 | TFSamepaddingLayer( 62 | ksize=unit_ksize[1], stride=stride if idx == 0 else 1 63 | ), 64 | ), 65 | ( 66 | "conv2", 67 | nn.Conv2d( 68 | unit_ch[0], 69 | unit_ch[1], 70 | unit_ksize[1], 71 | stride=stride if idx == 0 else 1, 72 | padding=0, 73 | bias=False, 74 | ), 75 | ), 76 | ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)), 77 | ("conv2/relu", nn.ReLU(inplace=True)), 78 | ( 79 | "conv3", 80 | nn.Conv2d( 81 | unit_ch[1], 82 | unit_ch[2], 83 | unit_ksize[2], 84 | stride=1, 85 | padding=0, 86 | bias=False, 87 | ), 88 | ), 89 | ] 90 | # * has bna to conclude each previous block so 91 | # * must not put preact for the first unit of this block 92 | unit_layer = unit_layer if idx != 0 else unit_layer[2:] 93 | self.units.append(nn.Sequential(OrderedDict(unit_layer))) 94 | unit_in_ch = unit_ch[-1] 95 | 96 | if in_ch != unit_ch[-1] or stride != 1: 97 | self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False) 98 | else: 99 | self.shortcut = None 100 | 101 | self.blk_bna = nn.Sequential( 102 | OrderedDict( 103 | [ 104 | ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)), 105 | ("relu", nn.ReLU(inplace=True)), 106 | ] 107 | ) 108 | ) 109 | 110 | def out_ch(self): 111 | return self.unit_ch[-1] 112 | 113 | def init_weights(self): 114 | """Kaiming (HE) initialization for convolutional layers and constant initialization for normalization and linear layers""" 115 | for m in self.modules(): 116 | classname = m.__class__.__name__ 117 | 118 | if isinstance(m, nn.Conv2d): 119 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 120 | 121 | if "norm" in classname.lower(): 122 | nn.init.constant_(m.weight, 1) 123 | nn.init.constant_(m.bias, 0) 124 | 125 | if "linear" in classname.lower(): 126 | if m.bias is not None: 127 | nn.init.constant_(m.bias, 0) 128 | 129 | def forward(self, prev_feat, freeze=False): 130 | if self.shortcut is None: 131 | shortcut = prev_feat 132 | else: 133 | shortcut = self.shortcut(prev_feat) 134 | 135 | for idx in range(0, len(self.units)): 136 | new_feat = prev_feat 137 | if self.training: 138 | with torch.set_grad_enabled(not freeze): 139 | new_feat = self.units[idx](new_feat) 140 | else: 141 | new_feat = self.units[idx](new_feat) 142 | prev_feat = new_feat + shortcut 143 | shortcut = prev_feat 144 | feat = self.blk_bna(prev_feat) 145 | return feat 146 | -------------------------------------------------------------------------------- /models/utils/tf_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TFSamepaddingLayer(nn.Module): 7 | """To align with tf `same` padding. 
8 | 9 | Putting this before any conv layer that need padding 10 | Assuming kernel has Height == Width for simplicity 11 | """ 12 | 13 | def __init__(self, ksize, stride): 14 | super(TFSamepaddingLayer, self).__init__() 15 | self.ksize = ksize 16 | self.stride = stride 17 | 18 | def forward(self, x): 19 | if x.shape[2] % self.stride == 0: 20 | pad = max(self.ksize - self.stride, 0) 21 | else: 22 | pad = max(self.ksize - (x.shape[2] % self.stride), 0) 23 | 24 | if pad % 2 == 0: 25 | pad_val = pad // 2 26 | padding = (pad_val, pad_val, pad_val, pad_val) 27 | else: 28 | pad_val_start = pad // 2 29 | pad_val_end = pad - pad_val_start 30 | padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end) 31 | # print(x.shape, padding) 32 | x = F.pad(x, padding, "constant", 0) 33 | # print(x.shape) 34 | return x 35 | -------------------------------------------------------------------------------- /models/utils/tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Helper functions for models 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | from torch import nn 9 | 10 | 11 | def reset_weights(model: nn.Module) -> None: 12 | """Reset the parameters of the model to avaid weight leakage 13 | 14 | Args: 15 | model (nn.Module): PyTorch Model 16 | """ 17 | for layer in model.children(): 18 | if hasattr(layer, "reset_parameters"): 19 | layer.reset_parameters() 20 | 21 | 22 | def initialize_weights(module: nn.Module) -> None: 23 | """Initialize Module weights according to xavier 24 | 25 | Args: 26 | module (nn.Module): Model 27 | """ 28 | for m in module.modules(): 29 | if isinstance(m, nn.Linear): 30 | nn.init.xavier_normal_(m.weight) 31 | m.bias.data.zero_() 32 | 33 | elif isinstance(m, nn.BatchNorm1d): 34 | nn.init.constant_(m.weight, 1) 35 | nn.init.constant_(m.bias, 0) 36 | -------------------------------------------------------------------------------- /preprocessing/encoding/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | logger = logging.getLogger(__name__) 5 | logger.addHandler(logging.NullHandler()) 6 | -------------------------------------------------------------------------------- /preprocessing/encoding/datasets/patched_wsi_inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Patched WSI Dataset used for inference, mainly for calculating embeddings 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | from typing import Callable, Tuple, List 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | from datamodel.wsi_datamodel import WSI 13 | 14 | 15 | class PatchedWSIInference(Dataset): 16 | """Inference Dataset, used for calculating embeddings of *one* WSI. Wrapped around a WSI object 17 | 18 | Args: 19 | wsi_object ( 20 | filelist (list[str]): List with filenames as entries. 
Filenames should match the key pattern in wsi_objects dictionary 21 | transform (Callable): Inference Transformations 22 | """ 23 | 24 | def __init__( 25 | self, 26 | wsi_object: WSI, 27 | transform: Callable, 28 | ) -> None: 29 | # set all configurations 30 | assert isinstance(wsi_object, WSI), "Must be a WSI-object" 31 | assert ( 32 | wsi_object.patched_slide_path is not None 33 | ), "Please provide a WSI that already has been patched into slices" 34 | 35 | self.transform = transform 36 | self.wsi_object = wsi_object 37 | 38 | def __getitem__( 39 | self, idx: int 40 | ) -> Tuple[torch.Tensor, list[list[str, str]], list[str], int, str]: 41 | """Returns one WSI with patches, coords, filenames, labels and wsi name for given idx 42 | 43 | Args: 44 | idx (int): Index of WSI to retrieve 45 | 46 | Returns: 47 | Tuple[torch.Tensor, list[list[str,str]], list[str], int, str]: 48 | 49 | * torch.Tensor: Tensor with shape [num_patches, 3, height, width], includes all patches for one WSI 50 | * list[list[str,str]]: List with coordinates as list entries, e.g., [['1', '1'], ['2', '1'], ..., ['row', 'col']] 51 | * list[str]: List with patch filenames 52 | * int: Patient label as integer 53 | * str: String with WSI name 54 | """ 55 | patch_name = self.wsi_object.patches_list[idx] 56 | 57 | patch, metadata = self.wsi_object.process_patch_image( 58 | patch_name=patch_name, transform=self.transform 59 | ) 60 | 61 | return patch, metadata 62 | 63 | def __len__(self) -> int: 64 | """Return len of dataset 65 | 66 | Returns: 67 | int: Len of dataset 68 | """ 69 | return int(self.wsi_object.get_number_patches()) 70 | 71 | @staticmethod 72 | def collate_batch(batch: List[Tuple]) -> Tuple[torch.Tensor, list[dict]]: 73 | """Create a custom batch 74 | 75 | Needed to unpack List of tuples with dictionaries and array 76 | 77 | Args: 78 | batch (List[Tuple]): Input batch consisting of a list of tuples (patch, patch-metadata) 79 | 80 | Returns: 81 | Tuple[torch.Tensor, list[dict]]: 82 | New batch: patches with shape [batch_size, 3, patch_size, patch_size], list of metadata dicts 83 | """ 84 | patches, metadata = zip(*batch) 85 | patches = torch.stack(patches) 86 | metadata = list(metadata) 87 | return patches, metadata 88 | -------------------------------------------------------------------------------- /preprocessing/patch_extraction/main_extraction.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Main entry point for patch-preprocessing 3 | # 4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de 5 | # Institute for Artifical Intelligence in Medicine, 6 | # University Medicine Essen 7 | 8 | import inspect 9 | import logging 10 | import os 11 | import sys 12 | 13 | 14 | logger = logging.getLogger() 15 | logger.addHandler(logging.NullHandler()) 16 | 17 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 18 | parentdir = os.path.dirname(currentdir) 19 | sys.path.insert(0, parentdir) 20 | parentdir = os.path.dirname(parentdir) 21 | sys.path.insert(0, parentdir) 22 | 23 | from preprocessing.patch_extraction.src.cli import PreProcessingParser 24 | from preprocessing.patch_extraction.src.patch_extraction import PreProcessor 25 | from utils.tools import close_logger 26 | 27 | if __name__ == "__main__": 28 | configuration_parser = PreProcessingParser() 29 | configuration, logger = configuration_parser.get_config() 30 | configuration_parser.store_config() 31 | 32 | slide_processor = 
PreProcessor(slide_processor_config=configuration) 33 | slide_processor.sample_patches_dataset() 34 | 35 | logger.info("Finished Preprocessing.") 36 | close_logger(logger) 37 | -------------------------------------------------------------------------------- /preprocessing/patch_extraction/scripts/macenko.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import inspect 3 | import logging 4 | import os 5 | import sys 6 | 7 | logger = logging.getLogger() 8 | logger.addHandler(logging.NullHandler()) 9 | 10 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 11 | parentdir = os.path.dirname(currentdir) 12 | sys.path.insert(0, parentdir) 13 | parentdir = os.path.dirname(parentdir) 14 | sys.path.insert(0, parentdir) 15 | parentdir = os.path.dirname(parentdir) 16 | sys.path.insert(0, parentdir) 17 | 18 | from preprocessing.patch_extraction.src.cli import MacenkoParser 19 | from preprocessing.patch_extraction.src.patch_extraction import PreProcessor 20 | 21 | if __name__ == "__main__": 22 | configuration_parser = MacenkoParser() 23 | configuration, logger = configuration_parser.get_config() 24 | 25 | slide_processor = PreProcessor(slide_processor_config=configuration) 26 | slide_processor.save_normalization_vector( 27 | wsi_file=configuration.wsi_paths, save_json_path=configuration.save_json_path 28 | ) 29 | 30 | logger.info("Finished Macenko Vector Calculation!") 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | addict==2.4.0 3 | albumentations==1.4.1 4 | aliyun-python-sdk-core==2.15.0 5 | aliyun-python-sdk-kms==2.16.2 6 | annotated-types==0.6.0 7 | apex 8 | appdirs==1.4.4 9 | Brotli 10 | cached-property==1.5.2 11 | certifi==2022.12.7 12 | cffi==1.16.0 13 | charset-normalizer==2.1.1 14 | chex==0.1.7 15 | click==8.1.7 16 | colorama==0.4.6 17 | contextlib2==21.6.0 18 | contourpy==1.1.1 19 | crcmod==1.7 20 | cryptography==42.0.5 21 | cycler==0.12.1 22 | depthwise-conv2d-implicit-gemm==0.0.0 23 | dill==0.3.8 24 | dm-tree==0.1.8 25 | docker-pycreds==0.4.0 26 | einops==0.7.0 27 | etils==1.3.0 28 | filelock==3.9.0 29 | flax==0.7.2 30 | fonttools==4.49.0 31 | fsspec==2024.2.0 32 | gitdb==4.0.11 33 | GitPython==3.1.42 34 | huggingface-hub==0.21.4 35 | idna==3.4 36 | imageio==2.34.0 37 | importlib_metadata==7.0.2 38 | importlib_resources==6.1.3 39 | jax==0.4.13 40 | jaxlib==0.4.13 41 | Jinja2==3.1.2 42 | jmespath==0.10.0 43 | joblib==1.3.2 44 | kiwisolver==1.4.5 45 | lazy_loader==0.3 46 | lightning-utilities==0.10.1 47 | llvmlite==0.41.1 48 | Markdown==3.5.2 49 | markdown-it-py==3.0.0 50 | MarkupSafe==2.1.3 51 | matplotlib==3.7.5 52 | mdurl==0.1.2 53 | ml-dtypes==0.2.0 54 | mmcv==2.1.0 55 | mmcv-full==1.7.2 56 | mmdet==3.3.0 57 | mmengine==0.10.3 58 | mmsegmentation==1.2.2 59 | model-index==0.1.11 60 | monai==1.3.0 61 | mpmath==1.3.0 62 | msgpack==1.0.8 63 | munch==4.0.0 64 | natsort==8.4.0 65 | nest-asyncio==1.6.0 66 | networkx==3.0 67 | ninja==1.11.1.1 68 | numba==0.58.1 69 | numpy 70 | opencv-python==4.9.0.80 71 | opencv-python-headless==4.9.0.80 72 | opendatalab==0.0.10 73 | openmim==0.3.9 74 | openxlab==0.0.35 75 | opt-einsum==3.3.0 76 | optax==0.1.8 77 | orbax-checkpoint==0.2.3 78 | ordered-set==4.1.0 79 | oss2==2.17.0 80 | packaging 81 | pandarallel==1.6.5 82 | pandas==2.0.3 83 | pillow==10.2.0 84 | platformdirs 85 | pretrainedmodels==0.7.4 
86 | prettytable==3.10.0 87 | protobuf==4.25.3 88 | psutil==5.9.8 89 | pycocotools==2.0.7 90 | pycparser==2.21 91 | pycryptodome==3.20.0 92 | pydantic==2.6.4 93 | pydantic_core==2.16.3 94 | Pygments==2.17.2 95 | pyparsing==3.1.2 96 | PySocks 97 | python-dateutil==2.9.0.post0 98 | pytz==2023.4 99 | PyWavelets==1.4.1 100 | PyYAML 101 | requests==2.28.2 102 | rich==13.4.2 103 | safetensors==0.4.2 104 | schema==0.7.5 105 | scikit-image==0.21.0 106 | scikit-learn==1.3.2 107 | scipy 108 | sentry-sdk==1.41.0 109 | setproctitle==1.3.3 110 | shapely==2.0.3 111 | six==1.16.0 112 | smmap==5.0.1 113 | sympy==1.12 114 | tabulate==0.9.0 115 | tensorstore==0.1.45 116 | termcolor 117 | terminaltables==3.1.10 118 | thop==0.1.1.post2209072238 119 | threadpoolctl==3.3.0 120 | tifffile==2023.7.10 121 | timm==0.9.16 122 | tomli==2.0.1 123 | toolz==0.12.1 124 | torch==2.1.2+cu118 125 | torchaudio==2.1.2+cu118 126 | torchinfo==1.8.0 127 | torchmetrics==1.3.1 128 | torchstat==0.0.7 129 | torchsummary==1.5.1 130 | torchvision==0.16.2+cu118 131 | tqdm==4.65.2 132 | triton==2.1.0 133 | typing_extensions==4.10.0 134 | tzdata==2024.1 135 | ujson==5.9.0 136 | urllib3==1.26.13 137 | wandb==0.16.4 138 | wcwidth==0.2.13 139 | yacs 140 | yapf==0.40.2 141 | zipp==3.17.0 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | logger = logging.getLogger("__main__") 5 | logger.addHandler(logging.NullHandler()) 6 | -------------------------------------------------------------------------------- /utils/file_handling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pathlib import Path 3 | from typing import List, Union 4 | import pandas as pd 5 | 6 | 7 | def load_wsi_files_from_csv(csv_path: Union[Path, str], wsi_extension: str) -> List: 8 | """Load filenames from csv file with column name "Filename" 9 | 10 | Args: 11 | csv_path (Union[Path, str]): Path to csv file 12 | wsi_extension (str): WSI file ending (suffix) 13 | 14 | Returns: 15 | List: _description_ 16 | """ 17 | wsi_filelist = pd.read_csv(csv_path) 18 | wsi_filelist = wsi_filelist["Filename"].to_list() 19 | wsi_filelist = [f for f in wsi_filelist if Path(f).suffix == f".{wsi_extension}"] 20 | 21 | return wsi_filelist 22 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Logging Class 3 | # 4 | 5 | import datetime 6 | from typing import Literal, Union 7 | from pathlib import Path 8 | import logging 9 | import logging.handlers 10 | import os 11 | import sys 12 | 13 | 14 | class Logger: 15 | """Initialize a Logger for sys-logging and RotatingFileHandler-logging by using python logging module. 16 | The logger can be used out of the box without any changes, but is also adaptable for specific use cases. 17 | In basic configuration, just the log level must be provided. If log_dir is provided, another handler object is created 18 | logging into a file into the log_dir directory. The filename can be changes by using comment, which basically is the filename. 19 | To create different log files with specific timestamp set 'use_timestamp' = True. This adds an additional timestamp to the filename. 
20 | 21 | Args: 22 | level (Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]): Logger.level 23 | log_dir (Union[Path, str], optional): Path to save logfile in. Defaults to None. 24 | comment (str, optional): additional comment for save file. Defaults to 'logs'. 25 | formatter (str, optional): Custom formatter. Defaults to None. 26 | use_timestamp (bool, optional): Using timestamp for time-logging. Defaults to False. 27 | file_level (Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], optional): Set Logger.level. for output file. 28 | Can be useful if a different logging level should be used for terminal output and logging file. 29 | If no level is selected, file level logging is the same as for console. Defaults to None. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | level: Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], 35 | log_dir: Union[Path, str] = None, 36 | comment: str = "logs", 37 | formatter: str = None, 38 | use_timestamp: bool = False, 39 | file_level: Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] = None, 40 | ) -> None: 41 | self.level = level 42 | self.comment = comment 43 | self.log_parent_dir = log_dir 44 | self.use_timestamp = use_timestamp 45 | if formatter is None: 46 | self.formatter = "%(asctime)s [%(levelname)s] - %(message)s" 47 | else: 48 | self.formatter = formatter 49 | if file_level is None: 50 | self.file_level = level 51 | else: 52 | self.file_level = file_level 53 | 54 | def create_handler(self, logger: logging.Logger) -> None: 55 | """Create logging handler for sys output and rotating files in parent_dir. 56 | 57 | Args: 58 | logger (logging.Logger): The Logger 59 | """ 60 | log_handlers = {"StreamHandler": logging.StreamHandler(stream=sys.stdout)} 61 | fh_formatter = logging.Formatter(f"{self.formatter}") 62 | log_handlers["StreamHandler"].setLevel(self.level) 63 | 64 | if self.log_parent_dir is not None: 65 | log_parent_dir = Path(self.log_parent_dir) 66 | if self.use_timestamp: 67 | log_name = f'{datetime.datetime.now().strftime("%Y-%m-%dT%H%M%S")}_{self.comment}.log' 68 | else: 69 | log_name = f"{self.comment}.log" 70 | log_parent_dir.mkdir(parents=True, exist_ok=True) 71 | 72 | should_roll_over = os.path.isfile(log_parent_dir / log_name) 73 | 74 | log_handlers["FileHandler"] = logging.handlers.RotatingFileHandler( 75 | log_parent_dir / log_name, backupCount=5 76 | ) 77 | 78 | if should_roll_over: # log already exists, roll over! 79 | log_handlers["FileHandler"].doRollover() 80 | log_handlers["FileHandler"].setLevel(self.file_level) 81 | 82 | for handler in log_handlers.values(): 83 | handler.setFormatter(fh_formatter) 84 | logger.addHandler(handler) 85 | 86 | def create_logger(self) -> logging.Logger: 87 | """Create the logger 88 | 89 | Returns: 90 | Logger: The logger to be used. 
91 | """ 92 | logger = logging.getLogger("__main__") 93 | logger.addHandler(logging.NullHandler()) 94 | 95 | logger.setLevel( 96 | "DEBUG" 97 | ) # set to debug because each handler level must be equal or lower 98 | self.create_handler(logger) 99 | 100 | return logger 101 | -------------------------------------------------------------------------------- /utils/tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Utility functions 3 | 4 | 5 | 6 | import importlib 7 | import logging 8 | import sys 9 | 10 | import types 11 | from datetime import timedelta 12 | from timeit import default_timer as timer 13 | from typing import Dict, List, Optional, Tuple, Union 14 | 15 | from utils.__init__ import logger 16 | 17 | 18 | # Helper timing functions 19 | def start_timer() -> float: 20 | """Returns the number of seconds passed since epoch. The epoch is the point where the time starts, 21 | and is platform dependent. 22 | 23 | Returns: 24 | float: The number of seconds passed since epoch 25 | """ 26 | return timer() 27 | 28 | 29 | def end_timer(start_time: float, timed_event: str = "Time usage") -> None: 30 | """Prints the time passed from start_time. 31 | 32 | 33 | Args: 34 | start_time (float): The number of seconds passed since epoch when the timer started 35 | timed_event (str, optional): A string describing the activity being monitored. Defaults to "Time usage". 36 | """ 37 | logger.info(f"{timed_event}: {timedelta(seconds=timer() - start_time)}") 38 | 39 | 40 | def module_exists( 41 | *names: Union[List[str], str], 42 | error: str = "ignore", 43 | warn_every_time: bool = False, 44 | __INSTALLED_OPTIONAL_MODULES: Dict[str, bool] = {}, 45 | ) -> Optional[Union[Tuple[types.ModuleType, ...], types.ModuleType]]: 46 | """Try to import optional dependencies. 47 | Ref: https://stackoverflow.com/a/73838546/4900327 48 | 49 | Args: 50 | names (Union(List(str), str)): The module name(s) to import. Str or list of strings. 51 | error (str, optional): What to do when a dependency is not found: 52 | * raise : Raise an ImportError. 53 | * warn: print a warning. 54 | * ignore: If any module is not installed, return None, otherwise, return the module(s). 55 | Defaults to "ignore". 56 | warn_every_time (bool, optional): Whether to warn every time an import is tried. Only applies when error="warn". 57 | Setting this to True will result in multiple warnings if you try to import the same library multiple times. 58 | Defaults to False. 59 | Raises: 60 | ImportError: ImportError of Module 61 | 62 | Returns: 63 | Optional[ModuleType, Tuple[ModuleType...]]: The imported module(s), if all are found. 64 | None is returned if any module is not found and `error!="raise"`. 65 | """ 66 | assert error in {"raise", "warn", "ignore"} 67 | if isinstance(names, (list, tuple, set)): 68 | names: List[str] = list(names) 69 | else: 70 | assert isinstance(names, str) 71 | names: List[str] = [names] 72 | modules = [] 73 | for name in names: 74 | try: 75 | module = importlib.import_module(name) 76 | modules.append(module) 77 | __INSTALLED_OPTIONAL_MODULES[name] = True 78 | except ImportError: 79 | modules.append(None) 80 | 81 | def error_msg(missing: Union[str, List[str]]): 82 | if not isinstance(missing, (list, tuple)): 83 | missing = [missing] 84 | missing_str: str = " ".join([f'"{name}"' for name in missing]) 85 | dep_str = "dependencies" 86 | if len(missing) == 1: 87 | dep_str = "dependency" 88 | msg = f"Missing optional {dep_str} {missing_str}. 
Use pip or conda to install." 89 | return msg 90 | 91 | missing_modules: List[str] = [ 92 | name for name, module in zip(names, modules) if module is None 93 | ] 94 | if len(missing_modules) > 0: 95 | if error == "raise": 96 | raise ImportError(error_msg(missing_modules)) 97 | if error == "warn": 98 | for name in missing_modules: 99 | # Ensures warning is printed only once 100 | if warn_every_time is True or name not in __INSTALLED_OPTIONAL_MODULES: 101 | logger.warning(f"Warning: {error_msg(name)}") 102 | __INSTALLED_OPTIONAL_MODULES[name] = False 103 | return None 104 | if len(modules) == 1: 105 | return modules[0] 106 | return tuple(modules) 107 | 108 | 109 | def close_logger(logger: logging.Logger) -> None: 110 | """Close a logger safely 111 | 112 | Args: 113 | logger (logging.Logger): Logger to close 114 | """ 115 | handlers = logger.handlers[:] 116 | for handler in handlers: 117 | logger.removeHandler(handler) 118 | handler.close() 119 | 120 | logger.handlers.clear() 121 | logging.shutdown() 122 | 123 | 124 | class AverageMeter(object): 125 | """Computes and stores the average and current value 126 | 127 | Original-Code: https://github.com/facebookresearch/simsiam 128 | """ 129 | 130 | def __init__(self, name, fmt=":f"): 131 | self.name = name 132 | self.fmt = fmt 133 | self.reset() 134 | 135 | def reset(self): 136 | self.val = 0 137 | self.avg = 0 138 | self.sum = 0 139 | self.count = 0 140 | 141 | def update(self, val, n=1): 142 | self.val = val 143 | self.sum += val * n 144 | self.count += n 145 | self.avg = self.sum / self.count 146 | 147 | def __str__(self): 148 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 149 | return fmtstr.format(**self.__dict__) 150 | 151 | 152 | def flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict: 153 | """Flatten a nested dictionary and insert the sep to separate keys 154 | 155 | Args: 156 | d (dict): dict to flatten 157 | parent_key (str, optional): parent key name. Defaults to ''. 158 | sep (str, optional): Separator. Defaults to '.'. 159 | 160 | Returns: 161 | dict: Flattened dict 162 | """ 163 | items = [] 164 | for k, v in d.items(): 165 | new_key = parent_key + sep + k if parent_key else k 166 | if isinstance(v, dict): 167 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 168 | else: 169 | items.append((new_key, v)) 170 | return dict(items) 171 | 172 | 173 | def unflatten_dict(d: dict, sep: str = ".") -> dict: 174 | """Unflatten a flattened dictionary (i.e., create a nested dictionary) 175 | 176 | Args: 177 | d (dict): Dict to be nested 178 | sep (str, optional): Separator of flattened keys. Defaults to '.'. 179 | 180 | Returns: 181 | dict: Nested dict 182 | """ 183 | output_dict = {} 184 | for key, value in d.items(): 185 | keys = key.split(sep) 186 | d = output_dict 187 | for k in keys[:-1]: 188 | d = d.setdefault(k, {}) 189 | d[keys[-1]] = value 190 | 191 | return output_dict 192 | 193 | 194 | def remove_parameter_tag(d: dict, sep: str = ".") -> dict: 195 | """Remove all parameter tags from dictionary 196 | 197 | Args: 198 | d (dict): Dict must be flattened with defined separator 199 | sep (str, optional): Separator used during flattening. Defaults to ".".
200 | 201 | Returns: 202 | dict: Dict with parameter tag removed 203 | """ 204 | param_dict = {} 205 | for k, _ in d.items(): 206 | unflattened_keys = k.split(sep) 207 | new_keys = [] 208 | max_num_insert = len(unflattened_keys) - 1 209 | for i, k in enumerate(unflattened_keys): 210 | if i < max_num_insert and k != "parameters": 211 | new_keys.append(k) 212 | joined_key = sep.join(new_keys) 213 | param_dict[joined_key] = {} 214 | # print(param_dict) 215 | for k, v in d.items(): 216 | unflattened_keys = k.split(sep) 217 | new_keys = [] 218 | max_num_insert = len(unflattened_keys) - 1 219 | for i, k in enumerate(unflattened_keys): 220 | if i < max_num_insert and k != "parameters": 221 | new_keys.append(k) 222 | joined_key = sep.join(new_keys) 223 | param_dict[joined_key][unflattened_keys[-1]] = v 224 | 225 | return param_dict 226 | 227 | def get_size_of_dict(d: dict) -> int: 228 | size = sys.getsizeof(d)  # shallow size of the dict object itself 229 | for key, value in d.items(): 230 | size += sys.getsizeof(key)  # plus shallow sizes of top-level keys 231 | size += sys.getsizeof(value)  # and top-level values 232 | return size 233 | --------------------------------------------------------------------------------
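The configuration helpers at the end of `utils/tools.py` are easiest to follow with a small round trip. The sketch below is a minimal usage example rather than repository code; it assumes the repository root is on `PYTHONPATH` so that `utils.tools` is importable, and the nested `config` dictionary is an illustrative stand-in for an experiment configuration.

```python
# Minimal usage sketch for the dictionary helpers in utils/tools.py.
# Assumption: run from the repository root so `utils.tools` resolves.
from utils.tools import flatten_dict, unflatten_dict, remove_parameter_tag

# Hypothetical experiment configuration (not taken from the repository).
config = {
    "model": {"parameters": {"lr": 1e-4, "weight_decay": 1e-5}},
    "data": {"batch_size": 8},
}

flat = flatten_dict(config)
# {'model.parameters.lr': 0.0001, 'model.parameters.weight_decay': 1e-05,
#  'data.batch_size': 8}

nested = unflatten_dict(flat)
assert nested == config  # flatten/unflatten round-trip for plain nested dicts

cleaned = remove_parameter_tag(flat)
# The intermediate "parameters" level is dropped and the values are regrouped
# under the remaining key prefix:
# {'model': {'lr': 0.0001, 'weight_decay': 1e-05}, 'data': {'batch_size': 8}}
```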
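The logging utilities are typically used together: `Logger.create_logger()` from `utils/logger.py` returns a configured `logging.Logger`, and `close_logger()` from `utils/tools.py` releases its handlers afterwards. A minimal sketch, assuming the repository root is on `PYTHONPATH`; the log directory, comment, and messages are illustrative values, not defaults from the repository.

```python
# Minimal usage sketch combining utils/logger.py and utils/tools.py.
# Assumption: run from the repository root; "./logs" is an arbitrary example directory.
from utils.logger import Logger
from utils.tools import close_logger

logger = Logger(
    level="INFO",           # console log level
    log_dir="./logs",       # additionally write a rotating logfile into ./logs
    comment="example_run",  # logfile name becomes example_run.log
    use_timestamp=False,
    file_level="DEBUG",     # the logfile may be more verbose than the console
).create_logger()

logger.info("Logger initialised.")
logger.debug("Written to the logfile only, since the console level is INFO.")

close_logger(logger)  # remove and close all handlers when finished
```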