├── LICENSE ├── README.md ├── deepmed ├── __init__.py ├── _deploy.py ├── _experiment.py ├── _load.py ├── _train.py ├── evaluators │ ├── __init__.py │ ├── adapters.py │ ├── aggregate_stats.py │ ├── gradcam.py │ ├── heatmap.py │ ├── metrics.py │ ├── roc.py │ ├── top_tiles.py │ └── types.py ├── experiment_imports.py ├── get │ ├── __init__.py │ ├── _crossval.py │ ├── _extract_features.py │ ├── _multi_target.py │ ├── _parameterize.py │ ├── _simple.py │ └── _subgroup.py ├── mil.py ├── multi_input.py ├── on_features.py ├── types.py └── utils.py ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── deepmed.get.rst │ ├── deepmed.rst │ ├── index.rst │ ├── modules.rst │ ├── users_guide.rst │ └── users_guide │ ├── crossval.rst │ ├── multi_target_training.rst │ └── simple_training.rst ├── examples ├── continuous.py ├── crossval.py ├── extract-with-custom-model.py ├── extract.py ├── mil.py ├── multi_target_deploy.py ├── multi_target_train.py ├── parameterize.py └── subgroup.py ├── pyproject.toml ├── setup.cfg └── test ├── __init__.py ├── test_examples.py └── test_simple.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 Kather Lab 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Welcome to Direct End-to-End Pipeline for Medical Imaging 2 | 3 | ## What is this? 4 | 5 | This is an open source platform for end-to-end artificial intelligence (AI) in 6 | computational pathology. It will enable you to use AI for prediction of any 7 | "label" directly from digitized pathology slides. Common use cases which can be 8 | reproduced by this pipeline are: 9 | 10 | - prediction of microsatellite instability in colorectal cancer (Kather et 11 | al., Nat Med 2019) 12 | - prediction of mutations in lung cancer (Coudray et al., Nat Med 2018) 13 | - prediction of subtypes of renal cell carcinoma (Lu et al., Nat Biomed Eng 14 | 2021) 15 | - other possible use cases are summarized by Echle et al., Br J Cancer 2021: 16 | https://www.nature.com/articles/s41416-020-01122-x 17 | 18 | This pipeline is modular, which means that new methods for pre-/postprocessing 19 | or new AI methods can be easily integrated. For an extensive protocol including many 20 | example scripts, please see https://www.biorxiv.org/content/10.1101/2021.12.19.473344v1 21 | 22 | 23 | ## Installation 24 | 25 | Deepmed has been tested on both Windows Server 2019 and Ubuntu 20.04. It 26 | requires a CUDA-enabled NVIDIA GPU and a Python installation of at least version 27 | 3.8. 
In most cases, deepmed can then be installed by typing: 28 | 29 | ```bash 30 | pip install git+https://github.com/KatherLab/deepmed 31 | ``` 32 | 33 | In some cases it may be necessary to install pytorch manually in order for it to 34 | recognize the system's GPU. To do so, please refer to the [pytorch installation 35 | guide]. 36 | 37 | [pytorch installation guide]: https://pytorch.org/get-started/locally/ 38 | 39 | 40 | ## Documentation 41 | 42 | To build the project's documentation, we need to install a few more 43 | dependencies: 44 | 45 | ```bash 46 | pip install sphinx sphinx_rtd_theme 47 | ``` 48 | 49 | After that, we can build the documentation by invoking the `Makefile` or 50 | `make.bat` in the docs directory, i.e.: 51 | 52 | ```bash 53 | make -C path/to/deepmed/docs html 54 | ``` 55 | 56 | on Linux systems or 57 | 58 | ```powershell 59 | path\to\deepmed\docs\make.bat html 60 | ``` 61 | 62 | on Windows. Afterwards, the documentation can be found in 63 | `docs/build/html/index.html`. 64 | 65 | 66 | ## Tests 67 | 68 | Deepmed comes with a set of integration tests.
These can be invoked by running 69 | 70 | ```bash 71 | cd path/to/deepmed && python -m unittest 72 | ``` 73 | -------------------------------------------------------------------------------- /deepmed/__init__.py: -------------------------------------------------------------------------------- 1 | from ._experiment import * 2 | from ._train import * 3 | from ._load import * 4 | from ._deploy import * 5 | 6 | __author__ = 'Marko van Treeck' 7 | __copyright__ = 'Copyright 2021, Kather Lab' 8 | __credits__ = [ 9 | 'Amelie Echle', 10 | 'Céline Nicole Heinz', 11 | 'Chiara Loeffler', 12 | 'Didem Cifci', 13 | 'Jakob Nikolas Kather', 14 | 'Marko van Treeck', 15 | 'Oliver Lester Saldanha', 16 | 'Tobias Paul Seraphin', 17 | 'Hannah Sophie Muti', 18 | ] 19 | __license__ = 'MIT' 20 | __version__ = '0.10.0rc0' 21 | __maintainer__ = 'Marko van Treeck' 22 | __email__ = 'markovantreeck@gmail.com' -------------------------------------------------------------------------------- /deepmed/_deploy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import Optional 4 | from fastai.data.transforms import CategoryMap 5 | from fastai.learner import Learner 6 | 7 | import pandas as pd 8 | from typing import Iterable 9 | 10 | from .types import GPUTask 11 | from .utils import exists_and_has_size, log_defaults, is_continuous, factory 12 | 13 | __all__ = ['Deploy'] 14 | 15 | 16 | @log_defaults 17 | def _deploy(learn: Learner, task: GPUTask) -> Optional[pd.DataFrame]: 18 | logger = logging.getLogger(str(task.path)) 19 | 20 | if task.test_df is None: 21 | logger.warning('No testing set found! 
Skipping deployment...') 22 | return None 23 | elif exists_and_has_size(preds_path := task.path/'predictions.csv.zip'): 24 | logger.warning(f'{preds_path} already exists, skipping deployment...') 25 | return pd.read_csv(preds_path, low_memory=False) 26 | 27 | test_df, target_label = task.test_df, task.target_label 28 | 29 | if hasattr(learn.dls, 'vocab'): # not possible for continuous targets 30 | vocab = learn.dls.vocab 31 | if not isinstance(vocab, CategoryMap): 32 | vocab = vocab[-1] 33 | 34 | test_df = _discretize_if_necessary( 35 | test_df=test_df, target_label=target_label, vocab=vocab) 36 | 37 | # restrict testing classes to those known by the model 38 | if not (known_idx := test_df[target_label].isin(vocab)).all(): 39 | unknown_classes = test_df[target_label][~known_idx].unique() 40 | logger.warning( 41 | f'classes unknown to model in test data: {unknown_classes}! Dropping them...') 42 | test_df = test_df[known_idx] 43 | else: 44 | vocab = None 45 | 46 | test_dl = learn.dls.test_dl(test_df) 47 | # inner needed so we don't jump GPUs 48 | # FIXME What does `inner` actually _do_? Is this harmful? 49 | scores, _, class_preds = learn.get_preds( 50 | dl=test_dl, inner=True, with_decoded=True) 51 | 52 | test_df = test_df.copy() 53 | if vocab is not None: 54 | # class-wise scores 55 | for class_, i in vocab.o2i.items(): 56 | test_df[f'{target_label}_{class_}'] = scores[:, i].numpy() 57 | 58 | # class prediction (i.e. 
the class w/ the highest score for each tile) 59 | test_df[f'{target_label}_pred'] = vocab.map_ids(class_preds) 60 | else: 61 | test_df[f'{target_label}_score'] = scores[:, 0].numpy() 62 | 63 | test_df.to_csv(preds_path, index=False, compression='zip') 64 | 65 | return test_df 66 | 67 | 68 | def _discretize_if_necessary(test_df: pd.DataFrame, target_label: str, vocab: Iterable[str]) -> pd.DataFrame: 69 | # check if this target was discretized for training and discretize testing set if necessary 70 | interval = re.compile( 71 | r'^\[([+-]?\d+\.?\d*(?:e[+-]?\d+|)|-inf),([+-]?\d+\.?\d*(?:e[+-]?\d+|)|inf)\)$') 72 | if is_continuous(test_df[target_label]) and \ 73 | all(isinstance(class_, str) and interval.match(class_) is not None for class_ in vocab): 74 | 75 | # extract thresholds from vocab 76 | threshs = [*(interval.match(class_).groups()[0] # type: ignore 77 | for class_ in vocab), 'inf'] 78 | threshs = sorted(threshs, key=float) 79 | 80 | def interval_label(x): 81 | """Discretizes data into ``[lower,upper)`` classes.""" 82 | for l, h in zip(threshs, threshs[1:]): 83 | # we only transform the values here, because we want h, l to be 84 | # *exactly* as in the training set 85 | if float(l) <= x and x < float(h): 86 | return f'[{l},{h})' 87 | raise RuntimeError('unreachable!') 88 | 89 | test_df[target_label] = test_df[target_label].map(interval_label) 90 | 91 | return test_df 92 | 93 | 94 | Deploy = factory(_deploy) 95 | -------------------------------------------------------------------------------- /deepmed/_experiment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import threading 5 | from typing import Mapping, Union, Optional 6 | from pathlib import Path 7 | from concurrent import futures 8 | from fastcore.parallel import ThreadPoolExecutor 9 | 10 | from .types import * 11 | 12 | 13 | __all__ = ['do_experiment'] 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 
| 19 | def do_experiment( 20 | project_dir: PathLike, 21 | get: TaskGetter, 22 | num_concurrent_tasks: Optional[int] = 0, 23 | devices: Mapping[Union[str, int], int] = {0: 4}, 24 | logfile: Optional[str] = 'logfile', 25 | keep_going: bool = False) -> None: 26 | """Runs an experiement. 27 | 28 | Args: 29 | project_dir: The directory to save project data in. 30 | get: A function which generates tasks. 31 | train: A function training a model for a specific task. 32 | deploy: A function deploying a trained model. 33 | num_concurrent_tasks: The maximum amount of tasks to do at the same 34 | time. If None, the number of tasks will grow with the number of 35 | available devices. If 0, all jobs will be task in the main process 36 | (useful for debugging). 37 | devices: The devices to use for training and the maximum number of 38 | models to be trained at once for each device. 39 | keep_going: Whether to stop all runs on an exception. 40 | """ 41 | project_dir = Path(project_dir) 42 | project_dir.mkdir(exist_ok=True, parents=True) 43 | 44 | # add logfile handler 45 | if logfile is not None: 46 | file_handler = logging.FileHandler(f'{project_dir/"logfile"}') 47 | file_handler.setLevel(logging.DEBUG) 48 | formatter = logging.Formatter( 49 | '%(asctime)s: %(levelname)s: %(name)s: %(message)s') 50 | file_handler.setFormatter(formatter) 51 | logging.getLogger().addHandler(file_handler) 52 | 53 | logger.info('Getting tasks') 54 | 55 | # semaphores which tell us which GPUs still have resources available 56 | capacities = { 57 | device: threading.Semaphore(capacity) # type: ignore 58 | for device, capacity in devices.items()} 59 | tasks = get(project_dir=project_dir, 60 | capacities=capacities) 61 | 62 | try: 63 | if num_concurrent_tasks == 0: 64 | for task in tasks: 65 | task.run() 66 | else: 67 | with ThreadPoolExecutor(num_concurrent_tasks) as e: 68 | running = [e.submit(Task.run, task) for task in tasks] 69 | for future in futures.as_completed(running): 70 | future.result() # 
consume results to trigger exceptions 71 | 72 | except Exception as e: 73 | if not keep_going: 74 | raise e -------------------------------------------------------------------------------- /deepmed/_load.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from fastai.learner import Learner 3 | from fastai.vision.learner import load_learner 4 | 5 | from .types import GPUTask 6 | from .utils import factory 7 | 8 | __all__ = ['Load'] 9 | 10 | 11 | def _load( 12 | task: GPUTask, /, 13 | project_dir: Path, 14 | training_project_dir: Path) -> Learner: 15 | model_path = training_project_dir/task.path.relative_to(project_dir)/'export.pkl' 16 | return load_learner(model_path, cpu=False) 17 | 18 | Load = factory(_load) 19 | -------------------------------------------------------------------------------- /deepmed/_train.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import logging 4 | import random 5 | from typing import Callable, Iterable, Optional, List 6 | from pathlib import Path 7 | from functools import lru_cache 8 | from dataclasses import dataclass, field 9 | from fastai.callback.progress import CSVLogger 10 | from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback, TrackerCallback 11 | from fastai.data.block import CategoryBlock, DataBlock, RegressionBlock, TransformBlock 12 | from fastai.data.transforms import ColReader, ColSplitter, IntToFloatTensor 13 | from fastai.learner import Learner, load_learner 14 | from fastai.losses import CrossEntropyLossFlat 15 | from fastai.vision.augment import aug_transforms 16 | from fastai.vision.core import PILImage 17 | from fastai.vision.learner import cnn_learner 18 | from torchvision.models import resnet18 19 | 20 | import torch 21 | import pandas as pd 22 | from torch import nn 23 | 24 | from .types import GPUTask 25 | from .utils import is_continuous 26 | 27 | 
__all__ = ['Train'] 28 | 29 | 30 | @lru_cache(10000) 31 | def get_tile_list(slide_dir: Path) -> List[Path]: 32 | return list(slide_dir.glob('*.jpg')) 33 | 34 | 35 | def get_tile(tile_path) -> PILImage: 36 | """Gets a tile. 37 | 38 | If tile_path points to a file, the file is loaded directly. If it's a 39 | directory, a random file will be sampled.""" 40 | # Don't specify arg types! Otherwise fastai will do some whack dispatching 41 | # and this function will not be called 42 | tile_path = Path(tile_path) 43 | if tile_path.is_dir(): 44 | tile_path = random.choice(get_tile_list(tile_path)) 45 | 46 | return PILImage.create(tile_path) 47 | 48 | 49 | TileBlock = TransformBlock(type_tfms=get_tile, batch_tfms=IntToFloatTensor) 50 | 51 | 52 | @dataclass 53 | class Train: 54 | """Trains a single model. 55 | 56 | Args: 57 | batch_size: The number of training samples used through the network during one forward and backward pass. 58 | task: The task to train a model for. 59 | arch: The architecture of the model to train. 60 | max_epochs: The absolute maximum number of epochs to train. 61 | lr: The initial learning rate. 62 | num_workers: The number of workers to use in the data loaders. Set to 63 | 0 on windows! 64 | tfms: Transforms to apply to the data. 65 | metrics: The metrics to calculate on the validation set each epoch. 66 | patience: The number of epochs without improvement before stopping the 67 | training. 68 | monitor: The metric to monitor for early stopping. 69 | 70 | Returns: 71 | The trained model. 72 | 73 | If the training is interrupted, it will be continued from the last model 74 | checkpoint. 
75 | """ 76 | arch: Callable[[bool], nn.Module] = resnet18 77 | batch_size: int = 64 78 | max_epochs: int = 32 79 | lr: float = 2e-3 80 | num_workers: int = 16 81 | tfms: Optional[Callable] = field( 82 | default_factory=lambda: aug_transforms( 83 | flip_vert=True, max_rotate=360, max_zoom=1, max_warp=0, size=224)) 84 | metrics: Iterable[Callable] = field(default_factory=list) 85 | patience: int = 3 86 | monitor: str = 'valid_loss' 87 | 88 | def __call__(self, task: GPUTask) -> Optional[Learner]: 89 | logger = logging.getLogger(str(task.path)) 90 | 91 | if (model_path := task.path/'export.pkl').exists(): 92 | logger.warning(f'{model_path} already exists! using old model...') 93 | return load_learner(model_path) 94 | 95 | target_label, train_df, result_dir = task.target_label, task.train_df, task.path 96 | 97 | if train_df is None: 98 | logger.warning('Cannot train: no training set given!') 99 | return None 100 | 101 | y_block = RegressionBlock if is_continuous(train_df[target_label]) else CategoryBlock 102 | 103 | dblock = DataBlock(blocks=(TileBlock, y_block), 104 | get_x=ColReader('tile_path'), 105 | get_y=ColReader(target_label), 106 | splitter=ColSplitter('is_valid'), 107 | batch_tfms=self.tfms) 108 | dls = dblock.dataloaders( 109 | train_df, bs=self.batch_size, num_workers=self.num_workers) 110 | 111 | target_col_idx = train_df[~train_df.is_valid].columns.get_loc( 112 | target_label) 113 | 114 | logger.debug( 115 | f'Class counts in training set: {train_df[~train_df.is_valid].iloc[:, target_col_idx].value_counts()}') 116 | logger.debug( 117 | f'Class counts in validation set: {train_df[train_df.is_valid].iloc[:, target_col_idx].value_counts()}') 118 | 119 | if is_continuous(train_df[target_label]): 120 | loss_func = None 121 | else: 122 | counts = train_df[~train_df.is_valid].iloc[:,target_col_idx].value_counts() 123 | weight = counts.sum() / counts 124 | weight /= weight.sum() 125 | weight = torch.tensor(list(map(weight.get, dls.vocab)), dtype=torch.float32) # 
reorder according to vocab 126 | loss_func = CrossEntropyLossFlat(weight=weight.cuda()) 127 | logger.debug(f'{dls.vocab = }, {weight = }') 128 | 129 | learn = cnn_learner( 130 | dls, self.arch, 131 | path=result_dir, 132 | loss_func=loss_func, 133 | metrics=self.metrics) 134 | 135 | cbs = [ 136 | SaveModelCallback( 137 | monitor=self.monitor, fname=f'best_{self.monitor}', reset_on_fit=False), 138 | SaveModelCallback(every_epoch=True, with_opt=True, 139 | reset_on_fit=False), 140 | EarlyStoppingCallback( 141 | monitor=self.monitor, min_delta=0.001, patience=self.patience, reset_on_fit=False), 142 | CSVLogger(append=True)] 143 | 144 | if (result_dir/'models'/f'best_{self.monitor}.pth').exists(): 145 | _fit_from_checkpoint( 146 | learn=learn, result_dir=result_dir, lr=self.lr/2, max_epochs=self.max_epochs, cbs=cbs, 147 | monitor=self.monitor, logger=logger) 148 | else: 149 | learn.fine_tune(epochs=self.max_epochs, base_lr=self.lr, cbs=cbs) 150 | 151 | learn.export() 152 | shutil.rmtree(result_dir/'models') 153 | return learn 154 | 155 | 156 | def _fit_from_checkpoint( 157 | learn: Learner, result_dir: Path, lr: float, max_epochs: int, cbs: Iterable[Callable], 158 | monitor: str, logger) -> None: 159 | logger.info('Continuing from checkpoint...') 160 | 161 | # get best performance so far 162 | history_df = pd.read_csv(result_dir/'history.csv') 163 | scores = pd.to_numeric(history_df[monitor], errors='coerce') 164 | high_score = scores.min() if 'loss' in monitor or 'error' in monitor else scores.max() 165 | logger.info(f'Best {monitor} up to checkpoint: {high_score}') 166 | 167 | # update tracker callback's high scores 168 | for cb in cbs: 169 | if isinstance(cb, TrackerCallback): 170 | cb.best = high_score 171 | 172 | # load newest model 173 | name = max((result_dir/'models').glob('model_*.pth'), 174 | key=os.path.getctime).stem 175 | learn.load(name, with_opt=True, strict=True) 176 | 177 | remaining_epochs = max_epochs - int(name.split('_')[1]) 178 | 
logger.info(f'{remaining_epochs = }') 179 | learn.unfreeze() 180 | learn.fit_one_cycle(remaining_epochs, slice( 181 | lr/100, lr), pct_start=.3, div=5., cbs=cbs) 182 | -------------------------------------------------------------------------------- /deepmed/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregate_stats import * 2 | from .heatmap import * 3 | from .roc import * 4 | from .adapters import * 5 | from .top_tiles import * 6 | from .metrics import * 7 | from .gradcam import * 8 | 9 | __all__ = ['Grouped', 'SubGrouped', 'AggregateStats', 'OnDiscretized', 'Roc', 'GroupMode', 10 | 'Heatmap', 'TopTiles', 'F1', 'auroc', 'count', 'p_value', 'ConfusionMatrix', 'r2', 11 | 'gradcam'] 12 | -------------------------------------------------------------------------------- /deepmed/evaluators/adapters.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Callable 3 | from pathlib import Path 4 | from enum import Enum, auto 5 | 6 | import pandas as pd 7 | 8 | import pandas as pd 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | from deepmed.utils import is_continuous 13 | 14 | from .types import Evaluator 15 | 16 | 17 | class GroupMode(Enum): 18 | """Describes how to calculate grouped predictions (see Grouped).""" 19 | prediction_rate = auto() 20 | """The group class scores are set to the ratio of the elements' predictions.""" 21 | mean = auto() 22 | """The group class scores are set to the mean of the elements' scores.""" 23 | 24 | 25 | @dataclass 26 | class Grouped: 27 | """Calculates a metric with the data grouped on an attribute. 28 | 29 | It's not always meaningful to calculate metrics on the sample level. 
This 30 | function first accumulates the predictions according to another property of 31 | the sample (as specified in the clinical table), grouping samples with the 32 | same value together. Furthermore, the result dir given to the result dir 33 | will be extended by a subdirectory named after the grouped-by property. 34 | """ 35 | evaluator: Evaluator 36 | """Metric to evaluate on the grouped predictions.""" 37 | mode: Optional[GroupMode] = None 38 | """Mode to group predictions.""" 39 | by: str = 'PATIENT' 40 | """Label to group the predictions by.""" 41 | 42 | def __call__(self, target_label: str, preds_df: pd.DataFrame, result_dir: Path) \ 43 | -> Optional[pd.DataFrame]: 44 | group_dir = result_dir/self.by 45 | group_dir.mkdir(exist_ok=True) 46 | grouped_df = _group_df(preds_df, target_label, self.by, self.mode) 47 | if (df := self.evaluator(target_label, grouped_df, group_dir)) is not None: # type: ignore 48 | columns = pd.MultiIndex.from_product([df.columns, [self.by]]) 49 | return pd.DataFrame(df.values, index=df.index, columns=columns) 50 | 51 | return None 52 | 53 | 54 | def _group_df(preds_df: pd.DataFrame, target_label: str, by: str, mode: Optional[GroupMode]) -> pd.DataFrame: 55 | grouped_df = preds_df.groupby(by).first() 56 | 57 | if mode is None: 58 | mode = (GroupMode.mean if is_continuous(preds_df[target_label]) 59 | else GroupMode.prediction_rate) 60 | 61 | for class_ in preds_df[target_label].unique(): 62 | if mode == GroupMode.prediction_rate: 63 | grouped_df[f'{target_label}_{class_}'] = ( 64 | preds_df.groupby(by)[f'{target_label}_pred'] 65 | .agg(lambda x: sum(x == class_) / len(x))) 66 | elif mode == GroupMode.mean: 67 | if is_continuous(preds_df[target_label]): 68 | grouped_df[f'{target_label}_score'] = \ 69 | preds_df.groupby(by)[f'{target_label}_score'].mean() 70 | else: 71 | raise NotImplementedError() #TODO 72 | else: 73 | raise ValueError(f'unexpected {mode=}') 74 | 75 | return grouped_df 76 | 77 | 78 | @dataclass 79 | class 
SubGrouped: 80 | """Calculates a metric for different subgroups.""" 81 | evaluator: Evaluator 82 | by: str 83 | """The property to group by. 84 | 85 | The metric will be calculated seperately for each distinct label of this 86 | property. 87 | """ 88 | 89 | def __call__(self, target_label: str, preds_df: pd.DataFrame, result_dir: Path) \ 90 | -> Optional[pd.DataFrame]: 91 | dfs = [] 92 | for group, group_df in preds_df.groupby(self.by): 93 | group_dir = result_dir/group 94 | group_dir.mkdir(parents=True, exist_ok=True) 95 | if (df := self.evaluator(target_label, group_df, group_dir)) is not None: # type: ignore 96 | columns = pd.MultiIndex.from_product([df.columns, [group]]) 97 | dfs.append(pd.DataFrame( 98 | df.values, index=df.index, columns=columns)) 99 | 100 | if dfs: 101 | return pd.concat(dfs) 102 | 103 | return None 104 | 105 | @dataclass 106 | class OnDiscretized: 107 | """Discretizes continuous values before passing it to an evaluator.""" 108 | #TODO implement for arbitrary bin number 109 | evaluator: Evaluator 110 | 111 | def __call__(self, target_label: str, preds_df: pd.DataFrame, result_dir: Path) -> Optional[pd.DataFrame]: 112 | median = preds_df[target_label].median() 113 | discretized_df = preds_df.copy() 114 | median = discretized_df[target_label].median() 115 | 116 | discretized_df[target_label] = preds_df[target_label] > median 117 | discretized_df[f'{target_label}_pred'] = preds_df[f'{target_label}_score'] > median 118 | 119 | centered = discretized_df[f'{target_label}_score'] - median 120 | 121 | scaled_positives = (centered / centered.max() / 2 + .5) 122 | scaled_negatives = (-centered / centered.min() / 2 + .5) 123 | pos_scores = scaled_positives.where(centered > 0, scaled_negatives) 124 | 125 | discretized_df[f'{target_label}_True'] = pos_scores 126 | discretized_df[f'{target_label}_False'] = 1 - pos_scores 127 | 128 | return self.evaluator(target_label, discretized_df, result_dir) 129 | 
-------------------------------------------------------------------------------- /deepmed/evaluators/aggregate_stats.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Tuple, Optional, Union, Sequence 2 | import logging 3 | import re 4 | from pathlib import Path 5 | 6 | import pandas as pd 7 | 8 | import pandas as pd 9 | from pathlib import Path 10 | import scipy.stats as st 11 | 12 | from ..utils import factory 13 | 14 | 15 | def _aggregate_stats( 16 | _target_label, _preds_df, path: Path, /, label: Optional[str] = None, 17 | over: Optional[Iterable[Union[str, int]]] = None, conf: float = .95) -> pd.DataFrame: 18 | """Accumulates stats from subdirectories. 19 | 20 | Args: 21 | over: Index columns to aggregate over. 22 | conf: The confindence interval calculated during aggregation. 23 | 24 | Returns: 25 | An aggregation of the subdirectories' statistics. 26 | 27 | By default, this function simply concatenates the contents of all the 28 | ``stats.pkl`` files in ``path``'s immediate subdirectories. Each of 29 | the subdirectories' names will be added as to the index at its top level. 30 | 31 | The ``over`` argument can be used to aggregate over certain index columns of 32 | the resulting concatenated dataframe; let's assume the concatenated 33 | dataframe looks like this: 34 | 35 | #TODO update!!!!! (( 36 | ====== ====== ======= ======= === 37 | Metric auroc f1 38 | ------ ------ ------- ------- --- 39 | Group Patient nan 40 | ------ ------ ------- ------- --- 41 | target fold class 42 | ====== ====== ======= ======= === 43 | isMSIH fold_0 MSIH 0.7 0.4 44 | isMSIH fold_0 nonMSIH 0.6 0.1 45 | isMSIH fold_1 MSIH 0.8 0.2 46 | isMSIH fold_1 nonMSIH 0.2 0.4 47 | ====== ====== ======= ======= === 48 | 49 | Then ``aggregate_stats(over=['fold'])`` would calculate the means and 50 | confidence intervals for all (target, class) pairs, using the different 51 | folds as samples. 
Alternatively, numerical indices can be given (c.f. 52 | :func:`pandas.DataFrame.groupby` :obj:`level`). 53 | """ 54 | # collect all parent stats dfs 55 | dfs = [] 56 | stats_df_paths = list(path.glob('*/stats.pkl')) 57 | for df_path in stats_df_paths: 58 | dfs.append(pd.read_pickle(df_path)) 59 | 60 | assert dfs, 'could not find any stats.pkls to aggregate! ' \ 61 | 'Did you accidentally use AggregateStats on the bottommost evaluator level?' 62 | assert all(df.index.names == dfs[0].index.names for df in dfs[1:]), \ 63 | 'index labels differ between stats.pkls to aggregate over!' 64 | stats_df = pd.concat( 65 | dfs, 66 | keys=[path.parent.name for path in stats_df_paths], 67 | names=[label] + dfs[0].index.names) 68 | 69 | if over is not None: 70 | level = _get_groupby_levels(stats_df, over) 71 | # sum all labels which have 'count' in their topmost column level; calculate means, 72 | # confidence intervals for the rest 73 | count_labels = [col for col in stats_df.columns 74 | if 'count' in (col[0] if isinstance(col, tuple) else col)] 75 | extreme_labels = [col for col in stats_df.columns 76 | if (col[0] if isinstance(col, tuple) else col) == 'p value'] #TODO make configurable 77 | metric_labels = list(set(stats_df.columns) 78 | - set(count_labels)) 79 | 80 | # calculate count sums 81 | try: 82 | grouped = stats_df.groupby(level=level) 83 | except IndexError as e: 84 | logging.getLogger(str(path)).critical( 85 | 'Invalid group levels in aggregate_stats! ' 86 | 'Did you use it in the right evaluator group?' 
87 | ) 88 | raise e 89 | counts = grouped[count_labels].sum(min_count=1) 90 | 91 | maxs = grouped[extreme_labels].max() 92 | mins = grouped[extreme_labels].min() 93 | 94 | # calculate means, confidence interval bounds 95 | grouped = stats_df[metric_labels].groupby(level=level) 96 | means, ns, sems = grouped.mean(), grouped.count(), grouped.sem() 97 | l, h = st.t.interval(alpha=conf, df=ns-1, loc=means, scale=sems) 98 | confs = pd.DataFrame( 99 | (h - l) / 2, index=means.index, columns=means.columns) 100 | 101 | # for some reason concat doesn't like it if one of the dfs is empty and we supply a key 102 | # nonetheless... so only generate the headers if needed 103 | keys = (([] if means.empty else ['mean', '95% conf']) 104 | + ([] if counts.empty else ['total']) 105 | + ([] if maxs.empty else ['max']) 106 | + ([] if mins.empty else ['min'])) 107 | stats_df = pd.concat( 108 | [means, confs, counts, maxs, mins], keys=keys, axis=1) 109 | 110 | # make mean, conf, total the lowest of the column levels 111 | stats_df = pd.DataFrame( 112 | stats_df.values, index=stats_df.index, 113 | columns=stats_df.columns.reorder_levels([*range(1, stats_df.columns.nlevels), 0])) 114 | 115 | # sort by every but the last (mean, 95%) columns so we get a nice hierarchical order 116 | stats_df = stats_df[sorted(stats_df.columns, 117 | key=lambda x: x[:stats_df.columns.nlevels-1])] 118 | 119 | return stats_df 120 | 121 | 122 | def _get_groupby_levels(df: pd.DataFrame, over: Iterable[Union[str, int]]) -> Sequence[int]: 123 | """Returns numeric levels to give to pd.DataFrame.group. 124 | 125 | If we have the index ``df.index.names=['target', 'subgroup', 'fold', 126 | 'class']`` and ``over=[1, 'fold']``, then this function will return [0, 3], 127 | i.e. the indices which are *not* index 1 and the index with the name 128 | ``fold``. 129 | """ 130 | assert not isinstance(over, str), f'`over` has to be a list of labels. Try `over=[{over}]`.' 
131 | 132 | # check if any of the labels appears zero / more than one time 133 | assert (label := next((label 134 | for label in over 135 | if isinstance(label, str) and df.index.names.count(label) != 1), 136 | None)) is None, \ 137 | f'{label!r} appears {df.index.names.count(label)} times in stats.pkl! ' \ 138 | 'Use index numbers to disambiguate!' 139 | 140 | return [ 141 | i for i, name in enumerate(df.index.names) 142 | if name not in over and i not in over] 143 | 144 | 145 | AggregateStats = factory(_aggregate_stats) 146 | -------------------------------------------------------------------------------- /deepmed/evaluators/gradcam.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pathlib import Path 3 | from fastai.learner import load_learner 4 | from fastai.torch_core import TensorImage 5 | from fastcore.basics import first 6 | 7 | import pandas as pd 8 | from PIL import Image 9 | import matplotlib.pyplot as plt 10 | 11 | from deepmed.multi_input import MultiInputModel 12 | from .top_tiles import _generate_tiles_fn 13 | 14 | 15 | class Hook(): 16 | def __init__(self, m): 17 | self.hook = m.register_forward_hook(self.hook_func) 18 | 19 | def hook_func(self, m, i, o): self.stored = o.detach().clone() 20 | def __enter__(self, *args): return self 21 | def __exit__(self, *args): self.hook.remove() 22 | 23 | 24 | class HookBwd(): 25 | def __init__(self, m): 26 | self.hook = m.register_backward_hook(self.hook_func) 27 | 28 | def hook_func(self, m, gi, go): 29 | self.stored = go[0].detach().clone() 30 | 31 | def __enter__(self, *args): return self 32 | def __exit__(self, *args): self.hook.remove() 33 | 34 | 35 | def gradcam( 36 | target_label: str, preds_df: pd.DataFrame, result_dir: Path, 37 | n_patients: int = 4, n_tiles: int = 4, patient_label: str = 'PATIENT', 38 | best_patients: bool = True, best_tiles: Optional[bool] = None, 39 | save_images: bool = False) -> None: 40 | """Generates a 
grid of Grad-CAM images for the best scoring tiles for each class. 41 | 42 | The function outputs a `n_patients` × `n_tiles` grid of tiles, where each 43 | row contains the `n_tiles` highest scoring tiles for one of the `n_patients` 44 | best-classified patients. 45 | 46 | Args: 47 | best_patients: Whether to select the best or worst n patients. 48 | best_tiles: Whether to select the highest or lowest scoring tiles. If 49 | set to ``None``, then the same as ``best_patients``. 50 | save_images: Also save the tiles separately. 51 | """ 52 | # set `best_tiles` to `best_patients` if undefined 53 | best_tiles = best_tiles if best_tiles is not None else best_patients 54 | 55 | for class_ in preds_df[f'{target_label}_pred'].unique(): 56 | outdir = result_dir/_generate_tiles_fn( 57 | target_label, class_, best_patients, best_tiles, n_patients, n_tiles) 58 | outfile = Path(str(outdir) + '_GradCAM.svg') 59 | 60 | if outfile.exists() and (outdir.exists() or not save_images): 61 | continue 62 | if save_images: 63 | outdir.mkdir(parents=True, exist_ok=True) 64 | 65 | tile_dict = dict() 66 | 67 | class_instance_df = preds_df[preds_df[target_label] == class_] 68 | patient_scores = \ 69 | class_instance_df.groupby(patient_label)[f'{target_label}_pred'].agg( 70 | lambda x: sum(x == class_) / len(x)) 71 | 72 | patients = (patient_scores.nlargest(n_patients) if best_patients 73 | else patient_scores.nsmallest(n_patients)) 74 | 75 | for i, patient in enumerate(patients.keys()): 76 | patient_tiles = preds_df[preds_df[patient_label] == patient] 77 | 78 | tiles = (patient_tiles.nlargest(n=n_tiles, columns=f'{target_label}_{class_}') 79 | if best_tiles 80 | else patient_tiles.nsmallest(n=n_tiles, columns=f'{target_label}_{class_}')) 81 | 82 | for j, (_, tile) in enumerate(tiles.iterrows()): 83 | # if hasattr(tile, "fold"): 84 | # n_fold = tile.fold 85 | # p_fold = result_dir/f"fold_{n_fold}"/"export.pkl" 86 | # else: 87 | p_fold = result_dir/"export.pkl" 88 | 89 | learn = 
load_learner(p_fold) # p/target , cpu=False 90 | 91 | dls = learn.dls 92 | 93 | if hasattr(dls.vocab, 'o2i'): 94 | cls_dec = dls.vocab.o2i[class_] 95 | else: 96 | try: 97 | dict_vocab = {w: i for i, w in enumerate(dls.vocab)} 98 | cls_dec = dict_vocab[str(class_)] 99 | except TypeError: # not hashable, categorymap 100 | dict_vocab = {w: i for i, 101 | w in enumerate(dls.vocab[-1])} 102 | cls_dec = dict_vocab[str(class_)] 103 | 104 | x = first(dls.test_dl(tile.to_frame().transpose())) 105 | 106 | # TODO: referencing MultiInputModel 107 | 108 | if isinstance(learn.model, MultiInputModel): 109 | feature_extractor = learn.model.cnn_feature_extractor[0] 110 | else: 111 | feature_extractor = learn.model[0] 112 | 113 | with HookBwd(feature_extractor) as hookg: 114 | with Hook(feature_extractor) as hook: 115 | output = learn.model.eval()(*x) # .cuda() 116 | act = hook.stored 117 | output[0, cls_dec].backward() 118 | grad = hookg.stored 119 | 120 | w = grad[0].mean(dim=[1, 2], keepdim=True) 121 | cam_map = (w * act[0]).sum(0) 122 | 123 | x_dec = TensorImage(dls.train.decode(x)[0][0]) 124 | _, ax = plt.subplots() 125 | x_dec.show(ctx=ax) 126 | 127 | ax.imshow(cam_map.detach().cpu(), alpha=0.6, extent=( 128 | 0, x[0].shape[2], x[0].shape[3], 0), interpolation='bilinear', cmap='magma') 129 | 130 | plt.axis('off') 131 | 132 | index = i*n_tiles + j+1 133 | tile_dict[index] = tile 134 | tile_name = Path(tile.tile_path).stem 135 | # if save_images == True: 136 | gradcam_dir = result_dir / 'Grad-CAM_images' 137 | gradcam_dir.mkdir(exist_ok=True) 138 | out_pic = gradcam_dir/f"{tile_name}_{class_}_Grad-CAM.png" 139 | 140 | tile_dict[index] = out_pic 141 | plt.savefig(out_pic) 142 | plt.close() 143 | 144 | if not outfile.exists(): 145 | plt.figure(figsize=(n_patients, n_tiles), dpi=600) 146 | for i, im in tile_dict.items(): 147 | plt.subplot(n_patients, n_tiles, i) 148 | plt.axis('off') 149 | # cannot read svg thus atm gradcams are saved as PNG imgs 150 | plt.imshow(Image.open(im)) 
151 | plt.savefig(outfile, bbox_inches='tight') 152 | 153 | plt.close() 154 | -------------------------------------------------------------------------------- /deepmed/evaluators/heatmap.py: -------------------------------------------------------------------------------- 1 | from deepmed.utils import factory 2 | from typing import Optional, Iterable, Union, Tuple 3 | import pandas as pd 4 | import numpy as np 5 | from matplotlib.patches import Patch 6 | from scipy import interpolate 7 | import re 8 | import logging 9 | from pathlib import Path 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | 13 | 14 | def _heatmap( 15 | target_label: str, preds_df: pd.DataFrame, path: Path, 16 | colors=np.array([[1, 0, 0], [0, 0, 1], [0, 1, 1], [1, 1, 0]]), 17 | wsi_paths: Optional[Iterable[Union[Path, str]]] = None, 18 | wsi_suffixes: Iterable[str] = ['.svs', '.ndpi'], 19 | superimpose: bool = False, format: str = '.svg') -> None: 20 | 21 | # openslide is non-trivial to install, so let's make it optional by lazily loading it 22 | from openslide import OpenSlide 23 | 24 | logger = logging.getLogger(str(path)) 25 | outdir = path/'heatmaps' 26 | 27 | classes = sorted(preds_df[target_label].unique()) 28 | score_labels = [f'{target_label}_{class_}' for class_ in classes] 29 | legend_elements = [ 30 | Patch(facecolor=color, label=class_) for class_, color in zip(classes, colors)] 31 | 32 | for slide_name, tiles in preds_df.groupby('FILENAME'): 33 | true_label = tiles.iloc[0][target_label] 34 | 35 | outfile = (outdir/str(true_label)/slide_name).with_suffix(format) 36 | if outfile.exists(): 37 | continue 38 | 39 | try: 40 | plt.figure(dpi=600) 41 | slide_path = Path(tiles.tile_path.iloc[0]).parent 42 | map_coords = np.array([ 43 | _get_coords(tile_path.name) for tile_path in slide_path.glob('*.jpg')]) 44 | 45 | stride = _get_stride(map_coords) 46 | scaled_map_coords = map_coords // stride 47 | 48 | mask = np.zeros(scaled_map_coords.max(0) + 1) 49 | for coord in 
scaled_map_coords: 50 | mask[coord[0], coord[1]] = 1 51 | 52 | points = tiles.tile_path.map(lambda x: _get_coords(Path(x).name)) 53 | points = np.array(list(points)) 54 | 55 | points = points // stride 56 | 57 | values = tiles[score_labels].to_numpy() 58 | 59 | assert points.shape[1] == 2, "expected points to have shape (_, 2)" 60 | assert points.shape[0] == values.shape[0], \ 61 | "expected points and values to have the same number of elements" 62 | # grid which will form the basis for our output image 63 | grid_x, grid_y = np.mgrid[0:scaled_map_coords[:,0].max()+1, 64 | 0:scaled_map_coords[:,1].max()+1] 65 | 66 | # interpolate heatmap over grid 67 | activations = interpolate.griddata(points, values, (grid_x, grid_y)) 68 | activations = np.nan_to_num(activations) * np.expand_dims(mask, 2) 69 | 70 | if not wsi_paths: 71 | heatmap = _visualize_activation_map( 72 | activations.transpose(1, 0, 2), colors[:activations.shape[-1]]) 73 | heatmap = heatmap.resize(np.multiply(heatmap.size, 8), resample=Image.NEAREST) 74 | plt.imshow(heatmap) 75 | plt.axis('off') 76 | legend = plt.legend( 77 | title=target_label, handles=legend_elements, bbox_to_anchor=(1, 1), loc='upper left') 78 | else: 79 | # find a wsi file with the slide 80 | fn = next(filter(Path.exists, 81 | ((Path(wsi_path)/str(slide_name)).with_suffix(suffix) 82 | for wsi_path in wsi_paths 83 | for suffix in wsi_suffixes)), 84 | None) 85 | 86 | if fn is None: continue 87 | slide = OpenSlide(str(fn)) 88 | 89 | # get the first level smaller than max_size 90 | level = next((i for i, dims in enumerate(slide.level_dimensions) 91 | if max(dims) <= 2400*2), 92 | slide.level_count-1) 93 | thumb = slide.read_region((0, 0), level, slide.level_dimensions[level]) 94 | covered_area_size = ( 95 | (map_coords.max(0)+stride) / 96 | slide.level_downsamples[level]).astype(int) 97 | heatmap = _visualize_activation_map( 98 | activations.transpose(1, 0, 2), 99 | colors=colors[:activations.shape[-1]], 100 | alpha=.5 if superimpose else 
1) 101 | 102 | scaled_heatmap = Image.new('RGBA', thumb.size) 103 | scaled_heatmap.paste( 104 | heatmap.resize(covered_area_size, resample=Image.NEAREST)) 105 | 106 | if superimpose: 107 | thumb.alpha_composite( 108 | scaled_heatmap) 109 | plt.imshow(thumb) 110 | plt.axis('off') 111 | legend = plt.legend( 112 | title=target_label, handles=legend_elements, bbox_to_anchor=(1, 1), loc='upper left') 113 | else: 114 | fig, axs = plt.subplots(1, 2, figsize=(12,6), dpi=300) 115 | axs[0].imshow(thumb) 116 | axs[0].axis('off') 117 | axs[1].imshow(scaled_heatmap) 118 | axs[1].axis('off') 119 | legend = axs[1].legend( 120 | title=target_label, handles=legend_elements, bbox_to_anchor=(1, 1), loc='upper left') 121 | 122 | outfile.parent.mkdir(exist_ok=True, parents=True) 123 | plt.savefig(outfile, bbox_extra_artists=[legend], bbox_inches='tight') 124 | plt.close('all') 125 | except Exception as exp: 126 | logger.exception(exp) 127 | 128 | 129 | def _get_coords(filename: str) -> Optional[Tuple[int, int]]: 130 | if matches := re.match(r'.*\((\d+),(\d+)\)\.jpg', filename): 131 | coords = tuple(map(int, matches.groups())) 132 | assert len(coords) == 2, 'Error extracting coordinates' 133 | return (coords[0], coords[1]) # weird return format so mypy doesn't complain 134 | else: return None 135 | 136 | 137 | def _get_stride(coordinates: np.ndarray) -> int: 138 | xs = sorted(set(coordinates[:, 0])) 139 | x_strides = np.subtract(xs[1:], xs[:-1]) 140 | 141 | ys = sorted(set(coordinates[:, 1])) 142 | y_strides = np.subtract(ys[1:], ys[:-1]) 143 | 144 | stride = min(*x_strides, *y_strides) 145 | return stride 146 | 147 | 148 | def _visualize_activation_map(activations: np.ndarray, colors: np.ndarray, alpha: float = 1.) -> Image: 149 | """Transforms an activation map into an RGBA image. 150 | Args: 151 | activations: An (h, w, D) array of activations. 152 | colors: A (D, 3) array mapping each of the target classes to a color. 153 | Returns: 154 | An interpolated activation map image. 
Regions which the algorithm assumes to be background 155 | will be transparent. 156 | """ 157 | assert colors.shape[1] == 3, "expected color map to have three color channels" 158 | assert colors.shape[0] == activations.shape[2], "one color map entry per class required" 159 | 160 | # transform activation map into RGB map 161 | rgbmap = activations.dot(colors) 162 | 163 | # create RGBA image with non-zero activations being the foreground 164 | mask = activations.any(axis=2) 165 | im_data = (np.concatenate([rgbmap, np.expand_dims(mask * alpha, -1)], axis=2) * 255.5).astype(np.uint8) 166 | 167 | return Image.fromarray(im_data) 168 | 169 | 170 | Heatmap = factory(_heatmap) 171 | -------------------------------------------------------------------------------- /deepmed/evaluators/metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics to evaluate a model's performance with. 2 | 3 | During evaluation, each metric will be called with three arguments: 4 | 1. The target label for which the metric is to be evaluated. 5 | 2. A predictions data frame, containing the complete testing set data frame and additional columns 6 | titled ``{target_label}_{class_}``, which contain the class scores as well as a column 7 | ``{target_label}_pred`` which contains a hard predicion for that item. 8 | 3. A path the metric can store results to. 9 | 10 | In general, metrics are implemented in two ways: 11 | 1. As a function. Some of these functions may have additional arguments; these can be set using 12 | ``functools.partial``, e.g. ``partial(f1, min_tpr=.95)``. 13 | 2. As a function object. These metrics usually encode meta-metrics, i.e. metrics which modify other 14 | metrics. 15 | 16 | Metrics may return a ``DataFrame``, which will be written to the result directory inside the file 17 | ``stats.pkl``. 
18 | """ 19 | 20 | from typing import Optional 21 | from pathlib import Path 22 | 23 | import pandas as pd 24 | 25 | import sklearn.metrics as skm 26 | import pandas as pd 27 | from pathlib import Path 28 | from typing import Optional 29 | import matplotlib.pyplot as plt 30 | import scipy.stats as st 31 | 32 | from ..utils import factory 33 | 34 | 35 | def r2(target_label: str, preds_df: pd.DataFrame, _: Path) -> pd.DataFrame: 36 | return pd.DataFrame.from_dict( 37 | {'score': [skm.r2_score(preds_df[target_label], preds_df[f'{target_label}_score'])]}, 38 | columns=['r2'], orient='index') 39 | 40 | 41 | def p_value(target_label: str, preds_df: pd.DataFrame, _result_dir: Path) -> pd.DataFrame: 42 | stats = {} 43 | for class_ in preds_df[target_label].unique(): 44 | pos_scores = preds_df[f'{target_label}_{class_}'][preds_df[target_label] == class_] 45 | neg_scores = preds_df[f'{target_label}_{class_}'][preds_df[target_label] != class_] 46 | stats[class_] = [st.ttest_ind(pos_scores, neg_scores).pvalue] 47 | return pd.DataFrame.from_dict(stats, orient='index', columns=['p value']) 48 | 49 | 50 | def _f1(target_label: str, preds_df: pd.DataFrame, _result_dir: Path, 51 | min_tpr: Optional[float] = None) \ 52 | -> pd.DataFrame: 53 | """Calculates the F1 score. 54 | 55 | Args: 56 | min_tpr: If min_tpr is not given, a threshold which maximizes the F1 57 | score is selected; otherwise, the threshold which guarantees a tpr of at 58 | least min_tpr is used. 
59 | """ 60 | y_true = preds_df[target_label] 61 | 62 | stats = {} 63 | for class_ in y_true.unique(): 64 | thresh = _get_thresh(target_label, preds_df, class_, min_tpr=min_tpr) 65 | 66 | stats[class_] = \ 67 | skm.f1_score(y_true == class_, 68 | preds_df[f'{target_label}_{class_}'] >= thresh) 69 | 70 | return pd.DataFrame.from_dict( 71 | stats, columns=[f'f1 {min_tpr or "optimal"}'], orient='index') 72 | 73 | 74 | F1 = factory(_f1) 75 | 76 | 77 | def _confusion_matrix( 78 | target_label: str, preds_df: pd.DataFrame, result_dir: Path, 79 | min_tpr: Optional[float] = None) \ 80 | -> None: 81 | """Generates confusion matrices (one per class for binary targets, one overall otherwise). 82 | 83 | Args: 84 | min_tpr: The minimum true positive rate the confusion matrix shall have 85 | for each class. If None, the threshold maximizing the F1 86 | score will be used. 87 | """ 88 | classes = preds_df[target_label].unique() 89 | if len(classes) == 2: 90 | for class_ in classes: 91 | thresh = _get_thresh(target_label, preds_df, 92 | pos_label=class_, min_tpr=min_tpr) 93 | y_true = preds_df[target_label] == class_ 94 | y_pred = preds_df[f'{target_label}_{class_}'] >= thresh 95 | cm = skm.confusion_matrix(y_true, y_pred) 96 | disp = skm.ConfusionMatrixDisplay( 97 | confusion_matrix=cm, 98 | # FIXME this next line is horrible to read 99 | display_labels=(classes if class_ == classes[1] else list(reversed(classes)))) 100 | disp.plot() 101 | plt.title( 102 | f'{target_label} ' + 103 | (f"({class_} TPR ≥ {min_tpr})" if min_tpr 104 | else f"(Optimal {class_} F1 Score)")) 105 | plt.savefig(result_dir / 106 | f'conf_matrix_{target_label}_{class_}_{min_tpr or "opt"}.svg') 107 | plt.close() 108 | else: # TODO does this work?
109 | cm = skm.confusion_matrix( 110 | preds_df[target_label], preds_df[f'{target_label}_pred'], labels=classes) 111 | disp = skm.ConfusionMatrixDisplay( 112 | confusion_matrix=cm, display_labels=classes) 113 | disp.plot() 114 | plt.title(f'{target_label}') 115 | plt.savefig(result_dir/f'conf_matrix_{target_label}.svg') 116 | plt.close() 117 | 118 | 119 | ConfusionMatrix = factory(_confusion_matrix) 120 | 121 | 122 | def _get_thresh(target_label: str, preds_df: pd.DataFrame, pos_label: str, 123 | min_tpr: Optional[float] = None) -> float: 124 | """Calculates a classification threshold for a class. 125 | 126 | If `min_tpr` is given, the highest threshold which still guarantees the requested 127 | tpr is returned. Else, the threshold optimizing the F1 score will be returned. 128 | 129 | Args: 130 | pos_label: The class to optimize for. 131 | min_tpr: The minimum required true positive rate, or ``None`` to select the threshold 132 | which maximizes the F1 score. 133 | 134 | Returns: 135 | The optimal threshold.
136 | """ 137 | fprs, tprs, threshs = skm.roc_curve( 138 | (preds_df[target_label] == pos_label)*1., preds_df[f'{target_label}_{pos_label}']) 139 | 140 | if min_tpr: 141 | return threshs[next(i for i, tpr in enumerate(tprs) if tpr >= min_tpr)] 142 | else: 143 | return max( 144 | threshs, 145 | key=lambda t: skm.f1_score( 146 | preds_df[target_label] == pos_label, preds_df[f'{target_label}_{pos_label}'] > t)) 147 | 148 | 149 | def auroc(target_label: str, preds_df: pd.DataFrame, _result_dir) -> pd.DataFrame: 150 | """Calculates the one-vs-rest AUROC for each class of the target label.""" 151 | y_true = preds_df[target_label] 152 | df = pd.DataFrame.from_dict( 153 | {class_: [skm.roc_auc_score(y_true == class_, preds_df[f'{target_label}_{class_}'])] 154 | for class_ in y_true.unique()}, 155 | columns=['auroc'], orient='index') 156 | return df 157 | 158 | 159 | def count(target_label: str, preds_df: pd.DataFrame, _result_dir) -> pd.DataFrame: 160 | """Calculates the number of testing instances for each class.""" 161 | counts = preds_df[target_label].value_counts() 162 | return pd.DataFrame(counts.values, index=counts.index, columns=['count']) 163 | -------------------------------------------------------------------------------- /deepmed/evaluators/roc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | from sklearn.metrics import auc 5 | from sklearn.metrics import roc_curve, auc, RocCurveDisplay 6 | import scipy.stats as st 7 | import pandas as pd 8 | from pathlib import Path 9 | 10 | from ..utils import factory 11 | 12 | 13 | def _plot_roc(df: pd.DataFrame, target_label: str, pos_label: str, ax, conf: float = 0.95): 14 | # gracefully stolen from 15 | tprs = [] 16 | aucs = [] 17 | mean_fpr = np.linspace(0, 1, 100) 18 | 19 | for fold in sorted(df.fold.unique()): 20 | fold_df = df[df.fold == fold] 21 | fpr, tpr, _ = roc_curve((fold_df[target_label] == pos_label)*1., 
fold_df[f'{target_label}_{pos_label}']) 22 | 23 | roc_auc = auc(fpr, tpr) 24 | viz = RocCurveDisplay(fpr=fpr, 25 | tpr=tpr, 26 | estimator_name=f'Fold {int(fold)}', 27 | roc_auc=roc_auc) 28 | viz.plot(ax=ax) 29 | 30 | interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr) 31 | interp_tpr[0] = 0.0 32 | tprs.append(interp_tpr) 33 | aucs.append(viz.roc_auc) 34 | 35 | ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', 36 | label='Chance', alpha=.8) 37 | 38 | mean_tpr = np.mean(tprs, axis=0) 39 | mean_tpr[-1] = 1.0 # type: ignore 40 | 41 | # calculate mean and conf intervals 42 | auc_mean = np.mean(aucs) 43 | auc_conf_limits = st.t.interval(alpha=conf, df=len(aucs)-1, loc=np.mean(aucs), scale=st.sem(aucs)) 44 | auc_conf = (auc_conf_limits[1]-auc_conf_limits[0])/2 45 | 46 | ax.plot(mean_fpr, mean_tpr, color='b', 47 | label=f'Mean ROC (AUC = {auc_mean:0.2f} $\\pm$ {auc_conf:0.2f})', 48 | lw=2, alpha=.8) 49 | 50 | ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], 51 | title=f'{target_label}: {pos_label} ROC') 52 | ax.legend(loc="lower right") 53 | 54 | 55 | def _plot_simple_roc(df: pd.DataFrame, target_label: str, pos_label: str, ax, conf: float = 0.95): 56 | # gracefully stolen from 57 | fpr, tpr, _ = roc_curve((df[target_label] == pos_label)*1., df[f'{target_label}_{pos_label}']) 58 | 59 | roc_auc = auc(fpr, tpr) 60 | viz = RocCurveDisplay(fpr=fpr, 61 | tpr=tpr, 62 | roc_auc=roc_auc) 63 | viz.plot(ax=ax) 64 | 65 | ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', 66 | label='Chance', alpha=.8) 67 | 68 | ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05], 69 | title=f'{target_label}: {pos_label} ROC') 70 | ax.legend(loc="lower right") 71 | 72 | 73 | def _roc(target_label: str, preds_df: pd.DataFrame, result_dir: Path) -> None: 74 | """Creates a one-vs-all ROC curve plot for each class.""" 75 | y_true = preds_df[target_label] 76 | for class_ in y_true.unique(): 77 | outfile = result_dir/f'roc_{target_label}_{class_}.svg' 78 | if outfile.exists(): 79 | continue 80 | 81 | 
fig, ax = plt.subplots() 82 | if 'fold' in preds_df: 83 | _plot_roc(preds_df, target_label, class_, ax=ax, conf=.95) 84 | else: 85 | _plot_simple_roc(preds_df, target_label, class_, ax=ax, conf=.95) 86 | 87 | plt.savefig(outfile) 88 | plt.close() 89 | 90 | Roc = factory(_roc) -------------------------------------------------------------------------------- /deepmed/evaluators/top_tiles.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from typing import Optional 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | from PIL import Image 7 | 8 | import pandas as pd 9 | from pathlib import Path 10 | from typing import Optional 11 | import matplotlib.pyplot as plt 12 | from PIL import Image 13 | 14 | from ..utils import factory 15 | 16 | 17 | def _top_tiles( 18 | target_label: str, preds_df: pd.DataFrame, result_dir: Path, 19 | n_patients: int = 4, n_tiles: int = 4, patient_label: str = 'PATIENT', 20 | best_patients: bool = True, best_tiles: Optional[bool] = None, 21 | save_images: bool = False) -> None: 22 | """Generates a grid of the best scoring tiles for each class. 23 | 24 | The function outputs a `n_patients` × `n_tiles` grid of tiles, where each 25 | row contains the `n_tiles` highest scoring tiles for one of the `n_patients` 26 | best-classified patients. 27 | 28 | Args: 29 | best_patients: Whether to select the best or worst n patients. 30 | best_tiles: Whether to select the highest or lowest scoring tiles. If 31 | set to ``None``, then the same as ``best_patients``. 32 | save_images: Also save the tiles separately. 
33 | """ 34 | # set `best_tiles` to `best_patients` if undefined 35 | best_tiles = best_tiles if best_tiles is not None else best_patients 36 | 37 | for class_ in preds_df[f'{target_label}_pred'].unique(): 38 | # class_ == MSIH 39 | outdir = result_dir/_generate_tiles_fn( 40 | target_label, class_, best_patients, best_tiles, n_patients, n_tiles) 41 | outfile = Path(str(outdir) + '.svg') 42 | if outfile.exists() and (outdir.exists() or not save_images): 43 | continue 44 | if save_images: 45 | outdir.mkdir(parents=True, exist_ok=True) 46 | 47 | plt.figure(figsize=(n_tiles, n_patients), dpi=600) 48 | # get patients with the best overall ratings for the label 49 | class_instance_df = preds_df[preds_df[target_label] == class_] 50 | patient_scores = \ 51 | class_instance_df.groupby(patient_label)[f'{target_label}_pred'].agg(lambda x: sum(x == class_) / len(x)) 52 | 53 | patients = (patient_scores.nlargest(n_patients) if best_patients 54 | else patient_scores.nsmallest(n_patients)) 55 | 56 | top_tile_list = [] 57 | for i, patient in enumerate(patients.keys()): 58 | # get the best tile for that patient 59 | patient_tiles = preds_df[preds_df[patient_label] == patient] 60 | 61 | tiles = (patient_tiles.nlargest(n=n_tiles, columns=f'{target_label}_{class_}') 62 | if best_tiles 63 | else patient_tiles.nsmallest(n=n_tiles, columns=f'{target_label}_{class_}')) 64 | top_tile_list.append(tiles) 65 | 66 | for j, tile in enumerate(tiles.tile_path): 67 | if save_images: 68 | shutil.copy(tile, outdir/Path(tile).name) 69 | if not outfile.exists(): 70 | plt.subplot(n_patients, n_tiles, i*n_tiles + j+1) 71 | plt.axis('off') 72 | plt.imshow(Image.open(tile), cmap='gray') 73 | 74 | pd.concat(top_tile_list).to_csv(outfile.with_suffix('.csv'), index=False) 75 | 76 | if not outfile.exists(): 77 | plt.savefig(outfile, bbox_inches='tight') 78 | plt.close() 79 | 80 | 81 | def _generate_tiles_fn( 82 | target_label: str, class_: str, best_patients: bool, best_tiles: bool, 83 | n_patients: int, 
n_tiles: int) -> str: 84 | patient_str = f'{"best" if best_patients else "worst"}-{n_patients}-patients' 85 | tile_str = f'{"best" if best_tiles else "worst"}-{n_tiles}-tiles' 86 | 87 | return f'{target_label}_{class_}_{patient_str}_{tile_str}' 88 | 89 | 90 | TopTiles = factory(_top_tiles) -------------------------------------------------------------------------------- /deepmed/evaluators/types.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | import pandas as pd 3 | from pathlib import Path 4 | 5 | Evaluator = Callable[[Optional[str], pd.DataFrame, Path], Optional[pd.DataFrame]] -------------------------------------------------------------------------------- /deepmed/experiment_imports.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import coloredlogs 3 | 4 | coloredlogs.install( 5 | fmt='%(asctime)s,%(msecs)03d %(name)s %(levelname)s %(message)s', level=logging.DEBUG) 6 | 7 | from fastai.vision.all import * 8 | from packaging.specifiers import SpecifierSet 9 | 10 | from pathlib import Path 11 | import pandas as pd 12 | import deepmed 13 | from deepmed import * 14 | from deepmed import get, multi_input, evaluators, on_features, mil 15 | from deepmed.evaluators import * 16 | from deepmed.get import cohort -------------------------------------------------------------------------------- /deepmed/get/__init__.py: -------------------------------------------------------------------------------- 1 | from ._simple import * 2 | from ._subgroup import * 3 | from ._crossval import * 4 | from ._multi_target import * 5 | from ._parameterize import * 6 | from ._extract_features import * 7 | 8 | __all__ = [ 9 | 'cohort', 'SimpleRun', 'GetTiles', 'DatasetType', 'Subgroup', 10 | 'MultiTarget', 'MultiTargetBaseTaskGetter', 11 | 'Parameterize', 'ParameterizeBaseTaskGetter', 12 | 'Crossval', 'CrossvalBaseTaskGetter', 'ExtractFeatures'] 
-------------------------------------------------------------------------------- /deepmed/get/_crossval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Iterable, Iterator, Any, Optional 3 | from pathlib import Path 4 | from typing_extensions import Protocol 5 | 6 | import pandas as pd 7 | from sklearn.model_selection import StratifiedKFold, KFold 8 | 9 | from .._experiment import Task, EvalTask 10 | from ._simple import _prepare_cohorts 11 | from ..utils import exists_and_has_size, log_defaults, is_continuous, factory 12 | from ..get import Evaluator 13 | 14 | 15 | class CrossvalBaseTaskGetter(Protocol): 16 | """The signature of a task getter which can be modified by ``crossval``.""" 17 | 18 | def __call__( 19 | self, *args, 20 | project_dir: Path, target_label: str, 21 | train_cohorts_df: pd.DataFrame, test_cohorts_df: pd.DataFrame, min_support: int, 22 | **kwargs) \ 23 | -> Iterator[Task]: 24 | ... 25 | 26 | 27 | @log_defaults 28 | def _crossval( 29 | get: CrossvalBaseTaskGetter, 30 | *args, 31 | project_dir: Path, 32 | target_label: str, 33 | cohorts_df: pd.DataFrame, 34 | folds: int = 3, 35 | seed: int = 0, 36 | n_bins: Optional[int] = 2, 37 | na_values: Iterable[Any] = [], 38 | min_support: int = 10, 39 | patient_label: str = 'PATIENT', 40 | crossval_evaluators: Iterable[Evaluator] = [], 41 | **kwargs) \ 42 | -> Iterator[Task]: 43 | """Generates cross validation tasks for a single target. 44 | 45 | Args: 46 | get: Getter to perform cross-validation with. 47 | project_dir: Path to save project data to. 48 | train_cohorts_df: The cohorts to perform cross-validation on. 49 | valid_frac: The fraction of patients which will be reserved for 50 | validation during training. 51 | folds: Number of subsets to split the training data into. 52 | n_bins: The number of bins to discretize continuous values into. 53 | na_values: The class labels to consider as N/A values. 
54 | min_support: The minimum amount of class samples required for the class 55 | per fold to be included in training. Classes with fewer samples are 56 | dropped. 57 | *args: Arguments to pass to ``get``. 58 | **kwargs: Keyword arguments to pass to ``get``. 59 | 60 | Yields: 61 | The tasks generated by ``get`` for each fold, followed by an ``EvalTask`` over all folds. 62 | 63 | For each of the folds a new subdirectory will be created. Each of the folds 64 | will be generated in a stratified fashion, meaning that the cohorts' class 65 | distribution will be maintained (plain k-fold is used for unbinned continuous targets). 66 | """ 67 | logger = logging.getLogger(str(project_dir)) 68 | project_dir.mkdir(parents=True, exist_ok=True) 69 | 70 | if exists_and_has_size(folds_path := project_dir/'folds.csv.zip'): 71 | folded_df = pd.read_csv(folds_path) 72 | folded_df.slide_path = folded_df.slide_path.map(Path) 73 | else: 74 | cohorts_df = _prepare_cohorts( 75 | cohorts_df, target_label, na_values, n_bins, min_support*folds//(folds-1), logger=logger) 76 | 77 | if cohorts_df is None or cohorts_df.empty: 78 | logger.warning(f'No data left after preprocessing. Skipping...') 79 | return 80 | elif cohorts_df[target_label].nunique() < 2: 81 | logger.warning( 82 | f'Not enough classes for target {target_label}!
Skipping...') 83 | return 84 | 85 | logger.info( 86 | f'Slide target counts: {dict(cohorts_df[target_label].value_counts())}') 87 | 88 | folded_df = _create_folds( 89 | cohorts_df=cohorts_df, target_label=target_label, folds=folds, seed=seed, 90 | patient_label=patient_label, n_bins=n_bins) 91 | folded_df.to_csv(folds_path, compression='zip') 92 | 93 | # accumulate first to ensure training / testing set data is saved 94 | fold_tasks = ( 95 | task 96 | for fold in sorted(folded_df.fold.unique()) 97 | for task in get( # type: ignore 98 | *args, 99 | project_dir=project_dir/f'fold_{fold}', 100 | target_label=target_label, 101 | train_cohorts_df=folded_df[folded_df.fold != fold], 102 | test_cohorts_df=folded_df[folded_df.fold == fold], 103 | n_bins=n_bins, 104 | min_support=0, 105 | **kwargs) 106 | ) 107 | requirements = [] 108 | for task in fold_tasks: 109 | yield task 110 | requirements.append(task) 111 | 112 | yield EvalTask( 113 | path=project_dir, 114 | target_label=target_label, 115 | requirements=requirements, 116 | evaluators=crossval_evaluators) 117 | 118 | 119 | def _create_folds( 120 | cohorts_df: pd.DataFrame, target_label: str, folds: int, seed: int, patient_label: str, n_bins: Optional[int] 121 | ) -> pd.DataFrame: 122 | """Adds a ``fold`` column to ``cohorts_df`` (in place) and returns it.""" 123 | 124 | kf = (StratifiedKFold(n_splits=folds, random_state=seed, shuffle=True) 125 | if n_bins is not None or not is_continuous(cohorts_df[target_label]) 126 | else KFold(n_splits=folds, random_state=seed, shuffle=True)) 127 | 128 | # Prepare our dataframe 129 | # We enumerate each fold; this way, the training set for the `k`th iteration can be easily 130 | # obtained through `df[df.fold != k]`. Folds are assigned per patient, so all slides of a 131 | # patient share the same fold; no validation set is sampled here.
132 | patients = cohorts_df.groupby(patient_label)[target_label].first() 133 | cohorts_df['fold'] = 0 134 | for fold, (_, test_idx) \ 135 | in enumerate(kf.split(patients.index, patients)): 136 | cohorts_df.loc[cohorts_df[patient_label].isin( 137 | patients.iloc[test_idx].index), 'fold'] = fold 138 | 139 | return cohorts_df 140 | 141 | 142 | Crossval = factory(_crossval) 143 | -------------------------------------------------------------------------------- /deepmed/get/_extract_features.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | from pathlib import Path 4 | import tempfile 5 | from PIL import Image 6 | from fastai.layers import AdaptiveConcatPool2d 7 | from fastai.learner import Learner 8 | from fastai.losses import CrossEntropyLossFlat 9 | from deepmed.utils import factory 10 | from typing import Callable, Generator, Iterable, Iterator, Optional, Sequence, TypeVar 11 | from fastai.data.block import DataBlock 12 | from fastai.data.transforms import ColReader 13 | from fastai.vision.data import ImageBlock 14 | from fastai.vision.learner import cnn_learner 15 | from torch import nn 16 | from fastai.vision.augment import RandomCrop, Resize 17 | from fastai.vision.learner import create_body 18 | from fastai.vision.models import resnet18 19 | from tqdm import tqdm 20 | import pandas as pd 21 | import numpy as np 22 | import h5py 23 | import re 24 | import torch 25 | import logging 26 | from fastdownload import FastDownload 27 | from fastai.data.external import fastai_cfg 28 | import fastai 29 | from ..types import PathLike, Task 30 | 31 | 32 | all = ['Extract', 'PretrainedModel', 'ExtractTask'] 33 | 34 | 35 | def _extract( 36 | project_dir: Path, 37 | tile_dir: PathLike, 38 | feat_dir: Optional[PathLike] = None, 39 | arch: Callable[[bool], nn.Module] = resnet18, 40 | num_workers: int = 32 if os.name == 'posix' else 0, 41 | **kwargs 42 | ) -> Iterator[Task]: 43 | tile_dir = 
Path(tile_dir) 44 | feat_dir = Path(feat_dir) if feat_dir is not None else project_dir 45 | 46 | feat_dir.mkdir(exist_ok=True) 47 | yield ExtractTask(path=feat_dir, 48 | requirements=[], 49 | slides=list(tile_dir.iterdir()), 50 | arch=arch, 51 | num_workers=num_workers) 52 | 53 | 54 | Extract = factory(_extract) 55 | 56 | 57 | def PretrainedModel(url, arch=resnet18) -> nn.Module: 58 | d = FastDownload(fastai_cfg(), module=fastai.data, base='~/.fastai') 59 | path = d.download(url) 60 | 61 | model = arch(pretrained=False) 62 | checkpoint = torch.load(path) 63 | missing = model.load_state_dict(checkpoint, strict=False) 64 | assert not set(missing.missing_keys) 65 | 66 | return lambda pretrained: model 67 | 68 | 69 | @dataclass 70 | class ExtractTask(Task): 71 | slides: Iterable[Path] 72 | arch: Callable[[bool], nn.Module] 73 | num_workers: int 74 | 75 | def do_work(self): 76 | for slides in (slide_pbar := tqdm(list(batch(self.slides, n=256)), leave=False)): 77 | learn = feature_extractor( 78 | arch=self.arch, num_workers=self.num_workers, item_tfms=RandomCrop(224))#Resize(224)) 79 | slide_pbar.set_description(slides[0].name) 80 | do_slides(slides, learn, self.path) 81 | 82 | 83 | def do_slides(slides: Iterable[Path], learn: Learner, feat_dir: Path): 84 | #checksum = model_checksum(learn.model) 85 | 86 | dfs = [] 87 | for slide in slides: 88 | # if (h5_file := feat_dir/f'{slide.name}.h5').exists(): 89 | # assert (h5_checksum := h5py.File(h5_file, 'r').attrs['extractor-checksum']) == checksum, \ 90 | # f'{h5_file} has been extracted with a different model than the current one. 
' \ 91 | # f'(current: {checksum:08x}, {h5_file.name}: {h5_checksum:08x})' 92 | # continue 93 | 94 | slide_df = pd.DataFrame( 95 | list(slide.glob('*.jpg')), columns=['path']) 96 | slide_df['slide'] = slide 97 | if slide_df.empty: 98 | continue 99 | dfs.append(slide_df) 100 | 101 | if not dfs: 102 | return 103 | df = pd.concat(dfs).reset_index() 104 | 105 | test_dl = learn.dls.test_dl(df) 106 | preds, _ = learn.get_preds(dl=test_dl, act=nn.Identity()) 107 | 108 | for slide, data in df.groupby('slide'): 109 | coords = np.array(list(data.path.map(_get_coords))) 110 | outpath = feat_dir/f'{slide.name}.h5' 111 | with h5py.File(outpath, 'w') as f: 112 | f['feats'] = preds[data.index] 113 | f['coords'] = coords 114 | #f.attrs['extractor-checksum'] = checksum 115 | 116 | 117 | def model_checksum(m): 118 | checksum = torch.tensor(0, dtype=torch.int64) 119 | for p in m.parameters(): 120 | checksum += (p.cpu().abs()*(1<<24)).type(torch.int64).sum() 121 | checksum %= 1<<32 122 | return checksum 123 | 124 | 125 | T = TypeVar('T') 126 | 127 | 128 | def batch(sequence: Sequence[T], n: int) -> Iterable[Sequence[T]]: 129 | l = len(sequence) 130 | for ndx in range(0, l, n): 131 | yield sequence[ndx:min(ndx + n, l)] 132 | 133 | 134 | def _get_coords(filename: PathLike) -> Optional[np.ndarray]: 135 | if matches := re.match(r'.*\((-?\d+),(-?\d+)\)\.jpg', str(filename)): 136 | coords = tuple(map(int, matches.groups())) 137 | assert len(coords) == 2, 'Error extracting coordinates' 138 | return np.array(coords, dtype=int) 139 | else: 140 | return None 141 | 142 | 143 | def feature_extractor( 144 | arch: Callable[[bool], nn.Module], num_workers: int, **kwargs 145 | ) -> Learner: 146 | dblock = DataBlock( 147 | blocks=ImageBlock, 148 | get_x=ColReader('path'), 149 | **kwargs) 150 | 151 | with tempfile.TemporaryDirectory() as tempdir: 152 | tilepath = Path(tempdir)/'tile.jpg' 153 | Image.new('RGB', (224, 224)).save(tilepath) 154 | df = pd.DataFrame([tilepath], columns=['path']) 155 | dls = 
dblock.dataloaders(df, num_workers=num_workers) 156 | 157 | learn = cnn_learner(dls, arch, n_out=2, 158 | loss_func=CrossEntropyLossFlat(), 159 | custom_head=nn.Sequential(AdaptiveConcatPool2d(), 160 | nn.Flatten())) 161 | 162 | return learn 163 | -------------------------------------------------------------------------------- /deepmed/get/_multi_target.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Iterable, Iterator 3 | 4 | from ..types import Task 5 | from ..evaluators.types import Evaluator 6 | from ..utils import factory 7 | from ._parameterize import _parameterize, ParameterizeBaseTaskGetter 8 | 9 | 10 | def _multi_target( 11 | get: ParameterizeBaseTaskGetter, 12 | *args, 13 | project_dir: Path, 14 | target_labels: Iterable[str], 15 | multi_target_evaluators: Iterable[Evaluator] = [], 16 | **kwargs) -> Iterator[Task]: 17 | """Adapts a `TaskGetter` into a multi-target one. 18 | 19 | Convenience wrapper around :func:``deepmed.Parameterize``. 20 | 21 | Args: 22 | get: The `TaskGetter` to adapt; it has to take at least one keyword 23 | argument `target_label`. 24 | project_dir: The directory to save the tasks' results to. 25 | target_label: The target labels to invoke ``get`` on. 26 | *args: Additional arguments give to ``get``. 27 | **kwargs: Additional keyword arguments to give to ``get``. 28 | 29 | Yields: 30 | The tasks which would be yielded by `get` for each of the target labels, 31 | in the order of the target labels. The task directories are prepended 32 | by a the name of the target label. 
33 | """ 34 | return _parameterize( 35 | get, *args, project_dir=project_dir, 36 | parameterizations={ 37 | target_label: {'target_label': target_label} 38 | for target_label in target_labels}, 39 | parameterize_evaluators=multi_target_evaluators, 40 | **kwargs) 41 | 42 | 43 | MultiTarget = factory(_multi_target) 44 | -------------------------------------------------------------------------------- /deepmed/get/_parameterize.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Iterable, Iterator, Mapping, Any 3 | from typing_extensions import Protocol 4 | from ..types import Task, EvalTask 5 | from ..evaluators.types import Evaluator 6 | from ..utils import factory 7 | 8 | 9 | class ParameterizeBaseTaskGetter(Protocol): 10 | """The signature of a task getter which can be modified by ``parameterize``.""" 11 | 12 | def __call__( 13 | self, *args, 14 | project_dir: Path, **kwargs) -> Iterator[Task]: 15 | ... 16 | 17 | 18 | def _parameterize( 19 | get: ParameterizeBaseTaskGetter, 20 | *args, 21 | project_dir: Path, 22 | parameterizations: Mapping[str, Mapping[str, Any]], 23 | parameterize_evaluators: Iterable[Evaluator] = [], 24 | **kwargs) -> Iterator[Task]: 25 | """Starts a family of runs with different parameterizations. 26 | 27 | Args: 28 | parameterizations: A mapping from parameterization descriptions (i.e. 29 | descriptive names) to kwargs mappings. For each element, ``get`` 30 | will be invoked with these kwargs. 31 | parameterize_evaluators: Evaluators to run at the end of all 32 | parameterized runs. 33 | kwargs: Additional arguments to pass to each parameterized run. If a 34 | keyword argument appears both in ``kwargs`` and in a 35 | parameterization, the parameterization's argument takes precedence. 36 | 37 | Yields: 38 | The tasks ``get`` would yield for each of the parameterizations. 
39 | """ 40 | eval_reqirements = [] 41 | for name, parameterization in parameterizations.items(): 42 | for task in get( 43 | *args, project_dir=project_dir/name, 44 | # overwrite default ``kwargs``` w/ parameterization ones, if they were given 45 | **{**kwargs, **parameterization}): 46 | eval_reqirements.append(task) 47 | yield task 48 | 49 | yield EvalTask( 50 | path=project_dir, 51 | target_label=None, # TODO remove target label from eval task 52 | requirements=eval_reqirements, 53 | evaluators=parameterize_evaluators) 54 | 55 | 56 | Parameterize = factory(_parameterize) 57 | -------------------------------------------------------------------------------- /deepmed/get/_simple.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | from functools import partial 3 | import random 4 | import logging 5 | from typing import Callable, Iterable, Sequence, Iterator, Optional, Any, Union, Mapping 6 | from pathlib import Path 7 | from numbers import Number 8 | from threading import Semaphore 9 | 10 | import torch 11 | import pandas as pd 12 | from sklearn.model_selection import train_test_split 13 | from sklearn import preprocessing 14 | from tqdm import tqdm 15 | import numpy as np 16 | 17 | from ..evaluators.types import Evaluator 18 | from ..utils import exists_and_has_size, is_continuous, log_defaults 19 | from .._experiment import Task, GPUTask, EvalTask 20 | 21 | from .._train import Train 22 | from .._deploy import Deploy 23 | from ..types import Trainer, Deployer 24 | from ..utils import factory 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | PathLike = Union[str, Path] 31 | 32 | 33 | def cohort( 34 | tiles_path: PathLike, clini_path: PathLike, slide_path: PathLike, 35 | patient_label: str = 'PATIENT', slidename_label: str = 'FILENAME' 36 | ) -> pd.DataFrame: 37 | """Creates a cohort df from a slide and a clini table. 
38 | 39 | Args: 40 | tiles_path: The path in which the slides' tiles are stored. Each 41 | slides' tiles have to be stored in a directory in ``tiles_path`` 42 | named after the slide. 43 | clini_path: The path of the clinical table, either in csv or excel 44 | format. The clinical table contains information on each patient. 45 | It also needs to contain a column titled ``patient_label`` (default: 46 | 'PATIENT'). 47 | slide_path: A table in csv or excel format mapping slides (in a column 48 | ``slidename_label``) to patients (in a column ``patient_label``). 49 | patient_label: Label to merge the clinical and slide tables on. 50 | slidename_label: Column of the slide table containing the slide names. 51 | """ 52 | tiles_path, clini_path, slide_path = Path( 53 | tiles_path), Path(clini_path), Path(slide_path) 54 | 55 | dtype = {patient_label: str, slidename_label: str} 56 | clini_df = ( 57 | pd.read_csv(clini_path, dtype=dtype) if clini_path.suffix == '.csv' 58 | else pd.read_excel(clini_path, dtype=dtype)) 59 | slide_df = ( 60 | pd.read_csv(slide_path, dtype=dtype) if slide_path.suffix == '.csv' 61 | else pd.read_excel(slide_path, dtype=dtype)) 62 | 63 | cohort_df = clini_df.merge(slide_df, on=patient_label) 64 | cohort_df = cohort_df.copy() # for defragmentation 65 | cohort_df['slide_path'] = tiles_path/cohort_df[slidename_label].map(str) 66 | 67 | # TODO 68 | # assert cohort_df.slide_path.map(Path.exists).any(), \ 69 | # f'none of the slide paths for "{slide_path}" exist!' 
70 | 71 | logger.debug(f'#slides in {slide_path}: {len(slide_df)}') 72 | logger.debug(f'#patients in {clini_path}: {len(clini_df)}') 73 | logger.debug(f'#patients with slides for {tiles_path}: {len(cohort_df)}') 74 | 75 | return cohort_df 76 | 77 | 78 | class DatasetType(Enum): 79 | TRAIN = auto() 80 | VALID = auto() 81 | TEST = auto() 82 | 83 | 84 | def get_tiles( 85 | dataset_type: DatasetType, cohorts_df: pd.DataFrame, 86 | max_tile_nums: Mapping[DatasetType, int] = {DatasetType.TRAIN: 128, 87 | DatasetType.VALID: 256, 88 | DatasetType.TEST: 512}, 89 | resample_each_epoch: bool = False, 90 | logger=logging, 91 | ) -> pd.DataFrame: 92 | """Create df containing patient, tiles, other data.""" 93 | tiles_dfs = [] 94 | for _, data in tqdm(cohorts_df.groupby('PATIENT')): 95 | tiles = [(tile_dir, file) 96 | for tile_dir in data.slide_path 97 | if tile_dir.exists() 98 | for file in tile_dir.iterdir()] 99 | if (tile_num := max_tile_nums.get(dataset_type)): 100 | tiles = random.sample(tiles, min(len(tiles), tile_num)) 101 | tiles_df = pd.DataFrame(tiles, columns=['slide_path', 'tile_path']) 102 | 103 | tiles_dfs.append(data.merge( 104 | tiles_df, on='slide_path').drop(columns='slide_path')) 105 | 106 | tiles_df = pd.concat(tiles_dfs).reset_index(drop=True) 107 | logger.info( 108 | f'Found {len(tiles_df)} tiles for {len(tiles_df["PATIENT"].unique())} patients') 109 | 110 | # if we want the training procedure to resample a slide's tiles every epoch, 111 | # we have to supply a slide path instead of the tile path 112 | if dataset_type == DatasetType.TRAIN and resample_each_epoch: 113 | tiles_df.tile_path = tiles_df.tile_path.map(lambda p: p.parent) 114 | 115 | return tiles_df 116 | 117 | 118 | GetTiles = factory(get_tiles) 119 | 120 | 121 | @log_defaults 122 | def _simple_run( 123 | project_dir: Path, 124 | target_label: str, 125 | capacities: Mapping[Union[int, str], Semaphore], 126 | train_cohorts_df: Optional[pd.DataFrame] = None, 127 | test_cohorts_df: 
Optional[pd.DataFrame] = None, 128 | patient_label: str = 'PATIENT', 129 | balance: bool = True, 130 | train: Trainer = Train(), 131 | deploy: Deployer = Deploy(), 132 | valid_frac: float = .2, 133 | n_bins: Optional[int] = 2, 134 | na_values: Iterable[Any] = [], 135 | min_support: int = 10, 136 | evaluators: Iterable[Evaluator] = [], 137 | max_class_count: Optional[Mapping[str, int]] = None, 138 | get_items: Callable = get_tiles, 139 | **kwargs, 140 | ) -> Iterator[Task]: 141 | """Creates tasks for a basic test-deploy procedure. 142 | 143 | This function will generate a single training and / or deployment task. 144 | Due to large in-patient similarities in slides it may be useful to only 145 | sample a limited number of tiles will from each patient. The task will 146 | have: 147 | 148 | - A training set, if ``train_cohorts`` is not empty. The training set 149 | will be balanced in such a way that each class is represented with a 150 | number of tiles equal to that of the smallest class if ``balanced`` 151 | is ``True``. 152 | - A testing set, if ``test_cohorts`` is not empty. The testing set 153 | may be unbalanced. 154 | 155 | If the target is continuous, it will be discretized. 156 | 157 | Args: 158 | project_dir: Path to save project data to. 159 | train_cohorts_df: The cohorts to use for training. 160 | test_cohorts_df: The cohorts to test on. 161 | resample_each_epoch: Whether to resample the training tiles used 162 | from each slide each epoch. 163 | max_train_tile_num: The maximum number of tiles per patient to use 164 | for training in each epoch or ``None`` for no subsampling. 165 | max_valid_tile_num: The maximum number of validation tiles used in 166 | each epoch or ``None`` for no subsampling. 167 | max_test_tile_num: The maximum number of testing tiles used in 168 | each epoch or ``None`` for no subsampling. 169 | balance: Whether the training set should be balanced. Applies to 170 | categorical targets only. 
171 | valid_frac: The fraction of patients which will be reserved for 172 | validation during training. 173 | n_bins: The number of bins to discretize continuous values into. 174 | na_values: The class labels to consider as N/A values. 175 | min_support: The minimum amount of class samples required for the 176 | class to be included in training. Classes with less support are 177 | dropped. 178 | kwargs: Other arguments to be passed to train. 179 | 180 | Yields: 181 | A task to train and / or deploy a model on the given training and 182 | testing data as well as an evaluation task. 183 | """ 184 | logger = logging.getLogger(str(project_dir)) 185 | 186 | if exists_and_has_size(preds_df_path := project_dir/'predictions.csv.zip'): 187 | logger.warning( 188 | f'{preds_df_path} already exists, skipping training/deployment!') 189 | 190 | yield EvalTask( 191 | path=project_dir, 192 | target_label=target_label, 193 | requirements=[], 194 | evaluators=evaluators) 195 | else: 196 | # training set 197 | if exists_and_has_size(train_df_path := project_dir/'training_set.csv.zip'): 198 | logger.warning( 199 | f'{train_df_path} already exists, using old training set!') 200 | train_df = pd.read_csv(train_df_path, dtype={'is_valid': bool}) 201 | elif train_cohorts_df is not None: 202 | train_df = _generate_train_df( 203 | get_items=get_items, train_cohorts_df=train_cohorts_df, 204 | target_label=target_label, na_values=na_values, n_bins=n_bins, 205 | min_support=min_support, logger=logger, 206 | patient_label=patient_label, valid_frac=valid_frac, 207 | train_df_path=train_df_path, balance=balance, 208 | max_class_count=max_class_count) 209 | # unable to generate a train df (e.g. 
because of insufficient data) 210 | if train_df is None: 211 | return 212 | else: 213 | train_df = None 214 | 215 | # testing set 216 | if exists_and_has_size(test_df_path := project_dir/'testing_set.csv.zip'): 217 | # load old testing set if it exists 218 | logger.warning( 219 | f'{test_df_path} already exists, using old testing set!') 220 | test_df = pd.read_csv(test_df_path) 221 | elif test_cohorts_df is not None: 222 | logger.info(f'Searching for testing tiles') 223 | test_cohorts_df = _prepare_cohorts( 224 | test_cohorts_df, target_label, na_values, n_bins=None, min_support=0, logger=logger) 225 | 226 | logger.info(f'Testing slide counts: {len(test_cohorts_df)}') 227 | test_df = get_items( 228 | dataset_type=DatasetType.TEST, cohorts_df=test_cohorts_df, logger=logger) 229 | 230 | train_df_path.parent.mkdir(parents=True, exist_ok=True) 231 | test_df.to_csv(test_df_path, index=False, compression='zip') 232 | else: 233 | test_df = None 234 | 235 | assert train_df is None or train_df.is_valid.any(), f'no validation set!' 
236 | 237 | gpu_task = GPUTask( 238 | path=project_dir, 239 | target_label=target_label, 240 | requirements=[], 241 | train=partial(train, **kwargs), 242 | deploy=deploy, 243 | train_df=train_df, 244 | test_df=test_df, 245 | capacities=capacities) 246 | 247 | yield gpu_task 248 | 249 | if test_df is not None: 250 | yield EvalTask( 251 | path=project_dir, 252 | target_label=target_label, 253 | requirements=[gpu_task], 254 | evaluators=evaluators) 255 | 256 | 257 | def _generate_train_df( 258 | train_cohorts_df: pd.DataFrame, 259 | target_label: str, 260 | get_items: Callable, 261 | na_values: Iterable[Any], 262 | n_bins: Optional[int], 263 | min_support: int, 264 | logger, 265 | patient_label: str, 266 | valid_frac: float, 267 | train_df_path: Path, 268 | balance: bool, 269 | max_class_count: Optional[Mapping[str, int]], 270 | ) -> Optional[pd.DataFrame]: 271 | train_cohorts_df = _prepare_cohorts( 272 | train_cohorts_df, target_label, na_values, n_bins, min_support, logger) 273 | 274 | if train_cohorts_df[target_label].nunique() < 2: 275 | logger.warning( 276 | f'Not enough classes for target {target_label}! 
skipping...') 277 | return None 278 | 279 | if is_continuous(train_cohorts_df[target_label]): 280 | targets = train_cohorts_df[target_label] 281 | logger.info( 282 | f'Training slide count: {len(targets)} (mean={targets.mean()}, std={targets.std()})') 283 | else: 284 | logger.info( 285 | f'Training slide counts: {dict(train_cohorts_df[target_label].value_counts())}') 286 | 287 | # only use a subset of patients 288 | # (can be useful to compare behavior when training on different cohorts) 289 | if max_class_count is not None: 290 | patients_to_use = [] 291 | for class_, count in max_class_count.items(): 292 | class_patients = \ 293 | train_cohorts_df[train_cohorts_df[target_label] 294 | == class_][patient_label].unique() 295 | patients_to_use.append(np.random.choice( 296 | class_patients, size=count, replace=False)) 297 | train_cohorts_df = train_cohorts_df[train_cohorts_df[patient_label].isin( 298 | np.concatenate(patients_to_use))] 299 | 300 | # split off validation set 301 | patients = train_cohorts_df.groupby(patient_label)[target_label].first() 302 | if is_continuous(train_cohorts_df[target_label]): 303 | _, valid_patients = train_test_split( 304 | patients.index, test_size=valid_frac, shuffle=True) 305 | else: 306 | _, valid_patients = train_test_split( 307 | patients.index, test_size=valid_frac, stratify=patients, shuffle=True) 308 | 309 | train_cohorts_df['is_valid'] = train_cohorts_df[patient_label].isin( 310 | valid_patients) 311 | 312 | logger.info(f'Searching for training tiles') 313 | train_df = get_items( 314 | dataset_type=DatasetType.TRAIN, 315 | cohorts_df=train_cohorts_df[~train_cohorts_df.is_valid], logger=logger) 316 | if train_df.empty: 317 | logger.warning('did not find any tiles. 
Skipping...') 318 | return None 319 | 320 | valid_df = get_items( 321 | dataset_type=DatasetType.VALID, 322 | cohorts_df=train_cohorts_df[train_cohorts_df.is_valid], 323 | logger=logger) 324 | 325 | # restrict to classes present in training set 326 | if not is_continuous(train_df[target_label]): 327 | train_classes = train_df[target_label].unique() 328 | valid_df = valid_df[valid_df[target_label].isin(train_classes)] 329 | 330 | logger.debug( 331 | f'Training tiles: {dict(train_df[target_label].value_counts())}') 332 | logger.debug( 333 | f'Validation tiles: {dict(valid_df[target_label].value_counts())}') 334 | 335 | if balance and not is_continuous(train_df[target_label]): 336 | train_df = _balance_classes( 337 | tiles_df=train_df, target=target_label) 338 | valid_df = _balance_classes(tiles_df=valid_df, target=target_label) 339 | logger.info(f'Training tiles after balancing: {len(train_df)}') 340 | logger.info(f'Validation tiles after balancing: {len(valid_df)}') 341 | 342 | train_df = pd.concat([train_df, valid_df]) 343 | 344 | train_df_path.parent.mkdir(parents=True, exist_ok=True) 345 | train_df.to_csv(train_df_path, index=False, compression='zip') 346 | 347 | return train_df 348 | 349 | 350 | def _prepare_cohorts( 351 | cohorts_df: pd.DataFrame, target_label: str, na_values: Iterable[str], 352 | n_bins: Optional[int], min_support: int, logger: logging.Logger 353 | ) -> pd.DataFrame: 354 | """Preprocesses the cohorts. 355 | 356 | Discretizes continuous targets and drops classes for which only few 357 | examples are present. 
358 | """ 359 | assert not cohorts_df.empty 360 | cohorts_df = cohorts_df.copy() 361 | 362 | # remove N/As 363 | cohorts_df = cohorts_df[cohorts_df[target_label].notna()] 364 | for na_value in na_values: 365 | cohorts_df = cohorts_df[cohorts_df[target_label] != na_value] 366 | if cohorts_df.empty: 367 | logger.warning('no samples left after dropping NAs') 368 | return None 369 | 370 | if n_bins is not None and is_continuous(cohorts_df[target_label]): 371 | # discretize 372 | logger.info(f'Discretizing {target_label}') 373 | cohorts_df[target_label] = _discretize( 374 | cohorts_df[target_label].values, n_bins=n_bins) 375 | 376 | if not is_continuous(cohorts_df[target_label]): 377 | # drop classes with insufficient support 378 | class_counts = cohorts_df[target_label].value_counts() 379 | rare_classes = (class_counts[class_counts < min_support]).index 380 | cohorts_df = cohorts_df[~cohorts_df[target_label].isin(rare_classes)] 381 | if cohorts_df.empty: 382 | logger.warning('no samples left after excluding rare classes.') 383 | return 384 | 385 | return cohorts_df 386 | 387 | 388 | def _discretize(xs: Sequence[Number], n_bins: int) -> Sequence[str]: 389 | """Returns a discretized version of a Sequence of continuous values.""" 390 | unsqueezed = torch.tensor(xs).reshape(-1, 1) 391 | est = preprocessing.KBinsDiscretizer( 392 | n_bins=n_bins, encode='ordinal').fit(unsqueezed) 393 | labels = [f'[-inf,{est.bin_edges_[0][1]})', # label for smallest class 394 | # labels for intermediate classes 395 | *(f'[{lower},{upper})' 396 | for lower, upper in zip(est.bin_edges_[0][1:], est.bin_edges_[0][2:-1])), 397 | f'[{est.bin_edges_[0][-2]},inf)'] # label for largest class 398 | label_map = dict(enumerate(labels)) 399 | discretized = est.transform(unsqueezed).reshape(-1).astype(int) 400 | return list(map(label_map.get, discretized)) # type: ignore 401 | 402 | 403 | def _balance_classes(tiles_df: pd.DataFrame, target: str) -> pd.DataFrame: 404 | smallest_class_count = 
min(tiles_df[target].value_counts()) 405 | for label in tiles_df[target].unique(): 406 | tiles_with_label = tiles_df[tiles_df[target] == label] 407 | to_keep = tiles_with_label.sample(n=smallest_class_count).index 408 | tiles_df = tiles_df[(tiles_df[target] != label) | 409 | (tiles_df.index.isin(to_keep))] 410 | 411 | return tiles_df 412 | 413 | 414 | SimpleRun = factory(_simple_run) 415 | -------------------------------------------------------------------------------- /deepmed/get/_subgroup.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Iterator, Callable, Union 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | 6 | from .._experiment import Task, EvalTask 7 | from ..evaluators.types import Evaluator 8 | from ..utils import factory 9 | 10 | 11 | # @log_defaults 12 | def _subgroup( 13 | get, 14 | *args, 15 | project_dir: Path, 16 | target_label: str, 17 | subgrouper: Callable[[pd.Series], Union[str, None]], 18 | subgroup_evaluators: Iterable[Evaluator] = [], 19 | cohort_df_arg_names: Iterable[str] = [ 20 | 'cohorts_df', 'train_cohorts_df', 'test_cohorts_df'], 21 | **kwargs) -> Iterator[Task]: 22 | """Splits a training data set into multiple subgroups. 23 | 24 | Args: 25 | train_cohorts_df: Base data set to be split into subgroups. 26 | subgrouper: A function mapping a sample of the training dataset onto a 27 | subgroup. The function is given a row from the training dataset and 28 | has to return either a string describing the group name, or None if 29 | it shall be excluded from training. 30 | subgroup_evaluators: A list of evaluators to be executed after all 31 | subgroup runs have been completed. 32 | cohort_df_arg_names: The keys of cohort_dfs passed as kwargs to adapted 33 | task getters. 34 | """ 35 | assert any(arg_name in kwargs for arg_name in cohort_df_arg_names), \ 36 | f'none of {cohort_df_arg_names} given to `Subgroup()`!' 
37 | groups = { 38 | cohorts_df_name: kwargs[cohorts_df_name].apply(subgrouper, axis=1) 39 | for cohorts_df_name in cohort_df_arg_names 40 | if cohorts_df_name in kwargs 41 | } 42 | assert groups, 'no subgroup instances found!' 43 | group_names = { 44 | x 45 | for gs in groups.values() 46 | for x in gs.unique() 47 | if x is not None 48 | } 49 | 50 | tasks = ( 51 | task 52 | for group_name in group_names 53 | for task in get( # type: ignore 54 | *args, 55 | project_dir=project_dir/group_name, 56 | target_label=target_label, 57 | **{**kwargs, 58 | **{ 59 | cohorts_df_name: kwargs[cohorts_df_name][gs == group_name] 60 | for cohorts_df_name, gs in groups.items() 61 | }})) 62 | 63 | requirements = [] 64 | for task in tasks: 65 | yield task 66 | requirements.append(task) 67 | 68 | yield EvalTask( 69 | path=project_dir, 70 | target_label=target_label, 71 | requirements=requirements, 72 | evaluators=subgroup_evaluators) 73 | 74 | Subgroup = factory(_subgroup) -------------------------------------------------------------------------------- /deepmed/mil.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import logging 3 | from pathlib import Path 4 | import shutil 5 | from fastai.callback.progress import CSVLogger 6 | from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback 7 | from fastai.data.block import CategoryBlock, DataBlock, RegressionBlock, TransformBlock 8 | from fastai.data.transforms import ColReader, ColSplitter 9 | from fastai.learner import Learner, load_learner 10 | from fastai.losses import CrossEntropyLossFlat 11 | from fastai.vision.learner import create_head 12 | from fastcore.transform import Transform 13 | import h5py 14 | import os 15 | import torch 16 | import pandas as pd 17 | from torch import nn 18 | from tqdm import tqdm 19 | from typing import Callable, Iterable, Optional, Tuple, Union 20 | 21 | from deepmed.types import GPUTask 22 | from 
deepmed.utils import is_continuous 23 | 24 | 25 | __all__ = ['get_h5s', 'MILBagTransform', 'Attention', 'GatedAttention', 'MILModel', 'Train'] 26 | 27 | def get_h5s( 28 | dataset_type, cohorts_df: pd.DataFrame, 29 | resample_each_epoch: bool = True, 30 | logger=logging, 31 | ) -> pd.DataFrame: 32 | """Create df containing patient, tiles, other data.""" 33 | cohorts_df.slide_path = cohorts_df.slide_path.map(lambda p: p.parent/f'{p.name}.h5') 34 | cohorts_df = cohorts_df[cohorts_df.slide_path.map(lambda p: p.exists())] 35 | 36 | logger.info( 37 | f'Found {len(cohorts_df)} slides for {len(cohorts_df["PATIENT"].unique())} patients') 38 | 39 | return cohorts_df 40 | 41 | 42 | def _to_fixed_size_bag(bag: torch.Tensor, bag_size: int = 512) -> Tuple[torch.Tensor, int]: 43 | # get up to bag_size elements 44 | bag_idxs = torch.randperm(bag.shape[0])[:bag_size] 45 | bag_samples = bag[bag_idxs] 46 | 47 | # zero-pad if we don't have enough samples 48 | zero_padded = torch.cat((bag_samples, 49 | torch.zeros(bag_size-bag_samples.shape[0], bag_samples.shape[1]))) 50 | return zero_padded, min(bag_size, len(bag)) 51 | 52 | 53 | class MILBagTransform(Transform): 54 | def __init__(self, valid_files: Iterable[os.PathLike], max_bag_size: int = 512) -> None: 55 | self.max_train_bag_size = max_bag_size 56 | self.valid = {fn: self._draw(fn) for fn in tqdm(valid_files, leave=False)} 57 | 58 | def encodes(self, fn):# -> Tuple[torch.Tensor, int]: 59 | if not isinstance(fn, (Path, str)): 60 | return fn 61 | 62 | return self.valid.get(fn, self._draw(fn)) 63 | 64 | def _draw(self, fn: Union[str, Path]) -> Tuple[torch.Tensor, int]: 65 | with h5py.File(fn, 'r') as f: 66 | feats = torch.from_numpy(f['feats'][:]) 67 | return _to_fixed_size_bag(feats, bag_size=self.max_train_bag_size) 68 | 69 | 70 | def Attention(n_in: int, n_latent: Optional[int] = None) -> nn.Module: 71 | """A network calculating an embedding's importance weight. 
72 | 73 | Taken from arXiv:1802.04712 74 | """ 75 | n_latent = n_latent or (n_in + 1) // 2 76 | 77 | return nn.Sequential( 78 | nn.Linear(n_in, n_latent), 79 | nn.Tanh(), 80 | nn.Linear(n_latent, 1)) 81 | 82 | 83 | class GatedAttention(nn.Module): 84 | """A network calculating an embedding's importance weight. 85 | 86 | Taken from arXiv:1802.04712 87 | """ 88 | 89 | def __init__(self, n_in: int, n_latent: Optional[int] = None) -> None: 90 | super().__init__() 91 | n_latent = n_latent or (n_in + 1) // 2 92 | 93 | self.fc1 = nn.Linear(n_in, n_latent) 94 | self.gate = nn.Linear(n_in, n_latent) 95 | self.fc2 = nn.Linear(n_latent, 1) 96 | 97 | def forward(self, h: torch.Tensor) -> torch.Tensor: 98 | return self.fc2(torch.tanh(self.fc1(h)) * torch.sigmoid(self.gate(h))) 99 | 100 | 101 | class MILModel(nn.Module): 102 | def __init__( 103 | self, n_feats: int, n_out: int, 104 | encoder: Optional[nn.Module] = None, 105 | attention: Optional[nn.Module] = None, 106 | head: Optional[nn.Module] = None, 107 | with_attention_scores: bool = False, 108 | ) -> None: 109 | """ 110 | 111 | Args: 112 | n_feats: The nuber of features each bag instance has. 113 | n_out: The number of output layers of the model. 114 | encoder: A network transforming bag instances into feature vectors. 
115 | """ 116 | super().__init__() 117 | self.encoder = encoder or nn.Sequential( 118 | nn.Linear(n_feats, 256), nn.ReLU()) 119 | self.attention = attention or Attention(256)# GatedAttention(512) 120 | self.head = head or create_head( 121 | 256, n_out, concat_pool=False, lin_ftrs=[])[1:] 122 | 123 | self.with_attention_scores = with_attention_scores 124 | 125 | def forward(self, bags_and_lens): 126 | bags, lens = bags_and_lens 127 | assert bags.ndim == 3 128 | assert bags.shape[0] == lens.shape[0] 129 | 130 | embeddings = self.encoder(bags) 131 | 132 | masked_attention_scores = self._masked_attention_scores( 133 | embeddings, lens) 134 | weighted_embedding_sums = ( 135 | masked_attention_scores * embeddings).sum(-2) 136 | 137 | scores = self.head(weighted_embedding_sums) 138 | 139 | return scores 140 | 141 | def _masked_attention_scores(self, embeddings, lens): 142 | """Calculates attention scores for all bags. 143 | 144 | Returns: 145 | A tensor containing 146 | * The attention score of instance i of bag j if i < len[j] 147 | * 0 otherwise 148 | """ 149 | bs, bag_size = embeddings.shape[0], embeddings.shape[1] 150 | attention_scores = self.attention(embeddings) 151 | 152 | # a tensor containing a row [0, ..., bag_size-1] for each batch instance 153 | idx = (torch.arange(bag_size) 154 | .repeat(bs, 1) 155 | .to(attention_scores.device)) 156 | 157 | # False for every instance of bag i with index(instance) >= lens[i] 158 | attention_mask = (idx < lens.unsqueeze(-1)).unsqueeze(-1) 159 | 160 | masked_attention = torch.where( 161 | attention_mask, 162 | attention_scores, 163 | torch.full_like(attention_scores, -1e10)) 164 | return torch.softmax(masked_attention, dim=1) 165 | 166 | 167 | @dataclass 168 | class Train: 169 | """Trains a single model. 170 | 171 | Args: 172 | task: The task to train a model for. 173 | arch: The architecture of the model to train. 174 | max_epochs: The absolute maximum number of epochs to train. 175 | lr: The initial learning rate. 
176 | num_workers: The number of workers to use in the data loaders. Set to 177 | 0 on windows! 178 | tfms: Transforms to apply to the data. 179 | metrics: The metrics to calculate on the validation set each epoch. 180 | patience: The number of epochs without improvement before stopping the 181 | training. 182 | monitor: The metric to monitor for early stopping. 183 | 184 | Returns: 185 | The trained model. 186 | 187 | If the training is interrupted, it will be continued from the last model 188 | checkpoint. 189 | """ 190 | max_bag_size: int = 512 191 | batch_size: int = 32 192 | max_epochs: int = 64 193 | lr: Optional[float] = 1e-3 194 | num_workers: int = 0 195 | metrics: Iterable[Callable] = field(default_factory=list) 196 | patience: int = 12 197 | monitor: str = 'valid_loss' 198 | 199 | def __call__(self, task: GPUTask) -> Optional[Learner]: 200 | logger = logging.getLogger(str(task.path)) 201 | 202 | if (model_path := task.path/'export.pkl').exists(): 203 | logger.warning(f'{model_path} already exists! 
using old model...') 204 | return load_learner(model_path) 205 | 206 | target_label, train_df = task.target_label, task.train_df 207 | 208 | if train_df is None: 209 | logger.warning('Cannot train: no training set given!') 210 | return None 211 | 212 | # create dataloader 213 | y_block = RegressionBlock if is_continuous( 214 | train_df[target_label]) else CategoryBlock 215 | 216 | train_df.slide_path = train_df.slide_path.map(Path) 217 | mil_tfm = MILBagTransform(train_df[~train_df.is_valid].slide_path, self.max_bag_size) 218 | dblock = DataBlock(blocks=(TransformBlock, y_block), 219 | get_x=ColReader('slide_path'), 220 | get_y=ColReader(target_label), 221 | splitter=ColSplitter('is_valid'), 222 | item_tfms=mil_tfm) 223 | dls = dblock.dataloaders( 224 | train_df, bs=self.batch_size, num_workers=self.num_workers) 225 | 226 | target_col_idx = train_df[~train_df.is_valid].columns.get_loc(target_label) 227 | 228 | logger.debug( 229 | f'Class counts in training set: {train_df[~train_df.is_valid].iloc[:, target_col_idx].value_counts()}') 230 | logger.debug( 231 | f'Class counts in validation set: {train_df[train_df.is_valid].iloc[:, target_col_idx].value_counts()}') 232 | 233 | # create weighted loss function in case of categorical data 234 | if is_continuous(train_df[target_label]): 235 | loss_func = None 236 | else: 237 | counts = train_df[~train_df.is_valid].iloc[:, target_col_idx].value_counts() 238 | weight = counts.sum() / counts 239 | weight /= weight.sum() 240 | # reorder according to vocab 241 | weight = torch.tensor(list(map(weight.get, dls.vocab)), dtype=torch.float32) 242 | loss_func = CrossEntropyLossFlat(weight=weight.cuda()) 243 | logger.info(f'{dls.vocab = }, {weight = }') 244 | 245 | feat_no = dls.one_batch()[0][0].shape[-1] 246 | learn = Learner(dls, MILModel(feat_no, dls.c), 247 | path=task.path, loss_func=loss_func, metrics=self.metrics) 248 | 249 | # save the features' extractor in the model so we can trace it back later 250 | with 
h5py.File(train_df.slide_path.iloc[0]) as f: 251 | learn.extractor_checksum = f.attrs['extractor-checksum'] 252 | 253 | # find learning rate if necessary 254 | if not self.lr: 255 | logger.info('searching learning rate...') 256 | suggested_lrs = learn.lr_find() 257 | logger.info(f'{suggested_lrs = }') 258 | self.lr = suggested_lrs.valley 259 | 260 | # finally: train! 261 | cbs = [ 262 | SaveModelCallback( 263 | monitor=self.monitor, fname=f'best_{self.monitor}', reset_on_fit=False), 264 | SaveModelCallback(every_epoch=True, with_opt=True, 265 | reset_on_fit=False), 266 | EarlyStoppingCallback( 267 | monitor=self.monitor, min_delta=0.001, patience=self.patience, reset_on_fit=False), 268 | CSVLogger(append=True)] 269 | 270 | learn.fit_one_cycle(n_epoch=self.max_epochs, lr_max=self.lr, cbs=cbs) 271 | 272 | # make bag size "realistically big" for deployment 273 | mil_tfm.max_bag_size = max(_bag_lens(train_df[train_df.is_valid].slide_path)) 274 | dls.valid.bs = 1 275 | 276 | learn.export() 277 | shutil.rmtree(task.path/'models') 278 | return learn 279 | 280 | 281 | def _bag_lens(h5_files: Iterable[os.PathLike]) -> Iterable[int]: 282 | lens = [] 283 | for fn in h5_files: 284 | with h5py.File(fn, 'r') as f: 285 | lens.append(len(f['feats'])) 286 | return lens -------------------------------------------------------------------------------- /deepmed/multi_input.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import math 4 | import logging 5 | from typing import Callable, Iterable, Optional, cast 6 | from pathlib import Path 7 | from dataclasses import dataclass 8 | from typing import Callable, Iterable, Union, Optional 9 | from fastai.callback.hook import num_features_model 10 | from fastai.callback.progress import CSVLogger 11 | from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback, TrackerCallback 12 | from functools import partial 13 | from fastai.data.block import 
CategoryBlock, DataBlock, TransformBlock 14 | from fastai.data.transforms import CategoryMap, ColReader, ColSplitter, RegressionSetup, get_c 15 | from fastai.layers import AdaptiveConcatPool2d, Flatten 16 | from fastai.learner import Learner, load_learner 17 | from fastai.losses import CrossEntropyLossFlat 18 | from fastai.metrics import BalancedAccuracy 19 | from fastai.optimizer import Adam 20 | from fastai.torch_core import apply_init, params 21 | from fastai.vision.augment import aug_transforms 22 | from fastai.vision.data import ImageBlock 23 | from fastai.vision.learner import create_body, create_cnn_model, create_head, model_meta 24 | from fastcore.basics import ifnone, store_attr, defaults 25 | from fastcore.foundation import L 26 | from fastcore.meta import delegates 27 | 28 | import torch 29 | import pandas as pd 30 | from torch import nn 31 | 32 | from fastai.vision.learner import _add_norm, _default_meta 33 | from torchvision.models.resnet import resnet18 34 | 35 | from .utils import factory, log_defaults 36 | 37 | __all__ = ['Train'] 38 | 39 | 40 | class MultiInputModel(nn.Module): 41 | """A model which takes tabular information in addition to an image. 42 | 43 | In some cases, there may be additinal information available which may aid in 44 | classification. This model extends a CNN by feeding this information into 45 | the in addition to the image features calcuated by the convolutional layers. 
46 | """ 47 | 48 | def __init__( 49 | self, arch, n_out: int, n_additional: int, n_in: int = 3, init=nn.init.kaiming_normal_, 50 | pretrained: bool = True, cut=None) -> None: 51 | super().__init__() 52 | 53 | meta = model_meta.get(arch, _default_meta) 54 | body = create_body(arch, n_in, pretrained, ifnone(cut, meta['cut'])) 55 | self.cnn_feature_extractor = nn.Sequential( 56 | body, AdaptiveConcatPool2d(), Flatten()) 57 | 58 | nf_body = num_features_model(nn.Sequential(*body.children())) 59 | # throw away pooling / flattenting layers 60 | self.head = create_head(nf_body*2 + n_additional, 61 | n_out, concat_pool=False)[2:] 62 | if init is not None: 63 | apply_init(self.head, init) 64 | 65 | def forward(self, img, *tab): 66 | img_feats = self.cnn_feature_extractor(img) 67 | 68 | if tab: 69 | stack_val = torch.stack((tab), axis=1) 70 | features = torch.cat([img_feats, stack_val], dim=1) 71 | else: 72 | features = img_feats 73 | return self.head(features) 74 | 75 | 76 | def multi_input_splitter(model, base_splitter): 77 | # TODO HIER HABE ICH AUFGEHOERT 78 | return [*base_splitter(model.cnn_feature_extractor)[:-1], params(model.head)] 79 | 80 | 81 | @dataclass 82 | class Normalize: 83 | mean: float 84 | std: float 85 | 86 | def __call__(self, x): 87 | x = float(x) 88 | return (x - self.mean)/self.std if not math.isnan(x) else 0 89 | 90 | 91 | @delegates(create_cnn_model) 92 | def multi_input_learner( 93 | dls, arch, normalize=True, n_out=None, n_additional=0, pretrained=True, 94 | # learner args 95 | loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, 96 | path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, 97 | moms=(0.95, 0.85, 0.95), 98 | # other model args 99 | **kwargs): 100 | # adapted from fastai.vision.learner.cnn_learner 101 | 102 | meta = model_meta.get(arch, _default_meta) 103 | if normalize: 104 | _add_norm(dls, meta, pretrained) 105 | 106 | if n_out is None: 107 | n_out = L(get_c(dls))[-1] 108 | assert 
n_out, \ 109 | "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`" 110 | model = MultiInputModel( 111 | arch, n_out=n_out, n_additional=n_additional, pretrained=pretrained, **kwargs) 112 | 113 | splitter = ifnone(splitter, meta['split']) 114 | splitter = partial(multi_input_splitter, base_splitter=splitter) 115 | learn = Learner( 116 | dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, 117 | cbs=cbs, metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, 118 | train_bn=train_bn, moms=moms) 119 | 120 | if pretrained: 121 | learn.freeze() 122 | 123 | # keep track of args for loggers 124 | store_attr('arch, normalize, n_out, pretrained', self=learn, **kwargs) 125 | return learn 126 | 127 | 128 | @dataclass 129 | class Category: 130 | name: str 131 | vocab: Optional[Iterable[str]] = None 132 | 133 | @property 134 | def block(self) -> CategoryBlock: 135 | return CategoryBlock(vocab=self, sort=False, add_na=True) 136 | 137 | def __str__(self) -> str: 138 | return self.name 139 | 140 | 141 | @log_defaults 142 | def _train( 143 | task: 'GPUTask', /, #type: ignore 144 | arch: Callable[[bool], nn.Module] = resnet18, 145 | batch_size: int = 64, 146 | max_epochs: int = 10, 147 | lr: float = 2e-3, 148 | num_workers: int = 0, 149 | tfms: Callable = aug_transforms( 150 | flip_vert=True, max_rotate=360, max_zoom=1, max_warp=0, size=224), 151 | metrics: Iterable[Callable] = [BalancedAccuracy()], 152 | patience: int = 3, 153 | monitor: str = 'valid_loss', 154 | conts: Iterable[str] = [], 155 | cats: Iterable[Union[str, Category]] = []) -> Optional[Learner]: 156 | """Trains a single model. 157 | 158 | Args: 159 | batch_size: The number of training samples used through the network during one forward and backward pass. 160 | task: The task to train a model for. 161 | arch: The architecture of the model to train. 162 | max_epochs: The absolute maximum number of epochs to train. 
163 | lr: The initial learning rate. 164 | num_workers: The number of workers to use in the data loaders. Set to 165 | 0 on windows! 166 | tfms: Transforms to apply to the data. 167 | metrics: The metrics to calculate on the validation set each epoch. 168 | patience: The number of epochs without improvement before stopping the 169 | training. 170 | monitor: The metric to monitor for early stopping. 171 | 172 | Returns: 173 | The trained model. 174 | 175 | If the training is interrupted, it will be continued from the last model 176 | checkpoint. 177 | """ 178 | logger = logging.getLogger(str(task.path)) 179 | 180 | if (model_path := task.path/'export.pkl').exists(): 181 | logger.warning(f'{model_path} already exists! using old model...') 182 | return load_learner(model_path) 183 | 184 | target_label, train_df, result_dir = task.target_label, task.train_df, task.path 185 | 186 | if train_df is None: 187 | logger.debug('Cannot train: no training set given!') 188 | return None 189 | 190 | for col in conts: 191 | train_df[col] = train_df[col].astype(float) 192 | 193 | conts = [cont for cont in conts if cont != target_label] 194 | cats = [cat for cat in cats if str(cat) != target_label] 195 | 196 | cont_blocks = [ 197 | TransformBlock(type_tfms=[Normalize( 198 | mean=mean, std=std), RegressionSetup()]) 199 | for label in conts 200 | for mean, std in [(train_df[label].mean(), train_df[label].std())]] 201 | cat_blocks = [ 202 | CategoryBlock(add_na=True) if isinstance(cat, str) 203 | else cat.block 204 | for cat in cats] 205 | 206 | dblock = DataBlock( 207 | blocks=( 208 | ImageBlock, 209 | *cont_blocks, 210 | *cat_blocks, 211 | CategoryBlock), 212 | getters=( 213 | ColReader('tile_path'), 214 | *(ColReader(name) for name in conts), 215 | *(ColReader(name) for name in cats), 216 | ColReader(target_label), 217 | ), 218 | splitter=ColSplitter('is_valid'), 219 | batch_tfms=tfms) 220 | 221 | dls = dblock.dataloaders(train_df, bs=batch_size, num_workers=num_workers) 222 | 223 | 
target_col_idx = train_df[~train_df.is_valid].columns.get_loc(target_label) 224 | 225 | logger.debug( 226 | 'Class counts in training set: ' 227 | f'{dict(train_df[~train_df.is_valid].iloc[:, target_col_idx].value_counts())}') 228 | logger.debug( 229 | 'Class counts in validation set: ' 230 | f'{dict(train_df[train_df.is_valid].iloc[:, target_col_idx].value_counts())}') 231 | 232 | counts = train_df[~train_df.is_valid].iloc[:, 233 | target_col_idx].value_counts() 234 | 235 | vocab = dls.vocab if isinstance(dls.vocab, CategoryMap) else dls.vocab[-1] 236 | counts = torch.tensor([counts[k] for k in vocab]) 237 | weights = 1 - (counts / sum(counts)) 238 | 239 | logger.debug(f'{dls.vocab = }, {weights = }') 240 | 241 | learn = multi_input_learner( 242 | dls, arch, 243 | n_additional=len(conts)+len(cats), 244 | path=result_dir, 245 | loss_func=CrossEntropyLossFlat(weight=weights.cuda()), 246 | metrics=metrics) 247 | 248 | cbs = [ 249 | SaveModelCallback( 250 | monitor=monitor, fname=f'best_{monitor}', reset_on_fit=False), 251 | SaveModelCallback(every_epoch=True, with_opt=True, reset_on_fit=False), 252 | EarlyStoppingCallback( 253 | monitor=monitor, min_delta=0.001, patience=patience, reset_on_fit=False), 254 | CSVLogger(append=True)] 255 | 256 | if (result_dir/'models'/f'best_{monitor}.pth').exists(): 257 | _fit_from_checkpoint( 258 | learn=learn, result_dir=result_dir, lr=lr/2, max_epochs=max_epochs, cbs=cbs, 259 | monitor=monitor, logger=logger) 260 | else: 261 | learn.fine_tune(epochs=max_epochs, base_lr=lr, cbs=cbs) 262 | 263 | learn.export() 264 | 265 | shutil.rmtree(result_dir/'models') 266 | 267 | return learn 268 | 269 | 270 | def _fit_from_checkpoint( 271 | learn: Learner, result_dir: Path, lr: float, max_epochs: int, cbs: Iterable[Callable], 272 | monitor: str, logger) \ 273 | -> None: 274 | logger.info('Continuing from checkpoint...') 275 | 276 | # get best performance so far 277 | history_df = pd.read_csv(result_dir/'history.csv') 278 | scores = 
pd.to_numeric(history_df[monitor], errors='coerce') 279 | high_score = scores.min() if 'loss' in monitor or 'error' in monitor else scores.max() 280 | logger.info(f'Best {monitor} up to checkpoint: {high_score}') 281 | 282 | # update tracker callback's high scores 283 | for cb in cbs: 284 | if isinstance(cb, TrackerCallback): 285 | cb.best = high_score 286 | 287 | # load newest model 288 | name = max((result_dir/'models').glob('model_*.pth'), 289 | key=os.path.getctime).stem 290 | learn.load(name, with_opt=True, strict=True) 291 | 292 | remaining_epochs = max_epochs - int(name.split('_')[1]) 293 | logger.info(f'{remaining_epochs = }') 294 | learn.unfreeze() 295 | learn.fit_one_cycle(remaining_epochs, slice( 296 | lr/100, lr), pct_start=.3, div=5., cbs=cbs) 297 | 298 | 299 | Train = factory(_train) 300 | -------------------------------------------------------------------------------- /deepmed/on_features.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | import random 4 | import shutil 5 | import logging 6 | from typing import Callable, Iterable, Mapping, Optional 7 | from dataclasses import dataclass, field 8 | from fastai.callback.progress import CSVLogger 9 | from fastai.callback.tracker import EarlyStoppingCallback, SaveModelCallback 10 | from fastai.data.block import CategoryBlock, DataBlock, RegressionBlock, TransformBlock 11 | from fastai.data.transforms import ColReader, ColSplitter 12 | from fastai.learner import Learner, load_learner 13 | from fastai.losses import CrossEntropyLossFlat 14 | from fastai.vision.learner import create_head 15 | from fastcore.foundation import L 16 | from tqdm import tqdm 17 | 18 | import torch 19 | import h5py 20 | import numpy as np 21 | import pandas as pd 22 | 23 | from .types import GPUTask 24 | from .utils import is_continuous 25 | from .get import DatasetType 26 | 27 | __all__ = ['Train', 'get_h5s'] 28 | 29 | 30 | def get_h5s( 31 | 
dataset_type: DatasetType, cohorts_df: pd.DataFrame, 32 | max_tile_nums: Mapping[DatasetType, int] = {DatasetType.TRAIN: 128, 33 | DatasetType.VALID: 256, 34 | DatasetType.TEST: 512}, 35 | resample_each_epoch: bool = False, 36 | logger=logging, 37 | ) -> pd.DataFrame: 38 | """Create df containing patient, tiles, other data.""" 39 | cohorts_df.slide_path = cohorts_df.slide_path.map(lambda p: p.parent/f'{p.name}.h5') 40 | cohorts_df = cohorts_df[cohorts_df.slide_path.map(lambda p: p.exists())] 41 | 42 | tiles_dfs = [] 43 | for slide_path in tqdm(cohorts_df.slide_path): 44 | with h5py.File(slide_path, 'r') as f: 45 | tiles = [(slide_path, i) 46 | for i in range(len(f['feats']))] 47 | if (tile_num := max_tile_nums.get(dataset_type)): 48 | tiles = random.sample(tiles, min(len(tiles), tile_num)) 49 | tiles_df = pd.DataFrame(tiles, columns=['slide_path', 'i']) 50 | 51 | tiles_dfs.append(tiles_df) 52 | 53 | tiles_df = pd.concat(tiles_dfs) 54 | tiles_df = cohorts_df.merge(tiles_df, on='slide_path').reset_index() 55 | 56 | logger.info( 57 | f'Found {len(cohorts_df)} tiles for {len(cohorts_df["PATIENT"].unique())} patients') 58 | 59 | # if we want the training procedure to resample a slide's tiles every epoch, 60 | # we have to supply a slide path instead of the tile path 61 | if dataset_type == DatasetType.TRAIN and resample_each_epoch: 62 | tiles_df.i = -1 63 | 64 | return tiles_df 65 | 66 | 67 | def load_feats(args: L): 68 | path, i = args 69 | with h5py.File(path, 'r') as f: 70 | # check if all features stem from the same extractor 71 | #h5_checksum = f.attrs['extractor-checksum'] 72 | #assert self.extractor_checksum == h5_checksum, \ 73 | # f'feature extractor mismatch for {path} ' \ 74 | # f'(expected {self.extractor_checksum:08x}, got {h5_checksum:08x})' 75 | if i == -1: 76 | return torch.from_numpy(f['feats'][np.random.randint(len(f['feats']))]) 77 | else: 78 | return torch.from_numpy(f['feats'][i]) 79 | 80 | 81 | @dataclass 82 | class Train: 83 | """Trains a single 
model. 84 | 85 | Args: 86 | batch_size: The number of training samples used through the network during one forward and backward pass. 87 | task: The task to train a model for. 88 | arch: The architecture of the model to train. 89 | max_epochs: The absolute maximum number of epochs to train. 90 | lr: The initial learning rate. 91 | num_workers: The number of workers to use in the data loaders. Set to 92 | 0 on windows! 93 | tfms: Transforms to apply to the data. 94 | metrics: The metrics to calculate on the validation set each epoch. 95 | patience: The number of epochs without improvement before stopping the 96 | training. 97 | monitor: The metric to monitor for early stopping. 98 | 99 | Returns: 100 | The trained model. 101 | 102 | If the training is interrupted, it will be continued from the last model 103 | checkpoint. 104 | """ 105 | batch_size: int = 64 106 | max_epochs: int = 32 107 | lr: float = 2e-3 108 | num_workers: int = (32 if os.name == 'posix' else 0) 109 | metrics: Iterable[Callable] = field(default_factory=list) 110 | patience: int = 3 111 | monitor: str = 'valid_loss' 112 | 113 | def __call__(self, task: GPUTask) -> Optional[Learner]: 114 | logger = logging.getLogger(str(task.path)) 115 | 116 | if (model_path := task.path/'export.pkl').exists(): 117 | logger.warning(f'{model_path} already exists! 
using old model...') 118 | return load_learner(model_path) 119 | 120 | target_label, train_df, result_dir = task.target_label, task.train_df, task.path 121 | 122 | if train_df is None: 123 | logger.warning('Cannot train: no training set given!') 124 | return None 125 | 126 | y_block = RegressionBlock if is_continuous(train_df[target_label]) else CategoryBlock 127 | dblock = DataBlock(blocks=(TransformBlock(item_tfms=load_feats), y_block), 128 | get_x=ColReader(['slide_path', 'i']), 129 | get_y=ColReader(target_label), 130 | splitter=ColSplitter('is_valid')) 131 | dls = dblock.dataloaders( 132 | train_df, bs=self.batch_size, num_workers=self.num_workers) 133 | 134 | logger.debug( 135 | f'Class counts in training set: {train_df[~train_df.is_valid][target_label].value_counts()}') 136 | logger.debug( 137 | f'Class counts in validation set: {train_df[train_df.is_valid][target_label].value_counts()}') 138 | 139 | if is_continuous(train_df[target_label]): 140 | loss_func = None 141 | else: 142 | counts = torch.tensor(train_df[~train_df.is_valid][target_label].value_counts()) 143 | weight = counts.sum() / counts 144 | weight /= weight.sum() 145 | loss_func = CrossEntropyLossFlat(weight=weight.cuda()) 146 | logger.debug(f'{dls.vocab = }, {weight = }') 147 | 148 | n_feats = dls.one_batch()[0].shape[-1] 149 | head = create_head(n_feats, dls.c, concat_pool=False)[2:] 150 | 151 | learn = Learner( 152 | dls, head, 153 | path=result_dir, 154 | loss_func=loss_func, 155 | metrics=self.metrics) 156 | 157 | # save the features' extractor in the model so we can trace it back later 158 | with h5py.File(train_df.slide_path.iloc[0]) as f: 159 | learn.extractor_checksum = f.attrs['extractor-checksum'] 160 | 161 | cbs = [ 162 | SaveModelCallback( 163 | monitor=self.monitor, fname=f'best_{self.monitor}', reset_on_fit=False), 164 | SaveModelCallback(every_epoch=True, with_opt=True, 165 | reset_on_fit=False), 166 | EarlyStoppingCallback( 167 | monitor=self.monitor, min_delta=0.001, 
patience=self.patience, reset_on_fit=False), 168 | CSVLogger(append=True)] 169 | 170 | learn.fit_one_cycle(n_epoch=self.max_epochs, lr_max=self.lr, cbs=cbs) 171 | 172 | learn.export() 173 | shutil.rmtree(result_dir/'models') 174 | return learn -------------------------------------------------------------------------------- /deepmed/types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from abc import ABC, abstractmethod 3 | from collections import abc 4 | import threading 5 | from deepmed.utils import exists_and_has_size 6 | import logging 7 | from itertools import cycle 8 | 9 | from typing import Any, Optional, Callable, Iterator, Union, Mapping, final 10 | from typing_extensions import Protocol 11 | from pathlib import Path 12 | from dataclasses import dataclass, field 13 | from threading import Event 14 | from fastai.learner import Learner 15 | 16 | import pandas as pd 17 | from threading import Semaphore 18 | 19 | from typing import Iterable 20 | from pathlib import Path 21 | 22 | import torch 23 | import pandas as pd 24 | 25 | from .evaluators.types import Evaluator 26 | 27 | __all__ = [ 28 | 'Task', 'GPUTask', 'EvalTask', 'TaskGetter', 'Trainer', 'Deployer', 'PathLike'] 29 | 30 | 31 | @dataclass # type: ignore 32 | class Task(ABC): 33 | path: Path 34 | """The directory to save data in for this task.""" 35 | 36 | requirements: Iterable[Task] 37 | """List of events which have to have occurred before this task can be 38 | started.""" 39 | 40 | done: Event = field(default_factory=threading.Event, init=False) 41 | """Whether this task has concluded.""" 42 | 43 | res: Any = field(default=None, init=False) 44 | 45 | def run(self) -> None: 46 | """Start this task.""" 47 | for reqirement in self.requirements: 48 | reqirement.done.wait() 49 | try: 50 | self.res = self.do_work() 51 | finally: 52 | self.done.set() 53 | 54 | @abstractmethod 55 | def do_work(self) -> Any: 56 | ... 
57 | 58 | 59 | class TaskGetter(Protocol): 60 | def __call__( 61 | self, project_dir: Path, capacities: Mapping[Union[int, str], Semaphore] 62 | ) -> Iterator[Task]: 63 | """A function which creates a series of task. 64 | 65 | Args: 66 | project_dir: The directory to save the task's data in. 67 | 68 | Returns: 69 | An iterator over all tasks. 70 | """ 71 | raise NotImplementedError() 72 | 73 | 74 | Trainer = Callable[['GPUTask'], Optional[Learner]] 75 | """A function which trains a model. 76 | 77 | Args: 78 | task: The task to train. 79 | 80 | Returns: 81 | The trained model. 82 | """ 83 | 84 | Deployer = Callable[[Learner, Task], pd.DataFrame] 85 | """A function which deployes a model. 86 | 87 | Writes the results to a file ``predictions.csv.zip`` in the task directory. 88 | 89 | Args: 90 | model: The model to test on. 91 | target_label: The name to be given to the result column. 92 | test_df: A dataframe specifying which tiles to deploy the model on. 93 | result_dir: A folder to write intermediate results to. 94 | """ 95 | 96 | PathLike = Union[str, Path] 97 | 98 | 99 | @dataclass 100 | class GPUTask(Task): 101 | """A collection of data to train or test a model.""" 102 | 103 | target_label: str 104 | """The name of the target to train or deploy on.""" 105 | 106 | train: Trainer 107 | deploy: Deployer 108 | 109 | train_df: Optional[pd.DataFrame] 110 | """A dataframe mapping tiles to be used for training to their 111 | targets. 112 | 113 | It contains at least the following columns: 114 | - tile_path: Path 115 | - is_valid: bool: whether the tile should be used for validation (e.g. for 116 | early stopping). 117 | - At least one target column with the name saved in the task's `target`. 118 | """ 119 | test_df: Optional[pd.DataFrame] 120 | """A dataframe mapping tiles used for testing to their targets. 
121 | 122 | It contains at least the following columns: 123 | - tile_path: Path 124 | """ 125 | 126 | capacities: Mapping[Union[int, str], Semaphore] 127 | """Mapping of pytorch device names to their current capacities.""" 128 | 129 | def do_work(self) -> None: 130 | logger = logging.getLogger(str(self.path)) 131 | logger.info(f'Starting GPU task') 132 | 133 | for device, capacity in cycle(self.capacities.items()): 134 | # search for a free gpu 135 | if not capacity.acquire(blocking=False): # type: ignore 136 | continue 137 | try: 138 | with torch.cuda.device(device): 139 | learn = self.train(self) 140 | if learn: 141 | self.deploy(learn, self) 142 | break 143 | finally: 144 | capacity.release() 145 | 146 | 147 | @dataclass 148 | class EvalTask(Task): 149 | target_label: Optional[str] 150 | """The name of the target to train or deploy on.""" 151 | 152 | evaluators: Iterable[Evaluator] 153 | 154 | def do_work(self) -> None: 155 | logger = logging.getLogger(str(self.path)) 156 | logger.info('Evaluating') 157 | 158 | preds_df = _generate_preds_df(self.path) 159 | stats_df = None 160 | for evaluate in self.evaluators: 161 | if (df := evaluate(self.target_label, preds_df, self.path)) is not None: 162 | assert isinstance(df, pd.DataFrame), \ 163 | f'{getattr(evaluate, __name__, evaluate)} did not return a DataFrame! ' \ 164 | 'Did you forget parentheses after its evaluator constructor ' \ 165 | f'(e.g. `{_camel_case_name(evaluate)}()`)?' 
166 | 167 | if stats_df is None: 168 | stats_df = df 169 | stats_df.index.name = 'class' 170 | else: 171 | # make sure the two dfs have the same column level 172 | levels = max(stats_df.columns.nlevels, df.columns.nlevels) 173 | stats_df = _raise_df_column_level(stats_df, levels) 174 | df = _raise_df_column_level(df, levels) 175 | stats_df = stats_df.join(df) 176 | if stats_df is not None: 177 | stats_df.to_pickle(self.path/'stats.pkl') 178 | stats_df.to_excel(self.path/f'{self.path.name}_stats.xlsx') 179 | 180 | 181 | def _camel_case_name(obj) -> str: 182 | """Tries to construct the camel case name of an object.""" 183 | if hasattr(obj, '__name__'): 184 | return ''.join(word.title() for word in obj.__name__.split('_')) # make into CamelCase 185 | else: 186 | return repr(obj) # fallback: just return repr 187 | 188 | 189 | def _raise_df_column_level(df, level): 190 | if df.columns.empty: 191 | columns = pd.MultiIndex.from_product([[]] * level) 192 | elif isinstance(df.columns, pd.MultiIndex): 193 | columns = pd.MultiIndex.from_tuples([col + ('n/a',)*(level-df.columns.nlevels) 194 | for col in df.columns]) 195 | else: 196 | columns = pd.MultiIndex.from_tuples([(col,) + ('n/a',)*(level-df.columns.nlevels) 197 | for col in df.columns]) 198 | 199 | return pd.DataFrame(df.values, index=df.index, columns=columns) 200 | 201 | 202 | def _generate_preds_df(result_dir: Path) -> Optional[pd.DataFrame]: 203 | # load predictions 204 | if exists_and_has_size(preds_path := result_dir/'predictions.csv.zip'): 205 | preds_df = pd.read_csv(preds_path, low_memory=False) 206 | else: 207 | # create an accumulated predictions df if there isn't one already 208 | dfs = [] 209 | for df_path in result_dir.glob('**/predictions.csv.zip'): 210 | df = pd.read_csv(df_path, low_memory=False) 211 | # column which tells us which subset these predictions are from 212 | #TODO df[f'subset_{result_dir.name}'] = df_path.name 213 | dfs.append(df) 214 | 215 | if not dfs: 216 | return None 217 | 218 | 
preds_df = pd.concat(dfs) 219 | preds_df.to_csv(preds_path, index=False, compression='zip') 220 | 221 | return preds_df 222 | -------------------------------------------------------------------------------- /deepmed/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import logging 3 | from typing import Callable, Any 4 | from functools import wraps, cached_property, partial 5 | import pandas as pd 6 | from pathlib import Path 7 | 8 | __all__ = ['log_defaults', 'Lazy', 'is_continuous', 'factory', 'exists_and_has_size'] 9 | 10 | 11 | def log_defaults(func): 12 | """Decorator which logs used default values of parameters of a function.""" 13 | @wraps(func) 14 | def default_logged(*args, **kwargs): 15 | # find unset kwargs with default values 16 | params = inspect.signature(func).parameters 17 | remaining_keys = list(params)[len(args):] 18 | params_with_defaults = [param 19 | for k in set(remaining_keys) - set(kwargs) 20 | if (param := params[k]).default != inspect.Parameter.empty] 21 | # log them 22 | for param in params_with_defaults: 23 | logging.getLogger(func.__module__).debug( 24 | f'using default value {param}') 25 | 26 | # call wrapped function 27 | return func(*args, **kwargs) 28 | 29 | return default_logged 30 | 31 | 32 | class Lazy: 33 | """A wrapper which constructs the underlying object only when it is needed.""" 34 | 35 | def __init__(self, factory: Callable[[], Any]) -> None: 36 | self._factory = factory 37 | 38 | @cached_property 39 | def _val(self): 40 | return self._factory() 41 | 42 | def __getattr__(self, k): 43 | return getattr(self._val, k) 44 | 45 | def __setattr__(self, k, v): 46 | if k == '_factory': 47 | super().__setattr__(k, v) 48 | else: 49 | setattr(self._val, k, v) 50 | 51 | def __getitem__(self, k): 52 | return self._val[k] 53 | 54 | def __setitem__(self, k, v): 55 | self._val[k] = v 56 | 57 | 58 | def is_continuous(series: pd.Series) -> bool: 59 | return series.dtype == float 
60 | 61 | 62 | def factory(f: Callable) -> Callable[..., Callable]: 63 | @wraps(f) 64 | def g(*args, **kwargs) -> Callable: 65 | return partial(f, *args, **kwargs) 66 | return g 67 | 68 | 69 | def exists_and_has_size(zip_path: Path) -> bool: 70 | """Checks if a file exists and has non-zero size. 71 | 72 | This works as a heuristic to see if the writing of a large zip file was 73 | interrupted and thus is corrupted. 74 | """ 75 | return zip_path.exists() and zip_path.stat().st_size > 0 76 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../../src')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | import deepmed 21 | 22 | project = 'deepmed' 23 | copyright = deepmed.__copyright__ 24 | author = deepmed.__author__ 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = deepmed.__version__ 28 | 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 
35 | extensions = [ 36 | 'sphinx.ext.autodoc', 37 | 'sphinx.ext.napoleon', 38 | ] 39 | 40 | # Add any paths that contain templates here, relative to this directory. 41 | templates_path = ['_templates'] 42 | 43 | # List of patterns, relative to source directory, that match files and 44 | # directories to ignore when looking for source files. 45 | # This pattern also affects html_static_path and html_extra_path. 46 | exclude_patterns = [] 47 | 48 | 49 | # -- Options for HTML output ------------------------------------------------- 50 | 51 | # The theme to use for HTML and HTML Help pages. See the documentation for 52 | # a list of builtin themes. 53 | # 54 | html_theme = 'sphinx_rtd_theme' 55 | 56 | # Add any paths that contain custom static files (such as style sheets) here, 57 | # relative to this directory. They are copied after the builtin static files, 58 | # so a file named "default.css" will overwrite the builtin "default.css". 59 | html_static_path = ['_static'] 60 | -------------------------------------------------------------------------------- /docs/source/deepmed.get.rst: -------------------------------------------------------------------------------- 1 | deepmed.get package 2 | ============================== 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: deepmed.get 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/deepmed.rst: -------------------------------------------------------------------------------- 1 | deepmed package 2 | ========================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | deepmed.get 11 | 12 | Submodules 13 | ---------- 14 | 15 | deepmed.experiment\_imports module 16 | --------------------------------------------- 17 | 18 | .. 
automodule:: deepmed.experiment_imports 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | deepmed.evaluators.metrics module 24 | --------------------------------- 25 | 26 | .. automodule:: deepmed.evaluators.metrics 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | deepmed.types module 32 | ------------------------------- 33 | 34 | .. automodule:: deepmed.types 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | deepmed.utils module 40 | ------------------------------- 41 | 42 | .. automodule:: deepmed.utils 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. automodule:: deepmed 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Deepest Histology documentation master file, created by 2 | sphinx-quickstart on Mon Jul 5 11:08:17 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to deepmed's documentation! 7 | ============================================= 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | users_guide 14 | deepmed 15 | 16 | 17 | Indices and tables 18 | ================== 19 | 20 | * :ref:`genindex` 21 | * :ref:`modindex` 22 | * :ref:`search` 23 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | deepmed 2 | ======= 3 | 4 | ..
toctree:: 5 | :maxdepth: 4 6 | 7 | deepmed 8 | -------------------------------------------------------------------------------- /docs/source/users_guide.rst: -------------------------------------------------------------------------------- 1 | User's Guide 2 | ============ 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents: 7 | 8 | users_guide/simple_training 9 | users_guide/multi_target_training 10 | users_guide/crossval -------------------------------------------------------------------------------- /docs/source/users_guide/crossval.rst: -------------------------------------------------------------------------------- 1 | Cross-Validation 2 | ================ 3 | 4 | :: 5 | 6 | from deepmed.experiment_imports import * 7 | 8 | cohorts_df = pd.concat([ 9 | cohort(tile_path='E:/TCGA-BRCA-DX/BLOCKS_NORM', 10 | clini_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA-IMMUNO_CLINI.xlsx', 11 | slide_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA_SLIDE.csv'), 12 | cohort(tile_path='E:/TCGA-BRCA-DX/BLOCKS_NORM', 13 | clini_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA-IMMUNO_CLINI.xlsx', 14 | slide_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA_SLIDE.csv')]) 15 | 16 | crossval_get = partial( 17 | get.crossval, 18 | get.simple_run, 19 | cohorts_df=cohorts_df, 20 | target_label='isMSIH', 21 | max_tile_num=100, 22 | na_values=['inconclusive'], 23 | folds=3) -------------------------------------------------------------------------------- /docs/source/users_guide/multi_target_training.rst: -------------------------------------------------------------------------------- 1 | Multi-Target Training 2 | ===================== 3 | 4 | In the last tutorial we trained and deployed a model for a single target. While 5 | this in itself is of course interesting, the true power of deepmed only comes to 6 | light if we want to train a large number of models at the same time.
In this 7 | tutorial we will take a look at how to automatically train models for a range 8 | of different targets and harness the power of multiple GPUs in the process. 9 | 10 | 11 | Defining a Multi-Target Run Getter 12 | ---------------------------------- 13 | 14 | Let's start off by defining a simple run getter:: 15 | 16 | simple_train_get = partial( 17 | get.simple_run, 18 | train_cohorts_df=train_cohorts_df, 19 | max_tile_num=100, 20 | na_values=['inconclusive']) 21 | 22 | Unlike the run getter in the previous tutorial, we did not specify the 23 | ``target_label`` this time around. This is because we actually don't want to 24 | manually specify the run's target label, but automatically repeat the training 25 | with different target labels. To achieve this, we use a *run adapter*: instead 26 | of generating runs by itself, a run adapter takes another run getter and 27 | transforms it; in our example, we take a single-target run getter and adapt it 28 | into a multi-target one. It is constructed like this:: 29 | 30 | multi_train_get = partial( 31 | get.multi_target, 32 | simple_train_get, 33 | target_labels=['isMSIH', 'gender', ...] #TODO add some more target labels 34 | ) 35 | 36 | The rest of the file looks similar to last time:: 37 | 38 | if __name__ == '__main__': 39 | do_experiment( 40 | project_dir='/path/to/training/project/dir', 41 | get=multi_train_get, 42 | devices=['cuda:0', 'cuda:1'], 43 | num_concurrent_runs=2, 44 | num_workers=4) # set to 0 on Windows! 45 | 46 | There are two new additions this time: First of all, we train up to 47 | two models in parallel, using both GPUs in our system. Each of these runs will 48 | have four of the CPU's cores assigned to it for data preprocessing. This way, 49 | we can easily parallelize the training of multiple models and thus decrease the 50 | overall training time. On Windows, the ``num_workers`` option is currently 51 | broken due to a bug in the pytorch module internally used by deepmed.
On such 52 | machines it may be helpful to set ``num_workers`` to zero and further 53 | increase the number of concurrent runs. 54 | 55 | After running this script, the directory structure is as follows:: 56 | 57 | /path/to/training/project/dir 58 | ├── isMSIH 59 | │ └── export.pkl 60 | └── gender 61 | └── export.pkl 62 | 63 | In the previous tutorial, the model (the ``export.pkl`` file) was 64 | saved directly in the project directory. The ``multi_target`` adapter added an 65 | additional subdirectory with the target's name to the project directory for each 66 | target. 67 | 68 | 69 | Evaluating Multiple Targets 70 | --------------------------- 71 | 72 | The deployment script is modified in almost the same way as the training 73 | script:: 74 | 75 | simple_deploy_get = partial( 76 | get.simple_run, 77 | test_cohorts_df=test_cohorts_df, 78 | max_tile_num=100, 79 | na_values=['inconclusive'], 80 | evaluators=[Grouped(auroc), Grouped(count)]) 81 | 82 | multi_deploy_get = partial( 83 | get.multi_target, 84 | simple_deploy_get, 85 | target_labels=['isMSIH', 'gender'], 86 | multi_target_evaluators=[aggregate_stats]) 87 | 88 | project_dir = '/path/to/deployment/project/dir' 89 | 90 | if __name__ == '__main__': 91 | do_experiment( 92 | project_dir=project_dir, 93 | get=multi_deploy_get, 94 | train=partial( 95 | get.load, 96 | project_dir=project_dir, 97 | training_project_dir='/path/to/training/project/dir')) -------------------------------------------------------------------------------- /docs/source/users_guide/simple_training.rst: -------------------------------------------------------------------------------- 1 | Training and Deploying a Simple Model 2 | ===================================== 3 | 4 | TODO describe experiment file. 5 | 6 | 7 | Experiment Imports 8 | ------------------ 9 | 10 | To do anything, we first have to import all the necessary functionality.
This 11 | is easily done by writing:: 12 | 13 | from deepmed.experiment_imports import * 14 | 15 | at the top of our file. 16 | 17 | 18 | Defining the Cohorts 19 | -------------------- 20 | 21 | In the deepmed pipeline, both training and deployment are performed on *cohorts* 22 | of patients. 23 | 24 | We will now define a set of cohorts to train our model on:: 25 | 26 | #TODO make this different cohorts 27 | train_cohorts_df = pd.concat([ 28 | cohort(tile_path='E:/TCGA-BRCA-DX/BLOCKS_NORM', 29 | clini_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA-IMMUNO_CLINI.xlsx', 30 | slide_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA_SLIDE.csv'), 31 | cohort(tile_path='E:/TCGA-BRCA-DX/BLOCKS_NORM', 32 | clini_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA-IMMUNO_CLINI.xlsx', 33 | slide_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA_SLIDE.csv')]) 34 | 35 | When using Windows-like paths with backslashes, this string ought to be prefixed 36 | with an ``r`` to prevent the backslashes from being interpreted as character escapes: 37 | ``tile_path=r'C:\tile\path'``. 38 | 39 | TODO describe clini / slide table. 40 | 41 | 42 | Defining our Training Runs 43 | -------------------------- 44 | 45 | Next, we have to define how we want to use these cohorts. This is done using 46 | so-called ``RunGetters``. ``RunGetters`` allow us to define how we want to use 47 | our data to train models. We may for example want to perform cross-validation, 48 | train only on certain subgroups, train on many targets or even do a combination 49 | of the above. For this example, we will settle for a simple, single-target 50 | training. Let's construct our simple run getter:: 51 | 52 | simple_train_get = partial( 53 | get.simple_run, 54 | target_label='isMSIH', 55 | train_cohorts_df=train_cohorts_df, 56 | max_tile_num=100, 57 | na_values=['inconclusive']) 58 | 59 | That is quite a lot to take in! Let's break it down line by line.
60 | 61 | * ``get.simple_run`` describes how we want to use our data; in this case we 62 | want to train a simple, single-target model. All the following lines 63 | describe how we want this training to be performed. 64 | * ``target_label='isMSIH'`` is the label we want to predict with our model. 65 | The clinical table is expected to have a column with that name. 66 | * ``train_cohorts_df=train_cohorts_df`` are the cohorts we want to use for training. 67 | * ``max_tile_num=100`` states how many of a patient's tiles we want to sample. 68 | Oftentimes, increasing the number of tiles for a patient has only a minor 69 | effect on the actual training result. Thus sampling from a patient's tiles 70 | can significantly speed up training without hugely influencing our results. 71 | * ``na_values=['inconclusive']`` allows us to define additional values which 72 | indicate a non-informational training sample. Patients with this label will 73 | be excluded from training. 74 | 75 | 76 | Training the Model 77 | ------------------ 78 | 79 | We can now finally train our model:: 80 | 81 | if __name__ == '__main__': # required on Windows 82 | do_experiment( 83 | project_dir='/path/to/training/project/dir', 84 | get=simple_train_get) 85 | 86 | * ``project_dir='/path/to/training/project/dir'`` defines where we want to 87 | save our training's results. 88 | 89 | And that's it! Our model should now be merrily training! 90 | 91 | 92 | Deploying the Model 93 | ------------------- 94 | 95 | After our model has finished training, we may want to deploy it on another 96 | dataset to ascertain its performance. This is done quite similarly to the 97 | training process.
After defining our test cohorts, we can construct a run 98 | getter quite similarly to how we did before:: 99 | 100 | # file: simple_deploy.py 101 | 102 | from deepmed.experiment_imports import * 103 | 104 | test_cohorts_df = \ 105 | cohort(tile_path='E:/TCGA-BRCA-DX/BLOCKS_NORM', 106 | clini_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA-IMMUNO_CLINI.xlsx', 107 | slide_path='G:/immunoproject/TCGA-IMMUNO_Clini_Slide/TCGA-BRCA_SLIDE.csv') 108 | 109 | simple_deploy_get = partial( 110 | get.simple_run, 111 | target_label='isMSIH', 112 | test_cohorts_df=test_cohorts_df, 113 | max_tile_num=100, 114 | na_values=['inconclusive']) 115 | 116 | Observe how the getter has the exact same structure as the one above, with the 117 | only exception being that we specify ``test_cohorts_df`` instead of 118 | ``train_cohorts_df`` this time around. 119 | 120 | Next we have to specify how and where to load our models from:: 121 | 122 | project_dir = '/path/to/deployment/project/dir' 123 | load = partial( 124 | get.load, 125 | project_dir=project_dir, 126 | training_project_dir='/path/to/training/project/dir') 127 | 128 | We can now deploy our model like this:: 129 | 130 | if __name__ == '__main__': 131 | do_experiment( 132 | project_dir=project_dir, 133 | get=simple_deploy_get, 134 | train=load) 135 | 136 | Usually, the train parameter is used to further define the details of a 137 | network's training. In this case, we say that instead of training a model we 138 | want to load a pretrained model. 139 | 140 | 141 | Defining Evaluation Metrics 142 | --------------------------- 143 | 144 | While our model has now been deployed on the testing cohort, we don't have any 145 | results yet: this is because we haven't defined any metrics with which to 146 | evaluate our testing data.
Let's start off with some simple metrics:: 147 | 148 | evaluators = [auroc, count] 149 | 150 | These metrics will calculate the `area under the receiver operating 151 | characteristic curve`_ (AUROC) and the count of testing samples. These metrics 152 | are calculated on a *tile basis* though. It is often advantageous to instead 153 | calculate metrics on a per-patient basis. This can be done with the 154 | ``Grouped`` adapter:: 155 | 156 | evaluators += [Grouped(auroc, by='PATIENT'), Grouped(count, by='PATIENT')] 157 | 158 | This will modify the auroc and count metrics in such a way that they are 159 | calculated on a *per-patient* basis instead of a per-tile basis; instead of the 160 | overall tile count per class we for example get the number of patients per 161 | class. 162 | 163 | If we now extend our deployment script to make use of these evaluators, 164 | re-running the script should yield a file ``stats.pkl`` which contains the 165 | requested metrics:: 166 | 167 | project_dir = '/path/to/deployment/project/dir' 168 | simple_eval_get = partial( 169 | get.simple_run, 170 | target_label='isMSIH', 171 | test_cohorts_df=test_cohorts_df, 172 | max_tile_num=100, 173 | na_values=['inconclusive'], 174 | evaluators=evaluators) 175 | 176 | if __name__ == '__main__': 177 | do_experiment( 178 | project_dir=project_dir, 179 | get=simple_eval_get, 180 | train=partial( 181 | get.load, 182 | project_dir=project_dir, 183 | training_project_dir='/path/to/training/project/dir')) 184 | 185 | ..
_area under the receiver operating characteristic curve: https://en.wikipedia.org/wiki/Receiver_operating_characteristic 186 | 187 | 188 | Doing it All at Once 189 | -------------------- 190 | 191 | If we already know what data we want to train and deploy our model on 192 | beforehand, we can combine the two steps into one experiment:: 193 | 194 | from deepmed.experiment_imports import * 195 | 196 | if __name__ == '__main__': 197 | do_experiment( 198 | project_dir='/path/to/project/dir', 199 | get=partial( 200 | get.simple_run, 201 | target_label='isMSIH', 202 | train_cohorts_df=train_cohorts_df, 203 | max_tile_num=100, 204 | na_values=['inconclusive']) 205 | evaluator_groups=[evaluators]) 206 | 207 | Since we train our models in the same step as we deploy them, we don't need to 208 | specify where to load our models from this time. -------------------------------------------------------------------------------- /examples/continuous.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python3 3 | from deepmed.evaluators import Evaluator 4 | from dataclasses import dataclass 5 | from deepmed.experiment_imports import * 6 | 7 | # this is a tiny toy data set; do not expect any good results from this 8 | cohort_path = untar_data( 9 | 'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 10 | 11 | cohorts_df = cohort( 12 | tiles_path=cohort_path/'tiles', 13 | clini_path=cohort_path/'clini.csv', 14 | slide_path=cohort_path/'slide.csv') 15 | 16 | 17 | def main(): 18 | do_experiment( 19 | project_dir='continuous', 20 | get=get.SimpleRun( 21 | train_cohorts_df=cohorts_df, 22 | test_cohorts_df=cohorts_df, 23 | target_label='TMB (nonsynonymous)', 24 | n_bins=None, # don't discretize 25 | evaluators=[Grouped(r2), r2, OnDiscretized(Grouped(auroc))], 26 | ), 27 | ) 28 | 29 | 30 | if __name__ == '__main__': 31 | main() 32 | 
-------------------------------------------------------------------------------- /examples/crossval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from deepmed.experiment_imports import * 3 | 4 | # this is a tiny toy data set; do not expect any good results from this 5 | cohort_path = untar_data( 6 | 'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 7 | 8 | cohorts_df = cohort( 9 | tiles_path=cohort_path/'tiles', 10 | clini_path=cohort_path/'clini.csv', 11 | slide_path=cohort_path/'slide.csv') 12 | 13 | 14 | def main(): 15 | do_experiment( 16 | project_dir='crossval', 17 | get=get.Crossval( 18 | get.SimpleRun(), 19 | cohorts_df=cohorts_df, 20 | target_label='ER Status By IHC', 21 | valid_frac=.2, 22 | crossval_evaluators=[AggregateStats(label='fold', over=['fold'])], 23 | evaluators=[Grouped(auroc), Grouped(count), Grouped(p_value), gradcam], 24 | get_items=get.GetTiles(max_tile_nums={get.DatasetType.TRAIN: 128, 25 | get.DatasetType.VALID: 256, 26 | get.DatasetType.TEST: 512}), 27 | train=Train( 28 | batch_size=96, 29 | max_epochs=4), 30 | )) 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /examples/extract-with-custom-model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Extracts features with a custom model.""" 3 | from deepmed.experiment_imports import * 4 | import torch 5 | import torchvision 6 | 7 | 8 | # The below code is taken from an external source whose URL was not recorded 9 | # here (TODO: add the attribution link). 10 | def load_model_weights(model, weights): # copies only the entries of `weights` whose keys exist in model.state_dict(); others are dropped 11 | model_dict = model.state_dict() 12 | weights = {k: v for k, v in weights.items() if k in model_dict} 13 | if weights == {}: 14 | print('No weights could be loaded..') 15 | model_dict.update(weights) 16 | model.load_state_dict(model_dict) 17 | 18 | return model 19 | 20 | 21 | model =
torchvision.models.__dict__['resnet18'](pretrained=False) 22 | state = torch.load('tenpercent_resnet18.ckpt', map_location='cuda:0') # tensors are mapped onto cuda:0, so this needs a CUDA device 23 | 24 | state_dict = state['state_dict'] 25 | for key in list(state_dict.keys()): # iterate over a snapshot because keys are renamed in place; strips 'model.'/'resnet.' prefixes, presumably to match torchvision's resnet18 key names 26 | state_dict[key.replace('model.', '').replace( 27 | 'resnet.', '')] = state_dict.pop(key) 28 | 29 | model = load_model_weights(model, state_dict) 30 | 31 | 32 | # assuming `model` contains a model of choice, calling `Extract` with 33 | # `arch=lambda pretrained: model` will extract features with that model. 34 | def main(): 35 | do_experiment( 36 | project_dir='features/ozanciga', 37 | get=get.Extract( 38 | tile_dir='tile/dir', 39 | arch=lambda pretrained: model, 40 | )) 41 | 42 | 43 | if __name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /examples/extract.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from deepmed.experiment_imports import * 3 | 4 | 5 | def main(): 6 | do_experiment( 7 | project_dir='/feature/output/dir', 8 | get=get.Extract( 9 | tile_dir='/tile/dir', 10 | arch=resnet18)) 11 | 12 | 13 | if __name__ == '__main__': 14 | main() 15 | -------------------------------------------------------------------------------- /examples/mil.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from deepmed.experiment_imports import * 3 | 4 | 5 | cohorts_df = cohort( 6 | tiles_path='/path/to/features', 7 | clini_path='/path/to/clini.xlsx', 8 | slide_path='/path/to/slide.csv') 9 | 10 | 11 | def main(): 12 | do_experiment( 13 | project_dir='crossval_mil', 14 | get=get.Crossval( 15 | get.SimpleRun(), 16 | cohorts_df=cohorts_df, 17 | target_label='ISHLT_2004_rej', 18 | evaluators=[auroc], 19 | crossval_evaluators=[AggregateStats()], 20 | # The next three lines are different from normal training 21 | get_items=mil.get_h5s, 22 | train=mil.Train(), 23 |
balance=False, 24 | )) 25 | 26 | if __name__ == '__main__': 27 | main() 28 | -------------------------------------------------------------------------------- /examples/multi_target_deploy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Deploys previously trained models on a new data set. 3 | 4 | This example assumes that you have previously trained some models using e.g. the 5 | ``multi_target_train`` example. 6 | """ 7 | 8 | from deepmed.experiment_imports import * 9 | 10 | # this is a tiny toy data set; do not expect any good results from this 11 | cohort_path = untar_data( 12 | 'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 13 | 14 | test_cohorts_df = cohort( 15 | tiles_path=cohort_path/'tiles', 16 | clini_path=cohort_path/'clini.csv', 17 | slide_path=cohort_path/'slide.csv') 18 | 19 | 20 | def main(): 21 | project_dir = 'multi_target_deploy' 22 | 23 | do_experiment( 24 | project_dir=project_dir, 25 | get=get.MultiTarget( 26 | get.SimpleRun(), 27 | test_cohorts_df=test_cohorts_df, 28 | target_labels=['ER Status By IHC', 'TCGA Subtype', 'TMB (nonsynonymous)'], 29 | max_test_tile_num=512, # NOTE(review): the other examples limit tiles via get_items=get.GetTiles(max_tile_nums=...); confirm this keyword is still accepted 30 | evaluators=[auroc, Grouped(auroc), Grouped(F1()), Grouped(count)], 31 | multi_target_evaluators=[AggregateStats(label='target')], 32 | train=Load( # load the models found under training_project_dir instead of training new ones 33 | project_dir=project_dir, 34 | training_project_dir='multi_target_train'), 35 | ), 36 | devices={'cuda:0': 4}, 37 | ) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /examples/multi_target_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from deepmed.experiment_imports import * 3 | 4 | 5 | # this is a tiny toy data set; do not expect any good results from this 6 | cohort_path = untar_data( 7 |
'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 8 | 9 | train_cohorts_df = cohort( 10 | tiles_path=cohort_path/'tiles', 11 | clini_path=cohort_path/'clini.csv', 12 | slide_path=cohort_path/'slide.csv') 13 | 14 | 15 | def main(): 16 | do_experiment( 17 | project_dir='multi_target_train', 18 | get=get.MultiTarget( # train for multiple targets 19 | get.SimpleRun(), 20 | train_cohorts_df=train_cohorts_df, 21 | target_labels=['ER Status By IHC', 'TCGA Subtype', 22 | 'TMB (nonsynonymous)'], # target labels to train for 23 | get_items=get.GetTiles(max_tile_nums={ 24 | get.DatasetType.TRAIN: 128, # maximum number of tiles per patient to train with 25 | get.DatasetType.VALID: 256, # maximum number of tiles per patient to validate with 26 | get.DatasetType.TEST: 512 # maximum number of tiles per patient to test on 27 | }), 28 | # amount of data to use as validation set (for early stopping) 29 | valid_frac=.2, 30 | balance=True, # whether to balance the training set 31 | na_values=['inconclusive'], # labels to exclude in training 32 | n_bins=3, # presumably the number of bins continuous targets (e.g. TMB) are discretized into; examples/continuous.py passes None to skip discretization 33 | min_support=10, # minimal required patient-level class samples for a class to be considered 34 | train=Train( 35 | batch_size=96, 36 | # absolute maximum number of epochs to train for (usually preceded by early stopping) 37 | max_epochs=32, 38 | metrics=[BalancedAccuracy()], # additional metrics 39 | # epochs to continue training without improvement (will still select best model in the end) 40 | patience=3, 41 | monitor='valid_loss', # metric to monitor for improvement 42 | # augmentations to apply to data 43 | tfms=aug_transforms(flip_vert=True, max_rotate=360, 44 | max_zoom=1, max_warp=0, size=224), 45 | ), 46 | ), 47 | ) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /examples/parameterize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 |
from deepmed.experiment_imports import * 3 | 4 | 5 | # this is a tiny toy data set; do not expect any good results from this 6 | cohort_path = untar_data( 7 | 'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 8 | 9 | train_cohorts_df = cohort( 10 | tiles_path=cohort_path/'tiles', 11 | clini_path=cohort_path/'clini.csv', 12 | slide_path=cohort_path/'slide.csv') 13 | 14 | 15 | def main(): 16 | do_experiment( 17 | project_dir='parameterize', 18 | get=get.Parameterize( 19 | get.Crossval(), 20 | get.SimpleRun(), 21 | cohorts_df=train_cohorts_df, 22 | target_label='TMB (nonsynonymous)', 23 | parameterizations={ 24 | f'{patience=} {bs=} {folds=}': { 25 | 'folds': folds, 'train': Train(patience=patience, batch_size=bs, max_epochs=1)} 26 | for patience in [5, 8] 27 | for bs in [64, 128] 28 | for folds in [3, 5] 29 | }, 30 | evaluators=[Grouped(count)], 31 | crossval_evaluators=[AggregateStats(over=[0])], 32 | parameterize_evaluators=[AggregateStats()], 33 | ), 34 | ) 35 | 36 | 37 | if __name__ == '__main__': 38 | main() 39 | -------------------------------------------------------------------------------- /examples/subgroup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from deepmed.experiment_imports import * 3 | 4 | 5 | # this is a tiny toy data set; do not expect any good results from this 6 | cohort_path = untar_data( 7 | 'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 8 | 9 | cohorts_df = cohort( 10 | tiles_path=cohort_path/'tiles', 11 | clini_path=cohort_path/'clini.csv', 12 | slide_path=cohort_path/'slide.csv') 13 | 14 | 15 | def subgrouper(x: pd.Series): 16 | if x['Diagnosis Age'] > 50: 17 | return 'old' 18 | elif x['Diagnosis Age'] <= 50: 19 | return 'young' 20 | else: 21 | return None 22 | 23 | 24 | def main(): 25 | do_experiment( 26 | project_dir='subgroup', 27 | get=get.Subgroup( 28 | get.SimpleRun(), 29 | train_cohorts_df=cohorts_df, 
30 | test_cohorts_df=cohorts_df, 31 | target_label='ER Status By IHC', 32 | subgrouper=subgrouper, 33 | valid_frac=.2, 34 | evaluators=[Grouped(auroc), Grouped(count)], 35 | subgroup_evaluators=[AggregateStats()], 36 | train=Train(max_epochs=1), 37 | ), 38 | devices={'cuda:0': 4} 39 | ) 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = deepmed-katherlab 3 | version = 0.10.0rc0 4 | author = Marko van Treeck 5 | author_email = markovantreeck@gmail.com 6 | description = A pipeline for training networks for the analysis of histological images 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/KatherLab/deepmed 10 | project_urls = 11 | Bug Tracker = https://github.com/KatherLab/deepmed/issues 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | License :: OSI Approved :: MIT License 15 | Operating System :: OS Independent 16 | 17 | [options] 18 | packages = find: 19 | python_requires = >=3.8 20 | install_requires = 21 | torch 22 | pandas 23 | sklearn 24 | torchvision 25 | tqdm 26 | xlrd 27 | openpyxl 28 | matplotlib 29 | fastai 30 | pandas~=1.3 31 | coloredlogs 32 | openslide-python 33 | packaging 34 | h5py 35 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_simple import * 2 | from .test_examples import * 
-------------------------------------------------------------------------------- /test/test_examples.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import subprocess 3 | import sys 4 | import os 5 | from tempfile import TemporaryDirectory 6 | from pathlib import Path 7 | 8 | 9 | class TestExamples(unittest.TestCase): 10 | def test_examples(self): 11 | cwd = Path.cwd() 12 | env = os.environ.copy() 13 | env['PYTHONPATH'] = str(cwd) 14 | 15 | examples = [ 16 | 'multi_target_train.py', 17 | 'multi_target_deploy.py', 18 | 'crossval.py', 19 | 'subgroup.py', 20 | 'parameterize.py', 21 | 'continuous.py', 22 | 'extract_features.py', 23 | ] 24 | example_path = Path('examples').absolute() 25 | 26 | with TemporaryDirectory(prefix='deepmed-example-test-') as project_dir: 27 | for example in examples: 28 | with self.subTest(example=example): 29 | example = example_path/example 30 | 31 | try: 32 | os.chdir(project_dir) 33 | subprocess.run([sys.executable, example], 34 | env=env, check=True) 35 | finally: 36 | os.chdir(cwd) 37 | -------------------------------------------------------------------------------- /test/test_simple.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from tempfile import TemporaryDirectory 3 | from itertools import product 4 | 5 | from deepmed.evaluators.types import Evaluator 6 | from deepmed.experiment_imports import * 7 | 8 | 9 | class TestSeperateTrainAndDeploy(unittest.TestCase): 10 | @classmethod 11 | def setUpClass(cls) -> None: 12 | path = untar_data( 13 | 'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 14 | cls.cohorts_df = cohort( 15 | tiles_path=path/'tiles', 16 | clini_path=path/'clini.csv', 17 | slide_path=path/'slide.csv') 18 | 19 | def test_simple(self): 20 | # train a model 21 | with TemporaryDirectory() as training_dir: 22 | do_experiment( 23 | project_dir=training_dir, 24 | 
get=get.SimpleRun( 25 | train_cohorts_df=self.cohorts_df, 26 | target_label='ER Status By IHC', 27 | max_train_tile_num=4, 28 | max_valid_tile_num=4, 29 | train=Train(max_epochs=1)), 30 | logfile=None) 31 | 32 | train_df = pd.read_csv(Path(training_dir)/'training_set.csv.zip') 33 | counts = train_df['ER Status By IHC'].value_counts() 34 | self.assertEqual( 35 | counts['Positive'], counts['Negative'], msg='Training set not balanced!') 36 | 37 | with TemporaryDirectory() as testing_dir: 38 | # deploy it 39 | max_test_tile_num = 2 40 | 41 | do_experiment( 42 | project_dir=testing_dir, 43 | get=get.SimpleRun( 44 | test_cohorts_df=self.cohorts_df, 45 | target_label='ER Status By IHC', 46 | max_test_tile_num=max_test_tile_num, 47 | train=Load( 48 | project_dir=Path(testing_dir), 49 | training_project_dir=Path(training_dir))), 50 | logfile=None) 51 | 52 | # add some evaluation 53 | do_experiment( 54 | project_dir=testing_dir, 55 | get=get.SimpleRun( 56 | test_cohorts_df=self.cohorts_df, 57 | target_label='ER Status By IHC', 58 | evaluators=[Grouped(auroc), count, Grouped(count)]), 59 | logfile=None) 60 | 61 | stats_df = pd.read_pickle( 62 | Path(testing_dir)/'stats.pkl') 63 | self.assertEqual( 64 | stats_df[('count', 'PATIENT')]['Positive'], 76) 65 | self.assertEqual( 66 | stats_df[('count', 'PATIENT')]['Negative'], 24) 67 | self.assertEqual( 68 | stats_df[('count', 'nan')]['Positive'], max_test_tile_num*76) 69 | self.assertEqual( 70 | stats_df[('count', 'nan')]['Negative'], max_test_tile_num*24) 71 | 72 | 73 | class TestDiscretization(unittest.TestCase): 74 | def test_class(self): 75 | path = untar_data( 76 | 'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 77 | cohorts_df = cohort( 78 | tiles_path=path/'tiles', 79 | clini_path=path/'clini.csv', 80 | slide_path=path/'slide.csv') 81 | 82 | # train a model 83 | with TemporaryDirectory() as project_dir: 84 | do_experiment( 85 | project_dir=project_dir, 86 | get=get.SimpleRun( 87 | 
train_cohorts_df=cohorts_df, 88 | test_cohorts_df=cohorts_df, 89 | target_label='TMB (nonsynonymous)', 90 | evaluators=[Grouped(count)], 91 | n_bins=3, 92 | train=Train(max_epochs=1)), 93 | logfile=None) 94 | 95 | stats_df = pd.read_pickle(Path(project_dir)/'stats.pkl') 96 | self.assertEqual(stats_df[('count', 'PATIENT')]['[-inf,0.7)'], 31) 97 | self.assertEqual(stats_df[('count', 'PATIENT') 98 | ]['[0.7,1.43333333333)'], 34) 99 | self.assertEqual(stats_df[('count', 'PATIENT') 100 | ]['[1.43333333333,inf)'], 34) 101 | 102 | 103 | class TestEvaluators(unittest.TestCase): 104 | @classmethod 105 | def setUpClass(cls) -> None: 106 | path = untar_data( 107 | 'https://katherlab-datasets.s3.eu-central-1.amazonaws.com/tiny-test-data.zip') 108 | cls.cohorts_df = cohort( 109 | tiles_path=path/'tiles', 110 | clini_path=path/'clini.csv', 111 | slide_path=path/'slide.csv') 112 | 113 | cls.training_dir = TemporaryDirectory() 114 | cls.max_train_tile_num = 4 115 | 116 | # train and deploy 117 | do_experiment( 118 | project_dir=cls.training_dir.name, 119 | get=get.SimpleRun( 120 | train_cohorts_df=cls.cohorts_df, 121 | test_cohorts_df=cls.cohorts_df, 122 | target_label='ER Status By IHC', 123 | max_train_tile_num=cls.max_train_tile_num, 124 | train=Train(max_epochs=4)), 125 | logfile=None) 126 | 127 | @classmethod 128 | def tearDownClass(cls) -> None: 129 | cls.training_dir.cleanup() 130 | 131 | def test_auroc(self): 132 | """Test AUROC Metric.""" 133 | evaluate(self.training_dir.name, self.cohorts_df, 134 | [auroc, Grouped(auroc)]) 135 | stats_df = pd.read_pickle( 136 | Path(self.training_dir.name)/'stats.pkl') 137 | 138 | auroc_ = stats_df[('auroc', 'PATIENT')]['Positive'] 139 | self.assertTrue(auroc_ >= 0 and auroc_ <= 1, msg='AUROC not in [0,1]') 140 | self.assertAlmostEqual( 141 | stats_df[('auroc', 'PATIENT')]['Positive'], 142 | stats_df[('auroc', 'PATIENT')]['Negative'], 143 | msg='AUROC for binary target not symmetric!') 144 | 145 | def test_f1(self): 146 | 
evaluate(self.training_dir.name, 147 | self.cohorts_df, [F1(), Grouped(F1())]) 148 | stats_df = pd.read_pickle( 149 | Path(self.training_dir.name)/'stats.pkl') 150 | self.assertIn(('f1 optimal', 'nan'), stats_df.columns) 151 | self.assertIn(('f1 optimal', 'PATIENT'), stats_df.columns) 152 | 153 | def test_count(self): 154 | evaluate(self.training_dir.name, self.cohorts_df, 155 | [count, Grouped(count)]) 156 | stats_df = pd.read_picklel( 157 | Path(self.training_dir.name)/'stats.pkl') 158 | self.assertTrue( 159 | (stats_df[('count', 'nan')] == 160 | stats_df[('count', 'PATIENT')] * self.max_train_tile_num).all(), 161 | msg='Did not sample the correct number of tiles') 162 | 163 | def test_top_tiles(self): 164 | n_patients, n_tiles = 6, 3 165 | evaluate(self.training_dir.name, self.cohorts_df, [ 166 | TopTiles(n_patients=n_patients, n_tiles=n_tiles)]) 167 | for class_ in ['Positive', 'Negative']: 168 | self.assertTrue( 169 | (Path(self.training_dir.name) / 170 | f'ER Status By IHC_{class_}_best-{n_patients}-patients_best-{n_tiles}-tiles.svg').exists()) 171 | df = pd.read_csv( 172 | Path(self.training_dir.name) / 173 | f'ER Status By IHC_{class_}_best-{n_patients}-patients_best-{n_tiles}-tiles.csv') 174 | self.assertEqual(df.PATIENT.nunique(), n_patients) 175 | self.assertTrue( 176 | (df.groupby('PATIENT').tile_path.count() == n_tiles).all()) 177 | 178 | 179 | def evaluate(project_dir: Union[str, Path], cohorts_df: pd.DataFrame, evaluators: Iterable[Evaluator]): 180 | do_experiment( 181 | project_dir=project_dir, 182 | get=get.SimpleRun( 183 | test_cohorts_df=cohorts_df, 184 | target_label='ER Status By IHC', 185 | evaluators=evaluators), 186 | logfile=None) 187 | 188 | 189 | if __name__ == '__main__': 190 | unittest.main() 191 | --------------------------------------------------------------------------------