4 |
5 |
6 |
7 | [arXiv](https://arxiv.org/abs/2407.18054)
8 | [Python](https://www.python.org/downloads/release/python-360/)
9 | [License](LICENSE)
10 | [HUST Vision Lab](https://github.com/hustvl)
11 | [🤗 Hugging Face Space](https://huggingface.co/spaces/xiazhi/LKCell)
12 |
13 | [PapersWithCode: Panoptic Segmentation on PanNuke](https://paperswithcode.com/sota/panoptic-segmentation-on-pannuke?p=lkcell-efficient-cell-nuclei-instance)
14 |
15 | [Ziwei Cui](https://github.com/ziwei-cui) \
16 | School of Electronic Information and Communications, Huazhong University of Science and Technology \
17 | Department of Cardiology, Huanggang Central Hospital
18 |
19 |
20 | (\* equal contribution, 📧 corresponding author)
21 |
22 | [Key Features](#key-features) • [Installation](#installation) • [Usage](#usage) • [Training](#training) • [Inference](#inference) • [Citation](#citation)
23 |
24 |
25 |
26 | ---
27 |
28 |
29 |
30 | ## Key Features
31 |
32 | **Click and try LKCell on our [🤗 Hugging Face Space](https://huggingface.co/spaces/xiazhi/LKCell)!**
33 |
34 | This repository contains the code implementation of LKCell, a deep learning-based method for automated instance segmentation of cell nuclei in digitized tissue samples. LKCell utilizes an architecture based on large convolutional kernels and achieves state-of-the-art performance on the [PanNuke](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke) dataset, a challenging nuclei instance segmentation benchmark.
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 | ## Installation
47 |
48 | ```bash
49 | git clone https://github.com/hustvl/LKCell.git
50 | conda create -n lkcell
51 | conda activate lkcell
52 | pip install -r requirements.txt
53 | ```
54 |
55 | Note: (1) the preferred torch version is 2.0; (2) if you run into problems installing `depthwise-conv2d-implicit-gemm==0.0.0`, please follow the instructions [here](https://github.com/AILab-CVC/UniRepLKNet#use-an-efficient-large-kernel-convolution-with-pytorch).
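
To check the environment before training, the short sketch below verifies the installed torch version and whether the efficient large-kernel convolution extension can be imported. The module name `depthwise_conv2d_implicit_gemm` is assumed to match the pip package above; adjust it if your build differs.

```python
# Hedged environment check: torch version and the optional large-kernel conv extension.
import torch

print(f"torch version: {torch.__version__}")  # 2.0 is the preferred version

try:
    import depthwise_conv2d_implicit_gemm  # noqa: F401  (assumed module name)
    print("Efficient large-kernel depthwise conv extension found.")
except ImportError:
    print("Extension missing - follow the UniRepLKNet instructions linked above.")
```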
56 |
57 |
65 |
66 |
67 |
68 | ## Project Structure
69 |
70 | We are currently using the following folder structure:
71 |
72 | ```bash
73 | ├── base_ml               # Basic machine learning code: trainer, experiment, ...
74 | ├── cell_segmentation     # Cell segmentation training and inference files
75 | │   ├── datasets          # Datasets (PyTorch)
76 | │   ├── experiments       # Specific experiment code for different experiments
77 | │   ├── inference         # Inference code for experiment statistics and plots
78 | │   ├── trainer           # Trainer functions to train networks
79 | │   ├── utils             # Utility code
80 | │   └── run_cellvit.py    # Run file to start an experiment
81 | ├── config                # Python configuration file for global Python settings
82 | ├── docs                  # Documentation files (in addition to this main README.md)
83 | ├── models                # Machine learning models (PyTorch implementations)
84 | │   └── segmentation      # LKCell code
85 | ├── datamodel             # Dataclasses: graph data, WSI objects, ...
86 | ├── preprocessing         # Preprocessing code: encoding, patch extraction, ...
87 | ```
88 |
89 |
90 | ## Training
91 |
92 | ### Dataset preparation
93 | We use a customized dataset structure for the PanNuke and MoNuSeg datasets.
94 | The dataset structures are explained in the [pannuke.md](docs/readmes/pannuke.md) and [monuseg.md](docs/readmes/monuseg.md) documentation files.
95 | We also provide preparation scripts in the [`cell_segmentation/datasets/`](cell_segmentation/datasets/) folder.
96 |
97 | ### Training script
98 | The CLI for an ML experiment to train the LKCell network is as follows (using the [```run_cellvit.py```](cell_segmentation/run_cellvit.py) script):
99 | ```bash
100 | usage: run_cellvit.py [-h] --config CONFIG [--gpu GPU] [--sweep | --agent AGENT | --checkpoint CHECKPOINT]
101 | Start an experiment with given configuration file.
102 |
103 | python ./cell_segmentation/run_cellvit.py --config ./config.yaml
104 | ```
105 |
106 | The most important file is the configuration file, in which all paths are set, the model configuration is given, and the hyperparameters or sweeps are defined.
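
As a rough sketch of what happens under the hood, the snippet below mirrors how `ExperimentBaseParser.parse_arguments` in `base_ml/base_cli.py` loads the YAML file and lets the `--gpu` CLI flag override the value from the config; all other keys depend on your own configuration file.

```python
# Simplified sketch of the config handling in base_ml/base_cli.py.
import yaml


def load_experiment_config(config_path: str, gpu: int = None) -> dict:
    """Load the YAML experiment configuration and optionally override the GPU id."""
    with open(config_path, "r") as config_file:
        config = dict(yaml.safe_load(config_file))
    if gpu is not None:
        config["gpu"] = gpu  # the CLI flag takes precedence over the YAML value
    return config


# config = load_experiment_config("./config.yaml", gpu=0)
```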
107 |
108 |
109 |
110 | **Pre-trained UniRepLKNet models** for training initialization can be downloaded from Google Drive: [UniRepLKNet-Models](https://drive.google.com/drive/folders/1pqjCBZIv4WwEsE5raUPz5AUM7I-UPtMJ).
111 |
112 |
113 | ### Evaluation
114 | In our paper, we did not (!) use early stopping, but rather trained all models for 100 epochs to eliminate selection bias while keeping the largest possible database for training. Therefore, evaluation needs to be performed with the `latest_checkpoint.pth` model and not the best early-stopping model.
115 | We provide a script to create evaluation results: [`inference_cellvit_experiment.py`](cell_segmentation/inference/inference_cellvit_experiment.py) for PanNuke and [`inference_cellvit_monuseg.py`](cell_segmentation/inference/inference_cellvit_monuseg.py) for MoNuSeg.
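
Before running the evaluation scripts, you can quickly check which run a checkpoint belongs to by reading the metadata that `BaseTrainer.save_checkpoint` (see `base_ml/base_trainer.py`) stores alongside the weights; the path below is a placeholder for your own log directory.

```python
# Sketch: inspect a saved checkpoint (keys as written by BaseTrainer.save_checkpoint).
import torch

checkpoint = torch.load("path/to/logdir/checkpoints/latest_checkpoint.pth", map_location="cpu")
print(checkpoint["arch"], "- epoch", checkpoint["epoch"], "- run", checkpoint["run_name"])
# model.load_state_dict(checkpoint["model_state_dict"])  # load the weights into an LKCell model
```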
116 |
117 | ### Inference
118 |
119 | Model checkpoints can be downloaded from either Google Drive or Hugging Face (a scripted download sketch follows the list below):
120 |
121 | - Google Drive
122 | - [LKCell-L](https://drive.google.com/drive/folders/1r4vCwcyHgLtMJkr2rhFLox6SDldB2p7F?usp=drive_link) 🚀
123 | - [LKCell-B](https://drive.google.com/drive/folders/1i7SrHloSsGZSbesDZ9hBxbOG4RnaPhQU?usp=drive_link)
124 | - HuggingFace
125 | - [LKCell-L](https://huggingface.co/xiazhi/LKCell-L) 🚀
126 | - [LKCell-B](https://huggingface.co/xiazhi/LKCell-B)
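
If you prefer scripting the download, a minimal sketch using `huggingface_hub` is shown below; the exact file layout inside the model repositories may differ, so inspect the downloaded folder before loading.

```python
# Sketch: download an LKCell checkpoint repository from the Hugging Face Hub.
from huggingface_hub import snapshot_download  # pip install huggingface_hub

local_dir = snapshot_download(repo_id="xiazhi/LKCell-L")
print("Checkpoint files downloaded to:", local_dir)
```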
127 |
128 |
129 | You can also try model inference directly in our [🤗 Hugging Face Space](https://huggingface.co/spaces/xiazhi/LKCell).
130 |
131 |
132 | ## Acknowledgement
133 |
134 | This project is built upon [CellViT](https://github.com/TIO-IKIM/CellViT) and [UniRepLKNet](https://github.com/AILab-CVC/UniRepLKNet). Thanks for these awesome repos!
135 |
136 | ## Citation
137 | ```latex
138 | @misc{cui2024lkcellefficientcellnuclei,
139 | title={LKCell: Efficient Cell Nuclei Instance Segmentation with Large Convolution Kernels},
140 | author={Ziwei Cui and Jingfeng Yao and Lunbin Zeng and Juan Yang and Wenyu Liu and Xinggang Wang},
141 | year={2024},
142 | eprint={2407.18054},
143 | archivePrefix={arXiv},
144 | primaryClass={eess.IV},
145 | url={https://arxiv.org/abs/2407.18054},
146 | }
147 | ```
148 |
--------------------------------------------------------------------------------
/base_ml/base_cli.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Base CLI to parse Arguments
3 |
4 | import argparse
5 | import logging
6 | from abc import ABC, abstractmethod
7 | from typing import Tuple, Union
8 |
9 | import yaml
10 | from pydantic import BaseModel
11 |
12 |
13 | class ABCParser(ABC):
14 | """Blueprint for Argument Parser"""
15 |
16 | @abstractmethod
17 | def __init__(self) -> None:
18 | pass
19 |
20 | @abstractmethod
21 | def get_config(self) -> Tuple[Union[BaseModel, dict], logging.Logger]:
22 | """Load configuration and create a logger
23 |
24 | Returns:
25 | Tuple[PreProcessingConfig, logging.Logger]: Configuration and Logger
26 | """
27 | pass
28 |
29 | @abstractmethod
30 | def store_config(self) -> None:
31 | """Store the config file in the logging directory to keep track of the configuration."""
32 | pass
33 |
34 |
35 | class ExperimentBaseParser:
36 | """Configuration Parser for Machine Learning Experiments"""
37 |
38 | def __init__(self) -> None:
39 | parser = argparse.ArgumentParser(
40 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
41 | description="Start an experiment with given configuration file.",
42 | )
43 | requiredNamed = parser.add_argument_group("required named arguments")
44 | requiredNamed.add_argument(
45 | "--config", type=str, help="Path to a config file", required=True
46 | )
47 | parser.add_argument("--gpu", type=int, help="Cuda-GPU ID")
48 | group = parser.add_mutually_exclusive_group(required=False)
49 | group.add_argument(
50 | "--sweep",
51 | action="store_true",
52 | help="Starting a sweep. For this the configuration file must be structured according to WandB sweeping. "
53 | "Compare https://docs.wandb.ai/guides/sweeps and https://community.wandb.ai/t/nested-sweep-configuration/3369/3 "
54 | "for further information. This parameter cannot be set in the config file!",
55 | )
56 | group.add_argument(
57 | "--agent",
58 | type=str,
59 | help="Add a new agent to the sweep. "
60 | "Please pass the sweep ID as argument in the way entity/project/sweep_id, e.g., user1/test_project/v4hwbijh. "
61 | "The agent configuration can be found in the WandB dashboard for the running sweep in the sweep overview tab "
62 |             "under launch agent. Just paste the entity/project/sweep_id given there. The provided config file must be a sweep config file. "
63 | "This parameter cannot be set in the config file!",
64 | )
65 | group.add_argument(
66 | "--checkpoint",
67 | type=str,
68 | help="Path to a PyTorch checkpoint file. "
69 | "The file is loaded and continued to train with the provided settings. "
70 | "If this is passed, no sweeps are possible. "
71 | "This parameter cannot be set in the config file!",
72 | )
73 |
74 | self.parser = parser
75 |
76 | def parse_arguments(self) -> Tuple[Union[BaseModel, dict]]:
77 | """Parse the arguments from CLI and load yaml config
78 |
79 | Returns:
80 | Tuple[Union[BaseModel, dict]]: Parsed arguments
81 | """
82 | # parse the arguments
83 | opt = self.parser.parse_args()
84 | with open(opt.config, "r") as config_file:
85 | yaml_config = yaml.safe_load(config_file)
86 | yaml_config_dict = dict(yaml_config)
87 |
88 | opt_dict = vars(opt)
89 | # check for gpu to overwrite with cli argument
90 | if "gpu" in opt_dict:
91 | if opt_dict["gpu"] is not None:
92 | yaml_config_dict["gpu"] = opt_dict["gpu"]
93 |
94 | # check if either training, sweep, checkpoint or start agent should be called
95 | # first step: remove such keys from the config file
96 | if "run_sweep" in yaml_config_dict:
97 | yaml_config_dict.pop("run_sweep")
98 | if "agent" in yaml_config_dict:
99 | yaml_config_dict.pop("agent")
100 | if "checkpoint" in yaml_config_dict:
101 | yaml_config_dict.pop("checkpoint")
102 |
103 | # select one of the options
104 | if "sweep" in opt_dict and opt_dict["sweep"] is True:
105 | yaml_config_dict["run_sweep"] = True
106 | else:
107 | yaml_config_dict["run_sweep"] = False
108 | if "agent" in opt_dict:
109 | yaml_config_dict["agent"] = opt_dict["agent"]
110 | if "checkpoint" in opt_dict:
111 | if opt_dict["checkpoint"] is not None:
112 | yaml_config_dict["checkpoint"] = opt_dict["checkpoint"]
113 |
114 |         self.config = yaml_config_dict  # assign the merged YAML config to self.config
115 |
116 | return self.config
117 |
--------------------------------------------------------------------------------
/base_ml/base_early_stopping.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Early Stopping Class
3 |
4 | import logging
5 |
6 | logger = logging.getLogger("__main__")
7 | logger.addHandler(logging.NullHandler())
8 |
9 | import wandb
10 |
11 |
12 | class EarlyStopping:
13 | """Early Stopping Class
14 |
15 | Args:
16 | patience (int): Patience to wait before stopping
17 | strategy (str, optional): Optimization strategy.
18 | Please select 'minimize' or 'maximize' for strategy. Defaults to "minimize".
19 | """
20 |
21 | def __init__(self, patience: int, strategy: str = "minimize"):
22 | assert strategy.lower() in [
23 | "minimize",
24 | "maximize",
25 | ], "Please select 'minimize' or 'maximize' for strategy"
26 |
27 | self.patience = patience
28 | self.counter = 0
29 | self.strategy = strategy.lower()
30 | self.best_metric = None
31 | self.best_epoch = None
32 | self.early_stop = False
33 |
34 | logger.info(
35 |             f"Using early stopping with a patience of {self.patience} and the {self.strategy} strategy"
36 | )
37 |
38 | def __call__(self, metric: float, epoch: int) -> bool:
39 | """Early stopping update call
40 |
41 | Args:
42 | metric (float): Metric for early stopping
43 | epoch (int): Current epoch
44 |
45 | Returns:
46 | bool: Returns true if the model is performing better than the current best model,
47 | otherwise false
48 | """
49 | if self.best_metric is None:
50 | self.best_metric = metric
51 | self.best_epoch = epoch
52 | return True
53 | else:
54 | if self.strategy == "minimize":
55 | if self.best_metric >= metric:
56 | self.best_metric = metric
57 | self.best_epoch = epoch
58 | self.counter = 0
59 | wandb.run.summary["Best-Epoch"] = epoch
60 | wandb.run.summary["Best-Metric"] = metric
61 | return True
62 | else:
63 | self.counter += 1
64 | if self.counter >= self.patience:
65 | self.early_stop = True
66 | return False
67 | elif self.strategy == "maximize":
68 | if self.best_metric <= metric:
69 | self.best_metric = metric
70 | self.best_epoch = epoch
71 | self.counter = 0
72 | wandb.run.summary["Best-Epoch"] = epoch
73 | wandb.run.summary["Best-Metric"] = metric
74 | return True
75 | else:
76 | self.counter += 1
77 | if self.counter >= self.patience:
78 | self.early_stop = True
79 | return False
80 |
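
A minimal usage sketch for the class above: the validation losses are made up, and a wandb run is initialised in disabled mode because `__call__` writes to `wandb.run.summary` (assumed to be a safe no-op in disabled mode).

```python
import wandb

from base_ml.base_early_stopping import EarlyStopping

wandb.init(mode="disabled")  # EarlyStopping updates wandb.run.summary on improvements
early_stopping = EarlyStopping(patience=3, strategy="minimize")

for epoch, val_loss in enumerate([0.9, 0.7, 0.65, 0.66, 0.67, 0.68, 0.70]):
    improved = early_stopping(val_loss, epoch)
    if improved:
        print(f"epoch {epoch}: new best metric {val_loss}")
    elif early_stopping.early_stop:
        print(f"no improvement for {early_stopping.patience} rounds - stopping at epoch {epoch}")
        break
```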
--------------------------------------------------------------------------------
/base_ml/base_optim.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Wrapping all available PyTorch optimizers
3 |
4 |
5 | from torch.optim import (
6 | ASGD,
7 | LBFGS,
8 | SGD,
9 | Adadelta,
10 | Adagrad,
11 | Adam,
12 | Adamax,
13 | AdamW,
14 | RAdam,
15 | RMSprop,
16 | Rprop,
17 | SparseAdam,
18 | )
19 |
20 | OPTI_DICT = {
21 | "Adadelta": Adadelta,
22 | "Adagrad": Adagrad,
23 | "Adam": Adam,
24 | "AdamW": AdamW,
25 | "SparseAdam": SparseAdam,
26 | "Adamax": Adamax,
27 | "ASGD": ASGD,
28 | "LBFGS": LBFGS,
29 | "RAdam": RAdam,
30 | "RMSprop": RMSprop,
31 | "Rprop": Rprop,
32 | "SGD": SGD,
33 | }
34 |
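
Usage sketch for the dictionary above: the trainer can resolve the optimizer class by the name given in the experiment configuration (the toy model and hyperparameters are illustrative only).

```python
import torch.nn as nn

from base_ml.base_optim import OPTI_DICT

model = nn.Linear(16, 4)
optimizer = OPTI_DICT["AdamW"](model.parameters(), lr=1e-4, weight_decay=1e-2)
print(type(optimizer).__name__)  # AdamW
```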
--------------------------------------------------------------------------------
/base_ml/base_trainer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Base Trainer Class
3 |
4 |
5 | import logging
6 | from abc import abstractmethod
7 | from typing import Tuple, Union
8 |
9 | import torch
10 | import torch.nn as nn
11 | import wandb
12 | from base_ml.base_early_stopping import EarlyStopping
13 | from pathlib import Path
14 | from torch.nn.modules.loss import _Loss
15 | from torch.optim import Optimizer
16 | from torch.optim.lr_scheduler import _LRScheduler
17 | from torch.utils.data import DataLoader
18 | from utils.tools import flatten_dict
19 |
20 |
21 | class BaseTrainer:
22 | """
23 | Base class for all trainers with important ML components
24 |
25 | Args:
26 | model (nn.Module): Model that should be trained
27 | loss_fn (_Loss): Loss function
28 | optimizer (Optimizer): Optimizer
29 | scheduler (_LRScheduler): Learning rate scheduler
30 | device (str): Cuda device to use, e.g., cuda:0.
31 | logger (logging.Logger): Logger module
32 | logdir (Union[Path, str]): Logging directory
33 | experiment_config (dict): Configuration of this experiment
34 | early_stopping (EarlyStopping, optional): Early Stopping Class. Defaults to None.
35 | accum_iter (int, optional): Accumulation steps for gradient accumulation.
36 | Provide a number greater than 1 for activating gradient accumulation. Defaults to 1.
37 | mixed_precision (bool, optional): If mixed-precision should be used. Defaults to False.
38 | log_images (bool, optional): If images should be logged to WandB. Defaults to False.
39 | """
40 |
41 | def __init__(
42 | self,
43 | model: nn.Module,
44 | loss_fn: _Loss,
45 | optimizer: Optimizer,
46 | scheduler: _LRScheduler,
47 | device: str,
48 | logger: logging.Logger,
49 | logdir: Union[Path, str],
50 | experiment_config: dict,
51 | early_stopping: EarlyStopping = None,
52 | accum_iter: int = 1,
53 | mixed_precision: bool = False,
54 | log_images: bool = False,
55 | #model_ema: bool = True,
56 | ) -> None:
57 | self.model = model
58 |
59 | self.loss_fn = loss_fn
60 | self.optimizer = optimizer
61 | self.scheduler = scheduler
62 | self.device = device
63 | self.logger = logger
64 | self.logdir = Path(logdir)
65 | self.early_stopping = early_stopping
66 | self.accum_iter = accum_iter
67 | self.start_epoch = 0
68 | self.experiment_config = experiment_config
69 | self.log_images = log_images
70 | self.mixed_precision = mixed_precision
71 | if self.mixed_precision:
72 | self.scaler = torch.cuda.amp.GradScaler(enabled=True)
73 | else:
74 | self.scaler = None
75 |
76 | @abstractmethod
77 | def train_epoch(
78 | self, epoch: int, train_loader: DataLoader, **kwargs
79 | ) -> Tuple[dict, dict]:
80 | """Training logic for a training epoch
81 |
82 | Args:
83 | epoch (int): Current epoch number
84 | train_loader (DataLoader): Train dataloader
85 |
86 | Raises:
87 | NotImplementedError: Needs to be implemented
88 |
89 | Returns:
90 | Tuple[dict, dict]: wandb logging dictionaries
91 | * Scalar metrics
92 | * Image metrics
93 | """
94 | raise NotImplementedError
95 |
96 | @abstractmethod
97 | def validation_epoch(
98 | self, epoch: int, val_dataloader: DataLoader
99 | ) -> Tuple[dict, dict, float]:
100 |         """Validation logic for a validation epoch
101 |
102 | Args:
103 | epoch (int): Current epoch number
104 | val_dataloader (DataLoader): Validation dataloader
105 |
106 | Raises:
107 | NotImplementedError: Needs to be implemented
108 |
109 | Returns:
110 | Tuple[dict, dict, float]: wandb logging dictionaries and early_stopping_metric
111 | * Scalar metrics
112 | * Image metrics
113 | * Early Stopping metric as float
114 | """
115 | raise NotImplementedError
116 |
117 | @abstractmethod
118 | def train_step(self, batch: object, batch_idx: int, num_batches: int):
119 | """Training logic for one training batch
120 |
121 | Args:
122 | batch (object): A training batch
123 | batch_idx (int): Current batch index
124 | num_batches (int): Maximum number of batches
125 |
126 | Raises:
127 | NotImplementedError: Needs to be implemented
128 | """
129 |
130 | raise NotImplementedError
131 |
132 | @abstractmethod
133 | def validation_step(self, batch, batch_idx: int):
134 |         """Validation logic for one validation batch
135 |
136 | Args:
137 | batch (object): A training batch
138 | batch_idx (int): Current batch index
139 |
140 | Raises:
141 | NotImplementedError: Needs to be implemented
142 | """
143 |
144 | def fit(
145 | self,
146 | epochs: int,
147 | train_dataloader: DataLoader,
148 | val_dataloader: DataLoader,
149 | metric_init: dict = None,
150 | eval_every: int = 1,
151 | **kwargs,
152 | ):
153 | """Fitting function to start training and validation of the trainer
154 |
155 | Args:
156 | epochs (int): Number of epochs the network should be training
157 | train_dataloader (DataLoader): Dataloader with training data
158 | val_dataloader (DataLoader): Dataloader with validation data
159 | metric_init (dict, optional): Initialization dictionary with scalar metrics that should be initialized for startup.
160 |                 This is only important for logging with wandb if you want to have the plots properly scaled.
161 |                 The data in the metric dictionary is used as values for epoch 0 (before training has started).
162 | If not provided, step 0 (epoch 0) is not logged. Should have the same scalar keys as training and validation epochs report.
163 | For more information, you should have a look into the train_epoch and val_epoch methods where the wandb logging dicts are assembled.
164 | Defaults to None.
165 | eval_every (int, optional): How often the network should be evaluated (after how many epochs). Defaults to 1.
166 | **kwargs
167 | """
168 |
169 | self.logger.info(f"Starting training, total number of epochs: {epochs}")
170 | if metric_init is not None and self.start_epoch == 0:
171 | wandb.log(metric_init, step=0)
172 | for epoch in range(self.start_epoch, epochs):
173 | # training epoch
174 | #train_sampler.set_epoch(epoch) # for distributed training
175 | self.logger.info(f"Epoch: {epoch+1}/{epochs}")
176 | train_scalar_metrics, train_image_metrics = self.train_epoch(
177 | epoch, train_dataloader, **kwargs
178 | )
179 | wandb.log(train_scalar_metrics, step=epoch + 1)
180 | if self.log_images:
181 | wandb.log(train_image_metrics, step=epoch + 1)
182 |             if epoch >= 95 and (epoch + 1) % eval_every == 0:
183 | # validation epoch
184 | (
185 | val_scalar_metrics,
186 | val_image_metrics,
187 | early_stopping_metric,
188 | ) = self.validation_epoch(epoch, val_dataloader)
189 | wandb.log(val_scalar_metrics, step=epoch + 1)
190 | if self.log_images:
191 | wandb.log(val_image_metrics, step=epoch + 1)
192 |
193 | #self.save_checkpoint(epoch, f"checkpoint_{epoch}.pth")
194 |
195 | # log learning rate
196 | curr_lr = self.optimizer.param_groups[0]["lr"]
197 | wandb.log(
198 | {
199 | "Learning-Rate/Learning-Rate": curr_lr,
200 | },
201 | step=epoch + 1,
202 | )
203 |             if epoch >= 95 and (epoch + 1) % eval_every == 0:
204 | # early stopping
205 | if self.early_stopping is not None:
206 | best_model = self.early_stopping(early_stopping_metric, epoch)
207 | if best_model:
208 | self.logger.info("New best model - save checkpoint")
209 | self.save_checkpoint(epoch, "model_best.pth")
210 | elif self.early_stopping.early_stop:
211 | self.logger.info("Performing early stopping!")
212 | break
213 | self.save_checkpoint(epoch, "latest_checkpoint.pth")
214 |
215 | # scheduling
216 | if type(self.scheduler) == torch.optim.lr_scheduler.ReduceLROnPlateau:
217 | self.scheduler.step(float(val_scalar_metrics["Loss/Validation"]))
218 | else:
219 | self.scheduler.step()
220 | new_lr = self.optimizer.param_groups[0]["lr"]
221 | self.logger.debug(f"Old lr: {curr_lr:.6f} - New lr: {new_lr:.6f}")
222 |
223 | def save_checkpoint(self, epoch: int, checkpoint_name: str):
224 | if self.early_stopping is None:
225 | best_metric = None
226 | best_epoch = None
227 | else:
228 | best_metric = self.early_stopping.best_metric
229 | best_epoch = self.early_stopping.best_epoch
230 |
231 | arch = type(self.model).__name__
232 | state = {
233 | "arch": arch,
234 | "epoch": epoch,
235 | "model_state_dict": self.model.state_dict(),
236 | "optimizer_state_dict": self.optimizer.state_dict(),
237 | "scheduler_state_dict": self.scheduler.state_dict(),
238 | "best_metric": best_metric,
239 | "best_epoch": best_epoch,
240 | "config": flatten_dict(wandb.config),
241 | "wandb_id": wandb.run.id,
242 | "logdir": str(self.logdir.resolve()),
243 | "run_name": str(Path(self.logdir).name),
244 | "scaler_state_dict": self.scaler.state_dict()
245 | if self.scaler is not None
246 | else None,
247 | }
248 |
249 | checkpoint_dir = self.logdir / "checkpoints"
250 | checkpoint_dir.mkdir(exist_ok=True, parents=True)
251 |
252 | filename = str(checkpoint_dir / checkpoint_name)
253 | torch.save(state, filename)
254 |
255 | def resume_checkpoint(self, checkpoint):
256 | self.logger.info("Loading checkpoint")
257 | self.logger.info("Loading Model")
258 | self.model.load_state_dict(checkpoint["model_state_dict"])
259 | self.logger.info("Loading Optimizer state dict")
260 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
261 | self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
262 |
263 | if self.early_stopping is not None:
264 | self.early_stopping.best_metric = checkpoint["best_metric"]
265 | self.early_stopping.best_epoch = checkpoint["best_epoch"]
266 | if self.scaler is not None:
267 | self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
268 |
269 | self.logger.info(f"Checkpoint epoch: {int(checkpoint['epoch'])}")
270 | self.start_epoch = int(checkpoint["epoch"])
271 | self.logger.info(f"Next epoch is: {self.start_epoch + 1}")
272 |
--------------------------------------------------------------------------------
/base_ml/base_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | __all__ = ["filter2D", "gaussian", "gaussian_kernel2d", "sobel_hv"]
6 |
7 |
8 | def filter2D(input_tensor: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
9 | """Convolves a given kernel on input tensor without losing dimensional shape.
10 |
11 | Parameters
12 | ----------
13 | input_tensor : torch.Tensor
14 | Input image/tensor.
15 | kernel : torch.Tensor
16 | Convolution kernel/window.
17 |
18 | Returns
19 | -------
20 | torch.Tensor:
21 | The convolved tensor of same shape as the input.
22 | """
23 | (_, channel, _, _) = input_tensor.size()
24 |
25 | # "SAME" padding to avoid losing height and width
26 | pad = [
27 | kernel.size(2) // 2,
28 | kernel.size(2) // 2,
29 | kernel.size(3) // 2,
30 | kernel.size(3) // 2,
31 | ]
32 | pad_tensor = F.pad(input_tensor, pad, "replicate")
33 |
34 | out = F.conv2d(pad_tensor, kernel, groups=channel)
35 | return out
36 |
37 |
38 | def gaussian(
39 | window_size: int, sigma: float, device: torch.device = None
40 | ) -> torch.Tensor:
41 | """Create a gaussian 1D tensor.
42 |
43 | Parameters
44 | ----------
45 | window_size : int
46 | Number of elements for the output tensor.
47 | sigma : float
48 | Std of the gaussian distribution.
49 | device : torch.device
50 | Device for the tensor.
51 |
52 | Returns
53 | -------
54 | torch.Tensor:
55 | A gaussian 1D tensor. Shape: (window_size, ).
56 | """
57 | x = torch.arange(window_size, device=device).float() - window_size // 2
58 | if window_size % 2 == 0:
59 | x = x + 0.5
60 |
61 | gauss = torch.exp((-x.pow(2.0) / float(2 * sigma**2)))
62 |
63 | return gauss / gauss.sum()
64 |
65 |
66 | def gaussian_kernel2d(
67 | window_size: int, sigma: float, n_channels: int = 1, device: torch.device = None
68 | ) -> torch.Tensor:
69 |     """Create a 2D gaussian kernel of size (window_size, window_size).
70 |
71 | Parameters
72 | ----------
73 | window_size : int
74 | Number of rows and columns for the output tensor.
75 | sigma : float
76 | Std of the gaussian distribution.
77 |     n_channels : int
78 | Number of channels in the image that will be convolved with
79 | this kernel.
80 | device : torch.device
81 | Device for the kernel.
82 |
83 | Returns:
84 | -----------
85 | torch.Tensor:
86 |         A tensor of shape (n_channels, 1, window_size, window_size).
87 | """
88 | kernel_x = gaussian(window_size, sigma, device=device)
89 | kernel_y = gaussian(window_size, sigma, device=device)
90 |
91 | kernel_2d = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
92 | kernel_2d = kernel_2d.expand(n_channels, 1, window_size, window_size)
93 |
94 | return kernel_2d
95 |
96 |
97 | def sobel_hv(window_size: int = 5, device: torch.device = None):
98 | """Create a kernel that is used to compute 1st order derivatives.
99 |
100 | Parameters
101 | ----------
102 | window_size : int
103 | Size of the convolution kernel.
104 | device : torch.device:
105 | Device for the kernel.
106 |
107 | Returns
108 | -------
109 | torch.Tensor:
110 |         Horizontal and vertical Sobel-like kernels for computing the 1st order derivatives.
111 |         Shape (2, 1, window_size, window_size).
112 |
113 | Raises
114 | ------
115 | ValueError:
116 | If `window_size` is not an odd number.
117 | """
118 | if not window_size % 2 == 1:
119 | raise ValueError(f"window_size must be odd. Got: {window_size}")
120 |
121 | # Generate the sobel kernels
122 | range_h = torch.arange(
123 | -window_size // 2 + 1, window_size // 2 + 1, dtype=torch.float32, device=device
124 | )
125 | range_v = torch.arange(
126 | -window_size // 2 + 1, window_size // 2 + 1, dtype=torch.float32, device=device
127 | )
128 | h, v = torch.meshgrid(range_h, range_v)
129 |
130 | kernel_h = h / (h * h + v * v + 1e-6)
131 | kernel_h = kernel_h.unsqueeze(0).unsqueeze(0)
132 |
133 | kernel_v = v / (h * h + v * v + 1e-6)
134 | kernel_v = kernel_v.unsqueeze(0).unsqueeze(0)
135 |
136 | return torch.cat([kernel_h, kernel_v], dim=0)
137 |
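
Usage sketch for the helpers above: blur a random 3-channel tensor with a gaussian kernel; `filter2D` keeps the spatial shape thanks to its replicate "SAME" padding.

```python
import torch

from base_ml.base_utils import filter2D, gaussian_kernel2d

image = torch.rand(1, 3, 64, 64)  # (B, C, H, W)
kernel = gaussian_kernel2d(window_size=5, sigma=1.0, n_channels=3)
blurred = filter2D(image, kernel)
print(blurred.shape)  # torch.Size([1, 3, 64, 64])
```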
--------------------------------------------------------------------------------
/base_ml/base_validator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Validators
3 |
4 |
5 | from schema import Schema, Or
6 |
7 | sweep_schema = Schema(
8 | {
9 | "method": Or("grid", "random", "bayes"),
10 | "name": str,
11 | "metric": {"name": str, "goal": Or("maximize", "minimize")},
12 | "run_cap": int,
13 | },
14 | ignore_extra_keys=True,
15 | )
16 |
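
Usage sketch: validate a (hypothetical) WandB sweep configuration before launching it; extra keys such as `parameters` are tolerated because of `ignore_extra_keys=True`.

```python
from base_ml.base_validator import sweep_schema

sweep_config = {
    "method": "bayes",
    "name": "lkcell-sweep",  # illustrative sweep name
    "metric": {"name": "Loss/Validation", "goal": "minimize"},
    "run_cap": 20,
    "parameters": {"training.learning_rate": {"min": 1e-5, "max": 1e-3}},
}
sweep_schema.validate(sweep_config)  # raises schema.SchemaError if malformed
```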
--------------------------------------------------------------------------------
/base_ml/optim_factory.py:
--------------------------------------------------------------------------------
1 | # Based on RepLKNet, ConvNeXt, timm, DINO and DeiT code bases
2 | # https://github.com/DingXiaoH/RepLKNet-pytorch
3 | # https://github.com/facebookresearch/ConvNeXt
4 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm
5 | # https://github.com/facebookresearch/deit/
6 | # https://github.com/facebookresearch/dino
7 | # --------------------------------------------------------'
8 | import torch
9 | from torch import optim as optim
10 |
11 | from timm.optim.adafactor import Adafactor
12 | from timm.optim.adahessian import Adahessian
13 | from timm.optim.adamp import AdamP
14 | from timm.optim.lookahead import Lookahead
15 | from timm.optim.nadam import Nadam
16 | from timm.optim.radam import RAdam
17 | from timm.optim.rmsprop_tf import RMSpropTF
18 | from timm.optim.sgdp import SGDP
19 |
20 | import json
21 |
22 | try:
23 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
24 | has_apex = True
25 | except ImportError:
26 | has_apex = False
27 |
28 |
29 | def get_num_layer_for_convnext(var_name):
30 | """
31 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three
32 | consecutive blocks, including possible neighboring downsample layers;
33 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
34 | """
35 | num_max_layer = 12
36 | if var_name.startswith("downsample_layers"):
37 | stage_id = int(var_name.split('.')[1])
38 | if stage_id == 0:
39 | layer_id = 0
40 | elif stage_id == 1 or stage_id == 2:
41 | layer_id = stage_id + 1
42 | elif stage_id == 3:
43 | layer_id = 12
44 | return layer_id
45 |
46 | elif var_name.startswith("stages"):
47 | stage_id = int(var_name.split('.')[1])
48 | block_id = int(var_name.split('.')[2])
49 | if stage_id == 0 or stage_id == 1:
50 | layer_id = stage_id + 1
51 | elif stage_id == 2:
52 | layer_id = 3 + block_id // 3
53 | elif stage_id == 3:
54 | layer_id = 12
55 | return layer_id
56 | else:
57 | return num_max_layer + 1
58 |
59 | class LayerDecayValueAssigner(object):
60 | def __init__(self, values):
61 | self.values = values
62 |
63 | def get_scale(self, layer_id):
64 | return self.values[layer_id]
65 |
66 | def get_layer_id(self, var_name):
67 | return get_num_layer_for_convnext(var_name)
68 |
69 |
70 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
71 | parameter_group_names = {}
72 | parameter_group_vars = {}
73 |
74 | for name, param in model.named_parameters():
75 | if not param.requires_grad:
76 | continue # frozen weights
77 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
78 | group_name = "no_decay"
79 | this_weight_decay = 0.
80 | else:
81 | group_name = "decay"
82 | this_weight_decay = weight_decay
83 | if get_num_layer is not None:
84 | layer_id = get_num_layer(name)
85 | group_name = "layer_%d_%s" % (layer_id, group_name)
86 | else:
87 | layer_id = None
88 |
89 | if group_name not in parameter_group_names:
90 | if get_layer_scale is not None:
91 | scale = get_layer_scale(layer_id)
92 | else:
93 | scale = 1.
94 |
95 | parameter_group_names[group_name] = {
96 | "weight_decay": this_weight_decay,
97 | "params": [],
98 | "lr_scale": scale
99 | }
100 | parameter_group_vars[group_name] = {
101 | "weight_decay": this_weight_decay,
102 | "params": [],
103 | "lr_scale": scale
104 | }
105 |
106 | parameter_group_vars[group_name]["params"].append(param)
107 | parameter_group_names[group_name]["params"].append(name)
108 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
109 | return list(parameter_group_vars.values())
110 |
111 |
112 | def create_optimizer(model, weight_decay, lr, opt, get_num_layer=None, opt_eps=None, opt_betas=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None, momentum = 0.9):
113 | opt_lower = opt.lower()
114 | weight_decay = weight_decay
115 | # if weight_decay and filter_bias_and_bn:
116 | if filter_bias_and_bn:
117 | skip = {}
118 | if skip_list is not None:
119 | skip = skip_list
120 | elif hasattr(model, 'no_weight_decay'):
121 | skip = model.no_weight_decay()
122 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
123 | weight_decay = 0.
124 | else:
125 | parameters = model.parameters()
126 |
127 | if 'fused' in opt_lower:
128 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
129 |
130 | opt_args = dict(lr=lr, weight_decay=weight_decay)
131 | if opt_eps is not None:
132 | opt_args['eps'] = opt_eps
133 | if opt_betas is not None:
134 | opt_args['betas'] = opt_betas
135 |
136 | opt_split = opt_lower.split('_')
137 | opt_lower = opt_split[-1]
138 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
139 | opt_args.pop('eps', None)
140 | optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
141 | elif opt_lower == 'momentum':
142 | opt_args.pop('eps', None)
143 | optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
144 | elif opt_lower == 'adam':
145 | optimizer = optim.Adam(parameters, **opt_args)
146 |     elif opt_lower == 'adamw':
147 | optimizer = optim.AdamW(parameters, **opt_args)
148 | elif opt_lower == 'nadam':
149 | optimizer = Nadam(parameters, **opt_args)
150 | elif opt_lower == 'radam':
151 | optimizer = RAdam(parameters, **opt_args)
152 | elif opt_lower == 'adamp':
153 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
154 | elif opt_lower == 'sgdp':
155 | optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
156 | elif opt_lower == 'adadelta':
157 | optimizer = optim.Adadelta(parameters, **opt_args)
158 |
159 | elif opt_lower == 'adahessian':
160 | optimizer = Adahessian(parameters, **opt_args)
161 | elif opt_lower == 'rmsprop':
162 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
163 | elif opt_lower == 'rmsproptf':
164 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
165 | elif opt_lower == 'fusedsgd':
166 | opt_args.pop('eps', None)
167 | optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
168 | elif opt_lower == 'fusedmomentum':
169 | opt_args.pop('eps', None)
170 | optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
171 | elif opt_lower == 'fusedadam':
172 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
173 | elif opt_lower == 'fusedadamw':
174 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
175 | elif opt_lower == 'fusedlamb':
176 | optimizer = FusedLAMB(parameters, **opt_args)
177 | elif opt_lower == 'fusednovograd':
178 | opt_args.setdefault('betas', (0.95, 0.98))
179 | optimizer = FusedNovoGrad(parameters, **opt_args)
180 | else:
181 |         assert False, "Invalid optimizer"
182 |
183 | if len(opt_split) > 1:
184 | if opt_split[0] == 'lookahead':
185 | optimizer = Lookahead(optimizer)
186 |
187 | return optimizer
188 |
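
Usage sketch: build an AdamW optimizer with layer-wise learning-rate decay via the assigner above. The decay value, number of layers and the toy model are illustrative; a real run would pass the LKCell backbone instead.

```python
import torch.nn as nn

from base_ml.optim_factory import LayerDecayValueAssigner, create_optimizer

num_layers, layer_decay = 12, 0.9
assigner = LayerDecayValueAssigner(
    [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
)

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 8, 3))
optimizer = create_optimizer(
    model,
    weight_decay=0.05,
    lr=1e-4,
    opt="adamw",
    get_num_layer=assigner.get_layer_id,
    get_layer_scale=assigner.get_scale,
)
print(len(optimizer.param_groups), "parameter groups")
```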
--------------------------------------------------------------------------------
/base_ml/unireplknet_layer_decay_optimizer_constructor.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # UniRepLKNet
3 | # https://github.com/AILab-CVC/UniRepLKNet
4 | # Licensed under The Apache 2.0 License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | import json
7 | from mmcv.runner import OPTIMIZER_BUILDERS, DefaultOptimizerConstructor
8 | from mmcv.runner import get_dist_info
9 | from mmdet.utils import get_root_logger
10 |
11 | def get_layer_id(var_name, max_layer_id,):
12 | """Get the layer id to set the different learning rates in ``layer_wise``
13 | decay_type.
14 |
15 | Args:
16 | var_name (str): The key of the model.
17 | max_layer_id (int): Maximum layer id.
18 |
19 | Returns:
20 | int: The id number corresponding to different learning rate in
21 | ``LearningRateDecayOptimizerConstructor``.
22 | """
23 |
24 | if var_name in ('backbone.cls_token', 'backbone.mask_token',
25 | 'backbone.pos_embed'):
26 | return 0
27 |
28 | elif var_name.startswith('backbone.downsample_layers'):
29 | stage_id = int(var_name.split('.')[2])
30 | if stage_id == 0:
31 | layer_id = 0
32 | elif stage_id == 1:
33 | layer_id = 2
34 | elif stage_id == 2:
35 | layer_id = 3
36 | elif stage_id == 3:
37 | layer_id = max_layer_id
38 | return layer_id
39 |
40 | elif var_name.startswith('backbone.stages'):
41 | stage_id = int(var_name.split('.')[2])
42 | block_id = int(var_name.split('.')[3])
43 | if stage_id == 0:
44 | layer_id = 1
45 | elif stage_id == 1:
46 | layer_id = 2
47 | elif stage_id == 2:
48 | layer_id = 3 + block_id // 3
49 | elif stage_id == 3:
50 | layer_id = max_layer_id
51 | return layer_id
52 |
53 | else:
54 | return max_layer_id + 1
55 |
56 |
57 |
58 | def get_stage_id(var_name, max_stage_id):
59 | """Get the stage id to set the different learning rates in ``stage_wise``
60 | decay_type.
61 |
62 | Args:
63 | var_name (str): The key of the model.
64 | max_stage_id (int): Maximum stage id.
65 |
66 | Returns:
67 | int: The id number corresponding to different learning rate in
68 | ``LearningRateDecayOptimizerConstructor``.
69 | """
70 |
71 | if var_name in ('backbone.cls_token', 'backbone.mask_token',
72 | 'backbone.pos_embed'):
73 | return 0
74 | elif var_name.startswith('backbone.downsample_layers'):
75 | return 0
76 | elif var_name.startswith('backbone.stages'):
77 | stage_id = int(var_name.split('.')[2])
78 | return stage_id + 1
79 | else:
80 | return max_stage_id - 1
81 |
82 |
83 | @OPTIMIZER_BUILDERS.register_module()
84 | class UniRepLKNetLearningRateDecayOptimizerConstructor(DefaultOptimizerConstructor):
85 | # Different learning rates are set for different layers of backbone.
86 | # The design is inspired by and adapted from ConvNeXt.
87 |
88 | def add_params(self, params, module, **kwargs):
89 | """Add all parameters of module to the params list.
90 |
91 | The parameters of the given module will be added to the list of param
92 | groups, with specific rules defined by paramwise_cfg.
93 |
94 | Args:
95 | params (list[dict]): A list of param groups, it will be modified
96 | in place.
97 | module (nn.Module): The module to be added.
98 | """
99 | logger = get_root_logger()
100 |
101 | parameter_groups = {}
102 | logger.info(f'self.paramwise_cfg is {self.paramwise_cfg}')
103 | num_layers = self.paramwise_cfg.get('num_layers') + 2
104 | decay_rate = self.paramwise_cfg.get('decay_rate')
105 | decay_type = self.paramwise_cfg.get('decay_type', 'layer_wise')
106 | dw_scale = self.paramwise_cfg.get('dw_scale', 1)
107 | logger.info('Build UniRepLKNetLearningRateDecayOptimizerConstructor '
108 | f'{decay_type} {decay_rate} - {num_layers}')
109 | weight_decay = self.base_wd
110 | for name, param in module.named_parameters():
111 | if not param.requires_grad:
112 | continue # frozen weights
113 | if len(param.shape) == 1 or name.endswith('.bias') or name in (
114 | 'pos_embed', 'cls_token'):
115 | group_name = 'no_decay'
116 | this_weight_decay = 0.
117 | else:
118 | group_name = 'decay'
119 | this_weight_decay = weight_decay
120 | if 'layer_wise' in decay_type:
121 | layer_id = get_layer_id(name, self.paramwise_cfg.get('num_layers'))
122 | logger.info(f'set param {name} as id {layer_id}')
123 | elif decay_type == 'stage_wise':
124 | layer_id = get_stage_id(name, num_layers)
125 | logger.info(f'set param {name} as id {layer_id}')
126 |
127 | if dw_scale == 1 or 'dwconv' not in name:
128 | group_name = f'layer_{layer_id}_{group_name}'
129 | if group_name not in parameter_groups:
130 | scale = decay_rate ** (num_layers - layer_id - 1)
131 | parameter_groups[group_name] = {
132 | 'weight_decay': this_weight_decay,
133 | 'params': [],
134 | 'param_names': [],
135 | 'lr_scale': scale,
136 | 'group_name': group_name,
137 | 'lr': scale * self.base_lr,
138 | }
139 |
140 | parameter_groups[group_name]['params'].append(param)
141 | parameter_groups[group_name]['param_names'].append(name)
142 | else:
143 | group_name = f'layer_{layer_id}_{group_name}_dwconv'
144 | if group_name not in parameter_groups:
145 | scale = decay_rate ** (num_layers - layer_id - 1) * dw_scale
146 | parameter_groups[group_name] = {
147 | 'weight_decay': this_weight_decay,
148 | 'params': [],
149 | 'param_names': [],
150 | 'lr_scale': scale,
151 | 'group_name': group_name,
152 | 'lr': scale * self.base_lr,
153 | }
154 |
155 | parameter_groups[group_name]['params'].append(param)
156 | parameter_groups[group_name]['param_names'].append(name)
157 |
158 | rank, _ = get_dist_info()
159 | if rank == 0:
160 | to_display = {}
161 | for key in parameter_groups:
162 | to_display[key] = {
163 | 'param_names': parameter_groups[key]['param_names'],
164 | 'lr_scale': parameter_groups[key]['lr_scale'],
165 | 'lr': parameter_groups[key]['lr'],
166 | 'weight_decay': parameter_groups[key]['weight_decay'],
167 | }
168 | logger.info(f'Param groups = {json.dumps(to_display, indent=2)}')
169 | params.extend(parameter_groups.values())
170 |
--------------------------------------------------------------------------------
/cell_segmentation/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Cell Segmentation and detection using our cellvit model
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
--------------------------------------------------------------------------------
/cell_segmentation/datasets/base_cell.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Base cell segmentation dataset, based on torch Dataset implementation
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 | import logging
9 | from typing import Callable
10 |
11 | import torch
12 | from torch.utils.data import Dataset
13 |
14 | logger = logging.getLogger()
15 | logger.addHandler(logging.NullHandler())
16 |
17 | from abc import abstractmethod
18 |
19 |
20 | class CellDataset(Dataset):
21 | def set_transforms(self, transforms: Callable) -> None:
22 | self.transforms = transforms
23 |
24 | @abstractmethod
25 | def load_cell_count(self):
26 | """Load Cell count from cell_count.csv file. File must be located inside the fold folder
27 |
28 | Example file beginning:
29 | Image,Neoplastic,Inflammatory,Connective,Dead,Epithelial
30 | 0_0.png,4,2,2,0,0
31 | 0_1.png,8,1,1,0,0
32 | 0_10.png,17,0,1,0,0
33 | 0_100.png,10,0,11,0,0
34 | ...
35 | """
36 | pass
37 |
38 | @abstractmethod
39 | def get_sampling_weights_tissue(self, gamma: float = 1) -> torch.Tensor:
40 | """Get sampling weights calculated by tissue type statistics
41 |
42 | For this, a file named "weight_config.yaml" with the content:
43 | tissue:
44 | tissue_1: xxx
45 | tissue_2: xxx (name of tissue: count)
46 | ...
47 |         Must exist in the dataset main folder (parent path, not inside the folds)
48 |
49 | Args:
50 | gamma (float, optional): Gamma scaling factor, between 0 and 1.
51 | 1 means total balancing, 0 means original weights. Defaults to 1.
52 |
53 | Returns:
54 | torch.Tensor: Weights for each sample
55 | """
56 |
57 | @abstractmethod
58 | def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor:
59 | """Get sampling weights calculated by cell type statistics
60 |
61 | Args:
62 | gamma (float, optional): Gamma scaling factor, between 0 and 1.
63 | 1 means total balancing, 0 means original weights. Defaults to 1.
64 |
65 | Returns:
66 | torch.Tensor: Weights for each sample
67 | """
68 |
69 | def get_sampling_weights_cell_tissue(self, gamma: float = 1) -> torch.Tensor:
70 | """Get combined sampling weights by calculating tissue and cell sampling weights,
71 | normalizing them and adding them up to yield one score.
72 |
73 | Args:
74 | gamma (float, optional): Gamma scaling factor, between 0 and 1.
75 | 1 means total balancing, 0 means original weights. Defaults to 1.
76 |
77 | Returns:
78 | torch.Tensor: Weights for each sample
79 | """
80 | assert 0 <= gamma <= 1, "Gamma must be between 0 and 1"
81 | tw = self.get_sampling_weights_tissue(gamma)
82 | cw = self.get_sampling_weights_cell(gamma)
83 | weights = tw / torch.max(tw) + cw / torch.max(cw)
84 |
85 | return weights
86 |
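
Usage sketch: the combined weights can drive a `WeightedRandomSampler`; here `dataset` stands for any concrete `CellDataset` subclass with `load_cell_count()` already called, and the gamma value is illustrative.

```python
from torch.utils.data import DataLoader, WeightedRandomSampler

weights = dataset.get_sampling_weights_cell_tissue(gamma=0.85)  # dataset: a CellDataset subclass
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
train_loader = DataLoader(dataset, batch_size=16, sampler=sampler)
```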
--------------------------------------------------------------------------------
/cell_segmentation/datasets/cell_graph_datamodel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Graph Data model
3 | #
4 | # For more information, please check out docs/readmes/graphs.md
5 | #
6 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
7 | # Institute for Artificial Intelligence in Medicine,
8 | # University Medicine Essen
9 |
10 | from dataclasses import dataclass
11 | from typing import List
12 |
13 | import torch
14 |
15 | from datamodel.graph_datamodel import GraphDataWSI
16 |
17 |
18 | @dataclass
19 | class CellGraphDataWSI(GraphDataWSI):
20 | """Dataclass for Graph Data
21 |
22 | Args:
23 | contours (List[torch.Tensor]): Contour Data for each object.
24 | """
25 |
26 | contours: List[torch.Tensor]
27 |
--------------------------------------------------------------------------------
/cell_segmentation/datasets/conic.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # CoNIC Dataset
3 | #
4 | # Dataset information: https://arxiv.org/abs/2108.11195
5 | # Please Prepare Dataset as described here: docs/readmes/pannuke.md # TODO: write own documentation
6 | #
7 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
8 | # Institute for Artificial Intelligence in Medicine,
9 | # University Medicine Essen
10 |
11 |
12 | import logging
13 | from pathlib import Path
14 | from typing import Callable, Tuple, Union, List
15 |
16 | import numpy as np
17 | import pandas as pd
18 | import torch
19 | from PIL import Image
20 |
21 | from cell_segmentation.datasets.base_cell import CellDataset
22 | from cell_segmentation.datasets.pannuke import PanNukeDataset
23 |
24 | logger = logging.getLogger()
25 | logger.addHandler(logging.NullHandler())
26 |
27 |
28 | class CoNicDataset(CellDataset):
29 |     """Lizard dataset
30 |
31 | This dataset is always cached
32 |
33 | Args:
34 |         dataset_path (Union[Path, str]): Path to the Lizard dataset. Structure is described under ./docs/readmes/cell_segmentation.md
35 | folds (Union[int, list[int]]): Folds to use for this dataset
36 | transforms (Callable, optional): PyTorch transformations. Defaults to None.
37 | stardist (bool, optional): Return StarDist labels. Defaults to False
38 | regression (bool, optional): Return Regression of cells in x and y direction. Defaults to False
39 |         **kwargs are ignored
40 | """
41 |
42 | def __init__(
43 | self,
44 | dataset_path: Union[Path, str],
45 | folds: Union[int, List[int]],
46 | transforms: Callable = None,
47 | stardist: bool = False,
48 | regression: bool = False,
49 | **kwargs,
50 | ) -> None:
51 | if isinstance(folds, int):
52 | folds = [folds]
53 |
54 | self.dataset = Path(dataset_path).resolve()
55 | self.transforms = transforms
56 | self.images = []
57 | self.masks = []
58 | self.img_names = []
59 | self.folds = folds
60 | self.stardist = stardist
61 | self.regression = regression
62 | for fold in folds:
63 | image_path = self.dataset / f"fold{fold}" / "images"
64 | fold_images = [f for f in sorted(image_path.glob("*.png")) if f.is_file()]
65 |
66 | # sanity_check: mask must exist for image
67 | for fold_image in fold_images:
68 | mask_path = (
69 | self.dataset / f"fold{fold}" / "labels" / f"{fold_image.stem}.npy"
70 | )
71 | if mask_path.is_file():
72 | self.images.append(fold_image)
73 | self.masks.append(mask_path)
74 | self.img_names.append(fold_image.name)
75 |
76 | else:
77 | logger.debug(
78 |                         f"Found image {fold_image}, but no corresponding annotation file!"
79 | )
80 |
81 | # load everything in advance to speedup, as the dataset is rather small
82 | self.loaded_imgs = []
83 | self.loaded_masks = []
84 | for idx in range(len(self.images)):
85 | img_path = self.images[idx]
86 | img = np.array(Image.open(img_path)).astype(np.uint8)
87 |
88 | mask_path = self.masks[idx]
89 | mask = np.load(mask_path, allow_pickle=True)
90 | inst_map = mask[()]["inst_map"].astype(np.int32)
91 | type_map = mask[()]["type_map"].astype(np.int32)
92 | mask = np.stack([inst_map, type_map], axis=-1)
93 | self.loaded_imgs.append(img)
94 | self.loaded_masks.append(mask)
95 |
96 |         logger.info(f"Created CoNIC Dataset by using fold(s) {self.folds}")
97 | logger.info(f"Resulting dataset length: {self.__len__()}")
98 |
99 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str, str]:
100 | """Get one dataset item consisting of transformed image,
101 | masks (instance_map, nuclei_type_map, nuclei_binary_map, hv_map) and tissue type as string
102 |
103 | Args:
104 | index (int): Index of element to retrieve
105 |
106 | Returns:
107 | Tuple[torch.Tensor, dict, str, str]:
108 |                 torch.Tensor: Image, with shape (3, H, W), shape is arbitrary for Lizard (H and W approx. between 500 and 2000)
109 | dict:
110 |                     "instance_map": Instance-Map, each instance has one integer starting at 1 (zero is background), Shape (256, 256)
111 | "nuclei_type_map": Nuclei-Type-Map, for each nucleus (instance) the class is indicated by an integer. Shape (256, 256)
112 | "nuclei_binary_map": Binary Nuclei-Mask, Shape (256, 256)
113 | "hv_map": Horizontal and vertical instance map.
114 | Shape: (H, W, 2). First dimension is horizontal (horizontal gradient (-1 to 1)),
115 | last is vertical (vertical gradient (-1 to 1)) Shape (256, 256, 2)
116 | "dist_map": Probability distance map. Shape (256, 256)
117 | "stardist_map": Stardist vector map. Shape (n_rays, 256, 256)
118 | [Optional if regression]
119 | "regression_map": Regression map. Shape (2, 256, 256). First is vertical, second horizontal.
120 | str: Tissue type
121 | str: Image Name
122 | """
123 | img_path = self.images[index]
124 | img = self.loaded_imgs[index]
125 | mask = self.loaded_masks[index]
126 |
127 | if self.transforms is not None:
128 | transformed = self.transforms(image=img, mask=mask)
129 | img = transformed["image"]
130 | mask = transformed["mask"]
131 |
132 | inst_map = mask[:, :, 0].copy()
133 | type_map = mask[:, :, 1].copy()
134 | np_map = mask[:, :, 0].copy()
135 | np_map[np_map > 0] = 1
136 | hv_map = PanNukeDataset.gen_instance_hv_map(inst_map)
137 |
138 | # torch convert
139 | img = torch.Tensor(img).type(torch.float32)
140 | img = img.permute(2, 0, 1)
141 | if torch.max(img) >= 5:
142 | img = img / 255
143 |
144 | masks = {
145 | "instance_map": torch.Tensor(inst_map).type(torch.int64),
146 | "nuclei_type_map": torch.Tensor(type_map).type(torch.int64),
147 | "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64),
148 | "hv_map": torch.Tensor(hv_map).type(torch.float32),
149 | }
150 | if self.stardist:
151 | dist_map = PanNukeDataset.gen_distance_prob_maps(inst_map)
152 | stardist_map = PanNukeDataset.gen_stardist_maps(inst_map)
153 | masks["dist_map"] = torch.Tensor(dist_map).type(torch.float32)
154 | masks["stardist_map"] = torch.Tensor(stardist_map).type(torch.float32)
155 | if self.regression:
156 | masks["regression_map"] = PanNukeDataset.gen_regression_map(inst_map)
157 |
158 | return img, masks, "Colon", Path(img_path).name
159 |
160 | def __len__(self) -> int:
161 | """Length of Dataset
162 |
163 | Returns:
164 | int: Length of Dataset
165 | """
166 | return len(self.images)
167 |
168 | def set_transforms(self, transforms: Callable) -> None:
169 |         """Set the transformations, can be used to exchange transformations
170 |
171 | Args:
172 | transforms (Callable): PyTorch transformations
173 | """
174 | self.transforms = transforms
175 |
176 | def load_cell_count(self):
177 | """Load Cell count from cell_count.csv file. File must be located inside the fold folder
178 | and named "cell_count.csv"
179 |
180 | Example file beginning:
181 | Image,Neutrophil,Epithelial,Lymphocyte,Plasma,Eosinophil,Connective
182 | consep_1_0000.png,0,117,0,0,0,0
183 | consep_1_0001.png,0,95,1,0,0,8
184 | consep_1_0002.png,0,172,3,0,0,2
185 | ...
186 | """
187 | df_placeholder = []
188 | for fold in self.folds:
189 | csv_path = self.dataset / f"fold{fold}" / "cell_count.csv"
190 | cell_count = pd.read_csv(csv_path, index_col=0)
191 | df_placeholder.append(cell_count)
192 | self.cell_count = pd.concat(df_placeholder)
193 | self.cell_count = self.cell_count.reindex(self.img_names)
194 |
195 | def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor:
196 | """Get sampling weights calculated by cell type statistics
197 |
198 | Args:
199 | gamma (float, optional): Gamma scaling factor, between 0 and 1.
200 | 1 means total balancing, 0 means original weights. Defaults to 1.
201 |
202 | Returns:
203 | torch.Tensor: Weights for each sample
204 | """
205 | assert 0 <= gamma <= 1, "Gamma must be between 0 and 1"
206 | assert hasattr(self, "cell_count"), "Please run .load_cell_count() in advance!"
207 | binary_weight_factors = np.array([1069, 4189, 4356, 3103, 1025, 4527])
208 | k = np.sum(binary_weight_factors)
209 | cell_counts_imgs = np.clip(self.cell_count.to_numpy(), 0, 1)
210 | weight_vector = k / (gamma * binary_weight_factors + (1 - gamma) * k)
211 | img_weight = (1 - gamma) * np.max(cell_counts_imgs, axis=-1) + gamma * np.sum(
212 | cell_counts_imgs * weight_vector, axis=-1
213 | )
214 | img_weight[np.where(img_weight == 0)] = np.min(
215 | img_weight[np.nonzero(img_weight)]
216 | )
217 |
218 | return torch.Tensor(img_weight)
219 |
220 | # def get_sampling_weights_cell(self, gamma: float = 1) -> torch.Tensor:
221 | # """Get sampling weights calculated by cell type statistics
222 |
223 | # Args:
224 | # gamma (float, optional): Gamma scaling factor, between 0 and 1.
225 | # 1 means total balancing, 0 means original weights. Defaults to 1.
226 |
227 | # Returns:
228 | # torch.Tensor: Weights for each sample
229 | # """
230 | # assert 0 <= gamma <= 1, "Gamma must be between 0 and 1"
231 | # assert hasattr(self, "cell_count"), "Please run .load_cell_count() in advance!"
232 | # binary_weight_factors = np.array([4012, 222017, 93612, 24793, 2999, 98783])
233 | # k = np.sum(binary_weight_factors)
234 | # cell_counts_imgs = self.cell_count.to_numpy()
235 | # weight_vector = k / (gamma * binary_weight_factors + (1 - gamma) * k)
236 | # img_weight = (1 - gamma) * np.max(cell_counts_imgs, axis=-1) + gamma * np.sum(
237 | # cell_counts_imgs * weight_vector, axis=-1
238 | # )
239 | # img_weight[np.where(img_weight == 0)] = np.min(
240 | # img_weight[np.nonzero(img_weight)]
241 | # )
242 |
243 | # return torch.Tensor(img_weight)
244 |
--------------------------------------------------------------------------------
/cell_segmentation/datasets/consep.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # MoNuSeg Dataset
3 | #
4 | # Dataset information: https://monuseg.grand-challenge.org/Home/
5 | # Please Prepare Dataset as described here: docs/readmes/monuseg.md
6 | #
7 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
8 | # Institute for Artificial Intelligence in Medicine,
9 | # University Medicine Essen
10 |
11 | import logging
12 | from pathlib import Path
13 | from typing import Callable, Union, Tuple
14 |
15 | import numpy as np
16 | import torch
17 | from PIL import Image
18 | from torch.utils.data import Dataset
19 |
20 | from cell_segmentation.datasets.pannuke import PanNukeDataset
21 |
22 | logger = logging.getLogger()
23 | logger.addHandler(logging.NullHandler())
24 |
25 |
26 | class CoNSePDataset(Dataset):
27 | def __init__(
28 | self,
29 | dataset_path: Union[Path, str],
30 | transforms: Callable = None,
31 | ) -> None:
32 |         """CoNSeP Dataset
33 |
34 | Args:
35 | dataset_path (Union[Path, str]): Path to dataset
36 | transforms (Callable, optional): Transformations to apply on images. Defaults to None.
37 | Raises:
38 | FileNotFoundError: If no ground-truth annotation file was found in path
39 | """
40 | self.dataset = Path(dataset_path).resolve()
41 | self.transforms = transforms
42 | self.masks = []
43 | self.img_names = []
44 |
45 | image_path = self.dataset / "images"
46 | label_path = self.dataset / "labels"
47 | self.images = [f for f in sorted(image_path.glob("*.png")) if f.is_file()]
48 | self.masks = [f for f in sorted(label_path.glob("*.npy")) if f.is_file()]
49 |
50 | # sanity_check
51 | for idx, image in enumerate(self.images):
52 | image_name = image.stem
53 | mask_name = self.masks[idx].stem
54 | if image_name != mask_name:
55 | raise FileNotFoundError(f"Annotation for file {image_name} is missing")
56 |
57 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str]:
58 | """Get one item from dataset
59 |
60 | Args:
61 | index (int): Item to get
62 |
63 | Returns:
64 |             Tuple[torch.Tensor, dict, str]: Training batch
65 |                 * torch.Tensor: Image
66 |                 * dict: Ground-Truth values: keys are "instance_map", "nuclei_type_map", "nuclei_binary_map" and "hv_map"
67 | * str: filename
68 | """
69 | img_path = self.images[index]
70 | img = np.array(Image.open(img_path)).astype(np.uint8)
71 |
72 | mask_path = self.masks[index]
73 | mask = np.load(mask_path, allow_pickle=True)
74 | inst_map = mask[()]["inst_map"].astype(np.int32)
75 | type_map = mask[()]["type_map"].astype(np.int32)
76 | mask = np.stack([inst_map, type_map], axis=-1)
77 |
78 | if self.transforms is not None:
79 | transformed = self.transforms(image=img, mask=mask)
80 | img = transformed["image"]
81 | mask = transformed["mask"]
82 |
83 | inst_map = mask[:, :, 0].copy()
84 | type_map = mask[:, :, 1].copy()
85 | np_map = mask[:, :, 0].copy()
86 | np_map[np_map > 0] = 1
87 | hv_map = PanNukeDataset.gen_instance_hv_map(inst_map)
88 |
89 | # torch convert
90 | img = torch.Tensor(img).type(torch.float32)
91 | img = img.permute(2, 0, 1)
92 | if torch.max(img) >= 5:
93 | img = img / 255
94 |
95 | masks = {
96 | "instance_map": torch.Tensor(inst_map).type(torch.int64),
97 | "nuclei_type_map": torch.Tensor(type_map).type(torch.int64),
98 | "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64),
99 | "hv_map": torch.Tensor(hv_map).type(torch.float32),
100 | }
101 |
102 | return img, masks, Path(img_path).name
103 |
104 | def __len__(self) -> int:
105 | """Length of Dataset
106 |
107 | Returns:
108 | int: Length of Dataset
109 | """
110 | return len(self.images)
111 |
112 | def set_transforms(self, transforms: Callable) -> None:
113 |         """Set the transformations, can be used to exchange transformations
114 |
115 | Args:
116 | transforms (Callable): PyTorch transformations
117 | """
118 | self.transforms = transforms
119 |
--------------------------------------------------------------------------------
/cell_segmentation/datasets/dataset_coordinator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Coordinate the datasets, used to select the right dataset with corresponding setting
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 | from typing import Callable
9 |
10 | from torch.utils.data import Dataset
11 | from cell_segmentation.datasets.conic import CoNicDataset
12 |
13 | from cell_segmentation.datasets.pannuke import PanNukeDataset
14 |
15 |
16 | def select_dataset(
17 | dataset_name: str, split: str, dataset_config: dict, transforms: Callable = None
18 | ) -> Dataset:
19 |     """Select a cell segmentation dataset from the provided ones; currently PanNuke and CoNIC are implemented here
20 |
21 | Args:
22 | dataset_name (str): Name of dataset to use.
23 |             Must be one of: [pannuke, conic]
24 | split (str): Split to use.
25 | Must be one of: ["train", "val", "validation", "test"]
26 | dataset_config (dict): Dictionary with dataset configuration settings
27 | transforms (Callable, optional): PyTorch Image and Mask transformations. Defaults to None.
28 |
29 | Raises:
30 | NotImplementedError: Unknown dataset
31 |
32 | Returns:
33 | Dataset: Cell segmentation dataset
34 | """
35 | assert split.lower() in [
36 | "train",
37 | "val",
38 | "validation",
39 | "test",
40 | ], "Unknown split type!"
41 |
42 | if dataset_name.lower() == "pannuke":
43 | if split == "train":
44 | folds = dataset_config["train_folds"]
45 | if split == "val" or split == "validation":
46 | folds = dataset_config["val_folds"]
47 | if split == "test":
48 | folds = dataset_config["test_folds"]
49 | dataset = PanNukeDataset(
50 | dataset_path=dataset_config["dataset_path"],
51 | folds=folds,
52 | transforms=transforms,
53 | stardist=dataset_config.get("stardist", False),
54 | regression=dataset_config.get("regression_loss", False),
55 | )
56 | elif dataset_name.lower() == "conic":
57 | if split == "train":
58 | folds = dataset_config["train_folds"]
59 | if split == "val" or split == "validation":
60 | folds = dataset_config["val_folds"]
61 | if split == "test":
62 | folds = dataset_config["test_folds"]
63 | dataset = CoNicDataset(
64 | dataset_path=dataset_config["dataset_path"],
65 | folds=folds,
66 | transforms=transforms,
67 | stardist=dataset_config.get("stardist", False),
68 | regression=dataset_config.get("regression_loss", False),
69 | # TODO: Stardist and regression loss
70 | )
71 | else:
72 | raise NotImplementedError(f"Unknown dataset: {dataset_name}")
73 | return dataset
74 |
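A hedged usage sketch for `select_dataset` (not from the repository): the dictionary keys mirror the `data` section of `config.yaml`, the dataset path is a placeholder, and the repository root is assumed to be on the Python path.

```python
from cell_segmentation.datasets.dataset_coordinator import select_dataset

dataset_config = {
    "dataset_path": "/path/to/prepared-pannuke",  # output of prepare_pannuke.py
    "train_folds": [0],
    "val_folds": [1],
    "test_folds": [2],
}

# returns a PanNukeDataset built from fold 0; transforms (e.g. an albumentations
# Compose) are omitted here for brevity
train_dataset = select_dataset(
    dataset_name="pannuke",
    split="train",
    dataset_config=dataset_config,
    transforms=None,
)
print(len(train_dataset))
```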
--------------------------------------------------------------------------------
/cell_segmentation/datasets/monuseg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # MoNuSeg Dataset
3 | #
4 | # Dataset information: https://monuseg.grand-challenge.org/Home/
5 | # Please Prepare Dataset as described here: docs/readmes/monuseg.md
6 | #
7 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
8 | # Institute for Artificial Intelligence in Medicine,
9 | # University Medicine Essen
10 |
11 | import logging
12 | from pathlib import Path
13 | from typing import Callable, Union, Tuple
14 |
15 | import numpy as np
16 | import torch
17 | from PIL import Image
18 | from torch.utils.data import Dataset
19 |
20 | from cell_segmentation.datasets.pannuke import PanNukeDataset
21 | from einops import rearrange
22 |
23 | logger = logging.getLogger()
24 | logger.addHandler(logging.NullHandler())
25 |
26 |
27 | class MoNuSegDataset(Dataset):
28 | def __init__(
29 | self,
30 | dataset_path: Union[Path, str],
31 | transforms: Callable = None,
32 | patching: bool = False,
33 | overlap: int = 0,
34 | ) -> None:
35 | """MoNuSeg Dataset
36 |
37 | Args:
38 | dataset_path (Union[Path, str]): Path to dataset
39 | transforms (Callable, optional): Transformations to apply on images. Defaults to None.
40 |             patching (bool, optional): If True, the image is split into 256px patches. Otherwise, the entire MoNuSeg images are loaded. Defaults to False.
41 |             overlap (int, optional): Overlap in pixels used for patch sampling.
42 |                 Recommended value other than 0 is 64. Defaults to 0.
43 | Raises:
44 | FileNotFoundError: If no ground-truth annotation file was found in path
45 | """
46 | self.dataset = Path(dataset_path).resolve()
47 | self.transforms = transforms
48 | self.masks = []
49 | self.img_names = []
50 | self.patching = patching
51 | self.overlap = overlap
52 |
53 | image_path = self.dataset / "images"
54 | label_path = self.dataset / "labels"
55 | self.images = [f for f in sorted(image_path.glob("*.png")) if f.is_file()]
56 | self.masks = [f for f in sorted(label_path.glob("*.npy")) if f.is_file()]
57 |
58 | # sanity_check
59 | for idx, image in enumerate(self.images):
60 | image_name = image.stem
61 | mask_name = self.masks[idx].stem
62 | if image_name != mask_name:
63 | raise FileNotFoundError(f"Annotation for file {image_name} is missing")
64 |
65 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, dict, str]:
66 | """Get one item from dataset
67 |
68 | Args:
69 | index (int): Item to get
70 |
71 | Returns:
72 |             Tuple[torch.Tensor, dict, str]: Training batch
73 |                 * torch.Tensor: Image
74 |                 * dict: Ground-Truth values: keys are "instance_map", "nuclei_binary_map" and "hv_map"
75 | * str: filename
76 | """
77 | img_path = self.images[index]
78 | img = np.array(Image.open(img_path)).astype(np.uint8)
79 |
80 | mask_path = self.masks[index]
81 | mask = np.load(mask_path, allow_pickle=True)
82 | mask = mask.astype(np.int64)
83 |
84 | if self.transforms is not None:
85 | transformed = self.transforms(image=img, mask=mask)
86 | img = transformed["image"]
87 | mask = transformed["mask"]
88 |
89 | hv_map = PanNukeDataset.gen_instance_hv_map(mask)
90 | np_map = mask.copy()
91 | np_map[np_map > 0] = 1
92 |
93 | # torch convert
94 | img = torch.Tensor(img).type(torch.float32)
95 | img = img.permute(2, 0, 1)
96 | if torch.max(img) >= 5:
97 | img = img / 255
98 |
99 | if self.patching and self.overlap == 0:
100 | img = rearrange(img, "c (h i) (w j) -> c h w i j", i=256, j=256)
101 | if self.patching and self.overlap != 0:
102 | img = img.unfold(1, 256, 256 - self.overlap).unfold(
103 | 2, 256, 256 - self.overlap
104 | )
105 |
106 | masks = {
107 | "instance_map": torch.Tensor(mask).type(torch.int64),
108 | "nuclei_binary_map": torch.Tensor(np_map).type(torch.int64),
109 | "hv_map": torch.Tensor(hv_map).type(torch.float32),
110 | }
111 |
112 | return img, masks, Path(img_path).name
113 |
114 | def __len__(self) -> int:
115 | """Length of Dataset
116 |
117 | Returns:
118 | int: Length of Dataset
119 | """
120 | return len(self.images)
121 |
122 | def set_transforms(self, transforms: Callable) -> None:
123 |         """Set the transformations, can be used to exchange transformations
124 |
125 | Args:
126 | transforms (Callable): PyTorch transformations
127 | """
128 | self.transforms = transforms
129 |
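A usage sketch for the prepared MoNuSeg test split (an assumption, not repository code); paths are placeholders and the repository root is assumed to be importable. With `patching=True` and a non-zero overlap, each 1024px image is unfolded into a grid of 256px tiles.

```python
from torch.utils.data import DataLoader

from cell_segmentation.datasets.monuseg import MoNuSegDataset

dataset = MoNuSegDataset(
    dataset_path="/path/to/prepared-monuseg/testing",  # contains images/ and labels/
    patching=True,
    overlap=64,  # recommended overlap other than 0
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

img, masks, name = next(iter(loader))
# img: (1, 3, n_rows, n_cols, 256, 256) patch grid; masks stay at full resolution
print(name[0], img.shape, masks["instance_map"].shape)
```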
--------------------------------------------------------------------------------
/cell_segmentation/datasets/prepare_monuseg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Prepare MoNuSeg Dataset By converting and resorting files
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 | from PIL import Image
9 | import xml.etree.ElementTree as ET
10 | from skimage import draw
11 | import numpy as np
12 | from pathlib import Path
13 | from typing import Union
14 | import argparse
15 |
16 |
17 | def convert_monuseg(
18 | input_path: Union[Path, str], output_path: Union[Path, str]
19 | ) -> None:
20 | """Convert the MoNuSeg dataset to a new format (1000 -> 1024, tiff to png and xml to npy)
21 |
22 | Args:
23 | input_path (Union[Path, str]): Input dataset
24 | output_path (Union[Path, str]): Output path
25 | """
26 | input_path = Path(input_path)
27 | output_path = Path(output_path)
28 | output_path.mkdir(exist_ok=True, parents=True)
29 |
30 | # testing and training
31 | parts = ["testing", "training"]
32 | for part in parts:
33 | print(f"Prepare: {part}")
34 | input_path_part = input_path / part
35 | output_path_part = output_path / part
36 | output_path_part.mkdir(exist_ok=True, parents=True)
37 | (output_path_part / "images").mkdir(exist_ok=True, parents=True)
38 | (output_path_part / "labels").mkdir(exist_ok=True, parents=True)
39 |
40 | # images
41 | images = [f for f in sorted((input_path_part / "images").glob("*.tif"))]
42 | for img_path in images:
43 | loaded_image = Image.open(img_path)
44 | resized = loaded_image.resize(
45 | (1024, 1024), resample=Image.Resampling.LANCZOS
46 | )
47 | new_img_path = output_path_part / "images" / f"{img_path.stem}.png"
48 | resized.save(new_img_path)
49 | # masks
50 | annotations = [f for f in sorted((input_path_part / "labels").glob("*.xml"))]
51 | for annot_path in annotations:
52 | binary_mask = np.transpose(np.zeros((1000, 1000)))
53 |
54 | # extract xml file
55 | tree = ET.parse(annot_path)
56 | root = tree.getroot()
57 | child = root[0]
58 |
59 | for x in child:
60 | r = x.tag
61 | if r == "Regions":
62 | element_idx = 1
63 | for y in x:
64 | y_tag = y.tag
65 |
66 | if y_tag == "Region":
67 | regions = []
68 | vertices = y[1]
69 | coords = np.zeros((len(vertices), 2))
70 | for i, vertex in enumerate(vertices):
71 | coords[i][0] = vertex.attrib["X"]
72 | coords[i][1] = vertex.attrib["Y"]
73 | regions.append(coords)
74 | vertex_row_coords = regions[0][:, 0]
75 | vertex_col_coords = regions[0][:, 1]
76 | fill_row_coords, fill_col_coords = draw.polygon(
77 | vertex_col_coords, vertex_row_coords, binary_mask.shape
78 | )
79 | binary_mask[fill_row_coords, fill_col_coords] = element_idx
80 |
81 | element_idx = element_idx + 1
82 | inst_image = Image.fromarray(binary_mask)
83 | resized_mask = np.array(
84 | inst_image.resize((1024, 1024), resample=Image.Resampling.NEAREST)
85 | )
86 | new_mask_path = output_path_part / "labels" / f"{annot_path.stem}.npy"
87 | np.save(new_mask_path, resized_mask)
88 | print("Finished")
89 |
90 |
91 | parser = argparse.ArgumentParser(
92 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
93 | description="Convert the MoNuSeg dataset",
94 | )
95 | parser.add_argument(
96 | "--input_path",
97 | type=str,
98 | help="Input path of the original MoNuSeg dataset",
99 | required=True,
100 | )
101 | parser.add_argument(
102 | "--output_path",
103 | type=str,
104 | help="Output path to store the processed MoNuSeg dataset",
105 | required=True,
106 | )
107 |
108 | if __name__ == "__main__":
109 | opt = parser.parse_args()
110 | configuration = vars(opt)
111 |
112 | input_path = Path(configuration["input_path"])
113 | output_path = Path(configuration["output_path"])
114 |
115 | convert_monuseg(input_path=input_path, output_path=output_path)
116 |
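For reference, a hedged sketch of calling the conversion directly from Python instead of via the CLI; both paths are placeholders.

```python
from pathlib import Path

from cell_segmentation.datasets.prepare_monuseg import convert_monuseg

# expects <input>/training and <input>/testing, each with images/ (*.tif) and labels/ (*.xml)
convert_monuseg(
    input_path=Path("/path/to/MoNuSeg-raw"),
    output_path=Path("/path/to/MoNuSeg-prepared"),
)
```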
--------------------------------------------------------------------------------
/cell_segmentation/datasets/prepare_pannuke.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Prepare PanNuke Dataset by converting and resorting files
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 | import inspect
9 | import os
10 | import sys
11 |
12 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
13 | parentdir = os.path.dirname(currentdir)
14 | sys.path.insert(0, parentdir)
15 | parentdir = os.path.dirname(parentdir)
16 | sys.path.insert(0, parentdir)
17 |
18 | import numpy as np
19 | from pathlib import Path
20 | from PIL import Image
21 | from tqdm import tqdm
22 | import argparse
23 | from cell_segmentation.utils.metrics import remap_label
24 |
25 |
26 | def process_fold(fold, input_path, output_path) -> None:
27 | fold_path = Path(input_path) / f"fold{fold}"
28 | output_fold_path = Path(output_path) / f"fold{fold}"
29 | output_fold_path.mkdir(exist_ok=True, parents=True)
30 | (output_fold_path / "images").mkdir(exist_ok=True, parents=True)
31 | (output_fold_path / "labels").mkdir(exist_ok=True, parents=True)
32 |
33 | print(f"Fold: {fold}")
34 | print("Loading large numpy files, this may take a while")
35 | images = np.load(fold_path / "images.npy")
36 | masks = np.load(fold_path / "masks.npy")
37 |
38 | print("Process images")
39 | for i in tqdm(range(len(images)), total=len(images)):
40 | outname = f"{fold}_{i}.png"
41 | out_img = images[i]
42 | im = Image.fromarray(out_img.astype(np.uint8))
43 | im.save(output_fold_path / "images" / outname)
44 |
45 | print("Process masks")
46 | for i in tqdm(range(len(images)), total=len(images)):
47 | outname = f"{fold}_{i}.npy"
48 |
49 | # need to create instance map and type map with shape 256x256
50 | mask = masks[i]
51 | inst_map = np.zeros((256, 256))
52 | num_nuc = 0
53 | for j in range(5):
54 | # copy value from new array if value is not equal 0
55 | layer_res = remap_label(mask[:, :, j])
56 | # inst_map = np.where(mask[:,:,j] != 0, mask[:,:,j], inst_map)
57 | inst_map = np.where(layer_res != 0, layer_res + num_nuc, inst_map)
58 | num_nuc = num_nuc + np.max(layer_res)
59 | inst_map = remap_label(inst_map)
60 |
61 | type_map = np.zeros((256, 256)).astype(np.int32)
62 | for j in range(5):
63 | layer_res = ((j + 1) * np.clip(mask[:, :, j], 0, 1)).astype(np.int32)
64 | type_map = np.where(layer_res != 0, layer_res, type_map)
65 |
66 | outdict = {"inst_map": inst_map, "type_map": type_map}
67 | np.save(output_fold_path / "labels" / outname, outdict)
68 |
69 |
70 | parser = argparse.ArgumentParser(
71 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
72 |     description="Convert the PanNuke dataset",
73 | )
74 | parser.add_argument(
75 | "--input_path",
76 | type=str,
77 | help="Input path of the original PanNuke dataset",
78 | required=True,
79 | )
80 | parser.add_argument(
81 | "--output_path",
82 | type=str,
83 | help="Output path to store the processed PanNuke dataset",
84 | required=True,
85 | )
86 |
87 | if __name__ == "__main__":
88 | opt = parser.parse_args()
89 | configuration = vars(opt)
90 |
91 | input_path = Path(configuration["input_path"])
92 | output_path = Path(configuration["output_path"])
93 |
94 | for fold in [0, 1, 2]:
95 | process_fold(fold, input_path, output_path)
96 |
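A toy numpy illustration (not repository code) of the label collapsing performed in `process_fold`: the five PanNuke nuclei channels are merged into a single instance map with contiguous ids plus a type map. Array sizes are shrunk to 4x4 for readability.

```python
import numpy as np

from cell_segmentation.utils.metrics import remap_label

mask = np.zeros((4, 4, 6), dtype=np.int32)  # last channel is background and is skipped
mask[0:2, 0:2, 0] = 7   # one nucleus of class 1, arbitrary raw id
mask[2:4, 2:4, 2] = 3   # one nucleus of class 3, arbitrary raw id

inst_map = np.zeros((4, 4))
num_nuc = 0
for j in range(5):
    layer_res = remap_label(mask[:, :, j])
    inst_map = np.where(layer_res != 0, layer_res + num_nuc, inst_map)
    num_nuc = num_nuc + np.max(layer_res)
inst_map = remap_label(inst_map)

type_map = np.zeros((4, 4), dtype=np.int32)
for j in range(5):
    layer_res = ((j + 1) * np.clip(mask[:, :, j], 0, 1)).astype(np.int32)
    type_map = np.where(layer_res != 0, layer_res, type_map)

print(np.unique(inst_map))  # [0. 1. 2.]  two instances with contiguous ids
print(np.unique(type_map))  # [0 1 3]     background plus the two class labels
```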
--------------------------------------------------------------------------------
/cell_segmentation/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Experiment related methods for each network type
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
--------------------------------------------------------------------------------
/cell_segmentation/inference/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Inference related methods for each network type
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
--------------------------------------------------------------------------------
/cell_segmentation/run_cellvit.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Running an Experiment Using CellViT cell segmentation network
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 | import inspect
9 | import os
10 | import sys
11 |
12 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
13 | parentdir = os.path.dirname(currentdir)
14 | sys.path.insert(0, parentdir)
15 |
16 | import wandb
17 |
18 | from base_ml.base_cli import ExperimentBaseParser
19 | from cell_segmentation.experiments.experiment_cellvit_pannuke import (
20 | ExperimentCellVitPanNuke,
21 | )
22 | from cell_segmentation.experiments.experiment_cellvit_conic import (
23 | ExperimentCellViTCoNic,
24 | )
25 |
26 | from cell_segmentation.inference.inference_cellvit_experiment_pannuke import (
27 | InferenceCellViT,
28 | )
29 |
30 | if __name__ == "__main__":
31 | # Parse arguments
32 | configuration_parser = ExperimentBaseParser()
33 | configuration = configuration_parser.parse_arguments()
34 |
35 | if configuration["data"]["dataset"].lower() == "pannuke":
36 | experiment_class = ExperimentCellVitPanNuke
37 | elif configuration["data"]["dataset"].lower() == "conic":
38 | experiment_class = ExperimentCellViTCoNic
39 | # Setup experiment
40 | if "checkpoint" in configuration:
41 | # continue checkpoint
42 | experiment = experiment_class(
43 | default_conf=configuration, checkpoint=configuration["checkpoint"]
44 | )
45 | outdir = experiment.run_experiment()
46 | inference = InferenceCellViT(
47 | run_dir=outdir,
48 | gpu=configuration["gpu"],
49 | checkpoint_name=configuration["eval_checkpoint"],
50 | magnification=configuration["data"].get("magnification", 40),
51 | )
52 | (
53 | trained_model,
54 | inference_dataloader,
55 | dataset_config,
56 | ) = inference.setup_patch_inference()
57 | inference.run_patch_inference(
58 | trained_model, inference_dataloader, dataset_config, generate_plots=False
59 | )
60 | else:
61 | experiment = experiment_class(default_conf=configuration)
62 | if configuration["run_sweep"] is True:
63 | # run new sweep
64 | sweep_configuration = experiment_class.extract_sweep_arguments(
65 | configuration
66 | )
67 | os.environ["WANDB_DIR"] = os.path.abspath(
68 | configuration["logging"]["wandb_dir"]
69 | )
70 | sweep_id = wandb.sweep(
71 | sweep=sweep_configuration, project=configuration["logging"]["project"]
72 | )
73 | wandb.agent(sweep_id=sweep_id, function=experiment.run_experiment)
74 | elif "agent" in configuration and configuration["agent"] is not None:
75 |             # add agent to an already existing sweep; run_sweep is therefore set to true
76 | configuration["run_sweep"] = True
77 | os.environ["WANDB_DIR"] = os.path.abspath(
78 | configuration["logging"]["wandb_dir"]
79 | )
80 | wandb.agent(
81 | sweep_id=configuration["agent"], function=experiment.run_experiment
82 | )
83 | else:
84 | # casual run
85 | outdir = experiment.run_experiment()
86 | inference = InferenceCellViT(
87 | run_dir=outdir,
88 | gpu=configuration["gpu"],
89 | checkpoint_name=configuration["eval_checkpoint"],
90 | magnification=configuration["data"].get("magnification", 40),
91 | )
92 | (
93 | trained_model,
94 | inference_dataloader,
95 | dataset_config,
96 | ) = inference.setup_patch_inference()
97 | inference.run_patch_inference(
98 | trained_model,
99 | inference_dataloader,
100 | dataset_config,
101 | generate_plots=False,
102 | )
103 | wandb.finish()
104 |
--------------------------------------------------------------------------------
/cell_segmentation/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Trainer for each network type
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artifical Intelligence in Medicine,
6 | # University Medicine Essen
7 |
--------------------------------------------------------------------------------
/cell_segmentation/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Utils
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
--------------------------------------------------------------------------------
/cell_segmentation/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Implemented Metrics for Cell detection
3 | #
4 | # This code is based on the following repository: https://github.com/TissueImageAnalytics/PanNuke-metrics
5 | #
6 | # Implemented metrics are:
7 | #
8 | # Instance Segmentation Metrics
9 | # Binary PQ
10 | # Multiclass PQ
11 | # Neoplastic PQ
12 | # Non-Neoplastic PQ
13 | # Inflammatory PQ
14 | #     Connective PQ
15 | #     Dead PQ
16 | #
17 | #
18 | # Detection and Classification Metrics
19 | # Precision, Recall, F1
20 | #
21 | # Other
22 | # dice1, dice2, aji, aji_plus
23 | #
24 | # Binary PQ (bPQ): Assumes all nuclei belong to same class and reports the average PQ across tissue types.
25 | # Multi-Class PQ (mPQ): Reports the average PQ across the classes and tissue types.
26 | # Neoplastic PQ: Reports the PQ for the neoplastic class on all tissues.
27 | # Non-Neoplastic PQ: Reports the PQ for the non-neoplastic class on all tissues.
28 | # Inflammatory PQ: Reports the PQ for the inflammatory class on all tissues.
29 | # Connective PQ: Reports the PQ for the connective class on all tissues.
30 | # Dead PQ: Reports the PQ for the dead class on all tissues.
31 |
32 |
33 | from typing import List
34 | import numpy as np
35 | from scipy.optimize import linear_sum_assignment
36 |
37 |
38 | def get_fast_pq(true, pred, match_iou=0.5):
39 | """
40 | `match_iou` is the IoU threshold level to determine the pairing between
41 | GT instances `p` and prediction instances `g`. `p` and `g` is a pair
42 | if IoU > `match_iou`. However, pair of `p` and `g` must be unique
43 | (1 prediction instance to 1 GT instance mapping).
44 |
45 | If `match_iou` < 0.5, Munkres assignment (solving minimum weight matching
46 |     in bipartite graphs) is calculated to find the maximal number of unique pairings.
47 |
48 | If `match_iou` >= 0.5, all IoU(p,g) > 0.5 pairing is proven to be unique and
49 | the number of pairs is also maximal.
50 |
51 |     Fast computation requires instance IDs to be in contiguous order,
52 |     i.e. [1, 2, 3, 4], not [2, 3, 6, 10]. Please call `remap_label` beforehand
53 | and `by_size` flag has no effect on the result.
54 |
55 | Returns:
56 | [dq, sq, pq]: measurement statistic
57 |
58 | [paired_true, paired_pred, unpaired_true, unpaired_pred]:
59 | pairing information to perform measurement
60 |
61 | """
62 |     assert match_iou >= 0.0, "Can't be negative"
63 |
64 | true = np.copy(true) #[256,256]
65 | pred = np.copy(pred) #(256,256)
66 | true_id_list = list(np.unique(true))
67 | pred_id_list = list(np.unique(pred))
68 |
69 | # if there is no background, fixing by adding it
70 | if 0 not in pred_id_list:
71 | pred_id_list = [0] + pred_id_list
72 |
73 | true_masks = [
74 | None,
75 | ]
76 | for t in true_id_list[1:]:
77 | t_mask = np.array(true == t, np.uint8)
78 | true_masks.append(t_mask)
79 |
80 | pred_masks = [
81 | None,
82 | ]
83 | for p in pred_id_list[1:]:
84 | p_mask = np.array(pred == p, np.uint8)
85 | pred_masks.append(p_mask)
86 |
87 | # prefill with value
88 | pairwise_iou = np.zeros(
89 | [len(true_id_list) - 1, len(pred_id_list) - 1], dtype=np.float64
90 | )
91 |
92 | # caching pairwise iou for all instances
93 | for true_id in true_id_list[1:]: # 0-th is background
94 | t_mask = true_masks[true_id]
95 | pred_true_overlap = pred[t_mask > 0]
96 | pred_true_overlap_id = np.unique(pred_true_overlap)
97 | pred_true_overlap_id = list(pred_true_overlap_id)
98 | for pred_id in pred_true_overlap_id:
99 | if pred_id == 0: # ignore
100 | continue # overlaping background
101 | p_mask = pred_masks[pred_id]
102 | total = (t_mask + p_mask).sum()
103 | inter = (t_mask * p_mask).sum()
104 | iou = inter / (total - inter)
105 | pairwise_iou[true_id - 1, pred_id - 1] = iou
106 | #
107 | if match_iou >= 0.5:
108 | paired_iou = pairwise_iou[pairwise_iou > match_iou]
109 | pairwise_iou[pairwise_iou <= match_iou] = 0.0
110 | paired_true, paired_pred = np.nonzero(pairwise_iou)
111 | paired_iou = pairwise_iou[paired_true, paired_pred]
112 | paired_true += 1 # index is instance id - 1
113 | paired_pred += 1 # hence return back to original
114 | else: # * Exhaustive maximal unique pairing
115 | #### Munkres pairing with scipy library
116 |         # the algorithm returns (row indices, matched column indices)
117 |         # if there are multiple identical costs in a row, the index of the first occurrence
118 |         # is returned, thus the unique pairing is ensured
119 | # inverse pair to get high IoU as minimum
120 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
121 | ### extract the paired cost and remove invalid pair
122 | paired_iou = pairwise_iou[paired_true, paired_pred]
123 |
124 | # now select those above threshold level
125 | # paired with iou = 0.0 i.e no intersection => FP or FN
126 | paired_true = list(paired_true[paired_iou > match_iou] + 1)
127 | paired_pred = list(paired_pred[paired_iou > match_iou] + 1)
128 | paired_iou = paired_iou[paired_iou > match_iou]
129 |
130 | # get the actual FP and FN
131 | unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true]
132 | unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred]
133 | # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred))
134 |
135 | #
136 | tp = len(paired_true)
137 | fp = len(unpaired_pred)
138 | fn = len(unpaired_true)
139 | # get the F1-score i.e DQ
140 | dq = tp / (tp + 0.5 * fp + 0.5 * fn + 1.0e-6) # good practice?
141 | # get the SQ, no paired has 0 iou so not impact
142 | sq = paired_iou.sum() / (tp + 1.0e-6)
143 |
144 | return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred]
145 |
146 |
147 | #####
148 |
149 |
150 | def remap_label(pred, by_size=False):
151 | """
152 |     Rename all instance ids so that they are contiguous, i.e. [0, 1, 2, 3]
153 |     not [0, 2, 4, 6]. The ordering of instances (which one comes first)
154 |     is preserved unless by_size=True, in which case the instances are reordered
155 |     so that bigger nuclei get smaller IDs
156 |
157 | Args:
158 |         pred : the 2d array containing instances, where each instance is marked
159 |             by a non-zero integer
160 |         by_size : rename so that larger nuclei get smaller ids (on top)
161 | """
162 | pred_id = list(np.unique(pred))
163 | if 0 in pred_id:
164 | pred_id.remove(0)
165 | if len(pred_id) == 0:
166 | return pred # no label
167 | if by_size:
168 | pred_size = []
169 | for inst_id in pred_id:
170 | size = (pred == inst_id).sum()
171 | pred_size.append(size)
172 | # sort the id by size in descending order
173 | pair_list = zip(pred_id, pred_size)
174 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
175 | pred_id, pred_size = zip(*pair_list)
176 |
177 | new_pred = np.zeros(pred.shape, np.int32)
178 | for idx, inst_id in enumerate(pred_id):
179 | new_pred[pred == inst_id] = idx + 1
180 | return new_pred
181 |
182 |
183 | ####
184 |
185 |
186 | def binarize(x):
187 | """
188 |     convert multichannel (multiclass) instance segmentation tensor
189 | to binary instance segmentation (bg and nuclei),
190 |
191 | :param x: B*B*C (for PanNuke 256*256*5 )
192 | :return: Instance segmentation
193 | """
194 | #x = np.transpose(x, (1, 2, 0)) #[256,256,5]
195 |
196 | out = np.zeros([x.shape[0], x.shape[1]])
197 | count = 1
198 | for i in range(x.shape[2]):
199 | x_ch = x[:, :, i] #[256,256]
200 | unique_vals = np.unique(x_ch)
201 | unique_vals = unique_vals.tolist()
202 | unique_vals.remove(0)
203 | for j in unique_vals:
204 | x_tmp = x_ch == j
205 | x_tmp_c = 1 - x_tmp
206 | out *= x_tmp_c
207 | out += count * x_tmp
208 | count += 1
209 | out = out.astype("int32")
210 | return out
211 |
212 |
213 | def get_tissue_idx(tissue_indices, idx):
214 | for i in range(len(tissue_indices)):
215 | if tissue_indices[i].count(idx) == 1:
216 | tiss_idx = i
217 | return tiss_idx
218 |
219 |
220 | def cell_detection_scores(
221 | paired_true, paired_pred, unpaired_true, unpaired_pred, w: List = [1, 1]
222 | ):
223 | tp_d = paired_pred.shape[0]
224 | fp_d = unpaired_pred.shape[0]
225 | fn_d = unpaired_true.shape[0]
226 |
227 | # tp_tn_dt = (paired_pred == paired_true).sum()
228 | # fp_fn_dt = (paired_pred != paired_true).sum()
229 | prec_d = tp_d / (tp_d + fp_d)
230 | rec_d = tp_d / (tp_d + fn_d)
231 |
232 | f1_d = 2 * tp_d / (2 * tp_d + w[0] * fp_d + w[1] * fn_d)
233 |
234 | return f1_d, prec_d, rec_d
235 |
236 |
237 | def cell_type_detection_scores(
238 | paired_true,
239 | paired_pred,
240 | unpaired_true,
241 | unpaired_pred,
242 | type_id,
243 | w: List = [2, 2, 1, 1],
244 | exhaustive: bool = True,
245 | ):
246 | type_samples = (paired_true == type_id) | (paired_pred == type_id)
247 |
248 | paired_true = paired_true[type_samples]
249 | paired_pred = paired_pred[type_samples]
250 |
251 | tp_dt = ((paired_true == type_id) & (paired_pred == type_id)).sum()
252 | tn_dt = ((paired_true != type_id) & (paired_pred != type_id)).sum()
253 | fp_dt = ((paired_true != type_id) & (paired_pred == type_id)).sum()
254 | fn_dt = ((paired_true == type_id) & (paired_pred != type_id)).sum()
255 |
256 | if not exhaustive:
257 | ignore = (paired_true == -1).sum()
258 | fp_dt -= ignore
259 |
260 | fp_d = (unpaired_pred == type_id).sum() #
261 | fn_d = (unpaired_true == type_id).sum()
262 |
263 | prec_type = (tp_dt + tn_dt) / (tp_dt + tn_dt + w[0] * fp_dt + w[2] * fp_d)
264 | rec_type = (tp_dt + tn_dt) / (tp_dt + tn_dt + w[1] * fn_dt + w[3] * fn_d)
265 |
266 | f1_type = (2 * (tp_dt + tn_dt)) / (
267 | 2 * (tp_dt + tn_dt) + w[0] * fp_dt + w[1] * fn_dt + w[2] * fp_d + w[3] * fn_d
268 | )
269 | return f1_type, prec_type, rec_type
270 |
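A small self-contained check (an assumption, not repository code) of `get_fast_pq` on a toy ground-truth/prediction pair. `remap_label` is applied first because the function expects contiguous instance ids.

```python
import numpy as np

from cell_segmentation.utils.metrics import get_fast_pq, remap_label

true = np.zeros((8, 8), dtype=np.int32)
pred = np.zeros((8, 8), dtype=np.int32)
true[1:4, 1:4] = 5   # GT nucleus with a non-contiguous id
pred[1:4, 1:4] = 1   # perfectly matching prediction
true[5:8, 5:8] = 9   # second GT nucleus, missed by the prediction

[dq, sq, pq], _ = get_fast_pq(remap_label(true), remap_label(pred), match_iou=0.5)
print(round(dq, 3), round(sq, 3), round(pq, 3))  # ~0.667 1.0 0.667
```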
--------------------------------------------------------------------------------
/cell_segmentation/utils/post_proc_cellvit.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # PostProcessing Pipeline
3 | #
4 | # Adapted from HoverNet
5 | # HoverNet Network (https://doi.org/10.1016/j.media.2019.101563)
6 | # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
7 |
8 | import warnings
9 | from typing import Tuple, Literal,List
10 |
11 | import cv2
12 | import numpy as np
13 | from scipy.ndimage import measurements
14 | from scipy.ndimage.morphology import binary_fill_holes
15 | from skimage.segmentation import watershed
16 | import torch
17 |
18 | from .tools import get_bounding_box, remove_small_objects
19 |
20 |
21 | def noop(*args, **kargs):
22 | pass
23 |
24 |
25 | warnings.warn = noop
26 |
27 |
28 | class DetectionCellPostProcessor:
29 | def __init__(
30 | self,
31 | nr_types: int = None,
32 | magnification: Literal[20, 40] = 40,
33 | gt: bool = False,
34 | ) -> None:
35 |         """DetectionCellPostProcessor for postprocessing prediction maps and getting detected cells
36 |
37 | Args:
38 | nr_types (int, optional): Number of cell types, including background (background = 0). Defaults to None.
39 | magnification (Literal[20, 40], optional): Which magnification the data has. Defaults to 40.
40 |             gt (bool, optional): If this is ground-truth data (used so that we do not suppress tiny cells that may be noise in a prediction map).
41 | Defaults to False.
42 |
43 | Raises:
44 | NotImplementedError: Unknown magnification
45 | """
46 | self.nr_types = nr_types
47 | self.magnification = magnification
48 | self.gt = gt
49 |
50 | if magnification == 40:
51 | self.object_size = 10
52 | self.k_size = 21
53 | elif magnification == 20:
54 | self.object_size = 3 # 3 or 40, we used 5
55 | self.k_size = 11 # 11 or 41, we used 13
56 | else:
57 | raise NotImplementedError("Unknown magnification")
58 | if gt: # to not supress something in gt!
59 | self.object_size = 100
60 | self.k_size = 21
61 |
62 | def post_process_cell_segmentation(
63 | self,
64 | pred_map: np.ndarray,
65 | ) -> Tuple[np.ndarray, dict]:
66 | """Post processing of one image tile
67 |
68 | Args:
69 | pred_map (np.ndarray): Combined output of tp, np and hv branches, in the same order. Shape: (H, W, 4)
70 |
71 | Returns:
72 | Tuple[np.ndarray, dict]:
73 | np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W)
74 | dict: Instance dictionary. Main Key is the nuclei instance number (int), with a dict as value.
75 | For each instance, the dictionary contains the keys: bbox (bounding box), centroid (centroid coordinates),
76 | contour, type_prob (probability), type (nuclei type)
77 | """
78 | if self.nr_types is not None:
79 | pred_type = pred_map[..., :1]
80 | pred_inst = pred_map[..., 1:]
81 | pred_type = pred_type.astype(np.int32)
82 | else:
83 | pred_inst = pred_map
84 |
85 | pred_inst = np.squeeze(pred_inst)
86 | pred_inst = self.__proc_np_hv(
87 | pred_inst, object_size=self.object_size, ksize=self.k_size
88 | )
89 |
90 |         inst_id_list = np.unique(pred_inst)[1:]  # exclude background
91 | inst_info_dict = {}
92 | for inst_id in inst_id_list:
93 | inst_map = pred_inst == inst_id
94 | rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
95 | inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
96 | inst_map = inst_map[
97 | inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1]
98 | ]
99 | inst_map = inst_map.astype(np.uint8)
100 | inst_moment = cv2.moments(inst_map)
101 | inst_contour = cv2.findContours(
102 | inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
103 | )
104 | # * opencv protocol format may break
105 | inst_contour = np.squeeze(inst_contour[0][0].astype("int32"))
106 |             # fewer than 3 points do not make a contour, so skip; likely an artifact
107 |             # from the contour approximation (instance too small)
108 | if inst_contour.shape[0] < 3:
109 | continue
110 | if len(inst_contour.shape) != 2:
111 | continue # ! check for trickery shape
112 | inst_centroid = [
113 | (inst_moment["m10"] / inst_moment["m00"]),
114 | (inst_moment["m01"] / inst_moment["m00"]),
115 | ]
116 | inst_centroid = np.array(inst_centroid)
117 | inst_contour[:, 0] += inst_bbox[0][1] # X
118 | inst_contour[:, 1] += inst_bbox[0][0] # Y
119 | inst_centroid[0] += inst_bbox[0][1] # X
120 | inst_centroid[1] += inst_bbox[0][0] # Y
121 | inst_info_dict[inst_id] = { # inst_id should start at 1
122 | "bbox": inst_bbox,
123 | "centroid": inst_centroid,
124 | "contour": inst_contour,
125 | "type_prob": None,
126 | "type": None,
127 | }
128 |
129 |         #### * Get class of each instance id, stored at index id-1 (inst_id = number of detected nucleus)
130 | for inst_id in list(inst_info_dict.keys()):
131 | rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten()
132 | inst_map_crop = pred_inst[rmin:rmax, cmin:cmax]
133 | inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
134 | inst_map_crop = inst_map_crop == inst_id
135 | inst_type = inst_type_crop[inst_map_crop]
136 | type_list, type_pixels = np.unique(inst_type, return_counts=True)
137 | type_list = list(zip(type_list, type_pixels))
138 | type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
139 | inst_type = type_list[0][0]
140 | if inst_type == 0: # ! pick the 2nd most dominant if exist
141 | if len(type_list) > 1:
142 | inst_type = type_list[1][0]
143 | type_dict = {v[0]: v[1] for v in type_list}
144 | type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)
145 | inst_info_dict[inst_id]["type"] = int(inst_type)
146 | inst_info_dict[inst_id]["type_prob"] = float(type_prob)
147 |
148 | return pred_inst, inst_info_dict
149 |
150 | def __proc_np_hv(
151 | self, pred: np.ndarray, object_size: int = 10, ksize: int = 21
152 | ) -> np.ndarray:
153 | """Process Nuclei Prediction with XY Coordinate Map and generate instance map (each instance has unique integer)
154 |
155 | Separate Instances (also overlapping ones) from binary nuclei map and hv map by using morphological operations and watershed
156 |
157 | Args:
158 |             pred (np.ndarray): Prediction output. Shape: (H, W, 3)
159 |                 * channel 0 containing the probability map of nuclei
160 | * channel 1 containing the regressed X-map
161 | * channel 2 containing the regressed Y-map
162 |             object_size (int, optional): Smallest object size for filtering. Defaults to 10
163 |             ksize (int, optional): Sobel kernel size. Defaults to 21
164 | Returns:
165 | np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W)
166 | """
167 | pred = np.array(pred, dtype=np.float32)
168 |
169 | blb_raw = pred[..., 0]
170 | h_dir_raw = pred[..., 1]
171 | v_dir_raw = pred[..., 2]
172 |
173 | # processing
174 | blb = np.array(blb_raw >= 0.5, dtype=np.int32)
175 |
176 | blb = measurements.label(blb)[0] # ndimage.label(blb)[0]
177 | blb = remove_small_objects(blb, min_size=10) # 10
178 | blb[blb > 0] = 1 # background is 0 already
179 |
180 | h_dir = cv2.normalize(
181 | h_dir_raw,
182 | None,
183 | alpha=0,
184 | beta=1,
185 | norm_type=cv2.NORM_MINMAX,
186 | dtype=cv2.CV_32F,
187 | )
188 | v_dir = cv2.normalize(
189 | v_dir_raw,
190 | None,
191 | alpha=0,
192 | beta=1,
193 | norm_type=cv2.NORM_MINMAX,
194 | dtype=cv2.CV_32F,
195 | )
196 |
197 | # ksize = int((20 * scale_factor) + 1) # 21 vs 41
198 | # obj_size = math.ceil(10 * (scale_factor**2)) #10 vs 40
199 |
200 | sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
201 | sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)
202 |
203 | sobelh = 1 - (
204 | cv2.normalize(
205 | sobelh,
206 | None,
207 | alpha=0,
208 | beta=1,
209 | norm_type=cv2.NORM_MINMAX,
210 | dtype=cv2.CV_32F,
211 | )
212 | )
213 | sobelv = 1 - (
214 | cv2.normalize(
215 | sobelv,
216 | None,
217 | alpha=0,
218 | beta=1,
219 | norm_type=cv2.NORM_MINMAX,
220 | dtype=cv2.CV_32F,
221 | )
222 | )
223 |
224 | overall = np.maximum(sobelh, sobelv)
225 | overall = overall - (1 - blb)
226 | overall[overall < 0] = 0
227 |
228 | dist = (1.0 - overall) * blb
229 | ## nuclei values form mountains so inverse to get basins
230 | dist = -cv2.GaussianBlur(dist, (3, 3), 0)
231 |
232 | overall = np.array(overall >= 0.4, dtype=np.int32)
233 |
234 | marker = blb - overall
235 | marker[marker < 0] = 0
236 | marker = binary_fill_holes(marker).astype("uint8")
237 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
238 | marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
239 | marker = measurements.label(marker)[0]
240 | marker = remove_small_objects(marker, min_size=object_size)
241 |
242 | proced_pred = watershed(dist, markers=marker, mask=blb)
243 |
244 | return proced_pred
245 |
246 |
247 | def calculate_instances(
248 | pred_types: torch.Tensor, pred_insts: torch.Tensor
249 | ) -> List[dict]:
250 | """Best used for GT
251 |
252 | Args:
253 | pred_types (torch.Tensor): Binary or type map ground-truth.
254 | Shape must be (B, C, H, W) with C=1 for binary or num_nuclei_types for multi-class.
255 | pred_insts (torch.Tensor): Ground-Truth instance map with shape (B, H, W)
256 |
257 | Returns:
258 |         list[dict]: Dictionaries with nuclei information, output similar to post_process_cell_segmentation
259 | """
260 | type_preds = []
261 | pred_types = pred_types.permute(0, 2, 3, 1)
262 | for i in range(pred_types.shape[0]):
263 | pred_type = torch.argmax(pred_types, dim=-1)[i].detach().cpu().numpy()
264 | pred_inst = pred_insts[i].detach().cpu().numpy()
265 |         inst_id_list = np.unique(pred_inst)[1:]  # exclude background
266 | inst_info_dict = {}
267 | for inst_id in inst_id_list:
268 | inst_map = pred_inst == inst_id
269 | rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
270 | inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
271 | inst_map = inst_map[
272 | inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1]
273 | ]
274 | inst_map = inst_map.astype(np.uint8)
275 | inst_moment = cv2.moments(inst_map)
276 | inst_contour = cv2.findContours(
277 | inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
278 | )
279 | # * opencv protocol format may break
280 | inst_contour = np.squeeze(inst_contour[0][0].astype("int32"))
281 |             # fewer than 3 points do not make a contour, so skip; likely an artifact
282 |             # from the contour approximation (instance too small)
283 | if inst_contour.shape[0] < 3:
284 | continue
285 | if len(inst_contour.shape) != 2:
286 | continue # ! check for trickery shape
287 | inst_centroid = [
288 | (inst_moment["m10"] / inst_moment["m00"]),
289 | (inst_moment["m01"] / inst_moment["m00"]),
290 | ]
291 | inst_centroid = np.array(inst_centroid)
292 | inst_contour[:, 0] += inst_bbox[0][1] # X
293 | inst_contour[:, 1] += inst_bbox[0][0] # Y
294 | inst_centroid[0] += inst_bbox[0][1] # X
295 | inst_centroid[1] += inst_bbox[0][0] # Y
296 | inst_info_dict[inst_id] = { # inst_id should start at 1
297 | "bbox": inst_bbox,
298 | "centroid": inst_centroid,
299 | "contour": inst_contour,
300 | "type_prob": None,
301 | "type": None,
302 | }
303 |         #### * Get class of each instance id, stored at index id-1 (inst_id = number of detected nucleus)
304 | for inst_id in list(inst_info_dict.keys()):
305 | rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten()
306 | inst_map_crop = pred_inst[rmin:rmax, cmin:cmax]
307 | inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
308 | inst_map_crop = inst_map_crop == inst_id
309 | inst_type = inst_type_crop[inst_map_crop]
310 | type_list, type_pixels = np.unique(inst_type, return_counts=True)
311 | type_list = list(zip(type_list, type_pixels))
312 | type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
313 | inst_type = type_list[0][0]
314 | if inst_type == 0: # ! pick the 2nd most dominant if exist
315 | if len(type_list) > 1:
316 | inst_type = type_list[1][0]
317 | type_dict = {v[0]: v[1] for v in type_list}
318 | type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)
319 | inst_info_dict[inst_id]["type"] = int(inst_type)
320 | inst_info_dict[inst_id]["type_prob"] = float(type_prob)
321 | type_preds.append(inst_info_dict)
322 |
323 | return type_preds
324 |
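`calculate_instances` is intended for ground-truth tensors. The sketch below (an assumption, using made-up tensors) feeds it a one-hot type map and an instance map containing a single square nucleus and reads back the detected centroid and type.

```python
import numpy as np
import torch

from cell_segmentation.utils.post_proc_cellvit import calculate_instances

inst_map = np.zeros((1, 32, 32), dtype=np.int64)
inst_map[0, 8:16, 8:16] = 1            # one square "nucleus" with instance id 1

type_map = np.zeros((1, 2, 32, 32), dtype=np.float32)
type_map[0, 1, 8:16, 8:16] = 1.0       # channel 1: nucleus class
type_map[0, 0] = 1.0 - type_map[0, 1]  # channel 0: background

instances = calculate_instances(torch.from_numpy(type_map), torch.from_numpy(inst_map))
print(instances[0][1]["centroid"], instances[0][1]["type"])  # [11.5 11.5] 1
```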
--------------------------------------------------------------------------------
/cell_segmentation/utils/template_geojson.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # GeoJson templates
3 |
4 |
5 |
6 | def get_template_point() -> dict:
7 | """Return a template for a Point geojson object
8 |
9 | Returns:
10 | dict: Template
11 | """
12 | template_point = {
13 | "type": "Feature",
14 | "id": "TODO",
15 | "geometry": {
16 | "type": "MultiPoint",
17 | "coordinates": [
18 | [],
19 | ],
20 | },
21 | "properties": {
22 | "objectType": "annotation",
23 | "classification": {"name": "TODO", "color": []},
24 | },
25 | }
26 | return template_point
27 |
28 |
29 | def get_template_segmentation() -> dict:
30 | """Return a template for a MultiPolygon geojson object
31 |
32 | Returns:
33 | dict: Template
34 | """
35 | template_multipolygon = {
36 | "type": "Feature",
37 | "id": "TODO",
38 | "geometry": {
39 | "type": "MultiPolygon",
40 | "coordinates": [
41 | [],
42 | ],
43 | },
44 | "properties": {
45 | "objectType": "annotation",
46 | "classification": {"name": "TODO", "color": []},
47 | },
48 | }
49 | return template_multipolygon
50 |
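A short sketch of how the point template might be filled for a single detection; id, coordinates, class name and color are made up.

```python
from cell_segmentation.utils.template_geojson import get_template_point

feature = get_template_point()
feature["id"] = "nucleus-1"
feature["geometry"]["coordinates"] = [[120.5, 87.0]]  # x, y of one centroid
feature["properties"]["classification"] = {"name": "Neoplastic", "color": [255, 0, 0]}
print(feature["geometry"]["type"])  # MultiPoint
```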
--------------------------------------------------------------------------------
/cell_segmentation/utils/tools.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Helpful functions Pipeline
3 | #
4 | # Adapted from HoverNet
5 | # HoverNet Network (https://doi.org/10.1016/j.media.2019.101563)
6 | # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
7 |
8 | import math
9 | from typing import Tuple
10 |
11 | import numpy as np
12 | import scipy
13 | from numba import njit, prange
14 | from scipy import ndimage
15 | from scipy.optimize import linear_sum_assignment
16 | from skimage.draw import polygon
17 |
18 |
19 | def get_bounding_box(img):
20 | """Get bounding box coordinate information."""
21 | rows = np.any(img, axis=1)
22 | cols = np.any(img, axis=0)
23 | rmin, rmax = np.where(rows)[0][[0, -1]]
24 | cmin, cmax = np.where(cols)[0][[0, -1]]
25 | # due to python indexing, need to add 1 to max
26 | # else accessing will be 1px in the box, not out
27 | rmax += 1
28 | cmax += 1
29 | return [rmin, rmax, cmin, cmax]
30 |
31 |
32 | @njit
33 | def cropping_center(x, crop_shape, batch=False):
34 | """Crop an input image at the centre.
35 |
36 | Args:
37 | x: input array
38 | crop_shape: dimensions of cropped array
39 |
40 | Returns:
41 | x: cropped array
42 |
43 | """
44 | orig_shape = x.shape
45 | if not batch:
46 | h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
47 | w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
48 | x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1], ...]
49 | else:
50 | h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
51 | w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
52 | x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1], ...]
53 | return x
54 |
55 |
56 | def remove_small_objects(pred, min_size=64, connectivity=1):
57 | """Remove connected components smaller than the specified size.
58 |
59 | This function is taken from skimage.morphology.remove_small_objects, but the warning
60 | is removed when a single label is provided.
61 |
62 | Args:
63 | pred: input labelled array
64 | min_size: minimum size of instance in output array
65 | connectivity: The connectivity defining the neighborhood of a pixel.
66 |
67 | Returns:
68 | out: output array with instances removed under min_size
69 |
70 | """
71 | out = pred
72 |
73 | if min_size == 0: # shortcut for efficiency
74 | return out
75 |
76 | if out.dtype == bool:
77 | selem = ndimage.generate_binary_structure(pred.ndim, connectivity)
78 | ccs = np.zeros_like(pred, dtype=np.int32)
79 | ndimage.label(pred, selem, output=ccs)
80 | else:
81 | ccs = out
82 |
83 | try:
84 | component_sizes = np.bincount(ccs.ravel())
85 | except ValueError:
86 | raise ValueError(
87 | "Negative value labels are not supported. Try "
88 | "relabeling the input with `scipy.ndimage.label` or "
89 | "`skimage.morphology.label`."
90 | )
91 |
92 | too_small = component_sizes < min_size
93 | too_small_mask = too_small[ccs]
94 | out[too_small_mask] = 0
95 |
96 | return out
97 |
98 |
99 | def pair_coordinates(
100 | setA: np.ndarray, setB: np.ndarray, radius: float
101 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
102 |     """Use the Munkres or Kuhn-Munkres algorithm to find the optimal
103 | unique pairing (largest possible match) when pairing points in set B
104 | against points in set A, using distance as cost function.
105 |
106 | Args:
107 |         setA (np.ndarray): np.array (float32) of size Nx2 containing the XY coordinates
108 |             of N different points
109 |         setB (np.ndarray): np.array (float32) of size Nx2 containing the XY coordinates
110 |             of N different points
111 | radius (float): valid area around a point in setA to consider
112 | a given coordinate in setB a candidate for match
113 |
114 | Returns:
115 | Tuple[np.ndarray, np.ndarray, np.ndarray]:
116 | pairing: pairing is an array of indices
117 | where point at index pairing[0] in set A paired with point
118 | in set B at index pairing[1]
119 |             unpairedA: remaining points in set A that are unpaired
120 |             unpairedB: remaining points in set B that are unpaired
121 | """
122 | # * Euclidean distance as the cost matrix
123 | pair_distance = scipy.spatial.distance.cdist(setA, setB, metric="euclidean")
124 |
125 | # * Munkres pairing with scipy library
126 |     # the algorithm returns (row indices, matched column indices)
127 |     # if there are multiple identical costs in a row, the index of the first occurrence
128 |     # is returned, thus the unique pairing is ensured
129 | indicesA, paired_indicesB = linear_sum_assignment(pair_distance)
130 |
131 | # extract the paired cost and remove instances
132 | # outside of designated radius
133 | pair_cost = pair_distance[indicesA, paired_indicesB]
134 |
135 | pairedA = indicesA[pair_cost <= radius]
136 | pairedB = paired_indicesB[pair_cost <= radius]
137 |
138 | pairing = np.concatenate([pairedA[:, None], pairedB[:, None]], axis=-1)
139 | unpairedA = np.delete(np.arange(setA.shape[0]), pairedA)
140 | unpairedB = np.delete(np.arange(setB.shape[0]), pairedB)
141 |
142 | return pairing, unpairedA, unpairedB
143 |
144 |
145 | def fix_duplicates(inst_map: np.ndarray) -> np.ndarray:
146 | """Re-label duplicated instances in an instance labelled mask.
147 |
148 | Parameters
149 | ----------
150 | inst_map : np.ndarray
151 | Instance labelled mask. Shape (H, W).
152 |
153 | Returns
154 | -------
155 | np.ndarray:
156 | The instance labelled mask without duplicated indices.
157 | Shape (H, W).
158 | """
159 | current_max_id = np.amax(inst_map)
160 | inst_list = list(np.unique(inst_map))
161 | if 0 in inst_list:
162 | inst_list.remove(0)
163 |
164 | for inst_id in inst_list:
165 | inst = np.array(inst_map == inst_id, np.uint8)
166 | remapped_ids = ndimage.label(inst)[0]
167 | remapped_ids[remapped_ids > 1] += current_max_id
168 | inst_map[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
169 | current_max_id = np.amax(inst_map)
170 |
171 | return inst_map
172 |
173 |
174 | def polygons_to_label_coord(
175 | coord: np.ndarray, shape: Tuple[int, int], labels: np.ndarray = None
176 | ) -> np.ndarray:
177 | """Render polygons to image given a shape.
178 |
179 | Parameters
180 | ----------
181 |     coord : np.ndarray
182 |         Polygon coordinates. Shape: (n_polys, 2, n_rays)
183 | shape : Tuple[int, int]
184 | Shape of the output mask.
185 | labels : np.ndarray, optional
186 | Sorted indices of the centroids.
187 |
188 | Returns
189 | -------
190 | np.ndarray:
191 | Instance labelled mask. Shape: (H, W).
192 | """
193 | coord = np.asarray(coord)
194 | if labels is None:
195 | labels = np.arange(len(coord))
196 |
197 | assert coord.ndim == 3 and coord.shape[1] == 2 and len(coord) == len(labels)
198 |
199 | lbl = np.zeros(shape, np.int32)
200 |
201 | for i, c in zip(labels, coord):
202 | rr, cc = polygon(*c, shape)
203 | lbl[rr, cc] = i + 1
204 |
205 | return lbl
206 |
207 |
208 | def ray_angles(n_rays: int = 32):
209 | """Get linearly spaced angles for rays."""
210 | return np.linspace(0, 2 * np.pi, n_rays, endpoint=False)
211 |
212 |
213 | def dist_to_coord(
214 | dist: np.ndarray, points: np.ndarray, scale_dist: Tuple[int, int] = (1, 1)
215 | ) -> np.ndarray:
216 | """Convert list of distances and centroids from polar to cartesian coordinates.
217 |
218 | Parameters
219 | ----------
220 | dist : np.ndarray
221 | The centerpoint pixels of the radial distance map. Shape (n_polys, n_rays).
222 | points : np.ndarray
223 | The centroids of the instances. Shape: (n_polys, 2).
224 | scale_dist : Tuple[int, int], default=(1, 1)
225 | Scaling factor.
226 |
227 | Returns
228 | -------
229 | np.ndarray:
230 |         Cartesian coordinates of the polygons. Shape (n_polys, 2, n_rays).
231 | """
232 | dist = np.asarray(dist)
233 | points = np.asarray(points)
234 | assert (
235 | dist.ndim == 2
236 | and points.ndim == 2
237 | and len(dist) == len(points)
238 | and points.shape[1] == 2
239 | and len(scale_dist) == 2
240 | )
241 | n_rays = dist.shape[1]
242 | phis = ray_angles(n_rays)
243 | coord = (dist[:, np.newaxis] * np.array([np.sin(phis), np.cos(phis)])).astype(
244 | np.float32
245 | )
246 | coord *= np.asarray(scale_dist).reshape(1, 2, 1)
247 | coord += points[..., np.newaxis]
248 | return coord
249 |
250 |
251 | def polygons_to_label(
252 | dist: np.ndarray,
253 | points: np.ndarray,
254 | shape: Tuple[int, int],
255 | prob: np.ndarray = None,
256 | thresh: float = -np.inf,
257 | scale_dist: Tuple[int, int] = (1, 1),
258 | ) -> np.ndarray:
259 | """Convert distances and center points to instance labelled mask.
260 |
261 | Parameters
262 | ----------
263 | dist : np.ndarray
264 | The centerpoint pixels of the radial distance map. Shape (n_polys, n_rays).
265 | points : np.ndarray
266 | The centroids of the instances. Shape: (n_polys, 2).
267 | shape : Tuple[int, int]:
268 | Shape of the output mask.
269 | prob : np.ndarray, optional
270 | The centerpoint pixels of the regressed distance transform.
271 | Shape: (n_polys, n_rays).
272 | thresh : float, default=-np.inf
273 | Threshold for the regressed distance transform.
274 | scale_dist : Tuple[int, int], default=(1, 1)
275 | Scaling factor.
276 |
277 | Returns
278 | -------
279 | np.ndarray:
280 | Instance labelled mask. Shape (H, W).
281 | """
282 | dist = np.asarray(dist)
283 | points = np.asarray(points)
284 | prob = np.inf * np.ones(len(points)) if prob is None else np.asarray(prob)
285 |
286 | assert dist.ndim == 2 and points.ndim == 2 and len(dist) == len(points)
287 | assert len(points) == len(prob) and points.shape[1] == 2 and prob.ndim == 1
288 |
289 | ind = prob > thresh
290 | points = points[ind]
291 | dist = dist[ind]
292 | prob = prob[ind]
293 |
294 | ind = np.argsort(prob, kind="stable")
295 | points = points[ind]
296 | dist = dist[ind]
297 |
298 | coord = dist_to_coord(dist, points, scale_dist=scale_dist)
299 |
300 | return polygons_to_label_coord(coord, shape=shape, labels=ind)
301 |
302 |
303 | @njit(cache=True, fastmath=True)
304 | def intersection(boxA: np.ndarray, boxB: np.ndarray):
305 | """Compute area of intersection of two boxes.
306 |
307 | Parameters
308 | ----------
309 | boxA : np.ndarray
310 | First boxes
311 | boxB : np.ndarray
312 | Second box
313 |
314 | Returns
315 | -------
316 | float64:
317 | Area of intersection
318 | """
319 | xA = max(boxA[..., 0], boxB[..., 0])
320 | xB = min(boxA[..., 2], boxB[..., 2])
321 | dx = xB - xA
322 | if dx <= 0:
323 | return 0.0
324 |
325 | yA = max(boxA[..., 1], boxB[..., 1])
326 | yB = min(boxA[..., 3], boxB[..., 3])
327 | dy = yB - yA
328 | if dy <= 0.0:
329 | return 0.0
330 |
331 | return dx * dy
332 |
333 |
334 | @njit(parallel=True)
335 | def get_bboxes(
336 | dist: np.ndarray, points: np.ndarray
337 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
338 | """Get bounding boxes from the non-zero pixels of the radial distance maps.
339 |
340 | This is basically a translation from the stardist repo cpp code to python
341 |
342 | NOTE: jit compiled and parallelized with numba.
343 |
344 | Parameters
345 | ----------
346 | dist : np.ndarray
347 | The non-zero values of the radial distance maps. Shape: (n_nonzero, n_rays).
348 | points : np.ndarray
349 | The yx-coordinates of the non-zero points. Shape (n_nonzero, 2).
350 |
351 | Returns
352 | -------
353 | Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
354 | Returns the x0, y0, x1, y1 bbox coordinates, bbox areas and the maximum
355 | radial distance in the image.
356 | """
357 | n_polys = dist.shape[0]
358 | n_rays = dist.shape[1]
359 |
360 | bbox_x1 = np.zeros(n_polys)
361 | bbox_x2 = np.zeros(n_polys)
362 | bbox_y1 = np.zeros(n_polys)
363 | bbox_y2 = np.zeros(n_polys)
364 |
365 | areas = np.zeros(n_polys)
366 | angle_pi = 2 * math.pi / n_rays
367 | max_dist = 0
368 |
369 | for i in prange(n_polys):
370 | max_radius_outer = 0
371 | py = points[i, 0]
372 | px = points[i, 1]
373 |
374 | for k in range(n_rays):
375 | d = dist[i, k]
376 | y = py + d * np.sin(angle_pi * k)
377 | x = px + d * np.cos(angle_pi * k)
378 |
379 | if k == 0:
380 | bbox_x1[i] = x
381 | bbox_x2[i] = x
382 | bbox_y1[i] = y
383 | bbox_y2[i] = y
384 | else:
385 | bbox_x1[i] = min(x, bbox_x1[i])
386 | bbox_x2[i] = max(x, bbox_x2[i])
387 | bbox_y1[i] = min(y, bbox_y1[i])
388 | bbox_y2[i] = max(y, bbox_y2[i])
389 |
390 | max_radius_outer = max(d, max_radius_outer)
391 |
392 | areas[i] = (bbox_x2[i] - bbox_x1[i]) * (bbox_y2[i] - bbox_y1[i])
393 | max_dist = max(max_dist, max_radius_outer)
394 |
395 | return bbox_x1, bbox_y1, bbox_x2, bbox_y2, areas, max_dist
396 |
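397 | # NOTE: The following usage sketch is illustrative only and not part of the
398 | # original post-processing module. It relies on the module-level imports above
399 | # (numpy as np) and shows the expected input/output shapes of
400 | # `polygons_to_label` for two hypothetical nuclei described by 32 rays each.
401 | if __name__ == "__main__":
402 |     rng = np.random.default_rng(0)
403 |     dist = 5.0 + rng.random((2, 32))            # (n_polys, n_rays) radial distances
404 |     points = np.array([[20, 20], [40, 45]])     # (n_polys, 2) centroids (y, x)
405 |     prob = np.array([0.9, 0.8])                 # per-centroid confidence, shape (n_polys,)
406 |     mask = polygons_to_label(dist, points, shape=(64, 64), prob=prob, thresh=0.5)
407 |     print(mask.shape)                           # instance-labelled mask, (64, 64)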
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES: 3
2 | logging:
3 | log_dir: /data5/ziweicui/cellvit256-unireplknet-n
4 | mode: online
5 | project: Cell-Segmentation
6 | notes: CellViT-256
7 | log_comment: CellViT-256-resnet50-tiny
8 | tags:
9 | - Fold-1
10 | - ViT256
11 | wandb_dir: /data5/ziweicui/UniRepLKNet-optimizerconfig-unetdecoder-inputconv/results
12 | level: Debug
13 | group: CellViT256
14 | run_id: anifw9ux
15 | wandb_file: anifw9ux
16 | random_seed: 19
17 | gpu: 0
18 | data:
19 | dataset: PanNuke
20 | dataset_path: /data5/ziweicui/cellvit-png
21 | train_folds:
22 | - 0
23 | val_folds:
24 | - 1
25 | test_folds:
26 | - 2
27 | num_nuclei_classes: 6
28 | num_tissue_classes: 19
29 | model:
30 | backbone: default
31 | pretrained_encoder: /data5/ziweicui/semi_supervised_resnet50-08389792.pth
32 | shared_skip_connections: true
33 | loss:
34 | nuclei_binary_map:
35 | focaltverskyloss:
36 | loss_fn: FocalTverskyLoss
37 | weight: 1
38 | dice:
39 | loss_fn: dice_loss
40 | weight: 1
41 | hv_map:
42 | mse:
43 | loss_fn: mse_loss_maps
44 | weight: 2.5
45 | msge:
46 | loss_fn: msge_loss_maps
47 | weight: 8
48 | nuclei_type_map:
49 | bce:
50 | loss_fn: xentropy_loss
51 | weight: 0.5
52 | dice:
53 | loss_fn: dice_loss
54 | weight: 0.2
55 | mcfocaltverskyloss:
56 | loss_fn: MCFocalTverskyLoss
57 | weight: 0.5
58 | args:
59 | num_classes: 6
60 | tissue_types:
61 | ce:
62 | loss_fn: CrossEntropyLoss
63 | weight: 0.1
64 | training:
65 | drop_rate: 0
66 | attn_drop_rate: 0.1
67 | drop_path_rate: 0.1
68 | batch_size: 32
69 | epochs: 130
70 | optimizer: AdamW
71 | early_stopping_patience: 130
72 | scheduler:
73 | scheduler_type: cosine
74 | hyperparameters:
75 | #gamma: 0.85
76 | eta_min: 1e-5
77 | optimizer_hyperparameter:
78 | # betas:
79 | # - 0.85
80 | # - 0.95
81 | #lr: 0.004
82 | opt_lower: 'AdamW'
83 | lr: 0.0008
84 | opt_betas: [0.85,0.95]
85 | weight_decay: 0.05
86 | opt_eps: 0.00000008
87 | unfreeze_epoch: 25
88 | sampling_gamma: 0.85
89 | sampling_strategy: cell+tissue
90 | mixed_precision: true
91 | transformations:
92 | randomrotate90:
93 | p: 0.5
94 | horizontalflip:
95 | p: 0.5
96 | verticalflip:
97 | p: 0.5
98 | downscale:
99 | p: 0.15
100 | scale: 0.5
101 | blur:
102 | p: 0.2
103 | blur_limit: 10
104 | gaussnoise:
105 | p: 0.25
106 | var_limit: 50
107 | colorjitter:
108 | p: 0.2
109 | scale_setting: 0.25
110 | scale_color: 0.1
111 | superpixels:
112 | p: 0.1
113 | zoomblur:
114 | p: 0.1
115 | randomsizedcrop:
116 | p: 0.1
117 | elastictransform:
118 | p: 0.2
119 | normalize:
120 | mean:
121 | - 0.5
122 | - 0.5
123 | - 0.5
124 | std:
125 | - 0.5
126 | - 0.5
127 | - 0.5
128 | eval_checkpoint: latest_checkpoint.pth
129 | dataset_config:
130 | tissue_types:
131 | Adrenal_gland: 0
132 | Bile-duct: 1
133 | Bladder: 2
134 | Breast: 3
135 | Cervix: 4
136 | Colon: 5
137 | Esophagus: 6
138 | HeadNeck: 7
139 | Kidney: 8
140 | Liver: 9
141 | Lung: 10
142 | Ovarian: 11
143 | Pancreatic: 12
144 | Prostate: 13
145 | Skin: 14
146 | Stomach: 15
147 | Testis: 16
148 | Thyroid: 17
149 | Uterus: 18
150 | nuclei_types:
151 | Background: 0
152 | Neoplastic: 1
153 | Inflammatory: 2
154 | Connective: 3
155 | Dead: 4
156 | Epithelial: 5
157 | run_sweep: false
158 | agent: null
159 |
--------------------------------------------------------------------------------
/datamodel/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Data models
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artifical Intelligence in Medicine,
6 | # University Medicine Essen
7 |
--------------------------------------------------------------------------------
/datamodel/graph_datamodel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Graph Data model
3 | #
4 | # For more information, please check out docs/readmes/graphs.md
5 | #
6 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
7 | # Institute for Artifical Intelligence in Medicine,
8 | # University Medicine Essen
9 |
10 | from dataclasses import dataclass
11 |
12 | import torch
13 |
14 |
15 | @dataclass
16 | class GraphDataWSI:
17 | """Dataclass for Graph Data
18 |
19 | Args:
20 | x (torch.Tensor): Node feature matrix with shape (num_nodes, num_nodes_features)
21 |         positions (torch.Tensor): Each of the objects defined in x (detected cells or extracted patches) has a physical
22 |             position in a Cartesian coordinate system, stored here as global 2D coordinates for the WSI.
23 | Shape (num_nodes, 2)
24 | metadata (dict, optional): Metadata about the object is stored here. Defaults to None
25 | """
26 |
27 | x: torch.Tensor
28 | positions: torch.Tensor
29 | metadata: dict
30 |
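31 | # NOTE: Minimal usage sketch (illustrative only, not part of the original file):
32 | # a toy graph with 4 nodes, 16 features per node and 2D positions on the WSI.
33 | if __name__ == "__main__":
34 |     graph = GraphDataWSI(
35 |         x=torch.rand(4, 16),
36 |         positions=torch.rand(4, 2),
37 |         metadata={"wsi_name": "example_slide"},  # placeholder metadata
38 |     )
39 |     print(graph.x.shape, graph.positions.shape)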
--------------------------------------------------------------------------------
/datamodel/wsi_datamodel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # WSI Model
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artifical Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 |
9 | import json
10 | from pathlib import Path
11 | from typing import Union, List, Callable, Tuple
12 |
13 | from dataclasses import dataclass, field
14 | import numpy as np
15 | import yaml
16 | import logging
17 | import torch
18 | from PIL import Image
19 |
20 |
21 | @dataclass
22 | class WSI:
23 | """WSI object
24 |
25 | Args:
26 | name (str): WSI name
27 | patient (str): Patient name
28 | slide_path (Union[str, Path]): Full path to the WSI file.
29 | patched_slide_path (Union[str, Path], optional): Full path to preprocessed WSI files (patches). Defaults to None.
30 | embedding_name (Union[str, Path], optional): Defaults to None.
31 | label (Union[str, int, float, np.ndarray], optional): Label of the WSI. Defaults to None.
32 | logger (logging.logger, optional): Logger module for logging information. Defaults to None.
33 | """
34 |
35 | name: str
36 | patient: str
37 | slide_path: Union[str, Path]
38 | patched_slide_path: Union[str, Path] = None
39 | embedding_name: Union[str, Path] = None
40 | label: Union[str, int, float, np.ndarray] = None
41 | logger: logging.Logger = None
42 |
43 | # unset attributes used in this class
44 | metadata: dict = field(init=False, repr=False)
45 | all_patch_metadata: List[dict] = field(init=False, repr=False)
46 | patches_list: List = field(init=False, repr=False)
47 | patch_transform: Callable = field(init=False, repr=False)
48 |
49 | # name without ending (e.g. slide1 instead of slide1.svs)
50 | def __post_init__(self):
51 | """Post-Processing object"""
52 | super().__init__()
53 |         # define parameters that are used, but not defined at startup
54 |
55 | # convert string to path
56 | self.slide_path = Path(self.slide_path).resolve()
57 | if self.patched_slide_path is not None:
58 | self.patched_slide_path = Path(self.patched_slide_path).resolve()
59 | # load metadata
60 | self._get_metadata()
61 | self._get_wsi_patch_metadata()
62 | self.patch_transform = None # hardcode to None (should not be a parameter, but should be defined)
63 |
64 | if self.logger is not None:
65 | self.logger.debug(self.__repr__())
66 |
67 | def _get_metadata(self) -> None:
68 | """Load metadata yaml file"""
69 | self.metadata_path = self.patched_slide_path / "metadata.yaml"
70 | with open(self.metadata_path.resolve(), "r") as metadata_yaml:
71 | try:
72 | self.metadata = yaml.safe_load(metadata_yaml)
73 | except yaml.YAMLError as exc:
74 | print(exc)
75 | self.metadata["label_map_inverse"] = {
76 | v: k for k, v in self.metadata["label_map"].items()
77 | }
78 |
79 | def _get_wsi_patch_metadata(self) -> None:
80 | """Load patch_metadata json file and convert to dict and lists"""
81 | with open(self.patched_slide_path / "patch_metadata.json", "r") as json_file:
82 | metadata = json.load(json_file)
83 | self.patches_list = [str(list(elem.keys())[0]) for elem in metadata]
84 | self.all_patch_metadata = {
85 | str(list(elem.keys())[0]): elem[str(list(elem.keys())[0])]
86 | for elem in metadata
87 | }
88 |
89 | def load_patch_metadata(self, patch_name: str) -> dict:
90 | """Return the metadata of a patch with given name (including patch suffix, e.g., wsi_1_1.png)
91 |
92 | This function assumes that metadata path is a subpath of the patches dataset path
93 |
94 | Args:
95 | patch_name (str): Name of patch
96 |
97 | Returns:
98 | dict: metadata
99 | """
100 | patch_metadata_path = self.all_patch_metadata[patch_name]["metadata_path"]
101 | patch_metadata_path = self.patched_slide_path / patch_metadata_path
102 |
103 | # open
104 | with open(patch_metadata_path, "r") as metadata_yaml:
105 | patch_metadata = yaml.safe_load(metadata_yaml)
106 | patch_metadata["name"] = patch_name
107 |
108 | return patch_metadata
109 |
110 | def set_patch_transform(self, transform: Callable) -> None:
111 | """Set the transformation function to process a patch
112 |
113 | Args:
114 | transform (Callable): Transformation function
115 | """
116 | self.patch_transform = transform
117 |
118 | # patch processing
119 | def process_patch_image(
120 | self, patch_name: str, transform: Callable = None
121 | ) -> Tuple[torch.Tensor, dict]:
122 |         """Process one patch: load it from disk and apply the transformation if provided
123 |
124 | Args:
125 | patch_name (Path): Name of patch to load, including patch suffix, e.g., wsi_1_1.png
126 | transform (Callable, optional): Optional Patch-Transformation
127 | Returns:
128 | Tuple[torch.Tensor, dict]:
129 |
130 | * torch.Tensor: patch as torch.tensor (:,:,3)
131 | * dict: patch metadata as dictionary
132 | """
133 | patch = Image.open(self.patched_slide_path / "patches" / patch_name)
134 | if transform:
135 | patch = transform(patch)
136 |
137 | metadata = self.load_patch_metadata(patch_name)
138 | return patch, metadata
139 |
140 | def get_number_patches(self) -> int:
141 | """Return the number of patches for this WSI
142 |
143 | Returns:
144 | int: number of patches
145 | """
146 | return int(len(self.patches_list))
147 |
148 | def get_patches(
149 | self, transform: Callable = None
150 |     ) -> Tuple[torch.Tensor, list]:
151 | """Get all patches for one image
152 |
153 | Args:
154 | transform (Callable, optional): Optional Patch-Transformation
155 |
156 | Returns:
157 | Tuple[torch.Tensor, list]:
158 |
159 | * patched image: Shape of torch.Tensor(num_patches, 3, :, :)
160 | * coordinates as list metadata_dictionary
161 |
162 | """
163 | if self.logger is not None:
164 | self.logger.warning(f"Loading {self.get_number_patches()} patches!")
165 | patches = []
166 | metadata = []
167 | for patch in self.patches_list:
168 | transformed_patch, meta = self.process_patch_image(patch, transform)
169 | patches.append(transformed_patch)
170 | metadata.append(meta)
171 | patches = torch.stack(patches)
172 |
173 | return patches, metadata
174 |
175 | def load_embedding(self) -> torch.Tensor:
176 | """Load embedding from subfolder patched_slide_path/embedding/
177 |
178 | Raises:
179 | FileNotFoundError: If embedding is not given
180 |
181 | Returns:
182 | torch.Tensor: WSI embedding
183 | """
184 | embedding_path = (
185 | self.patched_slide_path / "embeddings" / f"{self.embedding_name}.pt"
186 | )
187 | if embedding_path.is_file():
188 | embedding = torch.load(embedding_path)
189 | return embedding
190 | else:
191 | raise FileNotFoundError(
192 | f"Embeddings for WSI {self.slide_path} cannot be found in path {embedding_path}"
193 | )
194 |
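195 | # NOTE: Minimal usage sketch (illustrative only, not part of the original file).
196 | # The paths below are placeholders and must point to a slide that has already
197 | # been preprocessed into patches (i.e. contains metadata.yaml and patch_metadata.json).
198 | if __name__ == "__main__":
199 |     wsi = WSI(
200 |         name="slide1",
201 |         patient="patient1",
202 |         slide_path="/path/to/slide1.svs",
203 |         patched_slide_path="/path/to/preprocessed/slide1",
204 |     )
205 |     print(wsi.get_number_patches())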
--------------------------------------------------------------------------------
/docs/datasets/PanNuke/dataset_config.yaml:
--------------------------------------------------------------------------------
1 | tissue_types:
2 | "Adrenal_gland": 0
3 | "Bile-duct": 1
4 | "Bladder": 2
5 | "Breast": 3
6 | "Cervix": 4
7 | "Colon": 5
8 | "Esophagus": 6
9 | "HeadNeck": 7
10 | "Kidney": 8
11 | "Liver": 9
12 | "Lung": 10
13 | "Ovarian": 11
14 | "Pancreatic": 12
15 | "Prostate": 13
16 | "Skin": 14
17 | "Stomach": 15
18 | "Testis": 16
19 | "Thyroid": 17
20 | "Uterus": 18
21 |
22 | nuclei_types:
23 | "Background": 0
24 | "Neoplastic": 1
25 | "Inflammatory": 2
26 | "Connective": 3
27 | "Dead": 4
28 | "Epithelial": 5
29 |
--------------------------------------------------------------------------------
/docs/datasets/PanNuke/weight_config.yaml:
--------------------------------------------------------------------------------
1 | tissue:
2 | "Adrenal_gland": 437
3 | "Bile-duct": 420
4 | "Bladder": 146
5 | "Breast": 2351
6 | "Cervix": 293
7 | "Colon": 1440
8 | "Esophagus": 424
9 | "HeadNeck": 384
10 | "Kidney": 134
11 | "Liver": 224
12 | "Lung": 184
13 | "Ovarian": 146
14 | "Pancreatic": 195
15 | "Prostate": 182
16 | "Skin": 187
17 | "Stomach": 146
18 | "Testis": 196
19 | "Thyroid": 226
20 | "Uterus": 186
21 |
--------------------------------------------------------------------------------
/docs/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/docs/model.png
--------------------------------------------------------------------------------
/docs/readmes/cell_segmentation.md:
--------------------------------------------------------------------------------
1 | # Cell Segmentation
2 |
3 | ## Training
4 |
5 | The data structure used to train cell segmentation networks differs from the one used to train classification networks on the WSI/patient level. Currently, due to the massive number of cells inside a WSI, all well-known cell segmentation datasets (such as [PanNuke](https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke), https://doi.org/10.48550/arXiv.2003.10778) provide just patches with cell annotations. Therefore, we use the following dataset structure (with k folds):
6 |
7 | ```bash
8 | dataset
9 | ├── dataset_config.yaml
10 | ├── fold0
11 | │ ├── images
12 | | | ├── 0_imgname0.png
13 | | | ├── 0_imgname1.png
14 | | | ├── 0_imgname2.png
15 | ...
16 | | | └── 0_imgnameN.png
17 | │ ├── labels
18 | | | ├── 0_imgname0.npy
19 | | | ├── 0_imgname1.npy
20 | | | ├── 0_imgname2.npy
21 | ...
22 | | | └── 0_imgnameN.npy
23 | | └── types.csv
24 | ├── fold1
25 | │ ├── images
26 | | | ├── 1_imgname0.png
27 | | | ├── 1_imgname1.png
28 | ...
29 | │ ├── labels
30 | | | ├── 1_imgname0.npy
31 | | | ├── 1_imgname1.npy
32 | ...
33 | | └── types.csv
34 | ...
35 | └── foldk
36 | │ ├── images
37 | | ├── k_imgname0.png
38 | | ├── k_imgname1.png
39 | ...
40 | ├── labels
41 | | ├── k_imgname0.npy
42 | | ├── k_imgname1.npy
43 | └── types.csv
44 | ```
45 |
46 | Each types.csv file should have the following header (a minimal loading sketch is given at the end of this file):
47 | ```csv
48 | img,type # Header
49 | foldnum_imgname0.png,SetTypeHere # Each row is one patch with tissue type
50 | ```
51 |
52 | The labels are numpy masks with the following structure:
53 | TBD
54 |
55 | ## Add a new dataset
56 | Add the new dataset to the dataset coordinator.
57 | 
58 | All settings of the dataset must be defined in the corresponding yaml file, under the `data` section.
59 | 
60 | The dataset name is **not** case sensitive!
61 |
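62 | As a minimal loading sketch (assuming `pandas` and `numpy` are installed and the fold follows the layout above; the file names below are placeholders), a single fold can be inspected like this:
63 | 
64 | ```python
65 | from pathlib import Path
66 | 
67 | import numpy as np
68 | import pandas as pd
69 | 
70 | fold_dir = Path("dataset/fold0")                # placeholder path
71 | types = pd.read_csv(fold_dir / "types.csv")     # columns: img, type
72 | print(types.head())
73 | 
74 | label = np.load(fold_dir / "labels" / "0_imgname0.npy", allow_pickle=True)
75 | print(type(label))                              # label structure: see above (TBD)
76 | ```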
--------------------------------------------------------------------------------
/docs/readmes/monuseg.md:
--------------------------------------------------------------------------------
1 | ## MoNuSeg Preparation
2 | The original MoNuSeg dataset has the following structure, using .xml annotations and .tif files with a size of $1000 \times 1000$ pixels:
3 |
4 | ```bash
5 | ├── testing
6 | │ ├── images
7 | │ │ ├── TCGA-2Z-A9J9-01A-01-TS1.tif
8 | │ │ ├── TCGA-44-2665-01B-06-BS6.tif
9 | ...
10 | │ └── labels
11 | │ ├── TCGA-2Z-A9J9-01A-01-TS1.xml
12 | │ ├── TCGA-44-2665-01B-06-BS6.xml
13 | ...
14 | └── training
15 | ├── images
16 | └── labels
17 | ```
18 | For our experiments, we resized the dataset images to $1024 \times 1024$ pixels and converted the .xml annotations to binary masks:
19 | ```bash
20 | ├── testing
21 | │ ├── images
22 | │ │ ├── TCGA-2Z-A9J9-01A-01-TS1.png
23 | │ │ ├── TCGA-44-2665-01B-06-BS6.png
24 | ...
25 | │ └── labels
26 | │ │ ├── TCGA-2Z-A9J9-01A-01-TS1.npy
27 | │ │ ├── TCGA-44-2665-01B-06-BS6.npy
28 | ...
29 | └── training
30 | ├── images
31 | └── labels
32 | ```
33 |
34 | Everything can be extracted using the [`cell_segmentation/datasets/prepare_monuseg.py`](cell_segmentation/datasets/prepare_monuseg.py) script.
35 |
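36 | As a minimal sketch of the resizing step only (assuming Pillow is installed; the full conversion, including the .xml-to-mask step, is handled by the script above), an image can be resized like this:
37 | 
38 | ```python
39 | from PIL import Image
40 | 
41 | img = Image.open("testing/images/TCGA-2Z-A9J9-01A-01-TS1.tif")
42 | img = img.resize((1024, 1024))   # 1000 x 1000 -> 1024 x 1024
43 | img.save("testing/images/TCGA-2Z-A9J9-01A-01-TS1.png")
44 | ```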
--------------------------------------------------------------------------------
/docs/readmes/pannuke.md:
--------------------------------------------------------------------------------
1 | ## PanNuke Preparation
2 | The original PanNuke dataset has the following structure, using just one large array for each dataset split:
3 |
4 | ```bash
5 | ├── fold0
6 | │ ├── images.npy
7 | │ ├── masks.npy
8 | │ └── types.npy
9 | ├── fold1
10 | │ ├── images.npy
11 | │ ├── masks.npy
12 | │ └── types.npy
13 | └── fold2
14 | ├── images.npy
15 | ├── masks.npy
16 | └── types.npy
17 | ```
18 |
19 | For memory efficiency and to make use of multi-threaded data loading with our augmentation pipeline, we reassemble the dataset into the following structure:
20 | ```bash
21 | ├── fold0
22 | │ ├── cell_count.csv # cell-count for each image to be used in sampling
23 | │ ├── images # H&E Image for each sample as .png files
25 | │ │ ├── 0_0.png
26 | │ │ ├── 0_1.png
27 | │ │ ├── 0_2.png
28 | ...
29 | │ ├── labels # label as .npy arrays for each sample
30 | │ │ ├── 0_0.npy
31 | │ │ ├── 0_1.npy
32 | │ │ ├── 0_2.npy
33 | ...
34 | │ └── types.csv # csv file with type for each image
35 | ├── fold1
36 | │ ├── cell_count.csv
37 | │ ├── images
38 | │ │ ├── 1_0.png
39 | ...
40 | │ ├── labels
41 | │ │ ├── 1_0.npy
42 | ...
43 | │ └── types.csv
44 | ├── fold2
45 | │ ├── cell_count.csv
46 | │ ├── images
47 | │ │ ├── 2_0.png
48 | ...
49 | │ ├── labels
50 | │ │ ├── 2_0.npy
51 | ...
52 | │ └── types.csv
53 | ├── dataset_config.yaml # dataset config with dataset information
54 | └── weight_config.yaml # config file for our sampling
55 | ```
56 |
57 | We provide all configuration files for the PanNuke dataset in the [`configs/datasets/PanNuke`](configs/datasets/PanNuke) folder. Please copy them into your dataset folder. Images and masks have to be extracted using the [`cell_segmentation/datasets/prepare_pannuke.py`](cell_segmentation/datasets/prepare_pannuke.py) script.
58 |
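59 | As a rough sketch of the idea behind this conversion (assuming `numpy` and Pillow are installed; paths are placeholders, and the actual script additionally derives the label maps, `types.csv` and `cell_count.csv`), a fold can be split into per-image files like this:
60 | 
61 | ```python
62 | from pathlib import Path
63 | 
64 | import numpy as np
65 | from PIL import Image
66 | 
67 | fold = Path("PanNuke/fold0")       # placeholder: original fold with the big arrays
68 | out = Path("dataset/fold0")        # placeholder: reassembled output folder
69 | (out / "images").mkdir(parents=True, exist_ok=True)
70 | (out / "labels").mkdir(parents=True, exist_ok=True)
71 | 
72 | images = np.load(fold / "images.npy")   # one array with all fold images
73 | masks = np.load(fold / "masks.npy")     # one array with all fold masks
74 | 
75 | for i, (img, mask) in enumerate(zip(images, masks)):
76 |     Image.fromarray(img.astype(np.uint8)).save(out / "images" / f"0_{i}.png")
77 |     np.save(out / "labels" / f"0_{i}.npy", mask)
78 | ```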
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Model implementations and pretrained models
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artifical Intelligence in Medicine,
6 | # University Medicine Essen
7 |
--------------------------------------------------------------------------------
/models/segmentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/__init__.py
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__init__.py
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/UNet_v2.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import warnings
3 |
4 | import torch
5 | from torch import nn
6 | from unet_v2.pvtv2 import *
7 | import torch.nn.functional as F
8 |
9 |
10 | class ChannelAttention(nn.Module):
11 | def __init__(self, in_planes, ratio=16):
12 | super(ChannelAttention, self).__init__()
13 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
14 | self.max_pool = nn.AdaptiveMaxPool2d(1)
15 |
16 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
17 | self.relu1 = nn.ReLU()
18 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
19 |
20 | self.sigmoid = nn.Sigmoid()
21 |
22 | def forward(self, x):
23 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
24 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
25 | out = avg_out + max_out
26 | return self.sigmoid(out)
27 |
28 |
29 | class SpatialAttention(nn.Module):
30 | def __init__(self, kernel_size=7):
31 | super(SpatialAttention, self).__init__()
32 |
33 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
34 | padding = 3 if kernel_size == 7 else 1
35 |
36 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
37 | self.sigmoid = nn.Sigmoid()
38 |
39 | def forward(self, x):
40 | avg_out = torch.mean(x, dim=1, keepdim=True)
41 | max_out, _ = torch.max(x, dim=1, keepdim=True)
42 | x = torch.cat([avg_out, max_out], dim=1)
43 | x = self.conv1(x)
44 | return self.sigmoid(x)
45 |
46 |
47 | class BasicConv2d(nn.Module):
48 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
49 | super(BasicConv2d, self).__init__()
50 |
51 | self.conv = nn.Conv2d(in_planes, out_planes,
52 | kernel_size=kernel_size, stride=stride,
53 | padding=padding, dilation=dilation, bias=False)
54 | self.bn = nn.BatchNorm2d(out_planes)
55 | self.relu = nn.ReLU(inplace=True)
56 |
57 | def forward(self, x):
58 | x = self.conv(x)
59 | x = self.bn(x)
60 | return x
61 |
62 |
63 | class Encoder(nn.Module):
64 | def __init__(self, pretrain_path):
65 | super().__init__()
66 | self.backbone = pvt_v2_b2()
67 |
68 | if pretrain_path is None:
69 |             warnings.warn('Please provide the pretrained pvt model. Not using a pretrained model.')
70 |         elif not os.path.isfile(pretrain_path):
71 |             warnings.warn(f'path: {pretrain_path} does not exist. Not using a pretrained model.')
72 | else:
73 | print(f"using pretrained file: {pretrain_path}")
74 | save_model = torch.load(pretrain_path)
75 | model_dict = self.backbone.state_dict()
76 | state_dict = {k: v for k, v in save_model.items() if k in model_dict.keys()}
77 | model_dict.update(state_dict)
78 |
79 | self.backbone.load_state_dict(model_dict)
80 |
81 | def forward(self, x):
82 | f1, f2, f3, f4 = self.backbone(x) # (x: 3, 352, 352)
83 | return f1, f2, f3, f4
84 |
85 |
86 | class SDI(nn.Module):
87 | def __init__(self, channel):
88 | super().__init__()
89 |
90 | self.convs = nn.ModuleList(
91 | [nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) for _ in range(4)])
92 |
93 | def forward(self, xs, anchor):
94 | ans = torch.ones_like(anchor)
95 | target_size = anchor.shape[-1]
96 |
97 | for i, x in enumerate(xs):
98 | if x.shape[-1] > target_size:
99 | x = F.adaptive_avg_pool2d(x, (target_size, target_size))
100 | elif x.shape[-1] < target_size:
101 | x = F.interpolate(x, size=(target_size, target_size),
102 | mode='bilinear', align_corners=True)
103 |
104 | ans = ans * self.convs[i](x)
105 |
106 | return ans
107 |
108 |
109 | class UNetV2(nn.Module):
110 | """
111 | use SpatialAtt + ChannelAtt
112 | """
113 | def __init__(self, channel=32, n_classes=1, deep_supervision=True, pretrained_path=None):
114 | super().__init__()
115 | self.deep_supervision = deep_supervision
116 |
117 | self.encoder = Encoder(pretrained_path)
118 |
119 | self.ca_1 = ChannelAttention(64)
120 | self.sa_1 = SpatialAttention()
121 |
122 | self.ca_2 = ChannelAttention(128)
123 | self.sa_2 = SpatialAttention()
124 |
125 | self.ca_3 = ChannelAttention(320)
126 | self.sa_3 = SpatialAttention()
127 |
128 | self.ca_4 = ChannelAttention(512)
129 | self.sa_4 = SpatialAttention()
130 |
131 | self.Translayer_1 = BasicConv2d(64, channel, 1)
132 | self.Translayer_2 = BasicConv2d(128, channel, 1)
133 | self.Translayer_3 = BasicConv2d(320, channel, 1)
134 | self.Translayer_4 = BasicConv2d(512, channel, 1)
135 |
136 | self.sdi_1 = SDI(channel)
137 | self.sdi_2 = SDI(channel)
138 | self.sdi_3 = SDI(channel)
139 | self.sdi_4 = SDI(channel)
140 |
141 | self.seg_outs = nn.ModuleList([
142 | nn.Conv2d(channel, n_classes, 1, 1) for _ in range(4)])
143 |
144 | self.deconv2 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, padding=1,
145 | bias=False)
146 | self.deconv3 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
147 | padding=1, bias=False)
148 | self.deconv4 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
149 | padding=1, bias=False)
150 | self.deconv5 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
151 | padding=1, bias=False)
152 |
153 | def forward(self, x):
154 | seg_outs = []
155 | f1, f2, f3, f4 = self.encoder(x)
156 |
157 | f1 = self.ca_1(f1) * f1
158 | f1 = self.sa_1(f1) * f1
159 | f1 = self.Translayer_1(f1)
160 |
161 | f2 = self.ca_2(f2) * f2
162 | f2 = self.sa_2(f2) * f2
163 | f2 = self.Translayer_2(f2)
164 |
165 | f3 = self.ca_3(f3) * f3
166 | f3 = self.sa_3(f3) * f3
167 | f3 = self.Translayer_3(f3)
168 |
169 | f4 = self.ca_4(f4) * f4
170 | f4 = self.sa_4(f4) * f4
171 | f4 = self.Translayer_4(f4)
172 |
173 | f41 = self.sdi_4([f1, f2, f3, f4], f4)
174 | f31 = self.sdi_3([f1, f2, f3, f4], f3)
175 | f21 = self.sdi_2([f1, f2, f3, f4], f2)
176 | f11 = self.sdi_1([f1, f2, f3, f4], f1)
177 |
178 | seg_outs.append(self.seg_outs[0](f41))
179 |
180 | y = self.deconv2(f41) + f31
181 | seg_outs.append(self.seg_outs[1](y))
182 |
183 | y = self.deconv3(y) + f21
184 | seg_outs.append(self.seg_outs[2](y))
185 |
186 | y = self.deconv4(y) + f11
187 | seg_outs.append(self.seg_outs[3](y))
188 |
189 | for i, o in enumerate(seg_outs):
190 | seg_outs[i] = F.interpolate(o, scale_factor=4, mode='bilinear')
191 |
192 | if self.deep_supervision:
193 | return seg_outs[::-1]
194 | else:
195 | return seg_outs[-1]
196 |
197 |
198 | if __name__ == "__main__":
199 | pretrained_path = "/afs/crc.nd.edu/user/y/ypeng4/Polyp-PVT_2/pvt_pth/pvt_v2_b2.pth"
200 | model = UNetV2(n_classes=2, deep_supervision=True, pretrained_path=None)
201 | x = torch.rand((2, 3, 256, 256))
202 | ys = model(x)
203 | for y in ys:
204 | print(y.shape)
205 |
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/cellvit.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit.cpython-310.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/cellvit.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit.cpython-38.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/cellvit_shared.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit_shared.cpython-38.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/cellvit_unirepLKnet.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit_unirepLKnet.cpython-310.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/cellvit_unirepLKnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/cellvit_unirepLKnet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/replknet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/replknet.cpython-38.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/segmentation/cell_segmentation/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/models/segmentation/cell_segmentation/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from einops import rearrange
3 | from models.encoders.VIT.SAM.image_encoder import ImageEncoderViT
4 | from models.encoders.VIT.vits_histo import VisionTransformer
5 |
6 | import torch
7 | import torch.nn as nn
8 | from typing import Callable, Tuple, Type, List
9 |
10 |
11 | class Conv2DBlock(nn.Module):
12 | """Conv2DBlock with convolution followed by batch-normalisation, ReLU activation and dropout
13 |
14 | Args:
15 | in_channels (int): Number of input channels for convolution
16 | out_channels (int): Number of output channels for convolution
17 | kernel_size (int, optional): Kernel size for convolution. Defaults to 3.
18 | dropout (float, optional): Dropout. Defaults to 0.
19 | """
20 |
21 | def __init__(
22 | self,
23 | in_channels: int,
24 | out_channels: int,
25 | kernel_size: int = 3,
26 | dropout: float = 0,
27 | ) -> None:
28 | super().__init__()
29 | self.block = nn.Sequential(
30 | nn.Conv2d(
31 | in_channels=in_channels,
32 | out_channels=out_channels,
33 | kernel_size=kernel_size,
34 | stride=1,
35 | padding=((kernel_size - 1) // 2),
36 | ),
37 | nn.BatchNorm2d(out_channels),
38 | nn.ReLU(True),
39 | nn.Dropout(dropout),
40 | )
41 |
42 | def forward(self, x):
43 | return self.block(x)
44 |
45 |
46 | class Deconv2DBlock(nn.Module):
47 | """Deconvolution block with ConvTranspose2d followed by Conv2d, batch-normalisation, ReLU activation and dropout
48 |
49 | Args:
50 | in_channels (int): Number of input channels for deconv block
51 | out_channels (int): Number of output channels for deconv and convolution.
52 | kernel_size (int, optional): Kernel size for convolution. Defaults to 3.
53 | dropout (float, optional): Dropout. Defaults to 0.
54 | """
55 |
56 | def __init__(
57 | self,
58 | in_channels: int,
59 | out_channels: int,
60 | kernel_size: int = 3,
61 | dropout: float = 0,
62 | ) -> None:
63 | super().__init__()
64 | self.block = nn.Sequential(
65 | nn.ConvTranspose2d(
66 | in_channels=in_channels,
67 | out_channels=out_channels,
68 | kernel_size=2,
69 | stride=2,
70 | padding=0,
71 | output_padding=0,
72 | ),
73 | nn.Conv2d(
74 | in_channels=out_channels,
75 | out_channels=out_channels,
76 | kernel_size=kernel_size,
77 | stride=1,
78 | padding=((kernel_size - 1) // 2),
79 | ),
80 | nn.BatchNorm2d(out_channels),
81 | nn.ReLU(True),
82 | nn.Dropout(dropout),
83 | )
84 |
85 | def forward(self, x):
86 | return self.block(x)
87 |
88 |
89 | class ViTCellViT(VisionTransformer):
90 | def __init__(
91 | self,
92 | extract_layers: List[int],
93 | img_size: List[int] = [224],
94 | patch_size: int = 16,
95 | in_chans: int = 3,
96 | num_classes: int = 0,
97 | embed_dim: int = 768,
98 | depth: int = 12,
99 | num_heads: int = 12,
100 | mlp_ratio: float = 4,
101 | qkv_bias: bool = False,
102 | qk_scale: float = None,
103 | drop_rate: float = 0,
104 | attn_drop_rate: float = 0,
105 | drop_path_rate: float = 0,
106 | norm_layer: Callable = nn.LayerNorm,
107 | **kwargs
108 | ):
109 | """Vision Transformer with 1D positional embedding
110 |
111 | Args:
112 |             extract_layers: (List[int]): List of Transformer blocks whose outputs should be returned in addition to the tokens. The first block starts with 1, and the maximum is N=depth.
113 | img_size (int, optional): Input image size. Defaults to 224.
114 | patch_size (int, optional): Patch Token size (one dimension only, cause tokens are squared). Defaults to 16.
115 | in_chans (int, optional): Number of input channels. Defaults to 3.
116 | num_classes (int, optional): Number of output classes. if num classes = 0, raw tokens are returned (nn.Identity).
117 | Default to 0.
118 | embed_dim (int, optional): Embedding dimension. Defaults to 768.
119 | depth(int, optional): Number of Transformer Blocks. Defaults to 12.
120 | num_heads (int, optional): Number of attention heads per Transformer Block. Defaults to 12.
121 | mlp_ratio (float, optional): MLP ratio for hidden MLP dimension (Bottleneck = dim*mlp_ratio).
122 | Defaults to 4.0.
123 | qkv_bias (bool, optional): If bias should be used for query (q), key (k), and value (v). Defaults to False.
124 | qk_scale (float, optional): Scaling parameter. Defaults to None.
125 | drop_rate (float, optional): Dropout in MLP. Defaults to 0.0.
126 | attn_drop_rate (float, optional): Dropout for attention layer. Defaults to 0.0.
127 | drop_path_rate (float, optional): Dropout for skip connection. Defaults to 0.0.
128 | norm_layer (Callable, optional): Normalization layer. Defaults to nn.LayerNorm.
129 |
130 | """
131 | super().__init__(
132 | img_size=img_size,
133 | patch_size=patch_size,
134 | in_chans=in_chans,
135 | num_classes=num_classes,
136 | embed_dim=embed_dim,
137 | depth=depth,
138 | num_heads=num_heads,
139 | mlp_ratio=mlp_ratio,
140 | qkv_bias=qkv_bias,
141 | qk_scale=qk_scale,
142 | drop_rate=drop_rate,
143 | attn_drop_rate=attn_drop_rate,
144 | drop_path_rate=drop_path_rate,
145 | norm_layer=norm_layer,
146 | )
147 | self.extract_layers = extract_layers
148 |
149 | def forward(
150 | self, x: torch.Tensor
151 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
152 | """Forward pass with returning intermediate outputs for skip connections
153 |
154 | Args:
155 | x (torch.Tensor): Input batch
156 |
157 | Returns:
158 | Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
159 | torch.Tensor: Output of last layers (all tokens, without classification)
160 | torch.Tensor: Classification output
161 | torch.Tensor: Skip connection outputs from extract_layer selection
162 | """
163 | extracted_layers = []
164 | x = self.prepare_tokens(x)
165 |
166 | for depth, blk in enumerate(self.blocks):
167 | x = blk(x)
168 | if depth + 1 in self.extract_layers:
169 | extracted_layers.append(x)
170 |
171 | x = self.norm(x)
172 | output = self.head(x[:, 0])
173 |
174 | return output, x[:, 0], extracted_layers
175 |
176 |
177 | class ViTCellViTDeit(ImageEncoderViT):
178 | def __init__(
179 | self,
180 | extract_layers: List[int],
181 | img_size: int = 1024,
182 | patch_size: int = 16,
183 | in_chans: int = 3,
184 | embed_dim: int = 768,
185 | depth: int = 12,
186 | num_heads: int = 12,
187 | mlp_ratio: float = 4,
188 | out_chans: int = 256,
189 | qkv_bias: bool = True,
190 | norm_layer: Type[nn.Module] = nn.LayerNorm,
191 | act_layer: Type[nn.Module] = nn.GELU,
192 | use_abs_pos: bool = True,
193 | use_rel_pos: bool = False,
194 | rel_pos_zero_init: bool = True,
195 | window_size: int = 0,
196 | global_attn_indexes: Tuple[int, ...] = (),
197 | ) -> None:
198 | super().__init__(
199 | img_size,
200 | patch_size,
201 | in_chans,
202 | embed_dim,
203 | depth,
204 | num_heads,
205 | mlp_ratio,
206 | out_chans,
207 | qkv_bias,
208 | norm_layer,
209 | act_layer,
210 | use_abs_pos,
211 | use_rel_pos,
212 | rel_pos_zero_init,
213 | window_size,
214 | global_attn_indexes,
215 | )
216 | self.extract_layers = extract_layers
217 |
218 | def forward(self, x: torch.Tensor) -> torch.Tensor:
219 | extracted_layers = []
220 | x = self.patch_embed(x)
221 |
222 | if self.pos_embed is not None:
223 | token_size = x.shape[1]
224 | x = x + self.pos_embed[:, :token_size, :token_size, :]
225 |
226 | for depth, blk in enumerate(self.blocks):
227 | x = blk(x)
228 | if depth + 1 in self.extract_layers:
229 | extracted_layers.append(x)
230 | output = self.neck(x.permute(0, 3, 1, 2))
231 | _output = rearrange(output, "b c h w -> b c (h w)")
232 |
233 | return torch.mean(_output, axis=-1), output, extracted_layers
234 |
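235 | # NOTE: Minimal shape-check sketch (illustrative only, not part of the original file):
236 | # Conv2DBlock keeps the spatial size, Deconv2DBlock doubles it.
237 | if __name__ == "__main__":
238 |     x = torch.rand(1, 64, 32, 32)
239 |     down = Conv2DBlock(in_channels=64, out_channels=128)
240 |     up = Deconv2DBlock(in_channels=128, out_channels=64)
241 |     print(down(x).shape)        # torch.Size([1, 128, 32, 32])
242 |     print(up(down(x)).shape)    # torch.Size([1, 64, 64, 64])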
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/LKCell/ae2bf75994c9c93aecb661ec89256bf52d28f09f/models/utils/__init__.py
--------------------------------------------------------------------------------
/models/utils/attention.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # PyTorch Implementation of Attention Modules
3 | #
4 | # Implementation based on: https://github.com/mahmoodlab/CLAM
5 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
6 | # Institute for Artifical Intelligence in Medicine,
7 | # University Medicine Essen
8 |
9 | from typing import Tuple
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | class Attention(nn.Module):
15 | """Basic Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL
16 |
17 | Args:
18 | in_features (int, optional): Input shape of attention module. Defaults to 1024.
19 | attention_features (int, optional): Number of attention features. Defaults to 128.
20 | num_classes (int, optional): Number of output classes. Defaults to 2.
21 | dropout (bool, optional): If True, dropout is used. Defaults to False.
22 |         dropout_rate (float, optional): Dropout rate, only applied if the dropout parameter is True.
23 | Needs to be between 0.0 and 1.0. Defaults to 0.25.
24 | """
25 |
26 | def __init__(
27 | self,
28 | in_features: int = 1024,
29 | attention_features: int = 128,
30 | num_classes: int = 2,
31 | dropout: bool = False,
32 | dropout_rate: float = 0.25,
33 | ):
34 | super(Attention, self).__init__()
35 | # naming
36 | self.model_name = "AttentionModule"
37 |
38 | # set parameter dimensions for attention
39 | self.attention_features = attention_features
40 | self.in_features = in_features
41 | self.num_classes = num_classes
42 | self.dropout = dropout
43 | self.d_rate = dropout_rate
44 |
45 | if self.dropout:
46 | assert self.d_rate < 1
47 | self.attention = nn.Sequential(
48 | nn.Linear(self.in_features, self.attention_features),
49 | nn.Tanh(),
50 | nn.Dropout(self.d_rate),
51 | nn.Linear(self.attention_features, self.num_classes),
52 | )
53 | else:
54 | self.attention = nn.Sequential(
55 | nn.Linear(self.in_features, self.attention_features),
56 | nn.Tanh(),
57 | nn.Linear(self.attention_features, self.num_classes),
58 | )
59 |
60 | def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
61 | """Forward pass, calculating attention scores for given input vector
62 |
63 | Args:
64 | H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)
65 |
66 | Returns:
67 | Tuple[torch.Tensor, torch.Tensor]:
68 |
69 | * Attention-Scores
70 |             * H: Bag of instances. Shape: (Number of instances, Feature-dimensions)
71 | """
72 | A = self.attention(H)
73 | return A, H
74 |
75 |
76 | class AttentionGated(nn.Module):
77 | """Gated Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL
78 |
79 | Args:
80 | in_features (int, optional): Input shape of attention module. Defaults to 1024.
81 | attention_features (int, optional): Number of attention features. Defaults to 128.
82 | num_classes (int, optional): Number of output classes. Defaults to 2.
83 | dropout (bool, optional): If True, dropout is used. Defaults to False.
84 |         dropout_rate (float, optional): Dropout rate, only applied if the dropout parameter is True.
85 | needs to be between 0.0 and 1.0. Defaults to 0.25.
86 | """
87 |
88 | def __init__(
89 | self,
90 | in_features: int = 1024,
91 | attention_features: int = 128,
92 | num_classes: int = 2,
93 | dropout: bool = False,
94 | dropout_rate: float = 0.25,
95 | ):
96 | super(AttentionGated, self).__init__()
97 | # naming
98 | self.model_name = "AttentionModuleGated"
99 |
100 | # set Parameter dimensions for attention
101 | self.attention_features = attention_features
102 | self.in_features = in_features
103 | self.num_classes = num_classes
104 | self.dropout = dropout
105 | self.d_rate = dropout_rate
106 |
107 | if self.dropout:
108 | assert self.d_rate < 1
109 | self.attention_V = nn.Sequential(
110 | nn.Linear(self.in_features, self.attention_features),
111 | nn.Tanh(),
112 | nn.Dropout(self.d_rate),
113 | )
114 | self.attention_U = nn.Sequential(
115 | nn.Linear(self.in_features, self.attention_features),
116 | nn.Sigmoid(),
117 | nn.Dropout(self.d_rate),
118 | )
119 | self.attention_W = nn.Sequential(
120 | nn.Linear(self.attention_features, self.num_classes)
121 | )
122 |
123 | else:
124 | self.attention_V = nn.Sequential(
125 | nn.Linear(self.in_features, self.attention_features), nn.Tanh()
126 | )
127 | self.attention_U = nn.Sequential(
128 | nn.Linear(self.in_features, self.attention_features), nn.Sigmoid()
129 | )
130 | self.attention_W = nn.Sequential(
131 | nn.Linear(self.attention_features, self.num_classes)
132 | )
133 |
134 | def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
135 | """Forward pass, calculating attention scores for given input vector
136 |
137 | Args:
138 | H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)
139 |
140 | Returns:
141 | Tuple[torch.Tensor, torch.Tensor]:
142 |
143 | * Attention-Scores. Shape: (Number of instances)
144 |             * H: Bag of instances. Shape: (Number of instances, Feature-dimensions)
145 | """
146 | v = self.attention_V(H)
147 | u = self.attention_U(H)
148 | A = self.attention_W(v * u)
149 | return A, H
150 |
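151 | # NOTE: Minimal usage sketch (illustrative only, not part of the original file):
152 | # score a bag of 10 instances with 1024 features each using gated attention.
153 | if __name__ == "__main__":
154 |     bag = torch.rand(10, 1024)
155 |     attention = AttentionGated(in_features=1024, attention_features=128, num_classes=1)
156 |     A, H = attention(bag)                 # A: (10, 1) attention scores, H: (10, 1024)
157 |     weights = torch.softmax(A, dim=0)     # normalise scores over the bag
158 |     print(weights.shape, H.shape)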
--------------------------------------------------------------------------------
/models/utils/dense.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Dense Block as defined in:
3 | # Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger.
4 | # "Densely connected convolutional networks." In Proceedings of the IEEE conference
5 | # on computer vision and pattern recognition, pp. 4700-4708. 2017.
6 | #
7 | # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
8 | #
9 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
10 | # Institute for Artifical Intelligence in Medicine,
11 | # University Medicine Essen
12 |
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | from collections import OrderedDict
18 |
19 |
20 | class DenseBlock(nn.Module):
21 | """Dense Block as defined in:
22 |
23 | Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger.
24 | "Densely connected convolutional networks." In Proceedings of the IEEE conference
25 | on computer vision and pattern recognition, pp. 4700-4708. 2017.
26 |
27 | Only performs `valid` convolution.
28 |
29 | """
30 |
31 | def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, split=1):
32 | super(DenseBlock, self).__init__()
33 | assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info"
34 |
35 | self.nr_unit = unit_count
36 | self.in_ch = in_ch
37 | self.unit_ch = unit_ch
38 |
39 | # ! For inference only so init values for batchnorm may not match tensorflow
40 | unit_in_ch = in_ch
41 | self.units = nn.ModuleList()
42 | for idx in range(unit_count):
43 | self.units.append(
44 | nn.Sequential(
45 | OrderedDict(
46 | [
47 | ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
48 | ("preact_bna/relu", nn.ReLU(inplace=True)),
49 | (
50 | "conv1",
51 | nn.Conv2d(
52 | unit_in_ch,
53 | unit_ch[0],
54 | unit_ksize[0],
55 | stride=1,
56 | padding=0,
57 | bias=False,
58 | ),
59 | ),
60 | ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)),
61 | ("conv1/relu", nn.ReLU(inplace=True)),
62 | # ('conv2/pool', TFSamepaddingLayer(ksize=unit_ksize[1], stride=1)),
63 | (
64 | "conv2",
65 | nn.Conv2d(
66 | unit_ch[0],
67 | unit_ch[1],
68 | unit_ksize[1],
69 | groups=split,
70 | stride=1,
71 | padding=0,
72 | bias=False,
73 | ),
74 | ),
75 | ]
76 | )
77 | )
78 | )
79 | unit_in_ch += unit_ch[1]
80 |
81 | self.blk_bna = nn.Sequential(
82 | OrderedDict(
83 | [
84 | ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
85 | ("relu", nn.ReLU(inplace=True)),
86 | ]
87 | )
88 | )
89 |
90 | def out_ch(self):
91 | return self.in_ch + self.nr_unit * self.unit_ch[-1]
92 |
93 | def init_weights(self):
94 | """Kaiming (HE) initialization for convolutional layers and constant initialization for normalization and linear layers"""
95 | for m in self.modules():
96 | classname = m.__class__.__name__
97 |
98 | if isinstance(m, nn.Conv2d):
99 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
100 |
101 | if "norm" in classname.lower():
102 | nn.init.constant_(m.weight, 1)
103 | nn.init.constant_(m.bias, 0)
104 |
105 | if "linear" in classname.lower():
106 | if m.bias is not None:
107 | nn.init.constant_(m.bias, 0)
108 |
109 | def forward(self, prev_feat):
110 | for idx in range(self.nr_unit):
111 | new_feat = self.units[idx](prev_feat)
112 | prev_feat = crop_to_shape(prev_feat, new_feat)
113 | prev_feat = torch.cat([prev_feat, new_feat], dim=1)
114 | prev_feat = self.blk_bna(prev_feat)
115 |
116 | return prev_feat
117 |
118 |
119 | # helper functions for cropping
120 | def crop_op(x, cropping, data_format="NCHW"):
121 | """Center crop image.
122 |
123 | Args:
124 | x: input image
125 | cropping: the substracted amount
126 | data_format: choose either `NCHW` or `NHWC`
127 |
128 | """
129 | crop_t = cropping[0] // 2
130 | crop_b = cropping[0] - crop_t
131 | crop_l = cropping[1] // 2
132 | crop_r = cropping[1] - crop_l
133 | if data_format == "NCHW":
134 | x = x[:, :, crop_t:-crop_b, crop_l:-crop_r]
135 | else:
136 | x = x[:, crop_t:-crop_b, crop_l:-crop_r, :]
137 | return x
138 |
139 |
140 | def crop_to_shape(x, y, data_format="NCHW"):
141 | """Centre crop x so that x has shape of y. y dims must be smaller than x dims.
142 |
143 | Args:
144 | x: input array
145 | y: array with desired shape.
146 |
147 | """
148 | assert (
149 | y.shape[0] <= x.shape[0] and y.shape[1] <= x.shape[1]
150 | ), "Ensure that y dimensions are smaller than x dimensions!"
151 |
152 | x_shape = x.size()
153 | y_shape = y.size()
154 | if data_format == "NCHW":
155 | crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3])
156 | else:
157 | crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2])
158 | return crop_op(x, crop_shape, data_format)
159 |
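160 | # NOTE: Minimal usage sketch (illustrative only, not part of the original file):
161 | # a HoverNet-style dense block with four units; the `valid` convolutions shrink
162 | # the spatial size by 4 px per unit, while the channel count grows by unit_ch[-1].
163 | if __name__ == "__main__":
164 |     block = DenseBlock(in_ch=64, unit_ksize=[1, 5], unit_ch=[128, 32], unit_count=4)
165 |     x = torch.rand(1, 64, 64, 64)
166 |     print(block(x).shape, block.out_ch())   # torch.Size([1, 192, 48, 48]) 192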
--------------------------------------------------------------------------------
/models/utils/residual.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Residual block as defined in:
3 | # He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning
4 | # for image recognition." In Proceedings of the IEEE conference on computer vision
5 | # and pattern recognition, pp. 770-778. 2016.
6 | #
7 | # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
8 | #
9 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
10 | # Institute for Artifical Intelligence in Medicine,
11 | # University Medicine Essen
12 |
13 |
14 | import torch
15 | import torch.nn as nn
16 |
17 | from collections import OrderedDict
18 |
19 | from models.utils.tf_utils import TFSamepaddingLayer
20 |
21 |
22 | class ResidualBlock(nn.Module):
23 | """Residual block as defined in:
24 |
25 | He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning
26 | for image recognition." In Proceedings of the IEEE conference on computer vision
27 | and pattern recognition, pp. 770-778. 2016.
28 |
29 | """
30 |
31 | def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1):
32 | super(ResidualBlock, self).__init__()
33 | assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info"
34 |
35 | self.nr_unit = unit_count
36 | self.in_ch = in_ch
37 | self.unit_ch = unit_ch
38 |
39 | # ! For inference only so init values for batchnorm may not match tensorflow
40 | unit_in_ch = in_ch
41 | self.units = nn.ModuleList()
42 | for idx in range(unit_count):
43 | unit_layer = [
44 | ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
45 | ("preact/relu", nn.ReLU(inplace=True)),
46 | (
47 | "conv1",
48 | nn.Conv2d(
49 | unit_in_ch,
50 | unit_ch[0],
51 | unit_ksize[0],
52 | stride=1,
53 | padding=0,
54 | bias=False,
55 | ),
56 | ),
57 | ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)),
58 | ("conv1/relu", nn.ReLU(inplace=True)),
59 | (
60 | "conv2/pad",
61 | TFSamepaddingLayer(
62 | ksize=unit_ksize[1], stride=stride if idx == 0 else 1
63 | ),
64 | ),
65 | (
66 | "conv2",
67 | nn.Conv2d(
68 | unit_ch[0],
69 | unit_ch[1],
70 | unit_ksize[1],
71 | stride=stride if idx == 0 else 1,
72 | padding=0,
73 | bias=False,
74 | ),
75 | ),
76 | ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)),
77 | ("conv2/relu", nn.ReLU(inplace=True)),
78 | (
79 | "conv3",
80 | nn.Conv2d(
81 | unit_ch[1],
82 | unit_ch[2],
83 | unit_ksize[2],
84 | stride=1,
85 | padding=0,
86 | bias=False,
87 | ),
88 | ),
89 | ]
90 | # * has bna to conclude each previous block so
91 | # * must not put preact for the first unit of this block
92 | unit_layer = unit_layer if idx != 0 else unit_layer[2:]
93 | self.units.append(nn.Sequential(OrderedDict(unit_layer)))
94 | unit_in_ch = unit_ch[-1]
95 |
96 | if in_ch != unit_ch[-1] or stride != 1:
97 | self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False)
98 | else:
99 | self.shortcut = None
100 |
101 | self.blk_bna = nn.Sequential(
102 | OrderedDict(
103 | [
104 | ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
105 | ("relu", nn.ReLU(inplace=True)),
106 | ]
107 | )
108 | )
109 |
110 | def out_ch(self):
111 | return self.unit_ch[-1]
112 |
113 | def init_weights(self):
114 | """Kaiming (HE) initialization for convolutional layers and constant initialization for normalization and linear layers"""
115 | for m in self.modules():
116 | classname = m.__class__.__name__
117 |
118 | if isinstance(m, nn.Conv2d):
119 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
120 |
121 | if "norm" in classname.lower():
122 | nn.init.constant_(m.weight, 1)
123 | nn.init.constant_(m.bias, 0)
124 |
125 | if "linear" in classname.lower():
126 | if m.bias is not None:
127 | nn.init.constant_(m.bias, 0)
128 |
129 | def forward(self, prev_feat, freeze=False):
130 | if self.shortcut is None:
131 | shortcut = prev_feat
132 | else:
133 | shortcut = self.shortcut(prev_feat)
134 |
135 | for idx in range(0, len(self.units)):
136 | new_feat = prev_feat
137 | if self.training:
138 | with torch.set_grad_enabled(not freeze):
139 | new_feat = self.units[idx](new_feat)
140 | else:
141 | new_feat = self.units[idx](new_feat)
142 | prev_feat = new_feat + shortcut
143 | shortcut = prev_feat
144 | feat = self.blk_bna(prev_feat)
145 | return feat
146 |
--------------------------------------------------------------------------------
/models/utils/tf_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class TFSamepaddingLayer(nn.Module):
7 | """To align with tf `same` padding.
8 |
9 |     Put this before any conv layer that needs padding.
10 |     Assumes the kernel has Height == Width for simplicity.
11 | """
12 |
13 | def __init__(self, ksize, stride):
14 | super(TFSamepaddingLayer, self).__init__()
15 | self.ksize = ksize
16 | self.stride = stride
17 |
18 | def forward(self, x):
19 | if x.shape[2] % self.stride == 0:
20 | pad = max(self.ksize - self.stride, 0)
21 | else:
22 | pad = max(self.ksize - (x.shape[2] % self.stride), 0)
23 |
24 | if pad % 2 == 0:
25 | pad_val = pad // 2
26 | padding = (pad_val, pad_val, pad_val, pad_val)
27 | else:
28 | pad_val_start = pad // 2
29 | pad_val_end = pad - pad_val_start
30 | padding = (pad_val_start, pad_val_end, pad_val_start, pad_val_end)
31 | # print(x.shape, padding)
32 | x = F.pad(x, padding, "constant", 0)
33 | # print(x.shape)
34 | return x
35 |
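36 | # NOTE: Minimal usage sketch (illustrative only, not part of the original file):
37 | # with this padding, a 7x7 / stride-1 convolution preserves the spatial size,
38 | # mimicking TensorFlow `same` padding.
39 | if __name__ == "__main__":
40 |     import torch
41 | 
42 |     pad = TFSamepaddingLayer(ksize=7, stride=1)
43 |     conv = nn.Conv2d(3, 8, kernel_size=7, stride=1, padding=0)
44 |     x = torch.rand(1, 3, 32, 32)
45 |     print(conv(pad(x)).shape)   # torch.Size([1, 8, 32, 32])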
--------------------------------------------------------------------------------
/models/utils/tools.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Helper functions for models
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artifical Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 | from torch import nn
9 |
10 |
11 | def reset_weights(model: nn.Module) -> None:
12 |     """Reset the parameters of the model to avoid weight leakage
13 |
14 | Args:
15 | model (nn.Module): PyTorch Model
16 | """
17 | for layer in model.children():
18 | if hasattr(layer, "reset_parameters"):
19 | layer.reset_parameters()
20 |
21 |
22 | def initialize_weights(module: nn.Module) -> None:
23 |     """Initialize module weights with Xavier initialization
24 |
25 | Args:
26 | module (nn.Module): Model
27 | """
28 | for m in module.modules():
29 | if isinstance(m, nn.Linear):
30 | nn.init.xavier_normal_(m.weight)
31 | m.bias.data.zero_()
32 |
33 | elif isinstance(m, nn.BatchNorm1d):
34 | nn.init.constant_(m.weight, 1)
35 | nn.init.constant_(m.bias, 0)
36 |
--------------------------------------------------------------------------------
/preprocessing/encoding/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 |
4 | logger = logging.getLogger(__name__)
5 | logger.addHandler(logging.NullHandler())
6 |
--------------------------------------------------------------------------------
/preprocessing/encoding/datasets/patched_wsi_inference.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Patched WSI Dataset used for inference, mainly for calculating embeddings
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artifical Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 | from typing import Callable, Tuple, List
9 |
10 | import torch
11 | from torch.utils.data import Dataset
12 | from datamodel.wsi_datamodel import WSI
13 |
14 |
15 | class PatchedWSIInference(Dataset):
16 | """Inference Dataset, used for calculating embeddings of *one* WSI. Wrapped around a WSI object
17 |
18 | Args:
19 |         wsi_object (WSI): WSI object to run inference on. Must already have been patched
20 |             (its patched_slide_path provides the patch files and their metadata)
21 | transform (Callable): Inference Transformations
22 | """
23 |
24 | def __init__(
25 | self,
26 | wsi_object: WSI,
27 | transform: Callable,
28 | ) -> None:
29 | # set all configurations
30 | assert isinstance(wsi_object, WSI), "Must be a WSI-object"
31 | assert (
32 | wsi_object.patched_slide_path is not None
33 | ), "Please provide a WSI that already has been patched into slices"
34 |
35 | self.transform = transform
36 | self.wsi_object = wsi_object
37 |
38 |     def __getitem__(
39 |         self, idx: int
40 |     ) -> Tuple[torch.Tensor, dict]:
41 |         """Return one patch of the WSI and its metadata for the given idx
42 |
43 |         The patch is loaded and transformed by the wrapped WSI object via
44 |         process_patch_image, using the transform passed at construction time.
45 |
46 |         Args:
47 |             idx (int): Index of the patch to retrieve
48 |
49 |         Returns:
50 |             Tuple[torch.Tensor, dict]:
51 |
52 |                 * torch.Tensor: Patch with shape [3, height, width]
53 |                 * dict: Metadata of the patch, e.g., its coordinates within the WSI
54 |         """
55 | patch_name = self.wsi_object.patches_list[idx]
56 |
57 | patch, metadata = self.wsi_object.process_patch_image(
58 | patch_name=patch_name, transform=self.transform
59 | )
60 |
61 | return patch, metadata
62 |
63 | def __len__(self) -> int:
64 | """Return len of dataset
65 |
66 | Returns:
67 | int: Len of dataset
68 | """
69 | return int(self.wsi_object.get_number_patches())
70 |
71 | @staticmethod
72 |     def collate_batch(batch: List[Tuple]) -> Tuple[torch.Tensor, List[dict]]:
73 | """Create a custom batch
74 |
75 | Needed to unpack List of tuples with dictionaries and array
76 |
77 | Args:
78 | batch (List[Tuple]): Input batch consisting of a list of tuples (patch, patch-metadata)
79 |
80 | Returns:
81 | Tuple[torch.Tensor, list[dict]]:
82 | New batch: patches with shape [batch_size, 3, patch_size, patch_size], list of metadata dicts
83 | """
84 | patches, metadata = zip(*batch)
85 | patches = torch.stack(patches)
86 | metadata = list(metadata)
87 | return patches, metadata
88 |
--------------------------------------------------------------------------------
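A self-contained sketch of what `collate_batch` produces, using dummy (patch, metadata) pairs in place of real WSI patches; in practice it is passed to a `torch.utils.data.DataLoader` as `collate_fn=PatchedWSIInference.collate_batch`. The metadata keys below are illustrative:

```python
import torch

# Dummy batch standing in for the (patch, metadata) tuples returned by __getitem__
dummy_batch = [(torch.randn(3, 256, 256), {"row": 0, "col": i}) for i in range(4)]

patches, metadata = PatchedWSIInference.collate_batch(dummy_batch)
print(patches.shape)  # torch.Size([4, 3, 256, 256])
print(metadata[1])    # {'row': 0, 'col': 1}
```
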
/preprocessing/patch_extraction/main_extraction.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Main entry point for patch-preprocessing
3 | #
4 | # @ Fabian Hörst, fabian.hoerst@uk-essen.de
5 | # Institute for Artificial Intelligence in Medicine,
6 | # University Medicine Essen
7 |
8 | import inspect
9 | import logging
10 | import os
11 | import sys
12 |
13 |
14 | logger = logging.getLogger()
15 | logger.addHandler(logging.NullHandler())
16 |
17 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
18 | parentdir = os.path.dirname(currentdir)
19 | sys.path.insert(0, parentdir)
20 | parentdir = os.path.dirname(parentdir)
21 | sys.path.insert(0, parentdir)
22 |
23 | from preprocessing.patch_extraction.src.cli import PreProcessingParser
24 | from preprocessing.patch_extraction.src.patch_extraction import PreProcessor
25 | from utils.tools import close_logger
26 |
27 | if __name__ == "__main__":
28 | configuration_parser = PreProcessingParser()
29 | configuration, logger = configuration_parser.get_config()
30 | configuration_parser.store_config()
31 |
32 | slide_processor = PreProcessor(slide_processor_config=configuration)
33 | slide_processor.sample_patches_dataset()
34 |
35 | logger.info("Finished Preprocessing.")
36 | close_logger(logger)
37 |
--------------------------------------------------------------------------------
/preprocessing/patch_extraction/scripts/macenko.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import inspect
3 | import logging
4 | import os
5 | import sys
6 |
7 | logger = logging.getLogger()
8 | logger.addHandler(logging.NullHandler())
9 |
10 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
11 | parentdir = os.path.dirname(currentdir)
12 | sys.path.insert(0, parentdir)
13 | parentdir = os.path.dirname(parentdir)
14 | sys.path.insert(0, parentdir)
15 | parentdir = os.path.dirname(parentdir)
16 | sys.path.insert(0, parentdir)
17 |
18 | from preprocessing.patch_extraction.src.cli import MacenkoParser
19 | from preprocessing.patch_extraction.src.patch_extraction import PreProcessor
20 |
21 | if __name__ == "__main__":
22 | configuration_parser = MacenkoParser()
23 | configuration, logger = configuration_parser.get_config()
24 |
25 | slide_processor = PreProcessor(slide_processor_config=configuration)
26 | slide_processor.save_normalization_vector(
27 | wsi_file=configuration.wsi_paths, save_json_path=configuration.save_json_path
28 | )
29 |
30 | logger.info("Finished Macenko Vector Calculation!")
31 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | addict==2.4.0
3 | albumentations==1.4.1
4 | aliyun-python-sdk-core==2.15.0
5 | aliyun-python-sdk-kms==2.16.2
6 | annotated-types==0.6.0
7 | apex
8 | appdirs==1.4.4
9 | Brotli
10 | cached-property==1.5.2
11 | certifi==2022.12.7
12 | cffi==1.16.0
13 | charset-normalizer==2.1.1
14 | chex==0.1.7
15 | click==8.1.7
16 | colorama==0.4.6
17 | contextlib2==21.6.0
18 | contourpy==1.1.1
19 | crcmod==1.7
20 | cryptography==42.0.5
21 | cycler==0.12.1
22 | depthwise-conv2d-implicit-gemm==0.0.0
23 | dill==0.3.8
24 | dm-tree==0.1.8
25 | docker-pycreds==0.4.0
26 | einops==0.7.0
27 | etils==1.3.0
28 | filelock==3.9.0
29 | flax==0.7.2
30 | fonttools==4.49.0
31 | fsspec==2024.2.0
32 | gitdb==4.0.11
33 | GitPython==3.1.42
34 | huggingface-hub==0.21.4
35 | idna==3.4
36 | imageio==2.34.0
37 | importlib_metadata==7.0.2
38 | importlib_resources==6.1.3
39 | jax==0.4.13
40 | jaxlib==0.4.13
41 | Jinja2==3.1.2
42 | jmespath==0.10.0
43 | joblib==1.3.2
44 | kiwisolver==1.4.5
45 | lazy_loader==0.3
46 | lightning-utilities==0.10.1
47 | llvmlite==0.41.1
48 | Markdown==3.5.2
49 | markdown-it-py==3.0.0
50 | MarkupSafe==2.1.3
51 | matplotlib==3.7.5
52 | mdurl==0.1.2
53 | ml-dtypes==0.2.0
54 | mmcv==2.1.0
55 | mmcv-full==1.7.2
56 | mmdet==3.3.0
57 | mmengine==0.10.3
58 | mmsegmentation==1.2.2
59 | model-index==0.1.11
60 | monai==1.3.0
61 | mpmath==1.3.0
62 | msgpack==1.0.8
63 | munch==4.0.0
64 | natsort==8.4.0
65 | nest-asyncio==1.6.0
66 | networkx==3.0
67 | ninja==1.11.1.1
68 | numba==0.58.1
69 | numpy
70 | opencv-python==4.9.0.80
71 | opencv-python-headless==4.9.0.80
72 | opendatalab==0.0.10
73 | openmim==0.3.9
74 | openxlab==0.0.35
75 | opt-einsum==3.3.0
76 | optax==0.1.8
77 | orbax-checkpoint==0.2.3
78 | ordered-set==4.1.0
79 | oss2==2.17.0
80 | packaging
81 | pandarallel==1.6.5
82 | pandas==2.0.3
83 | pillow==10.2.0
84 | platformdirs
85 | pretrainedmodels==0.7.4
86 | prettytable==3.10.0
87 | protobuf==4.25.3
88 | psutil==5.9.8
89 | pycocotools==2.0.7
90 | pycparser==2.21
91 | pycryptodome==3.20.0
92 | pydantic==2.6.4
93 | pydantic_core==2.16.3
94 | Pygments==2.17.2
95 | pyparsing==3.1.2
96 | PySocks
97 | python-dateutil==2.9.0.post0
98 | pytz==2023.4
99 | PyWavelets==1.4.1
100 | PyYAML
101 | requests==2.28.2
102 | rich==13.4.2
103 | safetensors==0.4.2
104 | schema==0.7.5
105 | scikit-image==0.21.0
106 | scikit-learn==1.3.2
107 | scipy
108 | sentry-sdk==1.41.0
109 | setproctitle==1.3.3
110 | shapely==2.0.3
111 | six==1.16.0
112 | smmap==5.0.1
113 | sympy==1.12
114 | tabulate==0.9.0
115 | tensorstore==0.1.45
116 | termcolor
117 | terminaltables==3.1.10
118 | thop==0.1.1.post2209072238
119 | threadpoolctl==3.3.0
120 | tifffile==2023.7.10
121 | timm==0.9.16
122 | tomli==2.0.1
123 | toolz==0.12.1
124 | torch==2.1.2+cu118
125 | torchaudio==2.1.2+cu118
126 | torchinfo==1.8.0
127 | torchmetrics==1.3.1
128 | torchstat==0.0.7
129 | torchsummary==1.5.1
130 | torchvision==0.16.2+cu118
131 | tqdm==4.65.2
132 | triton==2.1.0
133 | typing_extensions==4.10.0
134 | tzdata==2024.1
135 | ujson==5.9.0
136 | urllib3==1.26.13
137 | wandb==0.16.4
138 | wcwidth==0.2.13
139 | yacs
140 | yapf==0.40.2
141 | zipp==3.17.0
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 |
4 | logger = logging.getLogger("__main__")
5 | logger.addHandler(logging.NullHandler())
6 |
--------------------------------------------------------------------------------
/utils/file_handling.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from pathlib import Path
3 | from typing import List, Union
4 | import pandas as pd
5 |
6 |
7 | def load_wsi_files_from_csv(csv_path: Union[Path, str], wsi_extension: str) -> List:
8 | """Load filenames from csv file with column name "Filename"
9 |
10 | Args:
11 | csv_path (Union[Path, str]): Path to csv file
12 | wsi_extension (str): WSI file ending (suffix)
13 |
14 | Returns:
15 |         List: List of WSI filenames that match the given extension
16 | """
17 | wsi_filelist = pd.read_csv(csv_path)
18 | wsi_filelist = wsi_filelist["Filename"].to_list()
19 | wsi_filelist = [f for f in wsi_filelist if Path(f).suffix == f".{wsi_extension}"]
20 |
21 | return wsi_filelist
22 |
--------------------------------------------------------------------------------
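A small sketch of `load_wsi_files_from_csv`, assuming a hypothetical CSV with the expected `Filename` column:

```python
import pandas as pd

# Hypothetical CSV listing WSI files (only the .svs entries should be returned below)
pd.DataFrame({"Filename": ["case_01.svs", "case_02.svs", "notes.txt"]}).to_csv(
    "wsi_list.csv", index=False
)

files = load_wsi_files_from_csv("wsi_list.csv", wsi_extension="svs")
print(files)  # ['case_01.svs', 'case_02.svs']
```
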
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Logging Class
3 | #
4 |
5 | import datetime
6 | from typing import Literal, Union
7 | from pathlib import Path
8 | import logging
9 | import logging.handlers
10 | import os
11 | import sys
12 |
13 |
14 | class Logger:
15 | """Initialize a Logger for sys-logging and RotatingFileHandler-logging by using python logging module.
16 | The logger can be used out of the box without any changes, but is also adaptable for specific use cases.
17 | In basic configuration, just the log level must be provided. If log_dir is provided, another handler object is created
18 | logging into a file into the log_dir directory. The filename can be changes by using comment, which basically is the filename.
19 | To create different log files with specific timestamp set 'use_timestamp' = True. This adds an additional timestamp to the filename.
20 |
21 | Args:
22 | level (Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]): Logger.level
23 | log_dir (Union[Path, str], optional): Path to save logfile in. Defaults to None.
24 | comment (str, optional): additional comment for save file. Defaults to 'logs'.
25 | formatter (str, optional): Custom formatter. Defaults to None.
26 | use_timestamp (bool, optional): Using timestamp for time-logging. Defaults to False.
27 |         file_level (Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], optional): Logging level for the output file.
28 |             Useful if a different logging level should be used for terminal output and the log file.
29 |             If no level is selected, the file logging level is the same as for the console. Defaults to None.
30 | """
31 |
32 | def __init__(
33 | self,
34 | level: Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
35 | log_dir: Union[Path, str] = None,
36 | comment: str = "logs",
37 | formatter: str = None,
38 | use_timestamp: bool = False,
39 | file_level: Literal["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"] = None,
40 | ) -> None:
41 | self.level = level
42 | self.comment = comment
43 | self.log_parent_dir = log_dir
44 | self.use_timestamp = use_timestamp
45 | if formatter is None:
46 | self.formatter = "%(asctime)s [%(levelname)s] - %(message)s"
47 | else:
48 | self.formatter = formatter
49 | if file_level is None:
50 | self.file_level = level
51 | else:
52 | self.file_level = file_level
53 |
54 | def create_handler(self, logger: logging.Logger) -> None:
55 | """Create logging handler for sys output and rotating files in parent_dir.
56 |
57 | Args:
58 | logger (logging.Logger): The Logger
59 | """
60 | log_handlers = {"StreamHandler": logging.StreamHandler(stream=sys.stdout)}
61 | fh_formatter = logging.Formatter(f"{self.formatter}")
62 | log_handlers["StreamHandler"].setLevel(self.level)
63 |
64 | if self.log_parent_dir is not None:
65 | log_parent_dir = Path(self.log_parent_dir)
66 | if self.use_timestamp:
67 | log_name = f'{datetime.datetime.now().strftime("%Y-%m-%dT%H%M%S")}_{self.comment}.log'
68 | else:
69 | log_name = f"{self.comment}.log"
70 | log_parent_dir.mkdir(parents=True, exist_ok=True)
71 |
72 | should_roll_over = os.path.isfile(log_parent_dir / log_name)
73 |
74 | log_handlers["FileHandler"] = logging.handlers.RotatingFileHandler(
75 | log_parent_dir / log_name, backupCount=5
76 | )
77 |
78 | if should_roll_over: # log already exists, roll over!
79 | log_handlers["FileHandler"].doRollover()
80 | log_handlers["FileHandler"].setLevel(self.file_level)
81 |
82 | for handler in log_handlers.values():
83 | handler.setFormatter(fh_formatter)
84 | logger.addHandler(handler)
85 |
86 | def create_logger(self) -> logging.Logger:
87 | """Create the logger
88 |
89 | Returns:
90 | Logger: The logger to be used.
91 | """
92 | logger = logging.getLogger("__main__")
93 | logger.addHandler(logging.NullHandler())
94 |
95 | logger.setLevel(
96 | "DEBUG"
97 | ) # set to debug because each handler level must be equal or lower
98 | self.create_handler(logger)
99 |
100 | return logger
101 |
--------------------------------------------------------------------------------
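A minimal sketch of how this `Logger` class can be instantiated (the directory, comment, and levels below are illustrative):

```python
# Console logging at INFO; rotating file logging at DEBUG into ./logs/train.log
logger = Logger(
    level="INFO",
    log_dir="./logs",      # illustrative path
    comment="train",       # base filename -> train.log
    use_timestamp=False,
    file_level="DEBUG",
).create_logger()

logger.info("Starting experiment")                         # console + file
logger.debug("Written to the log file, not the console")   # file only
```
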
/utils/tools.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Utility functions
3 |
4 |
5 |
6 | import importlib
7 | import logging
8 | import sys
9 |
10 | import types
11 | from datetime import timedelta
12 | from timeit import default_timer as timer
13 | from typing import Dict, List, Optional, Tuple, Union
14 |
15 | from utils.__init__ import logger
16 |
17 |
18 | # Helper timing functions
19 | def start_timer() -> float:
20 | """Returns the number of seconds passed since epoch. The epoch is the point where the time starts,
21 | and is platform dependent.
22 |
23 | Returns:
24 | float: The number of seconds passed since epoch
25 | """
26 | return timer()
27 |
28 |
29 | def end_timer(start_time: float, timed_event: str = "Time usage") -> None:
30 | """Prints the time passed from start_time.
31 |
32 |
33 | Args:
34 |         start_time (float): Timer value returned by start_timer
35 | timed_event (str, optional): A string describing the activity being monitored. Defaults to "Time usage".
36 | """
37 | logger.info(f"{timed_event}: {timedelta(seconds=timer() - start_time)}")
38 |
39 |
40 | def module_exists(
41 | *names: Union[List[str], str],
42 | error: str = "ignore",
43 | warn_every_time: bool = False,
44 | __INSTALLED_OPTIONAL_MODULES: Dict[str, bool] = {},
45 | ) -> Optional[Union[Tuple[types.ModuleType, ...], types.ModuleType]]:
46 | """Try to import optional dependencies.
47 | Ref: https://stackoverflow.com/a/73838546/4900327
48 |
49 | Args:
50 |         names (Union[List[str], str]): The module name(s) to import. Str or list of strings.
51 | error (str, optional): What to do when a dependency is not found:
52 | * raise : Raise an ImportError.
53 | * warn: print a warning.
54 | * ignore: If any module is not installed, return None, otherwise, return the module(s).
55 | Defaults to "ignore".
56 | warn_every_time (bool, optional): Whether to warn every time an import is tried. Only applies when error="warn".
57 | Setting this to True will result in multiple warnings if you try to import the same library multiple times.
58 | Defaults to False.
59 | Raises:
60 | ImportError: ImportError of Module
61 |
62 | Returns:
63 |         Optional[Union[ModuleType, Tuple[ModuleType, ...]]]: The imported module(s), if all are found.
64 | None is returned if any module is not found and `error!="raise"`.
65 | """
66 | assert error in {"raise", "warn", "ignore"}
67 | if isinstance(names, (list, tuple, set)):
68 | names: List[str] = list(names)
69 | else:
70 | assert isinstance(names, str)
71 | names: List[str] = [names]
72 | modules = []
73 | for name in names:
74 | try:
75 | module = importlib.import_module(name)
76 | modules.append(module)
77 | __INSTALLED_OPTIONAL_MODULES[name] = True
78 | except ImportError:
79 | modules.append(None)
80 |
81 | def error_msg(missing: Union[str, List[str]]):
82 | if not isinstance(missing, (list, tuple)):
83 | missing = [missing]
84 | missing_str: str = " ".join([f'"{name}"' for name in missing])
85 | dep_str = "dependencies"
86 | if len(missing) == 1:
87 | dep_str = "dependency"
88 | msg = f"Missing optional {dep_str} {missing_str}. Use pip or conda to install."
89 | return msg
90 |
91 | missing_modules: List[str] = [
92 | name for name, module in zip(names, modules) if module is None
93 | ]
94 | if len(missing_modules) > 0:
95 | if error == "raise":
96 | raise ImportError(error_msg(missing_modules))
97 | if error == "warn":
98 | for name in missing_modules:
99 | # Ensures warning is printed only once
100 | if warn_every_time is True or name not in __INSTALLED_OPTIONAL_MODULES:
101 | logger.warning(f"Warning: {error_msg(name)}")
102 | __INSTALLED_OPTIONAL_MODULES[name] = False
103 | return None
104 | if len(modules) == 1:
105 | return modules[0]
106 | return tuple(modules)
107 |
108 |
109 | def close_logger(logger: logging.Logger) -> None:
110 | """Closing a logger savely
111 |
112 | Args:
113 | logger (logging.Logger): Logger to close
114 | """
115 | handlers = logger.handlers[:]
116 | for handler in handlers:
117 | logger.removeHandler(handler)
118 | handler.close()
119 |
120 | logger.handlers.clear()
121 | logging.shutdown()
122 |
123 |
124 | class AverageMeter(object):
125 | """Computes and stores the average and current value
126 |
127 | Original-Code: https://github.com/facebookresearch/simsiam
128 | """
129 |
130 | def __init__(self, name, fmt=":f"):
131 | self.name = name
132 | self.fmt = fmt
133 | self.reset()
134 |
135 | def reset(self):
136 | self.val = 0
137 | self.avg = 0
138 | self.sum = 0
139 | self.count = 0
140 |
141 | def update(self, val, n=1):
142 | self.val = val
143 | self.sum += val * n
144 | self.count += n
145 | self.avg = self.sum / self.count
146 |
147 | def __str__(self):
148 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
149 | return fmtstr.format(**self.__dict__)
150 |
151 |
152 | def flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict:
153 | """Flatten a nested dictionary and insert the sep to seperate keys
154 |
155 | Args:
156 | d (dict): dict to flatten
157 | parent_key (str, optional): parent key name. Defaults to ''.
158 |         sep (str, optional): Separator. Defaults to '.'.
159 |
160 | Returns:
161 | dict: Flattened dict
162 | """
163 | items = []
164 | for k, v in d.items():
165 | new_key = parent_key + sep + k if parent_key else k
166 | if isinstance(v, dict):
167 | items.extend(flatten_dict(v, new_key, sep=sep).items())
168 | else:
169 | items.append((new_key, v))
170 | return dict(items)
171 |
172 |
173 | def unflatten_dict(d: dict, sep: str = ".") -> dict:
174 | """Unflatten a flattened dictionary (created a nested dictionary)
175 |
176 | Args:
177 | d (dict): Dict to be nested
178 |         sep (str, optional): Separator of flattened keys. Defaults to '.'.
179 |
180 | Returns:
181 | dict: Nested dict
182 | """
183 | output_dict = {}
184 | for key, value in d.items():
185 | keys = key.split(sep)
186 | d = output_dict
187 | for k in keys[:-1]:
188 | d = d.setdefault(k, {})
189 | d[keys[-1]] = value
190 |
191 | return output_dict
192 |
193 |
194 | def remove_parameter_tag(d: dict, sep: str = ".") -> dict:
195 | """Remove all paramter tags from dictionary
196 |
197 | Args:
198 |         d (dict): Dict, must be flattened with the defined separator
199 |         sep (str, optional): Separator used during flattening. Defaults to ".".
200 |
201 | Returns:
202 | dict: Dict with parameter tag removed
203 | """
204 | param_dict = {}
205 | for k, _ in d.items():
206 | unflattened_keys = k.split(sep)
207 | new_keys = []
208 | max_num_insert = len(unflattened_keys) - 1
209 | for i, k in enumerate(unflattened_keys):
210 | if i < max_num_insert and k != "parameters":
211 | new_keys.append(k)
212 | joined_key = sep.join(new_keys)
213 | param_dict[joined_key] = {}
214 | print(param_dict)
215 | for k, v in d.items():
216 | unflattened_keys = k.split(sep)
217 | new_keys = []
218 | max_num_insert = len(unflattened_keys) - 1
219 | for i, k in enumerate(unflattened_keys):
220 | if i < max_num_insert and k != "parameters":
221 | new_keys.append(k)
222 | joined_key = sep.join(new_keys)
223 | param_dict[joined_key][unflattened_keys[-1]] = v
224 |
225 | return param_dict
226 |
227 | def get_size_of_dict(d: dict) -> int:
228 |     """Return the approximate size of a dict in bytes, including its top-level keys and values."""
229 |     size = sys.getsizeof(d)
230 |     for key, value in d.items():
231 |         size += sys.getsizeof(key)
232 |         size += sys.getsizeof(value)
233 |     return size
234 |
--------------------------------------------------------------------------------
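A few usage sketches for the helpers above; the configuration keys, module name, and numbers are illustrative:

```python
# Timing a block of work
t0 = start_timer()
# ... some work ...
end_timer(t0, timed_event="Data loading")   # logs e.g. "Data loading: 0:00:01.234567"

# Optional imports: returns the module if available, otherwise None (error="ignore")
yaml = module_exists("yaml", error="ignore")

# Flatten / unflatten nested configuration dicts
cfg = {"model": {"backbone": "unireplknet", "lr": 1e-4}}
flat = flatten_dict(cfg)        # {'model.backbone': 'unireplknet', 'model.lr': 0.0001}
assert unflatten_dict(flat) == cfg

# Running average, e.g., for a loss value
meter = AverageMeter("loss", fmt=":.4f")
meter.update(0.75, n=8)
meter.update(0.60, n=8)
print(meter)                    # loss 0.6000 (0.6750)
```
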