├── .github ├── FUNDING.yml ├── dependabot.yml ├── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md └── workflows │ └── python-publish.yml ├── .gitignore ├── home ├── assets │ ├── logo.png │ ├── cli-ui.png │ ├── notion.png │ ├── poster.png │ ├── recovery.png │ ├── logos │ │ ├── skule.png │ │ ├── uoft.png │ │ ├── uoftai.png │ │ ├── utmist.png │ │ └── vector.png │ └── visualization.png └── index.html ├── .gitmodules ├── mipcandy ├── __main__.py ├── common │ ├── numpy │ │ ├── __init__.py │ │ └── regressions.py │ ├── __init__.py │ ├── optim │ │ ├── __init__.py │ │ ├── lr_scheduler.py │ │ └── loss.py │ └── module │ │ ├── __init__.py │ │ ├── conv.py │ │ └── preprocess.py ├── presets │ ├── __init__.py │ └── segmentation.py ├── frontend │ ├── __init__.py │ ├── wandb_fe.py │ ├── prototype.py │ └── notion_fe.py ├── settings.yml ├── run.py ├── types.py ├── data │ ├── __init__.py │ ├── download.py │ ├── convertion.py │ ├── geometric.py │ ├── io.py │ ├── visualization.py │ ├── dataset.py │ └── inspection.py ├── __entry__.py ├── __init__.py ├── sanity_check.py ├── config.py ├── layer.py ├── evaluation.py ├── sliding_window.py ├── metrics.py ├── inference.py └── training.py ├── pyproject.toml ├── README.md └── LICENSE /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: ProjectNeura 2 | patreon: ProjectNeura 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | main.py 4 | build 5 | *.egg-info 6 | dist 7 | secrets.yml 8 | -------------------------------------------------------------------------------- /home/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logo.png -------------------------------------------------------------------------------- /home/assets/cli-ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/cli-ui.png -------------------------------------------------------------------------------- /home/assets/notion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/notion.png -------------------------------------------------------------------------------- /home/assets/poster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/poster.png -------------------------------------------------------------------------------- /home/assets/recovery.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/recovery.png -------------------------------------------------------------------------------- /home/assets/logos/skule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/skule.png -------------------------------------------------------------------------------- /home/assets/logos/uoft.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/uoft.png -------------------------------------------------------------------------------- /home/assets/logos/uoftai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/uoftai.png -------------------------------------------------------------------------------- /home/assets/logos/utmist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/utmist.png -------------------------------------------------------------------------------- /home/assets/logos/vector.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/vector.png -------------------------------------------------------------------------------- /home/assets/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/visualization.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mipcandy-docs"] 2 | path = mipcandy-docs 3 | url = https://github.com/ProjectNeura/mipcandy-docs 4 | -------------------------------------------------------------------------------- /mipcandy/__main__.py: -------------------------------------------------------------------------------- 1 | from mipcandy.__entry__ import __entry__ 2 | 3 | 4 | if __name__ == "__main__": 5 | __entry__() 6 | -------------------------------------------------------------------------------- /mipcandy/common/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | from mipcandy.common.numpy.regressions import quotient_regression, quotient_derivative, quotient_bounds 2 | -------------------------------------------------------------------------------- /mipcandy/common/__init__.py: -------------------------------------------------------------------------------- 1 | from mipcandy.common.numpy import * 2 | from mipcandy.common.module import * 3 | from mipcandy.common.optim import * 4 | -------------------------------------------------------------------------------- /mipcandy/presets/__init__.py: -------------------------------------------------------------------------------- 1 | from mipcandy.presets.segmentation import SegmentationTrainer, SlidingSegmentationTrainer, SlidingValidationTrainer 2 | -------------------------------------------------------------------------------- /mipcandy/frontend/__init__.py: -------------------------------------------------------------------------------- 1 | from mipcandy.frontend.notion_fe import NotionFrontend 2 | from mipcandy.frontend.prototype import Frontend, create_hybrid_frontend 3 | -------------------------------------------------------------------------------- /mipcandy/common/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from mipcandy.common.optim.loss import FocalBCEWithLogits, DiceBCELossWithLogits 2 | from mipcandy.common.optim.lr_scheduler import AbsoluteLinearLR 3 | -------------------------------------------------------------------------------- /mipcandy/common/module/__init__.py: 
-------------------------------------------------------------------------------- 1 | from mipcandy.common.module.conv import ConvBlock2d, ConvBlock3d, WSConv2d, WSConv3d 2 | from mipcandy.common.module.preprocess import Pad2d, Pad3d, Restore2d, Restore3d, Normalize, ColorizeLabel 3 | -------------------------------------------------------------------------------- /mipcandy/settings.yml: -------------------------------------------------------------------------------- 1 | # fill in your settings here 2 | note: "" 3 | num_checkpoints: 5 4 | ema: true 5 | seed: ~ 6 | early_stop_tolerance: 5 7 | val_score_prediction: true 8 | val_score_prediction_degree: 5 9 | save_preview: true 10 | preview_quality: 0.75 11 | -------------------------------------------------------------------------------- /mipcandy/run.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from mipcandy.config import load_settings, save_settings, load_secrets, save_secrets 4 | 5 | 6 | def config(target: Literal["setting", "secret"], key: str, value: str) -> None: 7 | match target: 8 | case "setting": 9 | settings = load_settings() 10 | settings[key] = value 11 | save_settings(settings) 12 | case "secret": 13 | secrets = load_secrets() 14 | secrets[key] = value 15 | save_secrets(secrets) 16 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /mipcandy/types.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from typing import Any, Iterable, Sequence 3 | 4 | import torch 5 | from torch import nn 6 | from torchvision.transforms import Compose 7 | 8 | type Setting = str | int | float | bool | None | dict[str, Setting] | list[Setting] 9 | type Settings = dict[str, Setting] 10 | type Params = Iterable[torch.Tensor] | Iterable[dict[str, Any]] 11 | type Transform = nn.Module | Compose 12 | type SupportedPredictant = Sequence[torch.Tensor] | str | PathLike[str] | Sequence[str] | torch.Tensor 13 | type Colormap = Sequence[int | tuple[int, int, int]] 14 | type Device = torch.device | str 15 | type Shape2d = tuple[int, int] 16 | type Shape3d = tuple[int, int, int] 17 | type Shape = Shape2d | Shape3d 18 | type AmbiguousShape = tuple[int, ...] 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: [ enhancement, question ] 6 | assignees: [ ATATC, qmascarenhas ] 7 | issue_type: Feature 8 | 9 | --- 10 | 11 | **Is your feature request related to a problem? 
Please describe.** 12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 13 | 14 | **Describe the solution you'd like** 15 | A clear and concise description of what you want to happen. 16 | 17 | **Describe alternatives you've considered** 18 | A clear and concise description of any alternative solutions or features you've considered. 19 | 20 | **Additional context** 21 | Add any other context or screenshots about the feature request here. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: 6 | - published 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | build-and-publish: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | id-token: write 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v4 19 | - name: Set up Python 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.12" 23 | - name: Install dependencies 24 | run: python -m pip install -U setuptools wheel build 25 | - name: Build 26 | run: python -m build . 27 | - name: Publish 28 | uses: pypa/gh-action-pypi-publish@release/v1 29 | with: 30 | skip-existing: true 31 | -------------------------------------------------------------------------------- /mipcandy/data/__init__.py: -------------------------------------------------------------------------------- 1 | from mipcandy.data.convertion import convert_ids_to_logits, convert_logits_to_ids, auto_convert 2 | from mipcandy.data.dataset import Loader, UnsupervisedDataset, SupervisedDataset, DatasetFromMemory, MergedDataset, \ 3 | PathBasedUnsupervisedDataset, SimpleDataset, PathBasedSupervisedDataset, NNUNetDataset, BinarizedDataset 4 | from mipcandy.data.download import download_dataset 5 | from mipcandy.data.geometric import ensure_num_dimensions, orthographic_views, aggregate_orthographic_views, crop 6 | from mipcandy.data.inspection import InspectionAnnotation, InspectionAnnotations, load_inspection_annotations, \ 7 | inspect, ROIDataset, RandomROIDataset 8 | from mipcandy.data.io import resample_to_isotropic, load_image, save_image 9 | from mipcandy.data.visualization import visualize2d, visualize3d, overlay 10 | -------------------------------------------------------------------------------- /mipcandy/__entry__.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from mipcandy.run import config 4 | 5 | 6 | def __entry__() -> None: 7 | parser = ArgumentParser(prog="MIP Candy CLI", description="MIP Candy Command Line Interface", 8 | epilog="GitHub: https://github.com/ProjectNeura/MIPCandy") 9 | parser.add_argument("-c", "--config", choices=("setting", "secret"), default=None, 10 | help="set a configuration such that key=value") 11 | parser.add_argument("-kv", "--key-value", nargs=2, action="append", default=None, help="define a key-value pair") 12 | args = parser.parse_args() 13 | if args.config: 14 | if not args.key_value: 15 | raise ValueError("Expected at least one key-value pair") 16 | for key_value in args.key_value: 17 | config(args.config, key_value[0], key_value[1]) 18 | -------------------------------------------------------------------------------- /mipcandy/data/download.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from zipfile 
import ZipFile 3 | 4 | from requests import get 5 | from rich.console import Console 6 | from rich.progress import track 7 | 8 | 9 | def download_dataset(name: str, to: str | PathLike[str], *, endpoint: str = "cds.projectneura.org", 10 | console: Console = Console()) -> None: 11 | to_zip = f"{to}.zip" 12 | with get(f"https://{endpoint}/{name}.zip", stream=True) as response: 13 | response.raise_for_status() 14 | with open(to_zip, "wb") as f: 15 | for chunk in track(response.iter_content(chunk_size=8192), description="Downloading...", console=console): 16 | f.write(chunk) 17 | console.print(f"Dataset successfully downloaded as {to_zip}") 18 | console.print("Unzipping...") 19 | with ZipFile(to_zip, "r") as zip_ref: 20 | zip_ref.extractall(to) 21 | console.print(f"Dataset successfully extracted to {to}") 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "mipcandy" 7 | version = "1.1.0-beta.2" 8 | description = "A Candy for Medical Image Processing" 9 | license = "Apache-2.0" 10 | readme = "README.md" 11 | requires-python = ">=3.12" 12 | authors = [ 13 | { name = "Project Neura", email = "central@projectneura.org" } 14 | ] 15 | dependencies = ["torch", "ptflops", "numpy", "SimpleITK", "matplotlib", "rich", "pandas", "requests"] 16 | 17 | [project.optional-dependencies] 18 | standard = ["pyvista"] 19 | all = ["pyvista", "mipcandy-bundles"] 20 | 21 | [tool.hatch.build.targets.sdist] 22 | only-include = ["mipcandy"] 23 | 24 | [tool.hatch.build.targets.wheel] 25 | packages = ["mipcandy"] 26 | 27 | [project.urls] 28 | Homepage = "https://mipcandy.projectneura.org" 29 | Documentation = "https://mipcandy-docs.projectneura.org" 30 | Repository = "https://github.com/ProjectNeura/MIPCandy" 31 | 32 | [project.scripts] 33 | mipcandy = "mipcandy:__entry__" -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: [ bug, todo ] 6 | assignees: ATATC 7 | issue_type: Bug 8 | 9 | --- 10 | 11 | **Describe the bug** 12 | A clear and concise description of what the bug is. 13 | 14 | **To Reproduce** 15 | Steps to reproduce the behavior: 16 | 1. Go to '...' 17 | 2. Click on '....' 18 | 3. Scroll down to '....' 19 | 4. See error 20 | 21 | **Expected behavior** 22 | A clear and concise description of what you expected to happen. 23 | 24 | **Screenshots** 25 | If applicable, add screenshots to help explain your problem. 26 | 27 | **Desktop (please complete the following information):** 28 | - OS: [e.g. iOS] 29 | - Browser [e.g. chrome, safari] 30 | - Version [e.g. 22] 31 | 32 | **Smartphone (please complete the following information):** 33 | - Device: [e.g. iPhone6] 34 | - OS: [e.g. iOS8.1] 35 | - Browser [e.g. stock browser, safari] 36 | - Version [e.g. 22] 37 | 38 | **Additional context** 39 | Add any other context about the problem here. 
40 | -------------------------------------------------------------------------------- /mipcandy/data/convertion.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from mipcandy.common import Normalize 7 | 8 | 9 | def convert_ids_to_logits(ids: torch.Tensor, d: Literal[1, 2, 3], num_classes: int) -> torch.Tensor: 10 | if ids.dtype != torch.int or ids.min() < 0: 11 | raise TypeError("`ids` should be positive integers") 12 | d += 1 13 | if ids.ndim != d: 14 | if ids.ndim == d + 1 and ids.shape[1] == 1: 15 | ids = ids.squeeze(1) 16 | else: 17 | raise ValueError(f"`ids` should be {d} dimensional or {d + 1} dimensional with single channel") 18 | return nn.functional.one_hot(ids.long(), num_classes).movedim(-1, 1).contiguous().float() 19 | 20 | 21 | def convert_logits_to_ids(logits: torch.Tensor, *, channel_dim: int = 1) -> torch.Tensor: 22 | return logits.max(channel_dim).indices.int() 23 | 24 | 25 | def auto_convert(image: torch.Tensor) -> torch.Tensor: 26 | return (image * 255 if 0 <= image.min() < image.max() <= 1 else Normalize(domain=(0, 255))(image)).int() 27 | -------------------------------------------------------------------------------- /mipcandy/data/geometric.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Sequence 2 | 3 | import torch 4 | 5 | 6 | def ensure_num_dimensions(x: torch.Tensor, num_dimensions: int) -> torch.Tensor: 7 | d = num_dimensions - x.ndim 8 | if d == 0: 9 | return x 10 | return x.reshape(*((1,) * d + x.shape)) if d > 0 else x.reshape(x.shape[-num_dimensions:]) 11 | 12 | 13 | def orthographic_views(x: torch.Tensor, reduction: Literal["mean", "sum"] = "mean") -> tuple[ 14 | torch.Tensor, torch.Tensor, torch.Tensor]: 15 | match reduction: 16 | case "mean": 17 | return x.mean(dim=-3), x.mean(dim=-2), x.mean(dim=-1) 18 | case "sum": 19 | return x.sum(dim=-3), x.sum(dim=-2), x.sum(dim=-1) 20 | 21 | 22 | def aggregate_orthographic_views(d: torch.Tensor, h: torch.Tensor, w: torch.Tensor) -> torch.Tensor: 23 | d, h, w = d.unsqueeze(-3), h.unsqueeze(-2), w.unsqueeze(-1) 24 | return d * h * w 25 | 26 | 27 | def crop(t: torch.Tensor, bbox: Sequence[int]) -> torch.Tensor: 28 | return t[:, :, bbox[0]:bbox[1], bbox[2]:bbox[3]] if len(bbox) == 4 else t[:, :, bbox[0]:bbox[1], bbox[2]:bbox[3], 29 | bbox[4]:bbox[5]] 30 | -------------------------------------------------------------------------------- /mipcandy/__init__.py: -------------------------------------------------------------------------------- 1 | from mipcandy.__entry__ import __entry__ 2 | from mipcandy.common import * 3 | from mipcandy.config import load_settings, save_settings, load_secrets, save_secrets 4 | from mipcandy.data import * 5 | from mipcandy.evaluation import EvalCase, EvalResult, Evaluator 6 | from mipcandy.frontend import * 7 | from mipcandy.inference import parse_predictant, Predictor 8 | from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \ 9 | WithNetwork 10 | from mipcandy.metrics import do_reduction, dice_similarity_coefficient_binary, \ 11 | dice_similarity_coefficient_multiclass, soft_dice_coefficient, accuracy_binary, accuracy_multiclass, \ 12 | precision_binary, precision_multiclass, recall_binary, recall_multiclass, iou_binary, iou_multiclass 13 | from mipcandy.presets import * 14 | from mipcandy.run import config 15 | from mipcandy.sanity_check 
import num_trainable_params, model_complexity_info, SanityCheckResult, sanity_check 16 | from mipcandy.training import TrainerToolbox, Trainer, SWMetadata, SlidingTrainer 17 | from mipcandy.types import Setting, Settings, Params, Transform, SupportedPredictant, Colormap, Device, Shape2d, \ 18 | Shape3d, Shape, AmbiguousShape 19 | -------------------------------------------------------------------------------- /mipcandy/common/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import override 2 | 3 | from torch import optim 4 | 5 | 6 | class AbsoluteLinearLR(optim.lr_scheduler.LRScheduler): 7 | """ 8 | lr = kx + b 9 | """ 10 | def __init__(self, optimizer: optim.Optimizer, k: float, b: float, *, min_lr: float = 1e-6, 11 | restart: bool = False, last_epoch: int = -1) -> None: 12 | self._k: float = k 13 | self._b: float = b 14 | if min_lr < 0: 15 | raise ValueError(f"`min_lr` must be positive, but got {min_lr}") 16 | self._min_lr: float = min_lr 17 | self._restart: bool = restart 18 | self._restart_step: int = 0 19 | super().__init__(optimizer, last_epoch) 20 | 21 | def _interp(self, step: int) -> float: 22 | step -= self._restart_step 23 | r = self._k * step + self._b 24 | if r < self._min_lr: 25 | if self._restart: 26 | self._restart_step = step 27 | return self._interp(step) 28 | return self._min_lr 29 | return r 30 | 31 | @override 32 | def get_lr(self) -> list[float]: 33 | target = self._interp(self.last_epoch) 34 | return [target for _ in self.optimizer.param_groups] 35 | -------------------------------------------------------------------------------- /mipcandy/frontend/wandb_fe.py: -------------------------------------------------------------------------------- 1 | from typing import override 2 | 3 | from wandb import init, Run 4 | 5 | from mipcandy.frontend.prototype import Frontend 6 | from mipcandy.types import Settings 7 | 8 | 9 | class WandBFrontend(Frontend): 10 | def __init__(self, secrets: Settings) -> None: 11 | super().__init__(secrets) 12 | self._entity: str = self.require_nonempty_secret("wandb_entity", require_type=str) 13 | self._project: str = self.require_nonempty_secret("wandb_project", require_type=str) 14 | self._run: Run | None = None 15 | 16 | @override 17 | def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float, 18 | num_params: float, num_epochs: int, early_stop_tolerance: int) -> None: 19 | self._run = init(entity=self._entity, project=self._project, config={ 20 | "experiment_id": experiment_id, "trainer": trainer, "model": model, "note": note, "num_macs": num_macs, 21 | "num_params": num_params, "num_epochs": num_epochs 22 | }) 23 | 24 | @override 25 | def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]], 26 | early_stop_tolerance: int) -> None: 27 | if self._run: 28 | self._run.log(metrics) 29 | 30 | @override 31 | def on_experiment_completed(self, experiment_id: str) -> None: 32 | if not self._run: 33 | raise RuntimeError("Experiment has not been created") 34 | self._run.finish() 35 | -------------------------------------------------------------------------------- /mipcandy/common/numpy/regressions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def quotient_regression(x: np.ndarray, y: np.ndarray, m: int, n: int) -> tuple[np.ndarray, np.ndarray]: 5 | matrix = [] 6 | for xi, yi in zip(x, y): 7 | row = [] 8 | for k in range(m + 
1): 9 | row.append(xi ** k) 10 | for k in range(1, n + 1): 11 | row.append(-yi * xi ** k) 12 | matrix.append(row) 13 | matrix = np.array(matrix) 14 | coefficients, _, _, _ = np.linalg.lstsq(matrix, y, rcond=None) 15 | return coefficients[:m + 1][::-1], np.concatenate(([1.0], coefficients[m + 1:]))[::-1] 16 | 17 | 18 | def quotient_derivative(a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, np.ndarray]: 19 | da = np.polyder(a) 20 | db = np.polyder(b) 21 | return np.polysub(np.polymul(da, b), np.polymul(a, db)), np.polymul(b, b) 22 | 23 | 24 | def quotient_bounds(a: np.ndarray, b: np.ndarray, lower_bound: float | None, upper_bound: float | None, *, 25 | x_start: float = 0, x_stop: float = 1e4, x_step: float = .01) -> tuple[float, float] | None: 26 | x = np.arange(x_start, x_stop, x_step) 27 | y = np.polyval(a, x) / np.polyval(b, x) 28 | mask = np.array(True, like=y) 29 | if lower_bound is not None: 30 | mask = mask & (y > lower_bound) 31 | if upper_bound is not None: 32 | mask = mask & (y < upper_bound) 33 | if isinstance(mask, bool): 34 | raise ValueError("Bounds must be specified on at least one side") 35 | return (float(x[mask][0]), float(x[mask][-1])) if mask.any() else None 36 | -------------------------------------------------------------------------------- /mipcandy/sanity_check.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from io import StringIO 3 | from typing import Sequence, override 4 | 5 | import torch 6 | from ptflops import get_model_complexity_info 7 | from torch import nn 8 | 9 | from mipcandy.layer import auto_device 10 | from mipcandy.types import Device 11 | 12 | 13 | def num_trainable_params(model: nn.Module) -> int: 14 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 15 | 16 | 17 | def model_complexity_info(model: nn.Module, example_shape: Sequence[int]) -> tuple[float | None, float | None, str]: 18 | layer_stats = StringIO() 19 | macs, params = get_model_complexity_info(model, tuple(example_shape), ost=layer_stats, as_strings=False) 20 | return macs, params, layer_stats.getvalue() 21 | 22 | 23 | @dataclass 24 | class SanityCheckResult(object): 25 | num_macs: float 26 | num_params: float 27 | layer_stats: str 28 | output: torch.Tensor 29 | 30 | @override 31 | def __str__(self) -> str: 32 | return f"MACs: {self.num_macs * 1e-9:.1f} G / Params: {self.num_params * 1e-6:.1f} M" 33 | 34 | 35 | def sanity_check(model: nn.Module, input_shape: Sequence[int], *, device: Device | None = None) -> SanityCheckResult: 36 | if device is None: 37 | device = auto_device() 38 | num_macs, num_params, layer_stats = model_complexity_info(model, input_shape) 39 | if num_macs is None or num_params is None: 40 | raise RuntimeError("Failed to validate model") 41 | outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device)) 42 | return SanityCheckResult(num_macs, num_params, layer_stats, ( 43 | outputs[0] if isinstance(outputs, tuple) else outputs).squeeze(0)) 44 | -------------------------------------------------------------------------------- /mipcandy/config.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from os.path import abspath, exists 3 | 4 | from yaml import load, SafeLoader, dump, SafeDumper 5 | 6 | from mipcandy.types import Settings 7 | 8 | _DIR: str = abspath(__file__)[:-9] 9 | _DEFAULT_SETTINGS_PATH: str = f"{_DIR}settings.yml" 10 | _DEFAULT_SECRETS_PATH: str = f"{_DIR}secrets.yml" 11 | 12 
| 13 | def _load(path: str | PathLike[str], *, hint: str = "fill in your settings here") -> Settings: 14 | if not exists(path): 15 | with open(path, "w") as f: 16 | f.write(f"# {hint}\n") 17 | with open(path) as f: 18 | settings = load(f.read(), SafeLoader) 19 | if settings is None: 20 | return {} 21 | if not isinstance(settings, dict): 22 | raise ValueError(f"Invalid settings file: {path}") 23 | return settings 24 | 25 | 26 | def _save(settings: Settings, path: str | PathLike[str], *, hint: str = "fill in your settings here") -> None: 27 | with open(path, "w") as f: 28 | f.write(f"# {hint}\n") 29 | dump(settings, f, SafeDumper) 30 | 31 | 32 | def load_settings(*, path: str | PathLike[str] = _DEFAULT_SETTINGS_PATH) -> Settings: 33 | return _load(path) 34 | 35 | 36 | def save_settings(settings: Settings, *, path: str | PathLike[str] = _DEFAULT_SETTINGS_PATH) -> None: 37 | _save(settings, path) 38 | 39 | 40 | def load_secrets(*, path: str | PathLike[str] = _DEFAULT_SECRETS_PATH) -> Settings: 41 | return _load(path, hint="fill in your secrets here, do not commit this file") 42 | 43 | 44 | def save_secrets(secrets: Settings, *, path: str | PathLike[str] = _DEFAULT_SECRETS_PATH) -> None: 45 | _save(secrets, path, hint="fill in your secrets here, do not commit this file") 46 | -------------------------------------------------------------------------------- /mipcandy/data/io.py: -------------------------------------------------------------------------------- 1 | from math import floor 2 | from os import PathLike 3 | 4 | import SimpleITK as SpITK 5 | import torch 6 | 7 | from mipcandy.data.convertion import auto_convert 8 | from mipcandy.data.geometric import ensure_num_dimensions 9 | from mipcandy.types import Device 10 | 11 | 12 | def resample_to_isotropic(image: SpITK.Image, *, target_iso: float | None = None, 13 | interpolator: int = SpITK.sitkBSpline) -> SpITK.Image: 14 | dim = image.GetDimension() 15 | old_spacing = image.GetSpacing() 16 | old_size = image.GetSize() 17 | origin = image.GetOrigin() 18 | direction = image.GetDirection() 19 | if target_iso is None: 20 | target_iso = min(old_spacing) 21 | new_spacing = (target_iso,) * dim 22 | new_size = (max(1, floor(old_spacing[i] * (old_size[i] - 1) / new_spacing[i] + 1)) for i in range(dim)) 23 | return SpITK.Resample( 24 | image, new_size, SpITK.Transform(), interpolator, origin, new_spacing, direction, 0, image.GetPixelID() 25 | ) 26 | 27 | 28 | def load_image(path: str | PathLike[str], *, is_label: bool = False, align_spacing: bool = False, 29 | device: Device = "cpu") -> torch.Tensor: 30 | file = SpITK.ReadImage(path) 31 | if align_spacing: 32 | file = resample_to_isotropic(file, interpolator=SpITK.sitkNearestNeighbor if is_label else SpITK.sitkBSpline) 33 | img = torch.tensor(SpITK.GetArrayFromImage(file), dtype=torch.float, device=device) 34 | if path.endswith(".nii.gz") or path.endswith(".nii") or path.endswith(".mha"): 35 | img = ensure_num_dimensions(img, 4) 36 | return img.squeeze(1) if img.shape[1] == 1 else img 37 | if path.endswith(".png") or path.endswith(".jpg") or path.endswith(".jpeg"): 38 | return ensure_num_dimensions(img, 3) 39 | raise NotImplementedError(f"Unsupported file type: {path}") 40 | 41 | 42 | def save_image(image: torch.Tensor, path: str | PathLike[str]) -> None: 43 | if path.endswith(".png"): 44 | image = auto_convert(image).to(torch.uint8) 45 | SpITK.WriteImage(SpITK.GetImageFromArray(image.detach().cpu().numpy()), path) 46 | -------------------------------------------------------------------------------- 
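A brief usage sketch for the I/O helpers above (`load_image` / `save_image`); the dataset paths below are placeholders, not files in this repository.

```python
from mipcandy import load_image, save_image

# `align_spacing=True` resamples to isotropic voxel spacing on load; labels are
# resampled with nearest-neighbour interpolation so class ids are preserved.
volume = load_image("imagesTr/case_0001.nii.gz", align_spacing=True)
label = load_image("labelsTr/case_0001.nii.gz", is_label=True, align_spacing=True)

# Tensors are written back through SimpleITK; PNG targets are converted to uint8 first.
save_image(volume, "case_0001_iso.nii.gz")
```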
/mipcandy/common/optim/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from mipcandy.data import convert_ids_to_logits 7 | from mipcandy.metrics import do_reduction, soft_dice_coefficient 8 | 9 | 10 | class FocalBCEWithLogits(nn.Module): 11 | def __init__(self, alpha: float, gamma: float, *, reduction: Literal["mean", "sum", "none"] = "mean") -> None: 12 | super().__init__() 13 | self.alpha: float = alpha 14 | self.gamma: float = gamma 15 | self.reduction: Literal["mean", "sum", "none"] = reduction 16 | 17 | def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 18 | bce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction="none") 19 | p = torch.sigmoid(logits) 20 | p_t = torch.where(targets.bool(), p, 1 - p) 21 | alpha_t = torch.where(targets.bool(), torch.as_tensor(self.alpha, device=logits.device), torch.as_tensor( 22 | 1 - self.alpha, device=logits.device)) 23 | loss = alpha_t * (1 - p_t).pow(self.gamma) * bce 24 | return do_reduction(loss, self.reduction) 25 | 26 | 27 | class DiceBCELossWithLogits(nn.Module): 28 | def __init__(self, num_classes: int, *, lambda_bce: float = .5, lambda_soft_dice: float = 1, 29 | smooth: float = 1e-5, include_bg: bool = True) -> None: 30 | super().__init__() 31 | self.num_classes: int = num_classes 32 | self.lambda_bce: float = lambda_bce 33 | self.lambda_soft_dice: float = lambda_soft_dice 34 | self.smooth: float = smooth 35 | self.include_bg: bool = include_bg 36 | 37 | def forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]: 38 | if self.num_classes != 1 and labels.shape[1] == 1: 39 | d = labels.ndim - 2 40 | if d not in (1, 2, 3): 41 | raise ValueError(f"Expected labels to be 1D, 2D, or 3D, got {d} spatial dimensions") 42 | labels = convert_ids_to_logits(labels.int(), d, self.num_classes) 43 | labels = labels.float() 44 | bce = nn.functional.binary_cross_entropy_with_logits(masks, labels) 45 | masks = masks.sigmoid() 46 | soft_dice = soft_dice_coefficient(masks, labels, smooth=self.smooth, include_bg=self.include_bg) 47 | c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - soft_dice) 48 | return c, {"soft dice": soft_dice.item(), "bce loss": bce.item()} 49 | -------------------------------------------------------------------------------- /mipcandy/common/module/conv.py: -------------------------------------------------------------------------------- 1 | from typing import override 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from mipcandy.layer import LayerT 7 | 8 | 9 | class AbstractConvBlock(nn.Module): 10 | def __init__(self, in_ch: int, out_ch: int, kernel_size: int, *, stride: int = 1, padding: int = 0, 11 | dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", 12 | conv: LayerT = ..., norm: LayerT = ..., act: LayerT = ...) 
-> None: 13 | super().__init__() 14 | self.conv: nn.Module = conv.assemble(in_ch, out_ch, kernel_size, stride, padding, dilation, groups, bias, 15 | padding_mode) 16 | self.norm: nn.Module = norm.assemble(in_ch=out_ch) 17 | self.act: nn.Module = act.assemble() 18 | 19 | def forward(self, x: torch.Tensor) -> torch.Tensor: 20 | return self.act(self.norm(self.conv(x))) 21 | 22 | 23 | def _conv_block(default_conv: LayerT, default_norm: LayerT, default_act: LayerT) -> type[AbstractConvBlock]: 24 | class ConvBlock(AbstractConvBlock): 25 | def __init__(self, *args, **kwargs) -> None: 26 | if "conv" not in kwargs: 27 | kwargs["conv"] = default_conv 28 | if "norm" not in kwargs: 29 | kwargs["norm"] = default_norm 30 | if "act" not in kwargs: 31 | kwargs["act"] = default_act 32 | super().__init__(*args, **kwargs) 33 | 34 | return ConvBlock 35 | 36 | 37 | ConvBlock2d: type[AbstractConvBlock] = _conv_block( 38 | LayerT(nn.Conv2d), LayerT(nn.BatchNorm2d, num_features="in_ch"), LayerT(nn.ReLU, inplace=True) 39 | ) 40 | 41 | ConvBlock3d: type[AbstractConvBlock] = _conv_block( 42 | LayerT(nn.Conv3d), LayerT(nn.BatchNorm3d, num_features="in_ch"), LayerT(nn.ReLU, inplace=True) 43 | ) 44 | 45 | 46 | class WSConv2d(nn.Conv2d): 47 | @override 48 | def forward(self, x: torch.Tensor) -> torch.Tensor: 49 | w = self.weight 50 | w = (w - w.mean(dim=(1, 2, 3), keepdim=True)) / (w.std(dim=(1, 2, 3), keepdim=True) + 1e-5) 51 | return nn.functional.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) 52 | 53 | 54 | class WSConv3d(nn.Conv3d): 55 | @override 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | w = self.weight 58 | w = (w - w.mean(dim=(1, 2, 3, 4), keepdim=True)) / (w.std(dim=(1, 2, 3, 4), keepdim=True) + 1e-5) 59 | return nn.functional.conv3d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups) 60 | -------------------------------------------------------------------------------- /mipcandy/frontend/prototype.py: -------------------------------------------------------------------------------- 1 | from typing import override 2 | 3 | 4 | from mipcandy.types import Setting, Settings 5 | 6 | 7 | class Frontend(object): 8 | def __init__(self, secrets: Settings) -> None: 9 | self._secrets: Settings = secrets 10 | 11 | def require_nonempty_secret(self, entry: str, *, require_type: type | None = None) -> Setting: 12 | if entry not in self._secrets: 13 | raise ValueError(f"Missing secret {entry}") 14 | secret = self._secrets[entry] 15 | if require_type is None or isinstance(secret, require_type): 16 | return secret 17 | raise ValueError(f"Invalid secret type {type(secret)}, {require_type} expected") 18 | 19 | def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_params: float, 20 | num_macs: float, num_epochs: int, early_stop_tolerance: int) -> None: 21 | ... 22 | 23 | def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]], 24 | early_stop_tolerance: int) -> None: 25 | ... 26 | 27 | def on_experiment_completed(self, experiment_id: str) -> None: 28 | ... 29 | 30 | def on_experiment_interrupted(self, experiment_id: str, error: Exception) -> None: 31 | ... 
32 | 33 | 34 | def create_hybrid_frontend(*frontends: Frontend) -> type[Frontend]: 35 | class HybridFrontend(Frontend): 36 | def __init__(self, secrets: Settings) -> None: 37 | super().__init__(secrets) 38 | 39 | @override 40 | def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float, 41 | num_params: float, num_epochs: int, early_stop_tolerance: int) -> None: 42 | for frontend in frontends: 43 | frontend.on_experiment_created(experiment_id, trainer, model, note, num_macs, num_params, num_epochs, 44 | early_stop_tolerance) 45 | 46 | @override 47 | def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]], 48 | early_stop_tolerance: int) -> None: 49 | for frontend in frontends: 50 | frontend.on_experiment_updated(experiment_id, epoch, metrics, early_stop_tolerance) 51 | 52 | @override 53 | def on_experiment_completed(self, experiment_id: str) -> None: 54 | for frontend in frontends: 55 | frontend.on_experiment_completed(experiment_id) 56 | 57 | @override 58 | def on_experiment_interrupted(self, experiment_id: str, error: Exception) -> None: 59 | for frontend in frontends: 60 | frontend.on_experiment_interrupted(experiment_id, error) 61 | 62 | return HybridFrontend 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MIP Candy: A Candy for Medical Image Processing 2 | 3 | ![PyPI](https://img.shields.io/pypi/v/mipcandy) 4 | ![GitHub Release](https://img.shields.io/github/v/release/ProjectNeura/MIPCandy) 5 | ![PyPI Downloads](https://img.shields.io/pypi/dm/mipcandy) 6 | ![GitHub Stars](https://img.shields.io/github/stars/ProjectNeura/MIPCandy) 7 | 8 | ![poster](home/assets/poster.png) 9 | 10 | MIP Candy is Project Neura's next-generation infrastructure framework for medical image processing. It defines a handful 11 | of common network architectures with their corresponding training, inference, and evaluation pipelines that are 12 | out-of-the-box ready to use. Additionally, it also provides integrations with popular frontend dashboards such as 13 | Notion, WandB, and TensorBoard. 14 | 15 | We provide a flexible and extensible framework for medical image processing researchers to quickly prototype their 16 | ideas. MIP Candy takes care of all the rest, so you can focus on only the key experiment designs. 17 | 18 | :link: [Home](https://mipcandy.projectneura.org) 19 | 20 | :link: [Docs](https://mipcandy-docs.projectneura.org) 21 | 22 | ## Key Features 23 | 24 | Why MIP Candy? :thinking: 25 | 26 |
27 | Easy adaptation to fit your needs 28 | We provide tons of easy-to-use techniques for training that seamlessly support your customized experiments. 29 | 30 | - Sliding window 31 | - ROI inspection 32 | - ROI cropping to align dataset shape (100% or 33% foreground) 33 | - Automatic padding 34 | - ... 35 | 36 | You only need to override one method to create a trainer for your network architecture. 37 | 38 | ```python 39 | from typing import override 40 | 41 | from torch import nn 42 | from mipcandy import SegmentationTrainer 43 | 44 | 45 | class MyTrainer(SegmentationTrainer): 46 | @override 47 | def build_network(self, example_shape: tuple[int, ...]) -> nn.Module: 48 | ... 49 | ``` 50 |
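The sliding-window support listed above follows the same pattern. A sketch, assuming `SlidingSegmentationTrainer` (exported from `mipcandy.presets`) takes its window shape from `get_window_shape()` as defined in `mipcandy/sliding_window.py`; the class name and shape below are illustrative only.

```python
from typing import override

from torch import nn

from mipcandy import SlidingSegmentationTrainer, Shape


class MySlidingTrainer(SlidingSegmentationTrainer):
    # Optional cap on how many windows are forwarded at once (see SlidingWindow).
    sliding_window_batch_size = 4

    @override
    def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
        ...

    @override
    def get_window_shape(self) -> Shape:
        # Used as the sliding stride; `do_sliding_window()` extracts windows of
        # twice this size, i.e. neighbouring windows overlap by 50%.
        return 64, 64, 64
```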
51 | 52 |
53 | Satisfying command-line UI design 54 | ![cmd-ui](home/assets/cli-ui.png)
56 | 57 |
58 | Built-in 2D and 3D visualization for intuitive understanding 59 | visualization 60 |
61 | 62 |
63 | High availability with interruption tolerance 64 | Interrupted experiments can be resumed with ease. 65 | recovery 66 |
67 | 68 |
69 | Support of various frontend platforms for remote monitoring 70 | 71 | MIP Candy Supports [Notion](https://mipcandy-projectneura.notion.site), WandB, and TensorBoard. 72 | 73 | notion 74 |
75 | 76 | ## Installation 77 | 78 | Note that MIP Candy requires **Python >= 3.12**. 79 | 80 | ```shell 81 | pip install "mipcandy[standard]" 82 | ``` 83 | 84 | ## Quick Start 85 | 86 | Below is a simple example of a nnU-Net style training. The batch size is set to 1 due to the varying shape of the 87 | dataset, although you can use a `ROIDataset` to align the shapes. 88 | 89 | ```python 90 | from typing import override 91 | 92 | import torch 93 | from mipcandy_bundles.unet import UNetTrainer 94 | from torch.utils.data import DataLoader 95 | 96 | from mipcandy import download_dataset, NNUNetDataset 97 | 98 | 99 | class PH2(NNUNetDataset): 100 | @override 101 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 102 | image, label = super().load(idx) 103 | return image.squeeze(0).permute(2, 0, 1), label 104 | 105 | 106 | download_dataset("nnunet_datasets/PH2", "tutorial/datasets/PH2") 107 | dataset, val_dataset = PH2("tutorial/datasets/PH2", device="cuda").fold() 108 | dataloader = DataLoader(dataset, 1, shuffle=True) 109 | val_dataloader = DataLoader(val_dataset, 1, shuffle=False) 110 | trainer = UNetTrainer("tutorial", dataloader, val_dataloader, device="cuda") 111 | trainer.train(1000, note="a nnU-Net style example") 112 | ``` -------------------------------------------------------------------------------- /mipcandy/layer.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Any, Generator, Self, Mapping 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from mipcandy.types import Device, AmbiguousShape 8 | 9 | 10 | def batch_int_multiply(f: float, *n: int) -> Generator[int, None, None]: 11 | for i in n: 12 | r = i * f 13 | if not r.is_integer(): 14 | raise ValueError(f"Inequivalent conversion") 15 | yield int(r) 16 | 17 | 18 | def batch_int_divide(f: float, *n: int) -> Generator[int, None, None]: 19 | return batch_int_multiply(1 / f, *n) 20 | 21 | 22 | class LayerT(object): 23 | def __init__(self, m: type[nn.Module], **kwargs) -> None: 24 | self.m: type[nn.Module] = m 25 | self.kwargs: dict[str, Any] = kwargs 26 | 27 | def update(self, *, must_exist: bool = True, inplace: bool = False, **kwargs) -> Self: 28 | if not inplace: 29 | return self.copy().update(must_exist=must_exist, inplace=True, **kwargs) 30 | for k, v in kwargs.items(): 31 | if not must_exist or k in self.kwargs: 32 | self.kwargs[k] = v 33 | return self 34 | 35 | def assemble(self, *args, **kwargs) -> nn.Module: 36 | self_kwargs = self.kwargs.copy() 37 | for k, v in self_kwargs.items(): 38 | if isinstance(v, str) and v in kwargs: 39 | self_kwargs[k] = kwargs.pop(v) 40 | return self.m(*args, **self_kwargs, **kwargs) 41 | 42 | def copy(self) -> Self: 43 | return self.__class__(self.m, **self.kwargs) 44 | 45 | 46 | class HasDevice(object): 47 | def __init__(self, device: Device) -> None: 48 | self._device: Device = device 49 | 50 | def device(self, *, device: Device | None = None) -> None | Device: 51 | if device is None: 52 | return self._device 53 | else: 54 | self._device = device 55 | 56 | 57 | def auto_device() -> Device: 58 | if torch.cuda.is_available(): 59 | return f"cuda:{max(range(torch.cuda.device_count()), 60 | key=lambda i: torch.cuda.memory_reserved(i) - torch.cuda.memory_allocated(i))}" 61 | if torch.backends.mps.is_available(): 62 | return "mps" 63 | return "cpu" 64 | 65 | 66 | class WithPaddingModule(HasDevice): 67 | def __init__(self, device: Device) -> None: 68 | super().__init__(device) 69 | 
self._padding_module: nn.Module | None = None 70 | self._restoring_module: nn.Module | None = None 71 | self._padding_module_built: bool = False 72 | 73 | def build_padding_module(self) -> nn.Module | None: 74 | return None 75 | 76 | def build_restoring_module(self, padding_module: nn.Module | None) -> nn.Module | None: 77 | return None 78 | 79 | def _lazy_load_padding_module(self) -> None: 80 | if self._padding_module_built: 81 | return 82 | self._padding_module = self.build_padding_module() 83 | if self._padding_module: 84 | self._padding_module = self._padding_module.to(self._device) 85 | self._restoring_module = self.build_restoring_module(self._padding_module) 86 | if self._restoring_module: 87 | self._restoring_module = self._restoring_module.to(self._device) 88 | self._padding_module_built = True 89 | 90 | def get_padding_module(self) -> nn.Module | None: 91 | self._lazy_load_padding_module() 92 | return self._padding_module 93 | 94 | def get_restoring_module(self) -> nn.Module | None: 95 | self._lazy_load_padding_module() 96 | return self._restoring_module 97 | 98 | 99 | class WithNetwork(HasDevice, metaclass=ABCMeta): 100 | def __init__(self, device: Device) -> None: 101 | super().__init__(device) 102 | 103 | @abstractmethod 104 | def build_network(self, example_shape: AmbiguousShape) -> nn.Module: 105 | raise NotImplementedError 106 | 107 | def build_network_from_checkpoint(self, example_shape: AmbiguousShape, checkpoint: Mapping[str, Any]) -> nn.Module: 108 | """ 109 | Internally exposed interface for overriding. Use `load_model()` instead. 110 | """ 111 | network = self.build_network(example_shape) 112 | network.load_state_dict(checkpoint) 113 | return network 114 | 115 | def load_model(self, example_shape: AmbiguousShape, *, checkpoint: Mapping[str, Any] | None = None) -> nn.Module: 116 | if checkpoint: 117 | return self.build_network_from_checkpoint(example_shape, checkpoint).to(self._device) 118 | return self.build_network(example_shape).to(self._device) 119 | -------------------------------------------------------------------------------- /mipcandy/evaluation.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable, Sequence, Generator, override 3 | 4 | import torch 5 | 6 | from mipcandy.data import SupervisedDataset, MergedDataset, Loader, DatasetFromMemory 7 | from mipcandy.inference import parse_predictant, Predictor 8 | from mipcandy.types import SupportedPredictant 9 | 10 | 11 | @dataclass 12 | class EvalCase(object): 13 | metrics: dict[str, float] 14 | output: torch.Tensor 15 | label: torch.Tensor 16 | image: torch.Tensor | None = None 17 | filename: str | None = None 18 | 19 | 20 | class EvalResult(Sequence[EvalCase]): 21 | def __init__(self, metrics: dict[str, list[float]], outputs: list[torch.Tensor], labels: list[torch.Tensor], *, 22 | images: list[torch.Tensor] | None = None, filenames: list[str] | None = None) -> None: 23 | if len(outputs) != len(labels): 24 | raise ValueError(f"Unmatched number of outputs ({len(outputs)}) and labels ({len(labels)})") 25 | self.metrics: dict[str, list[float]] = metrics 26 | self.mean_metrics: dict[str, float] = {name: sum(values) / len(values) for name, values in metrics.items()} 27 | self.images: list[torch.Tensor] | None = images 28 | self.outputs: list[torch.Tensor] = outputs 29 | self.labels: list[torch.Tensor] = labels 30 | self.filenames: list[str] | None = filenames 31 | 32 | @override 33 | def __len__(self) -> int: 34 | 
return len(self.outputs) 35 | 36 | @override 37 | def __getitem__(self, item: int) -> EvalCase: 38 | return EvalCase({name: values[item] for name, values in self.metrics.items()}, self.outputs[item], 39 | self.labels[item], self.images[item] if self.images else None, 40 | self.filenames[item] if self.filenames else None) 41 | 42 | def _select(self, metric: str, n: int, descending: bool) -> Generator[EvalCase, None, None]: 43 | o_values = self.metrics[metric] 44 | values = o_values.copy() 45 | values.sort(reverse=descending) 46 | for value in values[:n]: 47 | yield self[o_values.index(value)] 48 | 49 | def min(self, metric: str) -> EvalCase: 50 | return self.min_n(metric, 1)[0] 51 | 52 | def min_n(self, metric: str, n: int) -> tuple[EvalCase, ...]: 53 | return tuple(self._select(metric, n, False)) 54 | 55 | def max(self, metric: str) -> EvalCase: 56 | return self.max_n(metric, 1)[0] 57 | 58 | def max_n(self, metric: str, n: int) -> tuple[EvalCase, ...]: 59 | return tuple(self._select(metric, n, True)) 60 | 61 | 62 | class Evaluator(object): 63 | def __init__(self, *metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> None: 64 | self._metrics: tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], ...] = metrics 65 | 66 | def _evaluate_dataset(self, x: SupervisedDataset, *, prefilled_outputs: list[torch.Tensor] | None = None, 67 | prefilled_labels: list[torch.Tensor] | None = None) -> EvalResult: 68 | metrics = {} 69 | outputs = prefilled_outputs if prefilled_outputs else [] 70 | labels = prefilled_labels if prefilled_labels else [] 71 | for output, label in x: 72 | if not prefilled_outputs: 73 | outputs.append(output) 74 | if not prefilled_labels: 75 | labels.append(label) 76 | for m in self._metrics: 77 | if m.__name__ not in metrics: 78 | metrics[m.__name__] = [] 79 | metrics[m.__name__].append(m(output, label).item()) 80 | return EvalResult(metrics, outputs, labels) 81 | 82 | def evaluate_dataset(self, x: SupervisedDataset) -> EvalResult: 83 | return self._evaluate_dataset(x) 84 | 85 | def evaluate(self, outputs: SupportedPredictant, labels: SupportedPredictant) -> EvalResult: 86 | outputs, filenames = parse_predictant(outputs, Loader) 87 | labels, _ = parse_predictant(labels, Loader, as_label=True) 88 | r = self._evaluate_dataset(MergedDataset(DatasetFromMemory(outputs), DatasetFromMemory(labels)), 89 | prefilled_outputs=outputs, prefilled_labels=labels) 90 | r.filenames = filenames 91 | return r 92 | 93 | def predict_and_evaluate(self, x: SupportedPredictant, labels: SupportedPredictant, 94 | predictor: Predictor) -> EvalResult: 95 | x, filenames = parse_predictant(x, Loader) 96 | outputs = [e.cpu() for e in predictor.predict(x)] 97 | r = self.evaluate(outputs, labels) 98 | r.images = x 99 | r.filenames = filenames 100 | return r 101 | -------------------------------------------------------------------------------- /mipcandy/sliding_window.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Literal 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from mipcandy.layer import HasDevice 9 | from mipcandy.types import Shape 10 | 11 | 12 | @dataclass 13 | class SWMetadata(object): 14 | kernel: Shape 15 | stride: tuple[int, int] | tuple[int, int, int] 16 | ndim: Literal[2, 3] 17 | batch_size: int 18 | out_size: Shape 19 | n: int 20 | 21 | 22 | class SlidingWindow(HasDevice, metaclass=ABCMeta): 23 | sliding_window_batch_size: int | 
None = None 24 | 25 | @abstractmethod 26 | def get_window_shape(self) -> Shape: 27 | raise NotImplementedError 28 | 29 | def get_batch_size(self) -> int | None: 30 | return self.sliding_window_batch_size 31 | 32 | def gaussian_1d(self, k: int, *, sigma_scale: float = 0.5) -> torch.Tensor: 33 | x = torch.linspace(-1.0, 1.0, steps=k, device=self._device) 34 | sigma = sigma_scale 35 | g = torch.exp(-0.5 * (x / sigma) ** 2) 36 | g /= g.max() 37 | return g 38 | 39 | def do_sliding_window(self, t: torch.Tensor) -> tuple[torch.Tensor, SWMetadata]: 40 | window_shape = self.get_window_shape() 41 | if not (len(window_shape) + 2 == t.ndim): 42 | raise RuntimeError("Unmatched number of dimensions") 43 | stride = window_shape 44 | if len(stride) == 2: 45 | kernel = stride[0] * 2, stride[1] * 2 46 | b, c, h, w = t.shape 47 | t = nn.functional.unfold(t, kernel, stride=stride) 48 | n = t.shape[-1] 49 | kh, kw = kernel 50 | return (t.transpose(1, 2).contiguous().view(b * n, c, kh, kw), 51 | SWMetadata(kernel, stride, 2, b, (h, w), n)) 52 | else: 53 | b, c, d, h, w = t.shape 54 | sd, sh, sw = stride 55 | kd, kh, kw = kernel = sd * 2, sh * 2, sw * 2 56 | image_windows = [] 57 | for z in range(0, d - kd + 1, sd): 58 | for y in range(0, h - kh + 1, sh): 59 | for x in range(0, w - kw + 1, sw): 60 | image_windows.append(t[:, :, z:z + kd, y:y + kh, x:x + kw]) 61 | t = torch.stack(image_windows, dim=0) 62 | n = t.shape[0] 63 | return (t.permute(0, 1, 2, 3, 4, 5).contiguous().view(b * n, c, kd, kh, kw), 64 | SWMetadata(kernel, stride, 3, b, (d, h, w), n)) 65 | 66 | def revert_sliding_window(self, t: torch.Tensor, metadata: SWMetadata, *, clamp_min: float = 1e-8) -> torch.Tensor: 67 | kernel = metadata.kernel 68 | stride = metadata.stride 69 | dims = metadata.ndim 70 | b = metadata.batch_size 71 | out_size = metadata.out_size 72 | n = metadata.n 73 | dtype = t.dtype 74 | if dims == 2: 75 | kh, kw = kernel 76 | gh = self.gaussian_1d(kh) 77 | gw = self.gaussian_1d(kw) 78 | w2d = (gh[:, None] * gw[None, :]).to(dtype) 79 | w2d /= w2d.max() 80 | w2d = w2d.view(1, 1, kh, kw) 81 | bn, c, _, _ = t.shape 82 | if bn != b * n: 83 | raise RuntimeError("Inconsistent number of windows for reverting sliding window") 84 | weighted = t * w2d 85 | patches = weighted.view(b, n, c, kh, kw) 86 | cols = patches.view(b, n, c * kh * kw).transpose(1, 2).contiguous() 87 | numerator = nn.functional.fold(cols, out_size, kernel, stride=stride) 88 | w_cols = w2d.expand(b, n, 1, kh, kw).contiguous().view(b, n, 1 * kh * kw).transpose(1, 2) 89 | denominator = nn.functional.fold(w_cols, out_size, kernel, stride=stride) 90 | denominator = denominator.clamp_min(clamp_min) 91 | return numerator / denominator 92 | else: 93 | kd, kh, kw = kernel 94 | sd, sh, sw = stride 95 | d, h, w = out_size 96 | gd = self.gaussian_1d(kd) 97 | gh = self.gaussian_1d(kh) 98 | gw = self.gaussian_1d(kw) 99 | w3d = (gd[:, None, None] * gh[None, :, None] * gw[None, None, :]).to(dtype) 100 | w3d /= w3d.max() 101 | w3d = w3d.view(1, 1, kd, kh, kw) 102 | bn, c, _, _, _ = t.shape 103 | if bn != b * n: 104 | raise RuntimeError("Inconsistent number of windows for reverting sliding window") 105 | canvas = torch.zeros((b, c, d, h, w), dtype=dtype, device=self._device) 106 | acc_w = torch.zeros((b, 1, d, h, w), dtype=dtype, device=self._device) 107 | idx = 0 108 | for z in range(0, d - kd + 1, sd): 109 | for y in range(0, h - kh + 1, sh): 110 | for x in range(0, w - kw + 1, sw): 111 | window = t[idx * b:(idx + 1) * b] 112 | window *= w3d 113 | canvas[:, :, z:z + kd, y:y + kh, x:x + 
kw] += window 114 | acc_w[:, :, z:z + kd, y:y + kh, x:x + kw] += w3d 115 | idx += 1 116 | acc_w = acc_w.clamp_min(clamp_min) 117 | return canvas / acc_w 118 | -------------------------------------------------------------------------------- /mipcandy/data/visualization.py: -------------------------------------------------------------------------------- 1 | from importlib.util import find_spec 2 | from math import ceil 3 | from multiprocessing import get_context 4 | from os import PathLike 5 | from typing import Literal 6 | from warnings import warn 7 | 8 | import numpy as np 9 | import torch 10 | from matplotlib import pyplot as plt 11 | from torch import nn 12 | 13 | from mipcandy.common import ColorizeLabel 14 | from mipcandy.data.convertion import auto_convert 15 | from mipcandy.data.geometric import ensure_num_dimensions 16 | 17 | 18 | def visualize2d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray", 19 | blocking: bool = False, screenshot_as: str | PathLike[str] | None = None) -> None: 20 | image = image.detach().cpu() 21 | if image.ndim < 2: 22 | raise ValueError(f"`image` must have at least 2 dimensions, got {image.shape}") 23 | if image.ndim > 3: 24 | image = ensure_num_dimensions(image, 3) 25 | if image.ndim == 3: 26 | if image.shape[0] == 1: 27 | image = image.squeeze(0) 28 | else: 29 | image = image.permute(1, 2, 0) 30 | image = auto_convert(image) 31 | plt.imshow(image.numpy(), cmap, vmin=0, vmax=255) 32 | plt.title(title) 33 | plt.axis("off") 34 | if screenshot_as: 35 | plt.savefig(screenshot_as) 36 | if blocking: 37 | plt.close() 38 | return 39 | plt.show(block=blocking) 40 | 41 | 42 | def _visualize3d_with_pyvista(image: np.ndarray, title: str | None, cmap: str, 43 | screenshot_as: str | PathLike[str] | None) -> None: 44 | from pyvista import Plotter 45 | p = Plotter(title=title, off_screen=bool(screenshot_as)) 46 | p.add_volume(image, cmap=cmap) 47 | if screenshot_as: 48 | p.screenshot(screenshot_as) 49 | else: 50 | p.show() 51 | 52 | 53 | def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray", max_volume: int = 1e6, 54 | backend: Literal["auto", "matplotlib", "pyvista"] = "auto", blocking: bool = False, 55 | screenshot_as: str | PathLike[str] | None = None) -> None: 56 | image = image.detach().float().cpu() 57 | if image.ndim < 3: 58 | raise ValueError(f"`image` must have at least 3 dimensions, got {image.shape}") 59 | if image.ndim > 3: 60 | image = ensure_num_dimensions(image, 3) 61 | d, h, w = image.shape 62 | total = d * h * w 63 | ratio = int(ceil((total / max_volume) ** (1 / 3))) if total > max_volume else 1 64 | if ratio > 1: 65 | image = ensure_num_dimensions(nn.functional.avg_pool3d(ensure_num_dimensions(image, 5), kernel_size=ratio, 66 | stride=ratio, ceil_mode=True), 3) 67 | image /= image.max() 68 | image = image.numpy() 69 | if backend == "auto": 70 | backend = "pyvista" if find_spec("pyvista") else "matplotlib" 71 | match backend: 72 | case "matplotlib": 73 | warn("Using Matplotlib for 3D visualization is inefficient and inaccurate, consider using PyVista") 74 | face_colors = getattr(plt.cm, cmap)(image) 75 | face_colors[..., 3] = image * (image > 0) 76 | fig = plt.figure() 77 | ax = fig.add_subplot(111, projection="3d") 78 | ax.voxels(image, facecolors=face_colors) 79 | ax.set_title(title) 80 | if screenshot_as: 81 | fig.savefig(screenshot_as) 82 | if blocking: 83 | plt.close() 84 | return 85 | plt.show(block=blocking) 86 | case "pyvista": 87 | image = image.transpose(1, 2, 0) 88 | if blocking: 89 | return 
_visualize3d_with_pyvista(image, title, cmap, screenshot_as) 90 | ctx = get_context("spawn") 91 | return ctx.Process(target=_visualize3d_with_pyvista, args=(image, title, cmap, screenshot_as), 92 | daemon=False).start() 93 | 94 | 95 | def overlay(image: torch.Tensor, label: torch.Tensor, *, max_label_opacity: float = .5, 96 | label_colorizer: ColorizeLabel | None = ColorizeLabel()) -> torch.Tensor: 97 | if image.ndim < 2 or label.ndim < 2: 98 | raise ValueError("Only 2D images can be overlaid") 99 | image = ensure_num_dimensions(image, 3) 100 | label = ensure_num_dimensions(label, 2) 101 | image = auto_convert(image) 102 | if image.shape[0] == 1: 103 | image = image.repeat(3, 1, 1) 104 | image_c, image_shape = image.shape[0], image.shape[1:] 105 | label_shape = label.shape 106 | if image_shape != label_shape: 107 | raise ValueError(f"Unmatched shapes {image_shape} and {label_shape}") 108 | alpha = (label > 0).int() 109 | if label_colorizer: 110 | label = label_colorizer(label) 111 | if label.shape[0] == 4: 112 | alpha = label[-1] 113 | label = label[:-1] 114 | elif label.shape[0] == 1: 115 | label = label.repeat(3, 1, 1) 116 | if not (image_c == label.shape[0] == 3): 117 | raise ValueError("Unsupported number of channels") 118 | if alpha.max() > 0: 119 | alpha = alpha * max_label_opacity / alpha.max() 120 | return image * (1 - alpha) + label * alpha 121 | -------------------------------------------------------------------------------- /mipcandy/presets/segmentation.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from typing import override 3 | 4 | import torch 5 | from torch import nn, optim 6 | 7 | from mipcandy.common import AbsoluteLinearLR, DiceBCELossWithLogits 8 | from mipcandy.data import visualize2d, visualize3d, overlay, auto_convert 9 | from mipcandy.sliding_window import SWMetadata 10 | from mipcandy.training import Trainer, TrainerToolbox, SlidingTrainer 11 | from mipcandy.types import Params, Shape 12 | 13 | 14 | class SegmentationTrainer(Trainer, metaclass=ABCMeta): 15 | num_classes: int = 1 16 | include_bg: bool = True 17 | 18 | def _save_preview(self, x: torch.Tensor, title: str, quality: float) -> None: 19 | path = f"{self.experiment_folder()}/{title} (preview).png" 20 | if x.ndim == 3 and x.shape[0] in (1, 3, 4): 21 | visualize2d(auto_convert(x), title=title, blocking=True, screenshot_as=path) 22 | elif x.ndim == 4 and x.shape[0] == 1: 23 | visualize3d(x, title=title, max_volume=int(quality * 1e6), blocking=True, screenshot_as=path) 24 | 25 | @override 26 | def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.Tensor, *, 27 | quality: float = .75) -> None: 28 | output = output.sigmoid() 29 | self._save_preview(image, "input", quality) 30 | self._save_preview(label, "label", quality) 31 | self._save_preview(output, "prediction", quality) 32 | if image.ndim == label.ndim == output.ndim == 3 and label.shape[0] == output.shape[0] == 1: 33 | visualize2d(overlay(image, label), title="expected", blocking=True, 34 | screenshot_as=f"{self.experiment_folder()}/expected (preview).png") 35 | visualize2d(overlay(image, output), title="actual", blocking=True, 36 | screenshot_as=f"{self.experiment_folder()}/actual (preview).png") 37 | 38 | @override 39 | def build_criterion(self) -> nn.Module: 40 | return DiceBCELossWithLogits(self.num_classes, include_bg=self.include_bg) 41 | 42 | @override 43 | def build_optimizer(self, params: Params) -> optim.Optimizer: 44 | return optim.AdamW(params) 45 | 
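    # Note: `optim.AdamW(params)` relies on PyTorch's defaults (lr=1e-3, weight_decay=1e-2).
    # A subclass that needs different hyperparameters can override this hook; a minimal
    # sketch (the values below are illustrative only, not a library recommendation):
    #
    #     @override
    #     def build_optimizer(self, params: Params) -> optim.Optimizer:
    #         return optim.AdamW(params, lr=3e-4, weight_decay=1e-2)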
46 | @override 47 | def build_scheduler(self, optimizer: optim.Optimizer, num_epochs: int) -> optim.lr_scheduler.LRScheduler: 48 | return AbsoluteLinearLR(optimizer, -8e-6 / len(self._dataloader), 1e-2) 49 | 50 | @override 51 | def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ 52 | str, float]]: 53 | masks = toolbox.model(images) 54 | loss, metrics = toolbox.criterion(masks, labels) 55 | loss.backward() 56 | return loss.item(), metrics 57 | 58 | @override 59 | def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ 60 | str, float], torch.Tensor]: 61 | image, label = image.unsqueeze(0), label.unsqueeze(0) 62 | mask = (toolbox.ema if toolbox.ema else toolbox.model)(image) 63 | loss, metrics = toolbox.criterion(mask, label) 64 | return -loss.item(), metrics, mask.squeeze(0) 65 | 66 | 67 | class SlidingSegmentationTrainer(SlidingTrainer, SegmentationTrainer, metaclass=ABCMeta): 68 | sliding_window_shape: Shape = (128, 128) 69 | 70 | @override 71 | def backward_windowed(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox, 72 | metadata: SWMetadata) -> tuple[float, dict[str, float]]: 73 | masks = toolbox.model(images) 74 | loss, metrics = toolbox.criterion(masks, labels) 75 | loss.backward() 76 | return loss.item(), metrics 77 | 78 | @override 79 | def validate_case_windowed(self, images: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox, 80 | metadata: SWMetadata) -> tuple[float, dict[str, float], torch.Tensor]: 81 | batch_size = self.get_batch_size() 82 | model = toolbox.ema if toolbox.ema else toolbox.model 83 | if batch_size is None or batch_size >= images.shape[0]: 84 | outputs = model(images) 85 | else: 86 | output_list: list[torch.Tensor] = [] 87 | for i in range(0, images.shape[0], batch_size): 88 | batch = images[i:i + batch_size] 89 | output_list.append(model(batch)) 90 | outputs = torch.cat(output_list, dim=0) 91 | outputs = self.revert_sliding_window(outputs, metadata) 92 | loss, metrics = toolbox.criterion(outputs, label.unsqueeze(0)) 93 | return -loss.item(), metrics, outputs.squeeze(0) 94 | 95 | @override 96 | def get_window_shape(self) -> Shape: 97 | return self.sliding_window_shape 98 | 99 | 100 | class SlidingValidationTrainer(SlidingSegmentationTrainer, metaclass=ABCMeta): 101 | """ 102 | Use this when training data comes from RandomROIDataset (already patched), but validation data is full volumes 103 | requiring sliding window inference. 
104 | """ 105 | @override 106 | def backward_windowed(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox, 107 | metadata: SWMetadata) -> tuple[float, dict[str, float]]: 108 | raise RuntimeError("`backward_windowed()` should not be called in `SlidingValidationTrainer`") 109 | 110 | @override 111 | def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ 112 | str, float]]: 113 | return SegmentationTrainer.backward(self, images, labels, toolbox) 114 | 115 | @override 116 | def validate_case_windowed(self, images: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox, 117 | metadata: SWMetadata) -> tuple[float, dict[str, float], torch.Tensor]: 118 | return super().validate_case_windowed(images, label, toolbox, metadata) 119 | 120 | @override 121 | def get_window_shape(self) -> Shape: 122 | return self.sliding_window_shape 123 | -------------------------------------------------------------------------------- /mipcandy/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, Literal 2 | 3 | import torch 4 | 5 | from mipcandy.types import Device 6 | 7 | 8 | def _args_check(output: torch.Tensor, label: torch.Tensor, *, dtype: torch.dtype | None = None, 9 | device: Device | None = None) -> tuple[torch.dtype, Device]: 10 | if output.shape != label.shape: 11 | raise ValueError(f"Output ({output.shape}) and label ({label.shape}) must have the same shape") 12 | if (output_dtype := output.dtype) != label.dtype or dtype and output_dtype != dtype: 13 | raise TypeError(f"Output ({output_dtype}) and label ({label.dtype}) must both be {dtype}") 14 | if (output_device := output.device) != label.device: 15 | raise RuntimeError(f"Output ({output.device}) and label ({label.device}) must be on the same device") 16 | if device and output_device != device: 17 | raise RuntimeError(f"Tensors are expected to be on {device}, but instead they are on {output.device}") 18 | return output_dtype, output_device 19 | 20 | 21 | class Metric(Protocol): 22 | def __call__(self, output: torch.Tensor, label: torch.Tensor, *, if_empty: float = ...) -> torch.Tensor: ... 
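# Any of the `*_binary` functions below (`dice_similarity_coefficient_binary`,
# `accuracy_binary`, `precision_binary`, `recall_binary`, `iou_binary`) matches this
# protocol and can be handed to `apply_multiclass_to_binary`. A small usage sketch,
# assuming `prediction` and `label` are `torch.int` class-index maps of identical shape:
#
#     per_class_iou = apply_multiclass_to_binary(iou_binary, prediction, label,
#                                                num_classes=3, if_empty=1.0)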
23 | 24 | 25 | def do_reduction(x: torch.Tensor, method: Literal["mean", "median", "sum", "none"] = "mean") -> torch.Tensor: 26 | match method: 27 | case "mean": 28 | return x.mean() 29 | case "median": 30 | return x.median() 31 | case "sum": 32 | return x.sum() 33 | case "none": 34 | return x 35 | 36 | 37 | def apply_multiclass_to_binary(metric: Metric, output: torch.Tensor, label: torch.Tensor, num_classes: int | None, 38 | if_empty: float, *, reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor: 39 | _args_check(output, label, dtype=torch.int) 40 | if not num_classes: 41 | num_classes = max(output.max().item(), label.max().item()) 42 | if num_classes == 0: 43 | return torch.tensor(if_empty, dtype=torch.float) 44 | else: 45 | x = torch.tensor([metric(output == cls, label == cls, if_empty=if_empty) for cls in range(1, num_classes + 1)]) 46 | return do_reduction(x, reduction) 47 | 48 | 49 | def dice_similarity_coefficient_binary(output: torch.Tensor, label: torch.Tensor, *, 50 | if_empty: float = 1) -> torch.Tensor: 51 | _args_check(output, label, dtype=torch.bool) 52 | volume_sum = output.sum() + label.sum() 53 | if volume_sum == 0: 54 | return torch.tensor(if_empty, dtype=torch.float) 55 | return 2 * (output & label).sum() / volume_sum 56 | 57 | 58 | def dice_similarity_coefficient_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None, 59 | if_empty: float = 1) -> torch.Tensor: 60 | return apply_multiclass_to_binary(dice_similarity_coefficient_binary, output, label, num_classes, if_empty) 61 | 62 | 63 | def soft_dice_coefficient(output: torch.Tensor, label: torch.Tensor, *, 64 | smooth: float = 1e-5, include_bg: bool = True) -> torch.Tensor: 65 | _args_check(output, label) 66 | axes = tuple(range(2, output.ndim)) 67 | intersection = (output * label).sum(dim=axes) 68 | dice = (2 * intersection + smooth) / (output.sum(dim=axes) + label.sum(dim=axes) + smooth) 69 | if not include_bg: 70 | dice = dice[:, 1:] 71 | return dice.mean() 72 | 73 | 74 | def accuracy_binary(output: torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor: 75 | _args_check(output, label, dtype=torch.bool) 76 | numerator = (output & label).sum() + (~output & ~label).sum() 77 | denominator = numerator + (output & ~label).sum() + (label & ~output).sum() 78 | return torch.tensor(if_empty, dtype=torch.float) if denominator == 0 else numerator / denominator 79 | 80 | 81 | def accuracy_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None, 82 | if_empty: float = 1) -> torch.Tensor: 83 | return apply_multiclass_to_binary(accuracy_binary, output, label, num_classes, if_empty) 84 | 85 | 86 | def _precision_or_recall(output: torch.Tensor, label: torch.Tensor, if_empty: float, 87 | is_precision: bool) -> torch.Tensor: 88 | _args_check(output, label, dtype=torch.bool) 89 | tp = (output & label).sum() 90 | denominator = output.sum() if is_precision else label.sum() 91 | return torch.tensor(if_empty, dtype=torch.float) if denominator == 0 else tp / denominator 92 | 93 | 94 | def precision_binary(output: torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor: 95 | return _precision_or_recall(output, label, if_empty, True) 96 | 97 | 98 | def precision_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None, 99 | if_empty: float = 1) -> torch.Tensor: 100 | return apply_multiclass_to_binary(precision_binary, output, label, num_classes, if_empty) 101 | 102 | 103 | def recall_binary(output: 
torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor: 104 | return _precision_or_recall(output, label, if_empty, False) 105 | 106 | 107 | def recall_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None, 108 | if_empty: float = 1) -> torch.Tensor: 109 | return apply_multiclass_to_binary(recall_binary, output, label, num_classes, if_empty) 110 | 111 | 112 | def iou_binary(output: torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor: 113 | _args_check(output, label, dtype=torch.bool) 114 | denominator = (output | label).sum() 115 | return torch.tensor(if_empty, dtype=torch.float) if denominator == 0 else (output & label).sum() / denominator 116 | 117 | 118 | def iou_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None, 119 | if_empty: float = 1) -> torch.Tensor: 120 | return apply_multiclass_to_binary(iou_binary, output, label, num_classes, if_empty) 121 | -------------------------------------------------------------------------------- /mipcandy/inference.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from math import log, ceil 3 | from os import PathLike, listdir 4 | from os.path import isdir, basename, exists 5 | from typing import Sequence, override 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from mipcandy.common import Pad2d, Pad3d, Restore2d, Restore3d 11 | from mipcandy.data import save_image, Loader, UnsupervisedDataset, PathBasedUnsupervisedDataset 12 | from mipcandy.layer import WithPaddingModule, WithNetwork 13 | from mipcandy.sliding_window import SlidingWindow 14 | from mipcandy.types import SupportedPredictant, Device, AmbiguousShape 15 | 16 | 17 | def parse_predictant(x: SupportedPredictant, loader: type[Loader], *, as_label: bool = False) -> tuple[list[ 18 | torch.Tensor], list[str] | None]: 19 | if isinstance(x, str): 20 | if isdir(x): 21 | cases = listdir(x) 22 | return [loader.do_load(f"{x}/{case}", is_label=as_label) for case in cases], cases 23 | return [loader.do_load(x, is_label=as_label)], [basename(x)] 24 | if isinstance(x, torch.Tensor): 25 | return [x], None 26 | r, filenames = [], None 27 | for case in x: 28 | if isinstance(case, str): 29 | if not filenames: 30 | filenames = [] 31 | r.append(loader.do_load(case, is_label=as_label)) 32 | filenames.append(case[case.rfind("/") + 1:]) 33 | elif filenames: 34 | raise TypeError("`x` should be single-typed") 35 | elif isinstance(case, torch.Tensor): 36 | r.append(case) 37 | else: 38 | raise TypeError(f"Unexpected type of element {type(case)}") 39 | return r, filenames 40 | 41 | 42 | class Predictor(WithPaddingModule, WithNetwork, metaclass=ABCMeta): 43 | def __init__(self, experiment_folder: str | PathLike[str], example_shape: AmbiguousShape, *, 44 | checkpoint: str = "checkpoint_best.pth", device: Device = "cpu") -> None: 45 | WithPaddingModule.__init__(self, device) 46 | WithNetwork.__init__(self, device) 47 | self._experiment_folder: str = experiment_folder 48 | self._example_shape: AmbiguousShape = example_shape 49 | self._checkpoint: str = checkpoint 50 | self._model: nn.Module | None = None 51 | 52 | def lazy_load_model(self) -> None: 53 | if self._model: 54 | return 55 | self._model = self.load_model(self._example_shape, checkpoint=torch.load( 56 | f"{self._experiment_folder}/{self._checkpoint}" 57 | )) 58 | self._model.eval() 59 | 60 | def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Tensor: 61 
| self.lazy_load_model() 62 | image = image.to(self._device) 63 | if not batch: 64 | image = image.unsqueeze(0) 65 | padding_module = self.get_padding_module() 66 | if padding_module: 67 | image = padding_module(image) 68 | output = self._model(image) 69 | restoring_module = self.get_restoring_module() 70 | if restoring_module: 71 | output = restoring_module(output) 72 | return output if batch else output.squeeze(0) 73 | 74 | def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> tuple[list[torch.Tensor], list[str] | None]: 75 | if isinstance(x, PathBasedUnsupervisedDataset): 76 | return [self.predict_image(case) for case in x], x.paths() 77 | if isinstance(x, UnsupervisedDataset): 78 | return [self.predict_image(case) for case in x], None 79 | images, filenames = parse_predictant(x, Loader) 80 | return [self.predict_image(image) for image in images], filenames 81 | 82 | def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: 83 | return self._predict(x)[0] 84 | 85 | @staticmethod 86 | def save_prediction(output: torch.Tensor, path: str | PathLike[str]) -> None: 87 | save_image(output, path) 88 | 89 | def save_predictions(self, outputs: Sequence[torch.Tensor], folder: str | PathLike[str], *, 90 | filenames: Sequence[str | PathLike[str]] | None = None) -> None: 91 | if not exists(folder): 92 | raise FileNotFoundError(f"Folder {folder} does not exist") 93 | if not filenames: 94 | num_digits = ceil(log(len(outputs))) 95 | filenames = [f"prediction_{str(i).zfill(num_digits)}.{ 96 | "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha"}" for i, output in enumerate(outputs)] 97 | for i, prediction in enumerate(outputs): 98 | self.save_prediction(prediction, f"{folder}/{filenames[i]}") 99 | 100 | def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset, 101 | folder: str | PathLike[str]) -> list[str] | None: 102 | outputs, filenames = self._predict(x) 103 | self.save_predictions(outputs, folder, filenames=filenames) 104 | return filenames 105 | 106 | def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]: 107 | return self.predict(x) 108 | 109 | 110 | class SlidingPredictor(Predictor, SlidingWindow, metaclass=ABCMeta): 111 | @override 112 | def build_padding_module(self) -> nn.Module | None: 113 | window_shape = self.get_window_shape() 114 | return (Pad2d if len(window_shape) == 2 else Pad3d)(window_shape) 115 | 116 | @override 117 | def build_restoring_module(self, padding_module: nn.Module | None) -> nn.Module | None: 118 | if not isinstance(padding_module, (Pad2d, Pad3d)): 119 | raise TypeError("`padding_module` should be either `Pad2d` or `Pad3d`") 120 | window_shape = self.get_window_shape() 121 | return (Restore2d if len(window_shape) == 2 else Restore3d)(padding_module) 122 | 123 | @override 124 | def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Tensor: 125 | if not batch: 126 | image = image.unsqueeze(0) 127 | images, metadata = self.do_sliding_window(image) 128 | outputs = super().predict_image(images, batch=True) 129 | outputs = self.revert_sliding_window(outputs, metadata) 130 | return outputs if batch else outputs.squeeze(0) 131 | -------------------------------------------------------------------------------- /mipcandy/frontend/notion_fe.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import override, Literal 3 | 4 | from requests import get, post, patch, Response 5 | 
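# A sketch of constructing the frontend by hand (training runs presumably wire this up
# from their settings instead). `Settings` is assumed to behave like a plain mapping
# here, and the literal values are placeholders:
#
#     fe = NotionFrontend({"notion_api_key": "secret_xxx", "notion_database_id": "..."})
#     fe.on_experiment_created("exp-001", "SegmentationTrainer", "UNet", "baseline",
#                              num_macs=3.2e9, num_params=1.1e7, num_epochs=100,
#                              early_stop_tolerance=20)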
6 | from mipcandy.frontend.prototype import Frontend 7 | from mipcandy.types import Settings 8 | 9 | 10 | class NotionFrontend(Frontend): 11 | def __init__(self, secrets: Settings) -> None: 12 | super().__init__(secrets) 13 | self._api_key: str = self.require_nonempty_secret("notion_api_key", require_type=str) 14 | self._database_id: str = self.require_nonempty_secret("notion_database_id", require_type=str) 15 | self._headers: dict[str, str] = { 16 | "Authorization": f"Bearer {self._api_key}", 17 | "Content-Type": "application/json", 18 | "Notion-Version": "2022-06-28" 19 | } 20 | self._num_epochs: int = 1 21 | self._early_stop_tolerance: int = -1 22 | self._start_time: str = "" 23 | self._page_id: str = "" 24 | 25 | def retrieve_database(self) -> Response: 26 | return get(f"https://api.notion.com/v1/databases/{self._database_id}", headers=self._headers) 27 | 28 | def query_database(self, *, experiment_id: str | None = None) -> Response: 29 | json = {"filter": {"property": "Experiment ID", "title": {"equals": experiment_id}}} if experiment_id else None 30 | return post(f"https://api.notion.com/v1/databases/{self._database_id}/query", json=json, headers=self._headers) 31 | 32 | def select_experiment(self, experiment_id: str) -> str: 33 | experiments = self.query_database(experiment_id=experiment_id) 34 | if experiments.status_code != 200: 35 | raise RuntimeError(f"Failed to query database: {experiments.json()}") 36 | experiments = experiments.json()["results"] 37 | if len(experiments) == 1: 38 | return experiments[0]["id"] 39 | if len(experiments) > 1: 40 | raise RuntimeError(f"Found multiple experiments with the same ID {experiment_id}") 41 | return "" 42 | 43 | def new_experiment(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float, 44 | num_params: float) -> Response: 45 | self._start_time = datetime.now().astimezone().strftime("%Y-%m-%dT%H:%M:%S.000%z") 46 | properties = { 47 | "Experiment ID": {"title": [{"text": {"content": experiment_id}}]}, 48 | "Status": {"status": {"name": "In Progress"}}, 49 | "Progress": {"number": 0}, 50 | "Early Stop": {"number": 1}, 51 | "Trainer": {"select": {"name": trainer}}, 52 | "Model": {"select": {"name": model}}, 53 | "Time": {"date": {"start": self._start_time}}, 54 | "Note": {"rich_text": [{"text": {"content": note}}]}, 55 | "MACs (G)": {"number": round(num_macs, 1)}, 56 | "Params (M)": {"number": round(num_params, 1)}, 57 | "Epoch": {"number": 0}, 58 | "Score": {"number": 0}, 59 | } 60 | page_id = self.select_experiment(experiment_id) 61 | if page_id: 62 | self._page_id = page_id 63 | return patch(f"https://api.notion.com/v1/pages/{page_id}", json={"properties": properties}, 64 | headers=self._headers) 65 | res = post("https://api.notion.com/v1/pages", json={ 66 | "parent": {"database_id": self._database_id}, 67 | "icon": {"external": {"url": "https://www.notion.so/icons/science_gray.svg"}}, 68 | "properties": properties 69 | }, headers=self._headers) 70 | self._page_id = res.json()["id"] 71 | return res 72 | 73 | def update_experiment(self, experiment_id: str, status: Literal["In Progress", "Completed", "Interrupted"], 74 | *, epoch: int | None = None, score: float | None = None, 75 | early_stop_tolerance: int | None = None, observation: str | None = None) -> Response: 76 | if not self._page_id: 77 | raise RuntimeError(f"Experiment {experiment_id} has not been created") 78 | properties = {"Status": {"status": {"name": status}}} 79 | if epoch is not None: 80 | properties["Progress"] = {"number": epoch / self._num_epochs} 
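            # Progress is reported to Notion as a fraction (0..1) of the configured epoch
            # budget; the raw epoch index is written separately to the "Epoch" property below.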
81 | properties["Epoch"] = {"number": epoch} 82 | if early_stop_tolerance is not None: 83 | properties["Early Stop"] = {"number": max(early_stop_tolerance, 0) / self._early_stop_tolerance} 84 | if score is not None: 85 | properties["Score"] = {"number": round(score, 4)} 86 | if observation is not None: 87 | properties["Observation"] = {"rich_text": [{"text": {"content": observation}}]} 88 | if status == "Completed": 89 | properties["Progress"] = {"number": 1} 90 | properties["Time"] = {"date": {"start": self._start_time, 91 | "end": datetime.now().astimezone().strftime("%Y-%m-%dT%H:%M:%S.000%z")}} 92 | return patch(f"https://api.notion.com/v1/pages/{self._page_id}", json={"properties": properties}, 93 | headers=self._headers) 94 | 95 | @override 96 | def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float, 97 | num_params: float, num_epochs: int, early_stop_tolerance: int) -> None: 98 | self._num_epochs = num_epochs 99 | self._early_stop_tolerance = early_stop_tolerance 100 | res = self.new_experiment(experiment_id, trainer, model, note, num_macs * 1e-9, num_params * 1e-6) 101 | if res.status_code != 200: 102 | raise RuntimeError(f"Failed to create experiment: {res.json()}") 103 | 104 | @override 105 | def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]], 106 | early_stop_tolerance: int) -> None: 107 | try: 108 | self.update_experiment(experiment_id, "In Progress", epoch=epoch, score=max(metrics["val score"]), 109 | early_stop_tolerance=early_stop_tolerance) 110 | except RuntimeError: 111 | pass 112 | 113 | @override 114 | def on_experiment_completed(self, experiment_id: str) -> None: 115 | res = self.update_experiment(experiment_id, "Completed") 116 | if res.status_code != 200: 117 | raise RuntimeError(f"Failed to update experiment: {res.json()}") 118 | 119 | @override 120 | def on_experiment_interrupted(self, experiment_id: str, error: Exception) -> None: 121 | res = self.update_experiment(experiment_id, "Interrupted", observation=repr(error)) 122 | if res.status_code != 200: 123 | raise RuntimeError(f"Failed to update experiment: {res.json()}") 124 | -------------------------------------------------------------------------------- /mipcandy/common/module/preprocess.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from typing import Literal 3 | 4 | import torch 5 | from torch import nn 6 | 7 | from mipcandy.types import Colormap, Shape2d, Shape3d 8 | 9 | 10 | class Pad(nn.Module): 11 | def __init__(self, *, value: int = 0, mode: str = "constant", batch: bool = True) -> None: 12 | super().__init__() 13 | self._value: int = value 14 | self._mode: str = mode 15 | self.batch: bool = batch 16 | self._paddings: tuple[int, int, int, int, int, int] | tuple[int, int, int, int] | None = None 17 | self.requires_grad_(False) 18 | 19 | @staticmethod 20 | def _c_t(size: int, min_factor: int) -> int: 21 | """ 22 | Compute target on a single dimension 23 | """ 24 | return ceil(size / min_factor) * min_factor 25 | 26 | @staticmethod 27 | def _c_p(size: int, min_factor: int) -> tuple[int, int]: 28 | """ 29 | Compute padding on a single dimension 30 | """ 31 | excess = Pad._c_t(size, min_factor) - size 32 | before = excess // 2 33 | return before, excess - before 34 | 35 | 36 | class Pad2d(Pad): 37 | def __init__(self, min_factor: int | Shape2d, *, value: int = 0, mode: str = "constant", 38 | batch: bool = True) -> None: 39 | 
super().__init__(value=value, mode=mode, batch=batch) 40 | self._min_factor: Shape2d = (min_factor,) * 2 if isinstance(min_factor, int) else min_factor 41 | 42 | def paddings(self) -> tuple[int, int, int, int] | None: 43 | return self._paddings 44 | 45 | def padded_shape(self, in_shape: tuple[int, int, ...]) -> tuple[int, int, ...]: 46 | return *in_shape[:-2], self._c_t(in_shape[-2], self._min_factor[0]), self._c_t( 47 | in_shape[-1], self._min_factor[1]) 48 | 49 | def forward(self, x: torch.Tensor) -> torch.Tensor: 50 | if self.batch: 51 | _, _, h, w = x.shape 52 | suffix = (0,) * 4 53 | else: 54 | _, h, w = x.shape 55 | suffix = (0,) * 2 56 | self._paddings = self._c_p(h, self._min_factor[0]) + self._c_p(w, self._min_factor[1]) 57 | return nn.functional.pad(x, self._paddings[::-1] + suffix, self._mode, self._value) 58 | 59 | 60 | class Pad3d(Pad): 61 | def __init__(self, min_factor: int | Shape3d, *, value: int = 0, mode: str = "constant", 62 | batch: bool = True) -> None: 63 | super().__init__(value=value, mode=mode, batch=batch) 64 | self._min_factor: Shape3d = (min_factor,) * 3 if isinstance(min_factor, int) else min_factor 65 | 66 | def paddings(self) -> tuple[int, int, int, int, int, int] | None: 67 | return self._paddings 68 | 69 | def padded_shape(self, in_shape: tuple[int, int, int, ...]) -> tuple[int, int, int, ...]: 70 | return (*in_shape[:-3], self._c_t(in_shape[-3], self._min_factor[0]), self._c_t( 71 | in_shape[-2], self._min_factor[1]), self._c_t(in_shape[-1], self._min_factor[2])) 72 | 73 | def forward(self, x: torch.Tensor) -> torch.Tensor: 74 | if self.batch: 75 | _, _, d, h, w = x.shape 76 | suffix = (0,) * 4 77 | else: 78 | _, d, h, w = x.shape 79 | suffix = (0,) * 2 80 | self._paddings = self._c_p(d, self._min_factor[0]) + self._c_p(h, self._min_factor[1]) + self._c_p( 81 | w, self._min_factor[2]) 82 | return nn.functional.pad(x, self._paddings[::-1] + suffix, self._mode, self._value) 83 | 84 | 85 | class Restore2d(nn.Module): 86 | def __init__(self, conjugate_padding: Pad2d) -> None: 87 | super().__init__() 88 | self.conjugate_padding: Pad2d = conjugate_padding 89 | self.requires_grad_(False) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | paddings = self.conjugate_padding.paddings() 93 | if not paddings: 94 | raise ValueError("Paddings are not set yet, did you forget to pad before restoring?") 95 | pad_h0, pad_h1, pad_w0, pad_w1 = paddings 96 | if self.conjugate_padding.batch: 97 | _, _, h, w = x.shape 98 | return x[:, :, pad_h0: h - pad_h1, pad_w0: w - pad_w1] 99 | _, h, w = x.shape 100 | return x[:, pad_h0: h - pad_h1, pad_w0: w - pad_w1] 101 | 102 | 103 | class Restore3d(nn.Module): 104 | def __init__(self, conjugate_padding: Pad3d) -> None: 105 | super().__init__() 106 | self.conjugate_padding: Pad3d = conjugate_padding 107 | self.requires_grad_(False) 108 | 109 | def forward(self, x: torch.Tensor) -> torch.Tensor: 110 | paddings = self.conjugate_padding.paddings() 111 | if not paddings: 112 | raise ValueError("Paddings are not set yet, did you forget to pad before restoring?") 113 | pad_d0, pad_d1, pad_h0, pad_h1, pad_w0, pad_w1 = paddings 114 | if self.conjugate_padding.batch: 115 | _, _, d, h, w = x.shape 116 | return x[:, :, pad_d0: d - pad_d1, pad_h0: h - pad_h1, pad_w0: w - pad_w1] 117 | _, d, h, w = x.shape 118 | return x[:, pad_d0: d - pad_d1, pad_h0: h - pad_h1, pad_w0: w - pad_w1] 119 | 120 | 121 | class Normalize(nn.Module): 122 | def __init__(self, *, domain: tuple[float | None, float | None] = (0, None), strict: bool = False, 123 
| method: Literal["linear", "intercept", "cut"] = "linear") -> None: 124 | super().__init__() 125 | self._domain: tuple[float | None, float | None] = domain 126 | self._strict: bool = strict 127 | self._method: Literal["linear", "intercept", "cut"] = method 128 | 129 | def forward(self, x: torch.Tensor) -> torch.Tensor: 130 | left, right = self._domain 131 | if left is None and right is None: 132 | return x 133 | r_l, r_r = x.min(), x.max() 134 | match self._method: 135 | case "linear": 136 | if left is None or (left < r_l and not self._strict): 137 | left = r_l 138 | if right is None or (right > r_r and not self._strict): 139 | right = r_r 140 | numerator = right - left 141 | if numerator == 0: 142 | numerator = 1 143 | denominator = r_r - r_l 144 | if denominator == 0: 145 | denominator = 1 146 | return (x - r_l) * numerator / denominator + left 147 | case "intercept": 148 | if left is not None and right is None: 149 | return x - r_l + left if r_l < left or self._strict else x 150 | elif left is None and right is not None: 151 | return x - r_r + right if r_r > right or self._strict else x 152 | else: 153 | raise ValueError("Cannot use intercept normalization when both ends are fixed") 154 | case "cut": 155 | if self._strict: 156 | raise ValueError("Method \"cut\" cannot be strict") 157 | if left is not None: 158 | x[x < left] = left 159 | if right is not None: 160 | x[x > right] = right 161 | return x 162 | 163 | 164 | class ColorizeLabel(nn.Module): 165 | def __init__(self, *, colormap: Colormap | None = None) -> None: 166 | super().__init__() 167 | if not colormap: 168 | colormap = [] 169 | for r in range(8): 170 | for g in range(8): 171 | for b in range(32): 172 | colormap.append([r * 32, g * 32, 255 - b * 32]) 173 | self._colormap: torch.Tensor = torch.tensor(colormap) 174 | 175 | def forward(self, x: torch.Tensor) -> torch.Tensor: 176 | cmap = self._colormap.to(x.device) 177 | return ( 178 | torch.cat([cmap[(x > 0).int()].permute(2, 0, 1), x.unsqueeze(0)]) if 0 <= x.min() < x.max() <= 1 179 | else cmap[x.int()].permute(2, 0, 1) 180 | ) 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /mipcandy/data/dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from json import dump 3 | from os import PathLike, listdir, makedirs 4 | from os.path import exists 5 | from random import choices 6 | from shutil import copy2 7 | from typing import Literal, override, Self, Sequence, TypeVar, Generic, Any 8 | 9 | import torch 10 | from pandas import DataFrame 11 | from torch.utils.data import Dataset 12 | 13 | from mipcandy.data.io import load_image 14 | from mipcandy.layer import HasDevice 15 | from mipcandy.types import Transform, Device 16 | 17 | 18 | class KFPicker(object, metaclass=ABCMeta): 19 | @staticmethod 20 | @abstractmethod 21 | def pick(n: int, fold: Literal[0, 1, 2, 3, 4, "all"]) -> tuple[int, ...]: 22 | raise NotImplementedError 23 | 24 | 25 | class OrderedKFPicker(KFPicker): 26 | @staticmethod 27 | @override 28 | def pick(n: int, fold: Literal[0, 1, 2, 3, 4, "all"]) -> tuple[int, ...]: 29 | if fold == "all": 30 | return tuple(range(0, n, 4)) 31 | size = n // 5 32 | return tuple(range(size * fold, size * (fold + 1))) 33 | 34 | 35 | class RandomKFPicker(OrderedKFPicker): 36 | @staticmethod 37 | @override 38 | def pick(n: int, fold: Literal[0, 1, 2, 3, 4, "all"]) -> tuple[int, ...]: 39 | return tuple(choices(range(n), k=n // 5)) if fold == "all" else super().pick(n, fold) 40 | 41 | 42 | class Loader(object): 43 | @staticmethod 44 | def do_load(path: str | PathLike[str], *, is_label: bool = False, device: Device = "cpu", **kwargs) -> torch.Tensor: 45 | return load_image(path, is_label=is_label, device=device, **kwargs) 46 | 47 | 48 | T = TypeVar("T") 49 | 50 | 51 | class _AbstractDataset(Dataset, Loader, HasDevice, Generic[T], Sequence[T], metaclass=ABCMeta): 52 | @abstractmethod 53 | def load(self, idx: int) -> T: 54 | raise NotImplementedError 55 | 56 | @override 57 | def __getitem__(self, idx: int) -> T: 58 | return self.load(idx) 59 | 60 | 61 | D = TypeVar("D", bound=Sequence[Any]) 62 | 63 | 64 | class UnsupervisedDataset(_AbstractDataset[torch.Tensor], Generic[D], metaclass=ABCMeta): 65 | """ 66 | Do not use this as a generic class. Only parameterize it if you are inheriting from it. 67 | """ 68 | 69 | def __init__(self, images: D, *, device: Device = "cpu") -> None: 70 | super().__init__(device) 71 | self._images: D = images 72 | 73 | @override 74 | def __len__(self) -> int: 75 | return len(self._images) 76 | 77 | 78 | class SupervisedDataset(_AbstractDataset[tuple[torch.Tensor, torch.Tensor]], Generic[D], metaclass=ABCMeta): 79 | """ 80 | Do not use this as a generic class. Only parameterize it if you are inheriting from it. 
81 | """ 82 | 83 | def __init__(self, images: D, labels: D, *, device: Device = "cpu") -> None: 84 | super().__init__(device) 85 | if len(images) != len(labels): 86 | raise ValueError(f"Unmatched number of images {len(images)} and labels {len(labels)}") 87 | self._images: D = images 88 | self._labels: D = labels 89 | 90 | @override 91 | def __len__(self) -> int: 92 | return len(self._images) 93 | 94 | @abstractmethod 95 | def construct_new(self, images: D, labels: D) -> Self: 96 | raise NotImplementedError 97 | 98 | def fold(self, *, fold: Literal[0, 1, 2, 3, 4, "all"] = "all", picker: type[KFPicker] = OrderedKFPicker) -> tuple[ 99 | Self, Self]: 100 | indexes = picker.pick(len(self), fold) 101 | images_train = [] 102 | labels_train = [] 103 | images_val = [] 104 | labels_val = [] 105 | for i in range(len(self)): 106 | if i in indexes: 107 | images_val.append(self._images[i]) 108 | labels_val.append(self._labels[i]) 109 | else: 110 | images_train.append(self._images[i]) 111 | labels_train.append(self._labels[i]) 112 | return self.construct_new(images_train, labels_train), self.construct_new(images_val, labels_val) 113 | 114 | 115 | class DatasetFromMemory(UnsupervisedDataset[Sequence[torch.Tensor]]): 116 | def __init__(self, images: Sequence[torch.Tensor], device: Device = "cpu") -> None: 117 | super().__init__(images, device=device) 118 | 119 | @override 120 | def load(self, idx: int) -> torch.Tensor: 121 | return self._images[idx].to(self._device) 122 | 123 | 124 | class MergedDataset(SupervisedDataset[UnsupervisedDataset]): 125 | def __init__(self, images: UnsupervisedDataset, labels: UnsupervisedDataset, *, device: Device = "cpu") -> None: 126 | super().__init__(images, labels, device=device) 127 | 128 | @override 129 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 130 | return self._images[idx].to(self._device), self._labels[idx].to(self._device) 131 | 132 | @override 133 | def construct_new(self, images: UnsupervisedDataset, labels: UnsupervisedDataset) -> Self: 134 | return MergedDataset(DatasetFromMemory(images), DatasetFromMemory(labels), device=self._device) 135 | 136 | 137 | class ComposeDataset(_AbstractDataset[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]): 138 | def __init__(self, bases: Sequence[SupervisedDataset] | Sequence[UnsupervisedDataset], *, 139 | device: Device = "cpu") -> None: 140 | super().__init__(device) 141 | self._bases: dict[tuple[int, int], SupervisedDataset | UnsupervisedDataset] = {} 142 | self._len = 0 143 | for dataset in bases: 144 | end = len(dataset) 145 | self._bases[(self._len, self._len + end)] = dataset 146 | self._len += end 147 | 148 | @override 149 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor: 150 | for (start, end), base in self._bases.items(): 151 | if start <= idx < end: 152 | return base.load(idx - start) 153 | raise IndexError(f"Index {idx} out of range [0, {self._len})") 154 | 155 | @override 156 | def __len__(self) -> int: 157 | return self._len 158 | 159 | 160 | class PathBasedUnsupervisedDataset(UnsupervisedDataset[list[str]], metaclass=ABCMeta): 161 | def paths(self) -> list[str]: 162 | return self._images 163 | 164 | def save_paths(self, to: str | PathLike[str]) -> None: 165 | match (fmt := to.split(".")[-1]): 166 | case "csv": 167 | df = DataFrame([{"image": image_path} for image_path in self.paths()]) 168 | df.index = range(len(df)) 169 | df.index.name = "case" 170 | df.to_csv(to) 171 | case "json": 172 | with open(to, "w") as f: 173 | dump([{"image": image_path} for 
image_path in self.paths()], f) 174 | case "txt": 175 | with open(to, "w") as f: 176 | for image_path in self.paths(): 177 | f.write(f"{image_path}\n") 178 | case _: 179 | raise ValueError(f"Unsupported file extension: {fmt}") 180 | 181 | 182 | class SimpleDataset(PathBasedUnsupervisedDataset): 183 | def __init__(self, folder: str | PathLike[str], *, device: Device = "cpu") -> None: 184 | images = listdir(folder) 185 | images.sort() 186 | super().__init__(images, device=device) 187 | self._folder: str = folder 188 | 189 | @override 190 | def load(self, idx: int) -> torch.Tensor: 191 | return self.do_load(f"{self._folder}/{self._images[idx]}", device=self._device) 192 | 193 | 194 | class PathBasedSupervisedDataset(SupervisedDataset[list[str]], metaclass=ABCMeta): 195 | def paths(self) -> list[tuple[str, str]]: 196 | return [(self._images[i], self._labels[i]) for i in range(len(self))] 197 | 198 | def save_paths(self, to: str | PathLike[str]) -> None: 199 | match (fmt := to.split(".")[-1]): 200 | case "csv": 201 | df = DataFrame([{"image": image_path, "label": label_path} for image_path, label_path in self.paths()]) 202 | df.index = range(len(df)) 203 | df.index.name = "case" 204 | df.to_csv(to) 205 | case "json": 206 | with open(to, "w") as f: 207 | dump([{"image": image_path, "label": label_path} for image_path, label_path in self.paths()], f) 208 | case "txt": 209 | with open(to, "w") as f: 210 | for image_path, label_path in self.paths(): 211 | f.write(f"{image_path}\t{label_path}\n") 212 | case _: 213 | raise ValueError(f"Unsupported file extension: {fmt}") 214 | 215 | 216 | class NNUNetDataset(PathBasedSupervisedDataset): 217 | def __init__(self, folder: str | PathLike[str], *, split: str | Literal["Tr", "Ts"] = "Tr", prefix: str = "", 218 | align_spacing: bool = False, image_transform: Transform | None = None, 219 | label_transform: Transform | None = None, device: Device = "cpu") -> None: 220 | images: list[str] = [f for f in listdir(f"{folder}/images{split}") if f.startswith(prefix)] 221 | images.sort() 222 | labels: list[str] = [f for f in listdir(f"{folder}/labels{split}") if f.startswith(prefix)] 223 | labels.sort() 224 | self._multimodal_images: list[list[str]] = [] 225 | if len(images) == len(labels): 226 | super().__init__(images, labels, device=device) 227 | else: 228 | super().__init__([""] * len(labels), labels, device=device) 229 | current_case = "" 230 | for image in images: 231 | case = image[:image.rfind("_")] 232 | if case != current_case: 233 | self._multimodal_images.append([]) 234 | current_case = case 235 | self._multimodal_images[-1].append(image) 236 | if len(self._multimodal_images) != len(self._labels): 237 | raise ValueError("Unmatched number of images and labels") 238 | self._folder: str = folder 239 | self._split: str = split 240 | self._folded: bool = False 241 | self._prefix: str = prefix 242 | self._align_spacing: bool = align_spacing 243 | self._image_transform: Transform | None = image_transform 244 | self._label_transform: Transform | None = label_transform 245 | 246 | @staticmethod 247 | def _create_subset(folder: str) -> None: 248 | if exists(folder) and len(listdir(folder)) > 0: 249 | raise FileExistsError(f"{folder} already exists and is not empty") 250 | makedirs(folder, exist_ok=True) 251 | 252 | @override 253 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 254 | image = torch.cat([self.do_load( 255 | f"{self._folder}/images{self._split}/{path}", align_spacing=self._align_spacing, device=self._device 256 | ) for path in 
self._multimodal_images[idx]]) if self._multimodal_images else self.do_load( 257 | f"{self._folder}/images{self._split}/{self._images[idx]}", align_spacing=self._align_spacing, 258 | device=self._device 259 | ) 260 | label = self.do_load( 261 | f"{self._folder}/labels{self._split}/{self._labels[idx]}", is_label=True, align_spacing=self._align_spacing, 262 | device=self._device 263 | ) 264 | if self._image_transform: 265 | image = self._image_transform(image) 266 | if self._label_transform: 267 | label = self._label_transform(label) 268 | return image, label 269 | 270 | def save(self, split: str | Literal["Tr", "Ts"], *, target_folder: str | PathLike[str] | None = None) -> None: 271 | target_base = target_folder if target_folder else self._folder 272 | images_target = f"{target_base}/images{split}" 273 | labels_target = f"{target_base}/labels{split}" 274 | self._create_subset(images_target) 275 | self._create_subset(labels_target) 276 | for image_path, label_path in self.paths(): 277 | copy2(f"{self._folder}/images{self._split}/{image_path}", f"{images_target}/{image_path}") 278 | copy2(f"{self._folder}/labels{self._split}/{label_path}", f"{labels_target}/{label_path}") 279 | self._split = split 280 | self._folded = False 281 | 282 | @override 283 | def construct_new(self, images: list[str], labels: list[str]) -> Self: 284 | if self._folded: 285 | raise ValueError("Cannot construct a new dataset from a fold") 286 | new = self.__class__(self._folder, split=self._split, prefix=self._prefix, align_spacing=self._align_spacing, 287 | image_transform=self._image_transform, label_transform=self._label_transform, 288 | device=self._device) 289 | new._images = images 290 | new._labels = labels 291 | new._folded = True 292 | return new 293 | 294 | 295 | class BinarizedDataset(SupervisedDataset[D]): 296 | def __init__(self, base: SupervisedDataset[D], positive_ids: tuple[int, ...]) -> None: 297 | super().__init__(base._images, base._labels) 298 | self._base: SupervisedDataset[D] = base 299 | self._positive_ids: tuple[int, ...] 
= positive_ids 300 | 301 | @override 302 | def construct_new(self, images: D, labels: D) -> Self: 303 | raise NotImplementedError 304 | 305 | @override 306 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 307 | image, label = self._base.load(idx) 308 | for pid in self._positive_ids: 309 | label[label == pid] = -1 310 | label[label > 0] = 0 311 | label[label == -1] = 1 312 | return image, label 313 | -------------------------------------------------------------------------------- /mipcandy/data/inspection.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | from json import dump, load 3 | from os import PathLike 4 | from typing import Sequence, override, Callable, Self, Any 5 | 6 | import numpy as np 7 | import torch 8 | from rich.console import Console 9 | from rich.progress import Progress, SpinnerColumn 10 | from torch import nn 11 | 12 | from mipcandy.data.dataset import SupervisedDataset 13 | from mipcandy.data.geometric import crop 14 | from mipcandy.layer import HasDevice 15 | from mipcandy.types import Device, Shape, AmbiguousShape 16 | 17 | 18 | def format_bbox(bbox: Sequence[int]) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]: 19 | if len(bbox) == 4: 20 | return bbox[0], bbox[1], bbox[2], bbox[3] 21 | elif len(bbox) == 6: 22 | return bbox[0], bbox[1], bbox[2], bbox[3], bbox[4], bbox[5] 23 | else: 24 | raise ValueError(f"Invalid bbox with {len(bbox)} elements") 25 | 26 | 27 | @dataclass 28 | class InspectionAnnotation(object): 29 | shape: AmbiguousShape 30 | foreground_bbox: tuple[int, int, int, int] | tuple[int, int, int, int, int, int] 31 | ids: tuple[int, ...] 32 | 33 | def foreground_shape(self) -> Shape: 34 | r = (self.foreground_bbox[1] - self.foreground_bbox[0], self.foreground_bbox[3] - self.foreground_bbox[2]) 35 | return r if len(self.foreground_bbox) == 4 else r + (self.foreground_bbox[5] - self.foreground_bbox[4],) 36 | 37 | def center_of_foreground(self) -> tuple[int, int] | tuple[int, int, int]: 38 | r = (round((self.foreground_bbox[1] + self.foreground_bbox[0]) * .5), 39 | round((self.foreground_bbox[3] + self.foreground_bbox[2]) * .5)) 40 | return r if len(self.shape) == 2 else r + (round((self.foreground_bbox[5] + self.foreground_bbox[4]) * .5),) 41 | 42 | def to_dict(self) -> dict[str, tuple[int, ...]]: 43 | return asdict(self) 44 | 45 | 46 | class InspectionAnnotations(HasDevice, Sequence[InspectionAnnotation]): 47 | def __init__(self, dataset: SupervisedDataset, background: int, *annotations: InspectionAnnotation, 48 | device: Device = "cpu") -> None: 49 | super().__init__(device) 50 | self._dataset: SupervisedDataset = dataset 51 | self._background: int = background 52 | self._annotations: tuple[InspectionAnnotation, ...] 
= annotations 53 | self._shapes: tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape] | None = None 54 | self._foreground_shapes: tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape] | None = None 55 | self._statistical_foreground_shape: Shape | None = None 56 | self._foreground_heatmap: torch.Tensor | None = None 57 | self._center_of_foregrounds: tuple[int, int] | tuple[int, int, int] | None = None 58 | self._foreground_offsets: tuple[int, int] | tuple[int, int, int] | None = None 59 | self._roi_shape: Shape | None = None 60 | 61 | def dataset(self) -> SupervisedDataset: 62 | return self._dataset 63 | 64 | def background(self) -> int: 65 | return self._background 66 | 67 | def annotations(self) -> tuple[InspectionAnnotation, ...]: 68 | return self._annotations 69 | 70 | @override 71 | def __getitem__(self, item: int) -> InspectionAnnotation: 72 | return self._annotations[item] 73 | 74 | @override 75 | def __len__(self) -> int: 76 | return len(self._annotations) 77 | 78 | def save(self, path: str | PathLike[str]) -> None: 79 | with open(path, "w") as f: 80 | dump({"background": self._background, "annotations": [a.to_dict() for a in self._annotations]}, f) 81 | 82 | def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], AmbiguousShape]) -> tuple[ 83 | AmbiguousShape | None, AmbiguousShape, AmbiguousShape]: 84 | depths = [] 85 | widths = [] 86 | heights = [] 87 | for annotation in self._annotations: 88 | shape = get_shape(annotation) 89 | if len(shape) == 2: 90 | heights.append(shape[0]) 91 | widths.append(shape[1]) 92 | else: 93 | depths.append(shape[0]) 94 | heights.append(shape[1]) 95 | widths.append(shape[2]) 96 | return tuple(depths) if depths else None, tuple(heights), tuple(widths) 97 | 98 | def shapes(self) -> tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape]: 99 | if self._shapes: 100 | return self._shapes 101 | self._shapes = self._get_shapes(lambda annotation: annotation.shape) 102 | return self._shapes 103 | 104 | def foreground_shapes(self) -> tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape]: 105 | if self._foreground_shapes: 106 | return self._foreground_shapes 107 | self._foreground_shapes = self._get_shapes(lambda annotation: annotation.foreground_shape()) 108 | return self._foreground_shapes 109 | 110 | def statistical_foreground_shape(self, *, percentile: float = .95) -> Shape: 111 | if self._statistical_foreground_shape: 112 | return self._statistical_foreground_shape 113 | depths, heights, widths = self.foreground_shapes() 114 | percentile *= 100 115 | sfs = (round(np.percentile(heights, percentile)), round(np.percentile(widths, percentile))) 116 | self._statistical_foreground_shape = (round(np.percentile(depths, percentile)),) + sfs if depths else sfs 117 | return self._statistical_foreground_shape 118 | 119 | def crop_foreground(self, i: int, *, expand_ratio: float = 1) -> tuple[torch.Tensor, torch.Tensor]: 120 | image, label = self._dataset[i] 121 | annotation = self._annotations[i] 122 | bbox = list(annotation.foreground_bbox) 123 | shape = annotation.foreground_shape() 124 | for dim_idx, size in enumerate(shape): 125 | left = int((expand_ratio - 1) * size // 2) 126 | right = int((expand_ratio - 1) * size - left) 127 | bbox[dim_idx * 2] = max(0, bbox[dim_idx * 2] - left) 128 | bbox[dim_idx * 2 + 1] = min(bbox[dim_idx * 2 + 1] + right, annotation.shape[dim_idx]) 129 | return crop(image.unsqueeze(0), bbox).squeeze(0), crop(label.unsqueeze(0), bbox).squeeze(0) 130 | 131 | def foreground_heatmap(self) -> torch.Tensor: 132 
| if self._foreground_heatmap: 133 | return self._foreground_heatmap 134 | depths, heights, widths = self.foreground_shapes() 135 | max_shape = (max(depths), max(heights), max(widths)) if depths else (max(heights), max(widths)) 136 | accumulated_label = torch.zeros((1, *max_shape), device=self._device) 137 | for i, (_, label) in enumerate(self._dataset): 138 | annotation = self._annotations[i] 139 | paddings = [0, 0, 0, 0] 140 | shape = annotation.foreground_shape() 141 | for j, size in enumerate(max_shape): 142 | left = (size - shape[j]) // 2 143 | right = size - shape[j] - left 144 | paddings.append(right) 145 | paddings.append(left) 146 | paddings.reverse() 147 | accumulated_label += nn.functional.pad( 148 | crop((label != self._background).unsqueeze(0), annotation.foreground_bbox), paddings 149 | ).squeeze(0) 150 | self._foreground_heatmap = accumulated_label.squeeze(0) 151 | return self._foreground_heatmap 152 | 153 | def center_of_foregrounds(self) -> tuple[int, int] | tuple[int, int, int]: 154 | if self._center_of_foregrounds: 155 | return self._center_of_foregrounds 156 | heatmap = self.foreground_heatmap() 157 | center = (heatmap.sum(dim=1).argmax().item(), heatmap.sum(dim=0).argmax().item()) if heatmap.ndim == 2 else ( 158 | heatmap.sum(dim=(1, 2)).argmax().item(), 159 | heatmap.sum(dim=(0, 2)).argmax().item(), 160 | heatmap.sum(dim=(0, 1)).argmax().item(), 161 | ) 162 | self._center_of_foregrounds = center 163 | return self._center_of_foregrounds 164 | 165 | def center_of_foregrounds_offsets(self) -> tuple[int, int] | tuple[int, int, int]: 166 | if self._foreground_offsets: 167 | return self._foreground_offsets 168 | center = self.center_of_foregrounds() 169 | depths, heights, widths = self.foreground_shapes() 170 | max_shape = (max(depths), max(heights), max(widths)) if depths else (max(heights), max(widths)) 171 | offsets = (round(center[0] - max_shape[0] * .5), round(center[1] - max_shape[1] * .5)) 172 | self._foreground_offsets = offsets + (round(center[2] - max_shape[2] * .5),) if depths else offsets 173 | return self._foreground_offsets 174 | 175 | def set_roi_shape(self, roi_shape: Shape | None) -> None: 176 | if roi_shape is not None: 177 | depths, heights, widths = self.shapes() 178 | if depths: 179 | if roi_shape[0] > min(depths) or roi_shape[1] > min(heights) or roi_shape[2] > min(widths): 180 | raise ValueError(f"ROI shape {roi_shape} exceeds minimum image shape ({min(depths)}, {min(heights)}, {min(widths)})") 181 | else: 182 | if roi_shape[0] > min(heights) or roi_shape[1] > min(widths): 183 | raise ValueError(f"ROI shape {roi_shape} exceeds minimum image shape ({min(heights)}, {min(widths)})") 184 | self._roi_shape = roi_shape 185 | 186 | def roi_shape(self, *, percentile: float = .95) -> Shape: 187 | if self._roi_shape: 188 | return self._roi_shape 189 | sfs = self.statistical_foreground_shape(percentile=percentile) 190 | if len(sfs) == 2: 191 | sfs = (None, *sfs) 192 | depths, heights, widths = self.shapes() 193 | roi_shape = (min(min(heights), sfs[1]), min(min(widths), sfs[2])) 194 | if depths: 195 | roi_shape = (min(min(depths), sfs[0]),) + roi_shape 196 | self._roi_shape = roi_shape 197 | return self._roi_shape 198 | 199 | def roi(self, i: int, *, percentile: float = .95) -> tuple[int, int, int, int] | tuple[ 200 | int, int, int, int, int, int]: 201 | annotation = self._annotations[i] 202 | roi_shape = self.roi_shape(percentile=percentile) 203 | offsets = self.center_of_foregrounds_offsets() 204 | center = annotation.center_of_foreground() 205 | roi = [] 206 | 
for i, position in enumerate(center): 207 | left = roi_shape[i] // 2 208 | right = roi_shape[i] - left 209 | offset = min(max(offsets[i], left - position), annotation.shape[i] - right - position) 210 | roi.append(position + offset - left) 211 | roi.append(position + offset + right) 212 | return tuple(roi) 213 | 214 | def crop_roi(self, i: int, *, percentile: float = .95) -> tuple[torch.Tensor, torch.Tensor]: 215 | image, label = self._dataset[i] 216 | roi = self.roi(i, percentile=percentile) 217 | return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0) 218 | 219 | 220 | def _lists_to_tuples(pairs: Sequence[tuple[str, Any]]) -> dict[str, Any]: 221 | return {k: tuple(v) if isinstance(v, list) else v for k, v in pairs} 222 | 223 | 224 | def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDataset) -> InspectionAnnotations: 225 | with open(path) as f: 226 | obj = load(f, object_pairs_hook=_lists_to_tuples) 227 | return InspectionAnnotations(dataset, obj["background"], *( 228 | InspectionAnnotation(**row) for row in obj["annotations"] 229 | )) 230 | 231 | 232 | def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console = Console()) -> InspectionAnnotations: 233 | r = [] 234 | with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress: 235 | task = progress.add_task("Inspecting dataset...", total=len(dataset)) 236 | for _, label in dataset: 237 | progress.update(task, advance=1, description=f"Inspecting dataset {tuple(label.shape)}") 238 | indices = (label != background).nonzero() 239 | mins = indices.min(dim=0)[0].tolist() 240 | maxs = indices.max(dim=0)[0].tolist() 241 | bbox = (mins[1], maxs[1] + 1, mins[2], maxs[2] + 1) 242 | r.append(InspectionAnnotation( 243 | label.shape[1:], bbox if label.ndim == 3 else bbox + (mins[3], maxs[3] + 1), tuple(label.unique()) 244 | )) 245 | return InspectionAnnotations(dataset, background, *r, device=dataset.device()) 246 | 247 | 248 | class ROIDataset(SupervisedDataset[list[torch.Tensor]]): 249 | def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95) -> None: 250 | super().__init__([], []) 251 | self._annotations: InspectionAnnotations = annotations 252 | self._percentile: float = percentile 253 | 254 | @override 255 | def __len__(self) -> int: 256 | return len(self._annotations) 257 | 258 | @override 259 | def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self: 260 | return self.__class__(self._annotations, percentile=self._percentile) 261 | 262 | @override 263 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 264 | return self._annotations.crop_roi(idx, percentile=self._percentile) 265 | 266 | 267 | class RandomROIDataset(ROIDataset): 268 | def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95, 269 | foreground_oversample_percentage: float = .33, min_foreground_samples: int = 500, 270 | max_foreground_samples: int = 10000, min_percent_coverage: float = .01) -> None: 271 | super().__init__(annotations, percentile=percentile) 272 | self._fg_oversample: float = foreground_oversample_percentage 273 | self._min_fg_samples: int = min_foreground_samples 274 | self._max_fg_samples: int = max_foreground_samples 275 | self._min_coverage: float = min_percent_coverage 276 | self._fg_locations_cache: dict[int, tuple[tuple[int, ...], ...] | None] = {} 277 | 278 | def _get_foreground_locations(self, idx: int) -> tuple[tuple[int, ...], ...] 
| None: 279 | if idx not in self._fg_locations_cache: 280 | _, label = self._annotations.dataset()[idx] 281 | indices = (label != self._annotations.background()).nonzero()[:, 1:] 282 | if len(indices) == 0: 283 | self._fg_locations_cache[idx] = None 284 | elif len(indices) <= self._min_fg_samples: 285 | self._fg_locations_cache[idx] = tuple(tuple(coord.tolist()) for coord in indices) 286 | else: 287 | target_samples = min( 288 | self._max_fg_samples, 289 | max(self._min_fg_samples, int(np.ceil(len(indices) * self._min_coverage))) 290 | ) 291 | sampled_idx = torch.randperm(len(indices))[:target_samples] 292 | sampled = indices[sampled_idx] 293 | self._fg_locations_cache[idx] = tuple(tuple(coord.tolist()) for coord in sampled) 294 | return self._fg_locations_cache[idx] 295 | 296 | def _random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]: 297 | annotation = self._annotations[idx] 298 | roi_shape = self._annotations.roi_shape(percentile=self._percentile) 299 | roi = [] 300 | for dim_size, patch_size in zip(annotation.shape, roi_shape): 301 | left = patch_size // 2 302 | right = patch_size - left 303 | min_center = left 304 | max_center = dim_size - right 305 | center = torch.randint(min_center, max_center + 1, (1,)).item() 306 | roi.append(center - left) 307 | roi.append(center + right) 308 | return tuple(roi) 309 | 310 | def _foreground_guided_random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[ 311 | int, int, int, int, int, int]: 312 | annotation = self._annotations[idx] 313 | roi_shape = self._annotations.roi_shape(percentile=self._percentile) 314 | foreground_locations = self._get_foreground_locations(idx) 315 | 316 | if foreground_locations is None or len(foreground_locations) == 0: 317 | return self._random_roi(idx) 318 | 319 | fg_idx = torch.randint(0, len(foreground_locations), (1,)).item() 320 | fg_position = foreground_locations[fg_idx] 321 | 322 | roi = [] 323 | for fg_pos, dim_size, patch_size in zip(fg_position, annotation.shape, roi_shape): 324 | left = patch_size // 2 325 | right = patch_size - left 326 | center = max(left, min(fg_pos, dim_size - right)) 327 | roi.append(center - left) 328 | roi.append(center + right) 329 | return tuple(roi) 330 | 331 | @override 332 | def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self: 333 | return self.__class__(self._annotations, percentile=self._percentile, 334 | foreground_oversample_percentage=self._fg_oversample, 335 | min_foreground_samples=self._min_fg_samples, 336 | max_foreground_samples=self._max_fg_samples, 337 | min_percent_coverage=self._min_coverage) 338 | 339 | @override 340 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: 341 | image, label = self._annotations.dataset()[idx] 342 | force_fg = torch.rand(1).item() < self._fg_oversample 343 | if force_fg: 344 | roi = self._foreground_guided_random_roi(idx) 345 | else: 346 | roi = self._random_roi(idx) 347 | return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0) 348 | -------------------------------------------------------------------------------- /mipcandy/training.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from dataclasses import dataclass 3 | from datetime import datetime 4 | from hashlib import md5 5 | from json import load, dump 6 | from os import PathLike, urandom, makedirs, environ 7 | from os.path import exists 8 | from random import seed as 
random_seed, randint 9 | from shutil import copy 10 | from threading import Lock 11 | from time import time 12 | from typing import Sequence, override, Callable, Self 13 | 14 | import numpy as np 15 | import torch 16 | from matplotlib import pyplot as plt 17 | from pandas import DataFrame, read_csv 18 | from rich.console import Console 19 | from rich.progress import Progress, SpinnerColumn 20 | from rich.table import Table 21 | from torch import nn, optim 22 | from torch.utils.data import DataLoader 23 | 24 | from mipcandy.common import Pad2d, Pad3d, quotient_regression, quotient_derivative, quotient_bounds 25 | from mipcandy.config import load_settings, load_secrets 26 | from mipcandy.frontend import Frontend 27 | from mipcandy.layer import WithPaddingModule, WithNetwork 28 | from mipcandy.sanity_check import sanity_check 29 | from mipcandy.sliding_window import SWMetadata, SlidingWindow 30 | from mipcandy.types import Params, Setting, AmbiguousShape 31 | 32 | 33 | def try_append(new: float, to: dict[str, list[float]], key: str) -> None: 34 | if key in to: 35 | to[key].append(new) 36 | else: 37 | to[key] = [new] 38 | 39 | 40 | def try_append_all(new: dict[str, float], to: dict[str, list[float]]) -> None: 41 | for key, value in new.items(): 42 | try_append(value, to, key) 43 | 44 | 45 | @dataclass 46 | class TrainerToolbox(object): 47 | model: nn.Module 48 | optimizer: optim.Optimizer 49 | scheduler: optim.lr_scheduler.LRScheduler 50 | criterion: nn.Module 51 | ema: optim.swa_utils.AveragedModel | None = None 52 | 53 | 54 | @dataclass 55 | class TrainerTracker(object): 56 | epoch: int = 0 57 | best_score: float = float("-inf") 58 | worst_case: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None 59 | 60 | 61 | class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta): 62 | def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], 63 | validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True, 64 | device: torch.device | str = "cpu", console: Console = Console()) -> None: 65 | WithPaddingModule.__init__(self, device) 66 | WithNetwork.__init__(self, device) 67 | self._trainer_folder: str = trainer_folder 68 | self._trainer_variant: str = self.__class__.__name__ 69 | self._experiment_id: str = "tbd" 70 | self._dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = dataloader 71 | self._validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = validation_dataloader 72 | self._unrecoverable: bool | None = not recoverable # None if the trainer is recovered 73 | self._console: Console = console 74 | self._metrics: dict[str, list[float]] = {} 75 | self._epoch_metrics: dict[str, list[float]] = {} 76 | self._frontend: Frontend = Frontend({}) 77 | self._lock: Lock = Lock() 78 | self._tracker: TrainerTracker = TrainerTracker() 79 | 80 | # Recovery methods (PR #108 at https://github.com/ProjectNeura/MIPCandy/pull/108) 81 | 82 | def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker, 83 | **training_arguments) -> None: 84 | if self._unrecoverable: 85 | return 86 | torch.save(toolbox.optimizer.state_dict(), f"{self.experiment_folder()}/optimizer.pth") 87 | torch.save(toolbox.scheduler.state_dict(), f"{self.experiment_folder()}/scheduler.pth") 88 | torch.save(toolbox.criterion.state_dict(), f"{self.experiment_folder()}/criterion.pth") 89 | torch.save(tracker, f"{self.experiment_folder()}/tracker.pt") 90 | with 
open(f"{self.experiment_folder()}/training_arguments.json", "w") as f: 91 | dump(training_arguments, f) 92 | 93 | def load_tracker(self) -> TrainerTracker: 94 | return torch.load(f"{self.experiment_folder()}/tracker.pt", weights_only=False) 95 | 96 | def load_training_arguments(self) -> dict[str, Setting]: 97 | with open(f"{self.experiment_folder()}/training_arguments.json") as f: 98 | return load(f) 99 | 100 | def load_metrics(self) -> dict[str, list[float]]: 101 | df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch") 102 | return {column: df[column].astype(float).tolist() for column in df.columns} 103 | 104 | def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape) -> TrainerToolbox: 105 | toolbox = self._build_toolbox(num_epochs, example_shape, model=self.load_model( 106 | example_shape, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth") 107 | )) 108 | toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth")) 109 | toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth")) 110 | toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth")) 111 | return toolbox 112 | 113 | def recover_from(self, experiment_id: str) -> Self: 114 | self._experiment_id = experiment_id 115 | if not exists(self.experiment_folder()): 116 | raise FileNotFoundError(f"Experiment folder {self.experiment_folder()} not found") 117 | self._metrics = self.load_metrics() 118 | self._tracker = self.load_tracker() 119 | self._unrecoverable = None 120 | return self 121 | 122 | def continue_training(self, num_epochs: int) -> None: 123 | if not self.recovery(): 124 | raise RuntimeError("Must call `recover_from()` before continuing training") 125 | self.train(num_epochs, **self.load_training_arguments()) 126 | 127 | # Getters 128 | 129 | def trainer_folder(self) -> str: 130 | return self._trainer_folder 131 | 132 | def trainer_variant(self) -> str: 133 | return self._trainer_variant 134 | 135 | def experiment_id(self) -> str: 136 | return self._experiment_id 137 | 138 | def dataloader(self) -> DataLoader[tuple[torch.Tensor, torch.Tensor]]: 139 | return self._dataloader 140 | 141 | def validation_dataloader(self) -> DataLoader[tuple[torch.Tensor, torch.Tensor]]: 142 | return self._validation_dataloader 143 | 144 | def console(self) -> Console: 145 | return self._console 146 | 147 | def metrics(self) -> dict[str, list[float]]: 148 | return self._metrics.copy() 149 | 150 | def frontend(self) -> Frontend: 151 | return self._frontend 152 | 153 | def tracker(self) -> TrainerTracker: 154 | return self._tracker 155 | 156 | # Enhanced getters 157 | 158 | def initialized(self) -> bool: 159 | return self._experiment_id != "tbd" 160 | 161 | def recovery(self) -> bool: 162 | return self._unrecoverable is None 163 | 164 | def experiment_folder(self) -> str: 165 | return f"{self._trainer_folder}/{self._trainer_variant}/{self._experiment_id}" 166 | 167 | def predict_maximum_validation_score(self, num_epochs: int, *, degree: int = 5) -> tuple[int, float]: 168 | val_scores = np.array(self._metrics["val score"]) 169 | a, b = quotient_regression(np.arange(len(val_scores)), val_scores, degree, degree) 170 | da, db = quotient_derivative(a, b) 171 | max_roc = float(da[0] / db[0]) 172 | max_val_score = float(a[0] / b[0]) 173 | bounds = quotient_bounds(a, b, None, max_val_score * (1 - max_roc), x_start=0, x_stop=num_epochs, x_step=1) 174 | return (round(bounds[1]) + 1, max_val_score) if bounds else 
(0, 0) 175 | 176 | def etc(self, epoch: int, num_epochs: int, *, target_epoch: int | None = None, 177 | val_score_prediction_degree: int = 5) -> float: 178 | if not target_epoch: 179 | target_epoch, _ = self.predict_maximum_validation_score(num_epochs, degree=val_score_prediction_degree) 180 | epoch_durations = self._metrics["epoch duration"] 181 | return sum(epoch_durations) * (target_epoch - epoch) / len(epoch_durations) 182 | 183 | # Setters 184 | 185 | def set_frontend(self, frontend: type[Frontend], *, path_to_secrets: str | PathLike[str] | None = None) -> None: 186 | self._frontend = frontend(load_secrets(path=path_to_secrets) if path_to_secrets else load_secrets()) 187 | 188 | def set_seed(self, seed: int) -> None: 189 | np.random.seed(seed) 190 | torch.manual_seed(seed) 191 | torch.cuda.manual_seed(seed) 192 | torch.cuda.manual_seed_all(seed) 193 | torch.backends.cudnn.benchmark = False 194 | torch.backends.cudnn.deterministic = True 195 | random_seed(seed) 196 | np.random.seed(seed) 197 | environ['PYTHONHASHSEED'] = str(seed) 198 | if self.initialized(): 199 | self.log(f"Set to manual seed {seed}") 200 | 201 | # Initialization methods 202 | 203 | def _allocate_experiment_folder(self) -> str: 204 | self._experiment_id = datetime.now().strftime("%Y%m%d-%H-") + md5(urandom(8)).hexdigest()[:4] 205 | experiment_folder = self.experiment_folder() 206 | return self._allocate_experiment_folder() if exists(experiment_folder) else experiment_folder 207 | 208 | def allocate_experiment_folder(self) -> str: 209 | return self.experiment_folder() if self.initialized() else self._allocate_experiment_folder() 210 | 211 | def init_experiment(self) -> None: 212 | if self.recovery(): 213 | self.log(f"Training progress recovered from {self._experiment_id} from epoch {self._tracker.epoch}") 214 | return 215 | if self.initialized(): 216 | raise RuntimeError("Experiment already initialized") 217 | makedirs(self._trainer_folder, exist_ok=True) 218 | experiment_folder = self.allocate_experiment_folder() 219 | makedirs(experiment_folder) 220 | t = datetime.now() 221 | with open(f"{experiment_folder}/logs.txt", "w") as f: 222 | f.write(f"File created by FightTumor, copyright (C) {t.year} Project Neura. 
All rights reserved\n") 223 | self.log(f"Experiment (ID {self._experiment_id}) created at {t}") 224 | self.log(f"Trainer: {self.__class__.__name__}") 225 | 226 | # Logging utilities 227 | 228 | def log(self, msg: str, *, on_screen: bool = True) -> None: 229 | msg = f"[{datetime.now()}] {msg}" 230 | if self.initialized(): 231 | with open(f"{self.experiment_folder()}/logs.txt", "a") as f: 232 | f.write(f"{msg}\n") 233 | if on_screen: 234 | with self._lock: 235 | self._console.print(msg) 236 | 237 | def record(self, metric: str, value: float) -> None: 238 | try_append(value, self._epoch_metrics, metric) 239 | 240 | def _record(self, metric: str, value: float) -> None: 241 | try_append(value, self._metrics, metric) 242 | 243 | def record_all(self, metrics: dict[str, float]) -> None: 244 | try_append_all(metrics, self._epoch_metrics) 245 | 246 | def _bump_metrics(self) -> None: 247 | for metric, values in self._epoch_metrics.items(): 248 | epoch_overall = sum(values) / len(values) 249 | try_append(epoch_overall, self._metrics, metric) 250 | self._epoch_metrics.clear() 251 | 252 | def save_metrics(self) -> None: 253 | df = DataFrame(self._metrics) 254 | df.index = range(1, len(df) + 1) 255 | df.index.name = "epoch" 256 | df.to_csv(f"{self.experiment_folder()}/metrics.csv") 257 | 258 | def save_metric_curve(self, name: str, values: Sequence[float]) -> None: 259 | name = name.capitalize() 260 | plt.plot(values) 261 | plt.title(f"{name} over Epoch") 262 | plt.xlabel("Epoch") 263 | plt.ylabel(name) 264 | plt.grid() 265 | plt.savefig(f"{self.experiment_folder()}/{name.lower()}.png") 266 | plt.close() 267 | 268 | def save_metric_curve_combo(self, metrics: dict[str, Sequence[float]], *, title: str = "All Metrics") -> None: 269 | for name, values in metrics.items(): 270 | plt.plot(values, label=name.capitalize()) 271 | plt.title(title) 272 | plt.xlabel("Epoch") 273 | plt.legend() 274 | plt.grid() 275 | plt.savefig(f"{self.experiment_folder()}/{title.lower()}.png") 276 | plt.close() 277 | 278 | def save_metric_curves(self, *, names: Sequence[str] | None = None) -> None: 279 | if names is None: 280 | for name, values in self._metrics.items(): 281 | self.save_metric_curve(name, values) 282 | else: 283 | for name in names: 284 | self.save_metric_curve(name, self._metrics[name]) 285 | 286 | def save_progress(self, *, names: Sequence[str] = ("combined loss", "val score")) -> None: 287 | self.save_metric_curve_combo({name: self._metrics[name] for name in names}, title="Progress") 288 | 289 | def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.Tensor, *, 290 | quality: float = .75) -> None: 291 | ... 
292 | 293 | def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = None, prefix: str = "training", 294 | epochwise: bool = True, skip: Callable[[str, list[float]], bool] | None = None) -> None: 295 | if not metrics: 296 | metrics = self._metrics 297 | prefix = prefix.capitalize() 298 | table = Table(title=f"Epoch {epoch} {prefix}") 299 | table.add_column("Metric") 300 | table.add_column("Mean Value", style="green") 301 | table.add_column("Span", style="cyan") 302 | table.add_column("Diff", style="magenta") 303 | for metric, values in metrics.items(): 304 | if skip and skip(metric, values): 305 | continue 306 | span = f"[{min(values):.4f}, {max(values):.4f}]" 307 | if epochwise: 308 | value = f"{values[-1]:.4f}" 309 | diff = f"{values[-1] - values[-2]:+.4f}" if len(values) > 1 else "N/A" 310 | else: 311 | mean = sum(values) / len(values) 312 | value = f"{mean:.4f}" 313 | diff = f"{mean - self._metrics[metric][-1]:+.4f}" if metric in self._metrics else "N/A" 314 | table.add_row(metric, value, span, diff) 315 | self.log(f"{prefix} {metric}: {value} @{span} ({diff})") 316 | console = Console() 317 | console.print(table) 318 | 319 | # Builder interfaces 320 | 321 | @abstractmethod 322 | def build_optimizer(self, params: Params) -> optim.Optimizer: 323 | raise NotImplementedError 324 | 325 | @abstractmethod 326 | def build_scheduler(self, optimizer: optim.Optimizer, num_epochs: int) -> optim.lr_scheduler.LRScheduler: 327 | raise NotImplementedError 328 | 329 | @abstractmethod 330 | def build_criterion(self) -> nn.Module: 331 | raise NotImplementedError 332 | 333 | def _build_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, *, 334 | model: nn.Module | None = None) -> TrainerToolbox: 335 | if not model: 336 | model = self.load_model(example_shape) 337 | optimizer = self.build_optimizer(model.parameters()) 338 | scheduler = self.build_scheduler(optimizer, num_epochs) 339 | criterion = self.build_criterion().to(self._device) 340 | return TrainerToolbox(model, optimizer, scheduler, criterion) 341 | 342 | def build_toolbox(self, num_epochs: int, example_shape: AmbiguousShape) -> TrainerToolbox: 343 | return self._build_toolbox(num_epochs, example_shape) 344 | 345 | # Training methods 346 | 347 | @abstractmethod 348 | def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ 349 | str, float]]: 350 | raise NotImplementedError 351 | 352 | def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ 353 | str, float]]: 354 | toolbox.optimizer.zero_grad() 355 | loss, metrics = self.backward(images, labels, toolbox) 356 | toolbox.optimizer.step() 357 | toolbox.scheduler.step() 358 | if toolbox.ema: 359 | toolbox.ema.update_parameters(toolbox.model) 360 | return loss, metrics 361 | 362 | def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None: 363 | toolbox.model.train() 364 | if toolbox.ema: 365 | toolbox.ema.train() 366 | with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=self._console) as progress: 367 | epoch_prog = progress.add_task(f"Epoch {epoch}", total=len(self._dataloader)) 368 | for images, labels in self._dataloader: 369 | images, labels = images.to(self._device), labels.to(self._device) 370 | padding_module = self.get_padding_module() 371 | if padding_module: 372 | images, labels = padding_module(images), padding_module(labels) 373 | progress.update(epoch_prog, description=f"Training epoch {epoch} {tuple(images.shape)}") 374 
| loss, metrics = self.train_batch(images, labels, toolbox) 375 | self.record("combined loss", loss) 376 | self.record_all(metrics) 377 | progress.update(epoch_prog, advance=1, description=f"Training epoch {epoch} ({loss:.4f})") 378 | self._bump_metrics() 379 | 380 | def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, ema: bool = True, 381 | seed: int | None = None, early_stop_tolerance: int = 5, val_score_prediction: bool = True, 382 | val_score_prediction_degree: int = 5, save_preview: bool = True, preview_quality: float = .75) -> None: 383 | training_arguments = self.filter_train_params(**locals()) 384 | self.init_experiment() 385 | if note: 386 | self.log(f"Note: {note}") 387 | if seed is None: 388 | seed = randint(0, 100) 389 | self.set_seed(seed) 390 | example_input = self._dataloader.dataset[0][0].to(self._device).unsqueeze(0) 391 | padding_module = self.get_padding_module() 392 | if padding_module: 393 | example_input = padding_module(example_input) 394 | example_shape = tuple(example_input.shape[1:]) 395 | self.log(f"Example input shape: {example_shape}") 396 | toolbox = self.load_toolbox(num_epochs, example_shape) if self.recovery() else self.build_toolbox( 397 | num_epochs, example_shape) 398 | model_name = toolbox.model.__class__.__name__ 399 | sanity_check_result = sanity_check(toolbox.model, example_shape, device=self._device) 400 | self.log(f"Model: {model_name}") 401 | self.log(str(sanity_check_result)) 402 | self.log(f"Example output shape: {tuple(sanity_check_result.output.shape)}") 403 | if ema: 404 | toolbox.ema = optim.swa_utils.AveragedModel(toolbox.model) 405 | checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth" 406 | es_tolerance = early_stop_tolerance 407 | self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note, 408 | sanity_check_result.num_macs, sanity_check_result.num_params, num_epochs, 409 | early_stop_tolerance) 410 | try: 411 | for epoch in range(self._tracker.epoch, self._tracker.epoch + num_epochs): 412 | if early_stop_tolerance == -1: 413 | epoch -= 1 414 | self.log(f"Early stopping triggered because the validation score has not improved for { 415 | es_tolerance} epochs") 416 | break 417 | self._tracker.epoch = epoch 418 | # Training 419 | t0 = time() 420 | self.train_epoch(epoch, toolbox) 421 | lr = toolbox.scheduler.get_last_lr()[0] 422 | self._record("learning rate", lr) 423 | self.show_metrics(epoch, skip=lambda m, _: m.startswith("val ") or m == "epoch duration") 424 | torch.save(toolbox.model.state_dict(), checkpoint_path("latest")) 425 | if epoch % (num_epochs / num_checkpoints) == 0: 426 | copy(checkpoint_path("latest"), checkpoint_path(epoch)) 427 | self.log(f"Epoch {epoch} checkpoint saved") 428 | self.log(f"Epoch {epoch} training completed in {time() - t0:.1f} seconds") 429 | # Validation 430 | score, metrics = self.validate(toolbox) 431 | self._record("val score", score) 432 | msg = f"Validation score: {score:.4f}" 433 | if epoch > 1: 434 | msg += f" ({score - self._metrics["val score"][-2]:+.4f})" 435 | self.log(msg) 436 | if val_score_prediction and epoch > val_score_prediction_degree: 437 | target_epoch, max_score = self.predict_maximum_validation_score( 438 | num_epochs, degree=val_score_prediction_degree 439 | ) 440 | self.log(f"Maximum validation score {max_score:.4f} predicted at epoch {target_epoch}") 441 | etc = self.etc(epoch, num_epochs, target_epoch=target_epoch) 442 | self.log(f"Estimated time of completion in {etc:.1f} seconds at 
{datetime.fromtimestamp( 443 | time() + etc):%m-%d %H:%M:%S}") 444 | self.show_metrics(epoch, metrics=metrics, prefix="validation", epochwise=False) 445 | if score > self._tracker.best_score: 446 | copy(checkpoint_path("latest"), checkpoint_path("best")) 447 | self.log(f"======== Best checkpoint updated ({self._tracker.best_score:.4f} -> { 448 | score:.4f}) ========") 449 | self._tracker.best_score = score 450 | early_stop_tolerance = es_tolerance 451 | if save_preview: 452 | self.save_preview(*self._tracker.worst_case, quality=preview_quality) 453 | else: 454 | early_stop_tolerance -= 1 455 | epoch_duration = time() - t0 456 | self._record("epoch duration", epoch_duration) 457 | self.log(f"Epoch {epoch} completed in {epoch_duration:.1f} seconds") 458 | self.log(f"=============== Best Validation Score {self._tracker.best_score:.4f} ===============") 459 | self.save_metrics() 460 | self.save_progress() 461 | self.save_metric_curves() 462 | self.save_everything_for_recovery(toolbox, self._tracker, **training_arguments) 463 | self._frontend.on_experiment_updated(self._experiment_id, epoch, self._metrics, early_stop_tolerance) 464 | except Exception as e: 465 | self.log("Training interrupted") 466 | self.log(repr(e)) 467 | self._frontend.on_experiment_interrupted(self._experiment_id, e) 468 | raise e 469 | else: 470 | self.log("Training completed") 471 | self._frontend.on_experiment_completed(self._experiment_id) 472 | 473 | @staticmethod 474 | def filter_train_params(**kwargs) -> dict[str, Setting]: 475 | return {k: v for k, v in kwargs.items() if k in ( 476 | "note", "num_checkpoints", "ema", "seed", "early_stop_tolerance", "val_score_prediction", 477 | "val_score_prediction_degree", "save_preview", "preview_quality" 478 | )} 479 | 480 | def train_with_settings(self, num_epochs: int, **kwargs) -> None: 481 | settings = self.filter_train_params(**load_settings()) 482 | settings.update(kwargs) 483 | self.train(num_epochs, **settings) 484 | 485 | # Validation methods 486 | 487 | @abstractmethod 488 | def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ 489 | str, float], torch.Tensor]: 490 | raise NotImplementedError 491 | 492 | def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float]]]: 493 | if self._validation_dataloader.batch_size != 1: 494 | raise RuntimeError("Validation dataloader should have batch size 1") 495 | toolbox.model.eval() 496 | if toolbox.ema: 497 | toolbox.ema.eval() 498 | score = 0 499 | worst_score = float("+inf") 500 | metrics = {} 501 | num_cases = len(self._validation_dataloader) 502 | with torch.no_grad(), Progress( 503 | *Progress.get_default_columns(), SpinnerColumn(), console=self._console 504 | ) as progress: 505 | val_prog = progress.add_task(f"Validating", total=num_cases) 506 | for image, label in self._validation_dataloader: 507 | image, label = image.to(self._device), label.to(self._device) 508 | padding_module = self.get_padding_module() 509 | if padding_module: 510 | image, label = padding_module(image), padding_module(label) 511 | image, label = image.squeeze(0), label.squeeze(0) 512 | progress.update(val_prog, description=f"Validating {tuple(image.shape)}") 513 | case_score, case_metrics, output = self.validate_case(image, label, toolbox) 514 | score += case_score 515 | if case_score < worst_score: 516 | self._tracker.worst_case = (image, label, output) 517 | worst_score = case_score 518 | try_append_all(case_metrics, metrics) 519 | progress.update(val_prog, advance=1, 
description=f"Validating ({case_score:.4f})") 520 | return score / num_cases, metrics 521 | 522 | def __call__(self, *args, **kwargs) -> None: 523 | self.train(*args, **kwargs) 524 | 525 | @override 526 | def __str__(self) -> str: 527 | return f"{self.__class__.__name__} {self._experiment_id}" 528 | 529 | 530 | class SlidingTrainer(Trainer, SlidingWindow, metaclass=ABCMeta): 531 | @override 532 | def build_padding_module(self) -> nn.Module | None: 533 | window_shape = self.get_window_shape() 534 | return (Pad2d if len(window_shape) == 2 else Pad3d)(window_shape) 535 | 536 | @abstractmethod 537 | def validate_case_windowed(self, images: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox, 538 | metadata: SWMetadata) -> tuple[float, dict[str, float], torch.Tensor]: 539 | raise NotImplementedError 540 | 541 | @override 542 | def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ 543 | str, float], torch.Tensor]: 544 | images, metadata = self.do_sliding_window(image.unsqueeze(0)) 545 | return self.validate_case_windowed(images, label, toolbox, metadata) 546 | 547 | @abstractmethod 548 | def backward_windowed(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox, 549 | metadata: SWMetadata) -> tuple[float, dict[str, float]]: 550 | raise NotImplementedError 551 | 552 | @override 553 | def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[ 554 | str, float]]: 555 | images, metadata = self.do_sliding_window(images) 556 | labels, _ = self.do_sliding_window(labels) 557 | return self.backward_windowed(images, labels, toolbox, metadata) 558 | -------------------------------------------------------------------------------- /home/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | MIP Candy 6 | 7 | 8 | 597 | 598 | 599 | 618 |
Project Neura · MIP Candy

A candy for medical image processing
Next-generation infrastructure for fast prototyping in machine learning.
MIP Candy brings ready-to-use training, inference, and evaluation pipelines together with aesthetics, so you can focus on your experiments, not boilerplate.
[Badges: PyPI, GitHub release, PyPI downloads, GitHub stars]
[Image: MIP Candy overview poster]

Trusted by
MIP Candy powers research across top institutions.
[Institution logos]
Key features
Designed for modern medical image research pipelines.

Training: Easy adaptation to your workflow
Override a single method to plug in your own network architecture. Grab a tool from the box and customize your experiments.
  • Sliding window
  • ROI inspection & cropping
  • Automatic padding & shape alignment
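A minimal sketch of the ROI inspection and cropping workflow, pieced together from mipcandy/data/inspection.py; the dataset folder is a placeholder and the NNUNetDataset wrapper is borrowed from the quick-start example further down:

from torch.utils.data import DataLoader

from mipcandy import NNUNetDataset
from mipcandy.data.inspection import RandomROIDataset, inspect

# Any SupervisedDataset works here; the folder path is illustrative.
dataset = NNUNetDataset("datasets/Example", device="cuda")

# One pass over the labels records shapes, foreground bounding boxes, and class ids.
annotations = inspect(dataset, background=0)
annotations.save("datasets/Example/inspection.json")

# Fixed-shape, foreground-oversampled random crops suitable for batched training.
roi_dataset = RandomROIDataset(annotations, percentile=.95, foreground_oversample_percentage=.33)
dataloader = DataLoader(roi_dataset, batch_size=2, shuffle=True)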
Interface: Thoughtful command-line UX
A clean CLI layout makes it easy to configure experiments, track progress, and resume work without digging through scripts.
[Image: CLI UI]

Visualization: Built-in 2D & 3D views
Inspect slices or volumes directly from the training pipeline for intuitive understanding of your data and predictions.
[Image: 2D and 3D visualization]

Reliability: Interruption-tolerant runs
Experiments can be safely paused and resumed with built-in recovery mechanisms, so cluster hiccups don’t cost you progress.
[Image: Recovery screenshots]
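A sketch of how an interrupted run is resumed, assuming the UNetTrainer and dataloaders from the quick-start example further down; the experiment id is a placeholder for a folder created by an earlier run:

from mipcandy_bundles.unet import UNetTrainer

# Rebuild the trainer with the same data as the interrupted run.
trainer = UNetTrainer("tutorial", dataloader, val_dataloader, device="cuda")
# "20250101-12-abcd" stands in for a real experiment id under tutorial/UNetTrainer/.
trainer.recover_from("20250101-12-abcd")
trainer.continue_training(1000)  # reloads checkpoint, optimizer, scheduler, and the saved training arguments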
Dashboards: Remote monitoring ready
Connect to Notion, Weights & Biases, and TensorBoard for rich experiment tracking and sharing with your team.
[Image: Notion dashboard]
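A sketch of wiring a trainer to a remote dashboard via Trainer.set_frontend; treating NotionFrontend as the exposed Notion integration and secrets.yml as the credentials file are assumptions here:

from mipcandy.frontend import NotionFrontend  # assumed export of the Notion integration

# `trainer` is any Trainer subclass instance (see the quick start below).
trainer.set_frontend(NotionFrontend, path_to_secrets="secrets.yml")
trainer.train(1000, note="with a remote dashboard")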
Python 3.12+: Modern Python, modern stack
Built for Python 3.12 and above, MIP Candy takes advantage of modern typing and ecosystem improvements out of the box.
Custom trainers: Adapt a new network in one method.
Start from SegmentationTrainer and only implement the network construction. MIP Candy handles data flow, loss computation, augmentation, checkpointing, and evaluation out of the box.

Example — Custom model integration

from typing import override
from torch import nn
from mipcandy import SegmentationTrainer


class MyTrainer(SegmentationTrainer):
    @override
    def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
        ...

Provide your architecture once. MIP Candy takes care of the entire training pipeline.
830 |
831 |

Live Notion dashboard

832 |

833 | Explore an interactive MIP Candy frontend demo directly in Notion. 834 |

835 |
836 |
837 | 843 |
844 |
845 | 846 | 847 |
Quick start: Train like a Pro in a few lines
Download a dataset, create a dataset wrapper, and hand it to a bundled trainer. Below is an example using the PH2 dataset with batch size 1 for varying shapes, although you can use a ROIDataset to align them.

Example — nnU-Net style training

from typing import override

import torch
from mipcandy_bundles.unet import UNetTrainer
from torch.utils.data import DataLoader

from mipcandy import download_dataset, NNUNetDataset


class PH2(NNUNetDataset):
    @override
    def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        image, label = super().load(idx)
        return image.squeeze(0).permute(2, 0, 1), label


download_dataset("nnunet_datasets/PH2", "tutorial/datasets/PH2")
dataset, val_dataset = PH2("tutorial/datasets/PH2", device="cuda").fold()
dataloader = DataLoader(dataset, 1, shuffle=True)
val_dataloader = DataLoader(val_dataset, 1, shuffle=False)
trainer = UNetTrainer("tutorial", dataloader, val_dataloader, device="cuda")
trainer.train(1000, note="a nnU-Net style example")
Installation: Install MIP Candy
MIP Candy requires Python ≥ 3.12. Install the standard bundle from PyPI:

$ pip install "mipcandy[standard]"
Stand on the Giants
911 |

Install MIP Candy Bundles

912 |

913 | MIP Candy Bundles provide verified model architectures with corresponding trainers and predictors. 914 | You can install it along with MIP Candy. 915 |

916 | 917 |
918 |
919 |
920 | 921 | 922 | 923 |
924 | $ 925 | pip install "mipcandy[all]" 926 |
927 |
928 |
929 |
930 |
--------------------------------------------------------------------------------