├── .github
│   ├── FUNDING.yml
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   ├── feature_request.md
│   │   └── bug_report.md
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── home
│   ├── assets
│   │   ├── logo.png
│   │   ├── cli-ui.png
│   │   ├── notion.png
│   │   ├── poster.png
│   │   ├── recovery.png
│   │   ├── logos
│   │   │   ├── skule.png
│   │   │   ├── uoft.png
│   │   │   ├── uoftai.png
│   │   │   ├── utmist.png
│   │   │   └── vector.png
│   │   └── visualization.png
│   └── index.html
├── .gitmodules
├── mipcandy
│   ├── __main__.py
│   ├── common
│   │   ├── numpy
│   │   │   ├── __init__.py
│   │   │   └── regressions.py
│   │   ├── __init__.py
│   │   ├── optim
│   │   │   ├── __init__.py
│   │   │   ├── lr_scheduler.py
│   │   │   └── loss.py
│   │   └── module
│   │       ├── __init__.py
│   │       ├── conv.py
│   │       └── preprocess.py
│   ├── presets
│   │   ├── __init__.py
│   │   └── segmentation.py
│   ├── frontend
│   │   ├── __init__.py
│   │   ├── wandb_fe.py
│   │   ├── prototype.py
│   │   └── notion_fe.py
│   ├── settings.yml
│   ├── run.py
│   ├── types.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── download.py
│   │   ├── convertion.py
│   │   ├── geometric.py
│   │   ├── io.py
│   │   ├── visualization.py
│   │   ├── dataset.py
│   │   └── inspection.py
│   ├── __entry__.py
│   ├── __init__.py
│   ├── sanity_check.py
│   ├── config.py
│   ├── layer.py
│   ├── evaluation.py
│   ├── sliding_window.py
│   ├── metrics.py
│   ├── inference.py
│   └── training.py
├── pyproject.toml
├── README.md
└── LICENSE
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: ProjectNeura
2 | patreon: ProjectNeura
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 | main.py
4 | build
5 | *.egg-info
6 | dist
7 | secrets.yml
8 |
--------------------------------------------------------------------------------
/home/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logo.png
--------------------------------------------------------------------------------
/home/assets/cli-ui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/cli-ui.png
--------------------------------------------------------------------------------
/home/assets/notion.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/notion.png
--------------------------------------------------------------------------------
/home/assets/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/poster.png
--------------------------------------------------------------------------------
/home/assets/recovery.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/recovery.png
--------------------------------------------------------------------------------
/home/assets/logos/skule.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/skule.png
--------------------------------------------------------------------------------
/home/assets/logos/uoft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/uoft.png
--------------------------------------------------------------------------------
/home/assets/logos/uoftai.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/uoftai.png
--------------------------------------------------------------------------------
/home/assets/logos/utmist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/utmist.png
--------------------------------------------------------------------------------
/home/assets/logos/vector.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/logos/vector.png
--------------------------------------------------------------------------------
/home/assets/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ProjectNeura/MIPCandy/HEAD/home/assets/visualization.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "mipcandy-docs"]
2 | path = mipcandy-docs
3 | url = https://github.com/ProjectNeura/mipcandy-docs
4 |
--------------------------------------------------------------------------------
/mipcandy/__main__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.__entry__ import __entry__
2 |
3 |
4 | if __name__ == "__main__":
5 | __entry__()
6 |
--------------------------------------------------------------------------------
/mipcandy/common/numpy/__init__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.common.numpy.regressions import quotient_regression, quotient_derivative, quotient_bounds
2 |
--------------------------------------------------------------------------------
/mipcandy/common/__init__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.common.numpy import *
2 | from mipcandy.common.module import *
3 | from mipcandy.common.optim import *
4 |
--------------------------------------------------------------------------------
/mipcandy/presets/__init__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.presets.segmentation import SegmentationTrainer, SlidingSegmentationTrainer, SlidingValidationTrainer
2 |
--------------------------------------------------------------------------------
/mipcandy/frontend/__init__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.frontend.notion_fe import NotionFrontend
2 | from mipcandy.frontend.prototype import Frontend, create_hybrid_frontend
3 |
--------------------------------------------------------------------------------
/mipcandy/common/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.common.optim.loss import FocalBCEWithLogits, DiceBCELossWithLogits
2 | from mipcandy.common.optim.lr_scheduler import AbsoluteLinearLR
3 |
--------------------------------------------------------------------------------
/mipcandy/common/module/__init__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.common.module.conv import ConvBlock2d, ConvBlock3d, WSConv2d, WSConv3d
2 | from mipcandy.common.module.preprocess import Pad2d, Pad3d, Restore2d, Restore3d, Normalize, ColorizeLabel
3 |
--------------------------------------------------------------------------------
/mipcandy/settings.yml:
--------------------------------------------------------------------------------
1 | # fill in your settings here
2 | note: ""
3 | num_checkpoints: 5
4 | ema: true
5 | seed: ~
6 | early_stop_tolerance: 5
7 | val_score_prediction: true
8 | val_score_prediction_degree: 5
9 | save_preview: true
10 | preview_quality: 0.75
11 |
--------------------------------------------------------------------------------
/mipcandy/run.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | from mipcandy.config import load_settings, save_settings, load_secrets, save_secrets
4 |
5 |
6 | def config(target: Literal["setting", "secret"], key: str, value: str) -> None:
7 | match target:
8 | case "setting":
9 | settings = load_settings()
10 | settings[key] = value
11 | save_settings(settings)
12 | case "secret":
13 | secrets = load_secrets()
14 | secrets[key] = value
15 | save_secrets(secrets)
16 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5 |
6 | version: 2
7 | updates:
8 | - package-ecosystem: "pip" # See documentation for possible values
9 | directory: "/" # Location of package manifests
10 | schedule:
11 | interval: "weekly"
12 |
--------------------------------------------------------------------------------
/mipcandy/types.py:
--------------------------------------------------------------------------------
1 | from os import PathLike
2 | from typing import Any, Iterable, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from torchvision.transforms import Compose
7 |
8 | type Setting = str | int | float | bool | None | dict[str, Setting] | list[Setting]
9 | type Settings = dict[str, Setting]
10 | type Params = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
11 | type Transform = nn.Module | Compose
12 | type SupportedPredictant = Sequence[torch.Tensor] | str | PathLike[str] | Sequence[str] | torch.Tensor
13 | type Colormap = Sequence[int | tuple[int, int, int]]
14 | type Device = torch.device | str
15 | type Shape2d = tuple[int, int]
16 | type Shape3d = tuple[int, int, int]
17 | type Shape = Shape2d | Shape3d
18 | type AmbiguousShape = tuple[int, ...]
19 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: [ enhancement, question ]
6 | assignees: [ ATATC, qmascarenhas ]
7 | issue_type: Feature
8 |
9 | ---
10 |
11 | **Is your feature request related to a problem? Please describe.**
12 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
13 |
14 | **Describe the solution you'd like**
15 | A clear and concise description of what you want to happen.
16 |
17 | **Describe alternatives you've considered**
18 | A clear and concise description of any alternative solutions or features you've considered.
19 |
20 | **Additional context**
21 | Add any other context or screenshots about the feature request here.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types:
6 | - published
7 |
8 | permissions:
9 | contents: read
10 |
11 | jobs:
12 | build-and-publish:
13 | runs-on: ubuntu-latest
14 | permissions:
15 | id-token: write
16 | steps:
17 | - name: Checkout
18 | uses: actions/checkout@v4
19 | - name: Set up Python
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: "3.12"
23 | - name: Install dependencies
24 | run: python -m pip install -U setuptools wheel build
25 | - name: Build
26 | run: python -m build .
27 | - name: Publish
28 | uses: pypa/gh-action-pypi-publish@release/v1
29 | with:
30 | skip-existing: true
31 |
--------------------------------------------------------------------------------
/mipcandy/data/__init__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.data.convertion import convert_ids_to_logits, convert_logits_to_ids, auto_convert
2 | from mipcandy.data.dataset import Loader, UnsupervisedDataset, SupervisedDataset, DatasetFromMemory, MergedDataset, \
3 | PathBasedUnsupervisedDataset, SimpleDataset, PathBasedSupervisedDataset, NNUNetDataset, BinarizedDataset
4 | from mipcandy.data.download import download_dataset
5 | from mipcandy.data.geometric import ensure_num_dimensions, orthographic_views, aggregate_orthographic_views, crop
6 | from mipcandy.data.inspection import InspectionAnnotation, InspectionAnnotations, load_inspection_annotations, \
7 | inspect, ROIDataset, RandomROIDataset
8 | from mipcandy.data.io import resample_to_isotropic, load_image, save_image
9 | from mipcandy.data.visualization import visualize2d, visualize3d, overlay
10 |
--------------------------------------------------------------------------------
/mipcandy/__entry__.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from mipcandy.run import config
4 |
5 |
6 | def __entry__() -> None:
7 | parser = ArgumentParser(prog="MIP Candy CLI", description="MIP Candy Command Line Interface",
8 | epilog="GitHub: https://github.com/ProjectNeura/MIPCandy")
9 | parser.add_argument("-c", "--config", choices=("setting", "secret"), default=None,
10 | help="set a configuration such that key=value")
11 | parser.add_argument("-kv", "--key-value", nargs=2, action="append", default=None, help="define a key-value pair")
12 | args = parser.parse_args()
13 | if args.config:
14 | if not args.key_value:
15 | raise ValueError("Expected at least one key-value pair")
16 | for key_value in args.key_value:
17 | config(args.config, key_value[0], key_value[1])
18 |
--------------------------------------------------------------------------------
/mipcandy/data/download.py:
--------------------------------------------------------------------------------
1 | from os import PathLike
2 | from zipfile import ZipFile
3 |
4 | from requests import get
5 | from rich.console import Console
6 | from rich.progress import track
7 |
8 |
9 | def download_dataset(name: str, to: str | PathLike[str], *, endpoint: str = "cds.projectneura.org",
10 | console: Console = Console()) -> None:
11 | to_zip = f"{to}.zip"
12 | with get(f"https://{endpoint}/{name}.zip", stream=True) as response:
13 | response.raise_for_status()
14 | with open(to_zip, "wb") as f:
15 | for chunk in track(response.iter_content(chunk_size=8192), description="Downloading...", console=console):
16 | f.write(chunk)
17 | console.print(f"Dataset successfully downloaded as {to_zip}")
18 | console.print("Unzipping...")
19 | with ZipFile(to_zip, "r") as zip_ref:
20 | zip_ref.extractall(to)
21 | console.print(f"Dataset successfully extracted to {to}")
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "mipcandy"
7 | version = "1.1.0-beta.2"
8 | description = "A Candy for Medical Image Processing"
9 | license = "Apache-2.0"
10 | readme = "README.md"
11 | requires-python = ">=3.12"
12 | authors = [
13 | { name = "Project Neura", email = "central@projectneura.org" }
14 | ]
15 | dependencies = ["torch", "ptflops", "numpy", "SimpleITK", "matplotlib", "rich", "pandas", "requests"]
16 |
17 | [project.optional-dependencies]
18 | standard = ["pyvista"]
19 | all = ["pyvista", "mipcandy-bundles"]
20 |
21 | [tool.hatch.build.targets.sdist]
22 | only-include = ["mipcandy"]
23 |
24 | [tool.hatch.build.targets.wheel]
25 | packages = ["mipcandy"]
26 |
27 | [project.urls]
28 | Homepage = "https://mipcandy.projectneura.org"
29 | Documentation = "https://mipcandy-docs.projectneura.org"
30 | Repository = "https://github.com/ProjectNeura/MIPCandy"
31 |
32 | [project.scripts]
33 | mipcandy = "mipcandy:__entry__"
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: [ bug, todo ]
6 | assignees: ATATC
7 | issue_type: Bug
8 |
9 | ---
10 |
11 | **Describe the bug**
12 | A clear and concise description of what the bug is.
13 |
14 | **To Reproduce**
15 | Steps to reproduce the behavior:
16 | 1. Go to '...'
17 | 2. Click on '....'
18 | 3. Scroll down to '....'
19 | 4. See error
20 |
21 | **Expected behavior**
22 | A clear and concise description of what you expected to happen.
23 |
24 | **Screenshots**
25 | If applicable, add screenshots to help explain your problem.
26 |
27 | **Desktop (please complete the following information):**
28 | - OS: [e.g. iOS]
29 | - Browser [e.g. chrome, safari]
30 | - Version [e.g. 22]
31 |
32 | **Smartphone (please complete the following information):**
33 | - Device: [e.g. iPhone6]
34 | - OS: [e.g. iOS8.1]
35 | - Browser [e.g. stock browser, safari]
36 | - Version [e.g. 22]
37 |
38 | **Additional context**
39 | Add any other context about the problem here.
40 |
--------------------------------------------------------------------------------
/mipcandy/data/convertion.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from mipcandy.common import Normalize
7 |
8 |
9 | def convert_ids_to_logits(ids: torch.Tensor, d: Literal[1, 2, 3], num_classes: int) -> torch.Tensor:
10 | if ids.dtype != torch.int or ids.min() < 0:
11 |         raise TypeError("`ids` should be non-negative integers")
12 | d += 1
13 | if ids.ndim != d:
14 | if ids.ndim == d + 1 and ids.shape[1] == 1:
15 | ids = ids.squeeze(1)
16 | else:
17 | raise ValueError(f"`ids` should be {d} dimensional or {d + 1} dimensional with single channel")
18 | return nn.functional.one_hot(ids.long(), num_classes).movedim(-1, 1).contiguous().float()
19 |
20 |
21 | def convert_logits_to_ids(logits: torch.Tensor, *, channel_dim: int = 1) -> torch.Tensor:
22 | return logits.max(channel_dim).indices.int()
23 |
24 |
25 | def auto_convert(image: torch.Tensor) -> torch.Tensor:
26 | return (image * 255 if 0 <= image.min() < image.max() <= 1 else Normalize(domain=(0, 255))(image)).int()
27 |
--------------------------------------------------------------------------------
/mipcandy/data/geometric.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Sequence
2 |
3 | import torch
4 |
5 |
6 | def ensure_num_dimensions(x: torch.Tensor, num_dimensions: int) -> torch.Tensor:
7 | d = num_dimensions - x.ndim
8 | if d == 0:
9 | return x
10 | return x.reshape(*((1,) * d + x.shape)) if d > 0 else x.reshape(x.shape[-num_dimensions:])
11 |
12 |
13 | def orthographic_views(x: torch.Tensor, reduction: Literal["mean", "sum"] = "mean") -> tuple[
14 | torch.Tensor, torch.Tensor, torch.Tensor]:
15 | match reduction:
16 | case "mean":
17 | return x.mean(dim=-3), x.mean(dim=-2), x.mean(dim=-1)
18 | case "sum":
19 | return x.sum(dim=-3), x.sum(dim=-2), x.sum(dim=-1)
20 |
21 |
22 | def aggregate_orthographic_views(d: torch.Tensor, h: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
23 | d, h, w = d.unsqueeze(-3), h.unsqueeze(-2), w.unsqueeze(-1)
24 | return d * h * w
25 |
26 |
27 | def crop(t: torch.Tensor, bbox: Sequence[int]) -> torch.Tensor:
28 | return t[:, :, bbox[0]:bbox[1], bbox[2]:bbox[3]] if len(bbox) == 4 else t[:, :, bbox[0]:bbox[1], bbox[2]:bbox[3],
29 | bbox[4]:bbox[5]]
30 |
--------------------------------------------------------------------------------
/mipcandy/__init__.py:
--------------------------------------------------------------------------------
1 | from mipcandy.__entry__ import __entry__
2 | from mipcandy.common import *
3 | from mipcandy.config import load_settings, save_settings, load_secrets, save_secrets
4 | from mipcandy.data import *
5 | from mipcandy.evaluation import EvalCase, EvalResult, Evaluator
6 | from mipcandy.frontend import *
7 | from mipcandy.inference import parse_predictant, Predictor
8 | from mipcandy.layer import batch_int_multiply, batch_int_divide, LayerT, HasDevice, auto_device, WithPaddingModule, \
9 | WithNetwork
10 | from mipcandy.metrics import do_reduction, dice_similarity_coefficient_binary, \
11 | dice_similarity_coefficient_multiclass, soft_dice_coefficient, accuracy_binary, accuracy_multiclass, \
12 | precision_binary, precision_multiclass, recall_binary, recall_multiclass, iou_binary, iou_multiclass
13 | from mipcandy.presets import *
14 | from mipcandy.run import config
15 | from mipcandy.sanity_check import num_trainable_params, model_complexity_info, SanityCheckResult, sanity_check
16 | from mipcandy.training import TrainerToolbox, Trainer, SWMetadata, SlidingTrainer
17 | from mipcandy.types import Setting, Settings, Params, Transform, SupportedPredictant, Colormap, Device, Shape2d, \
18 | Shape3d, Shape, AmbiguousShape
19 |
--------------------------------------------------------------------------------
/mipcandy/common/optim/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from typing import override
2 |
3 | from torch import optim
4 |
5 |
6 | class AbsoluteLinearLR(optim.lr_scheduler.LRScheduler):
7 | """
8 | lr = kx + b
9 | """
10 | def __init__(self, optimizer: optim.Optimizer, k: float, b: float, *, min_lr: float = 1e-6,
11 | restart: bool = False, last_epoch: int = -1) -> None:
12 | self._k: float = k
13 | self._b: float = b
14 | if min_lr < 0:
15 |             raise ValueError(f"`min_lr` must be non-negative, but got {min_lr}")
16 | self._min_lr: float = min_lr
17 | self._restart: bool = restart
18 | self._restart_step: int = 0
19 | super().__init__(optimizer, last_epoch)
20 |
21 |     def _interp(self, step: int) -> float:
22 |         # keep `step` in absolute epochs so repeated restarts offset from the correct point
23 |         r = self._k * (step - self._restart_step) + self._b
24 |         if r < self._min_lr:
25 |             if self._restart:
26 |                 self._restart_step = step
27 |                 return self._interp(step)
28 |             return self._min_lr
29 |         return r
30 |
31 | @override
32 | def get_lr(self) -> list[float]:
33 | target = self._interp(self.last_epoch)
34 | return [target for _ in self.optimizer.param_groups]
35 |
--------------------------------------------------------------------------------
/mipcandy/frontend/wandb_fe.py:
--------------------------------------------------------------------------------
1 | from typing import override
2 |
3 | from wandb import init, Run
4 |
5 | from mipcandy.frontend.prototype import Frontend
6 | from mipcandy.types import Settings
7 |
8 |
9 | class WandBFrontend(Frontend):
10 | def __init__(self, secrets: Settings) -> None:
11 | super().__init__(secrets)
12 | self._entity: str = self.require_nonempty_secret("wandb_entity", require_type=str)
13 | self._project: str = self.require_nonempty_secret("wandb_project", require_type=str)
14 | self._run: Run | None = None
15 |
16 | @override
17 | def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float,
18 | num_params: float, num_epochs: int, early_stop_tolerance: int) -> None:
19 | self._run = init(entity=self._entity, project=self._project, config={
20 | "experiment_id": experiment_id, "trainer": trainer, "model": model, "note": note, "num_macs": num_macs,
21 | "num_params": num_params, "num_epochs": num_epochs
22 | })
23 |
24 | @override
25 | def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]],
26 | early_stop_tolerance: int) -> None:
27 | if self._run:
28 | self._run.log(metrics)
29 |
30 | @override
31 | def on_experiment_completed(self, experiment_id: str) -> None:
32 | if not self._run:
33 | raise RuntimeError("Experiment has not been created")
34 | self._run.finish()
35 |
--------------------------------------------------------------------------------
/mipcandy/common/numpy/regressions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def quotient_regression(x: np.ndarray, y: np.ndarray, m: int, n: int) -> tuple[np.ndarray, np.ndarray]:
5 | matrix = []
6 | for xi, yi in zip(x, y):
7 | row = []
8 | for k in range(m + 1):
9 | row.append(xi ** k)
10 | for k in range(1, n + 1):
11 | row.append(-yi * xi ** k)
12 | matrix.append(row)
13 | matrix = np.array(matrix)
14 | coefficients, _, _, _ = np.linalg.lstsq(matrix, y, rcond=None)
15 | return coefficients[:m + 1][::-1], np.concatenate(([1.0], coefficients[m + 1:]))[::-1]
16 |
17 |
18 | def quotient_derivative(a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
19 | da = np.polyder(a)
20 | db = np.polyder(b)
21 | return np.polysub(np.polymul(da, b), np.polymul(a, db)), np.polymul(b, b)
22 |
23 |
24 | def quotient_bounds(a: np.ndarray, b: np.ndarray, lower_bound: float | None, upper_bound: float | None, *,
25 | x_start: float = 0, x_stop: float = 1e4, x_step: float = .01) -> tuple[float, float] | None:
26 | x = np.arange(x_start, x_stop, x_step)
27 | y = np.polyval(a, x) / np.polyval(b, x)
28 |     mask = np.ones_like(y, dtype=bool)
29 | if lower_bound is not None:
30 | mask = mask & (y > lower_bound)
31 | if upper_bound is not None:
32 | mask = mask & (y < upper_bound)
33 |     if lower_bound is None and upper_bound is None:
34 |         raise ValueError("Bounds must be specified on at least one side")
35 | return (float(x[mask][0]), float(x[mask][-1])) if mask.any() else None
36 |
--------------------------------------------------------------------------------
/mipcandy/sanity_check.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from io import StringIO
3 | from typing import Sequence, override
4 |
5 | import torch
6 | from ptflops import get_model_complexity_info
7 | from torch import nn
8 |
9 | from mipcandy.layer import auto_device
10 | from mipcandy.types import Device
11 |
12 |
13 | def num_trainable_params(model: nn.Module) -> int:
14 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
15 |
16 |
17 | def model_complexity_info(model: nn.Module, example_shape: Sequence[int]) -> tuple[float | None, float | None, str]:
18 | layer_stats = StringIO()
19 | macs, params = get_model_complexity_info(model, tuple(example_shape), ost=layer_stats, as_strings=False)
20 | return macs, params, layer_stats.getvalue()
21 |
22 |
23 | @dataclass
24 | class SanityCheckResult(object):
25 | num_macs: float
26 | num_params: float
27 | layer_stats: str
28 | output: torch.Tensor
29 |
30 | @override
31 | def __str__(self) -> str:
32 | return f"MACs: {self.num_macs * 1e-9:.1f} G / Params: {self.num_params * 1e-6:.1f} M"
33 |
34 |
35 | def sanity_check(model: nn.Module, input_shape: Sequence[int], *, device: Device | None = None) -> SanityCheckResult:
36 | if device is None:
37 | device = auto_device()
38 | num_macs, num_params, layer_stats = model_complexity_info(model, input_shape)
39 | if num_macs is None or num_params is None:
40 | raise RuntimeError("Failed to validate model")
41 | outputs = model.to(device).eval()(torch.randn(1, *input_shape, device=device))
42 | return SanityCheckResult(num_macs, num_params, layer_stats, (
43 | outputs[0] if isinstance(outputs, tuple) else outputs).squeeze(0))
44 |
--------------------------------------------------------------------------------
/mipcandy/config.py:
--------------------------------------------------------------------------------
1 | from os import PathLike
2 | from os.path import abspath, exists
3 |
4 | from yaml import load, SafeLoader, dump, SafeDumper
5 |
6 | from mipcandy.types import Settings
7 |
8 | _DIR: str = abspath(__file__)[:-9]
9 | _DEFAULT_SETTINGS_PATH: str = f"{_DIR}settings.yml"
10 | _DEFAULT_SECRETS_PATH: str = f"{_DIR}secrets.yml"
11 |
12 |
13 | def _load(path: str | PathLike[str], *, hint: str = "fill in your settings here") -> Settings:
14 | if not exists(path):
15 | with open(path, "w") as f:
16 | f.write(f"# {hint}\n")
17 | with open(path) as f:
18 | settings = load(f.read(), SafeLoader)
19 | if settings is None:
20 | return {}
21 | if not isinstance(settings, dict):
22 | raise ValueError(f"Invalid settings file: {path}")
23 | return settings
24 |
25 |
26 | def _save(settings: Settings, path: str | PathLike[str], *, hint: str = "fill in your settings here") -> None:
27 | with open(path, "w") as f:
28 | f.write(f"# {hint}\n")
29 | dump(settings, f, SafeDumper)
30 |
31 |
32 | def load_settings(*, path: str | PathLike[str] = _DEFAULT_SETTINGS_PATH) -> Settings:
33 | return _load(path)
34 |
35 |
36 | def save_settings(settings: Settings, *, path: str | PathLike[str] = _DEFAULT_SETTINGS_PATH) -> None:
37 | _save(settings, path)
38 |
39 |
40 | def load_secrets(*, path: str | PathLike[str] = _DEFAULT_SECRETS_PATH) -> Settings:
41 | return _load(path, hint="fill in your secrets here, do not commit this file")
42 |
43 |
44 | def save_secrets(secrets: Settings, *, path: str | PathLike[str] = _DEFAULT_SECRETS_PATH) -> None:
45 | _save(secrets, path, hint="fill in your secrets here, do not commit this file")
46 |
--------------------------------------------------------------------------------
/mipcandy/data/io.py:
--------------------------------------------------------------------------------
1 | from math import floor
2 | from os import PathLike
3 |
4 | import SimpleITK as SpITK
5 | import torch
6 |
7 | from mipcandy.data.convertion import auto_convert
8 | from mipcandy.data.geometric import ensure_num_dimensions
9 | from mipcandy.types import Device
10 |
11 |
12 | def resample_to_isotropic(image: SpITK.Image, *, target_iso: float | None = None,
13 | interpolator: int = SpITK.sitkBSpline) -> SpITK.Image:
14 | dim = image.GetDimension()
15 | old_spacing = image.GetSpacing()
16 | old_size = image.GetSize()
17 | origin = image.GetOrigin()
18 | direction = image.GetDirection()
19 | if target_iso is None:
20 | target_iso = min(old_spacing)
21 | new_spacing = (target_iso,) * dim
22 |     new_size = [max(1, floor(old_spacing[i] * (old_size[i] - 1) / new_spacing[i] + 1)) for i in range(dim)]
23 | return SpITK.Resample(
24 | image, new_size, SpITK.Transform(), interpolator, origin, new_spacing, direction, 0, image.GetPixelID()
25 | )
26 |
27 |
28 | def load_image(path: str | PathLike[str], *, is_label: bool = False, align_spacing: bool = False,
29 | device: Device = "cpu") -> torch.Tensor:
30 | file = SpITK.ReadImage(path)
31 | if align_spacing:
32 | file = resample_to_isotropic(file, interpolator=SpITK.sitkNearestNeighbor if is_label else SpITK.sitkBSpline)
33 | img = torch.tensor(SpITK.GetArrayFromImage(file), dtype=torch.float, device=device)
34 | if path.endswith(".nii.gz") or path.endswith(".nii") or path.endswith(".mha"):
35 | img = ensure_num_dimensions(img, 4)
36 | return img.squeeze(1) if img.shape[1] == 1 else img
37 | if path.endswith(".png") or path.endswith(".jpg") or path.endswith(".jpeg"):
38 | return ensure_num_dimensions(img, 3)
39 | raise NotImplementedError(f"Unsupported file type: {path}")
40 |
41 |
42 | def save_image(image: torch.Tensor, path: str | PathLike[str]) -> None:
43 | if path.endswith(".png"):
44 | image = auto_convert(image).to(torch.uint8)
45 | SpITK.WriteImage(SpITK.GetImageFromArray(image.detach().cpu().numpy()), path)
46 |
--------------------------------------------------------------------------------
/mipcandy/common/optim/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from mipcandy.data import convert_ids_to_logits
7 | from mipcandy.metrics import do_reduction, soft_dice_coefficient
8 |
9 |
10 | class FocalBCEWithLogits(nn.Module):
11 | def __init__(self, alpha: float, gamma: float, *, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
12 | super().__init__()
13 | self.alpha: float = alpha
14 | self.gamma: float = gamma
15 | self.reduction: Literal["mean", "sum", "none"] = reduction
16 |
17 | def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
18 | bce = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction="none")
19 | p = torch.sigmoid(logits)
20 | p_t = torch.where(targets.bool(), p, 1 - p)
21 | alpha_t = torch.where(targets.bool(), torch.as_tensor(self.alpha, device=logits.device), torch.as_tensor(
22 | 1 - self.alpha, device=logits.device))
23 | loss = alpha_t * (1 - p_t).pow(self.gamma) * bce
24 | return do_reduction(loss, self.reduction)
25 |
26 |
27 | class DiceBCELossWithLogits(nn.Module):
28 | def __init__(self, num_classes: int, *, lambda_bce: float = .5, lambda_soft_dice: float = 1,
29 | smooth: float = 1e-5, include_bg: bool = True) -> None:
30 | super().__init__()
31 | self.num_classes: int = num_classes
32 | self.lambda_bce: float = lambda_bce
33 | self.lambda_soft_dice: float = lambda_soft_dice
34 | self.smooth: float = smooth
35 | self.include_bg: bool = include_bg
36 |
37 | def forward(self, masks: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, dict[str, float]]:
38 | if self.num_classes != 1 and labels.shape[1] == 1:
39 | d = labels.ndim - 2
40 | if d not in (1, 2, 3):
41 | raise ValueError(f"Expected labels to be 1D, 2D, or 3D, got {d} spatial dimensions")
42 | labels = convert_ids_to_logits(labels.int(), d, self.num_classes)
43 | labels = labels.float()
44 | bce = nn.functional.binary_cross_entropy_with_logits(masks, labels)
45 | masks = masks.sigmoid()
46 | soft_dice = soft_dice_coefficient(masks, labels, smooth=self.smooth, include_bg=self.include_bg)
47 | c = self.lambda_bce * bce + self.lambda_soft_dice * (1 - soft_dice)
48 | return c, {"soft dice": soft_dice.item(), "bce loss": bce.item()}
49 |
--------------------------------------------------------------------------------
/mipcandy/common/module/conv.py:
--------------------------------------------------------------------------------
1 | from typing import override
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from mipcandy.layer import LayerT
7 |
8 |
9 | class AbstractConvBlock(nn.Module):
10 | def __init__(self, in_ch: int, out_ch: int, kernel_size: int, *, stride: int = 1, padding: int = 0,
11 | dilation: int = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros",
12 | conv: LayerT = ..., norm: LayerT = ..., act: LayerT = ...) -> None:
13 | super().__init__()
14 | self.conv: nn.Module = conv.assemble(in_ch, out_ch, kernel_size, stride, padding, dilation, groups, bias,
15 | padding_mode)
16 | self.norm: nn.Module = norm.assemble(in_ch=out_ch)
17 | self.act: nn.Module = act.assemble()
18 |
19 | def forward(self, x: torch.Tensor) -> torch.Tensor:
20 | return self.act(self.norm(self.conv(x)))
21 |
22 |
23 | def _conv_block(default_conv: LayerT, default_norm: LayerT, default_act: LayerT) -> type[AbstractConvBlock]:
24 | class ConvBlock(AbstractConvBlock):
25 | def __init__(self, *args, **kwargs) -> None:
26 | if "conv" not in kwargs:
27 | kwargs["conv"] = default_conv
28 | if "norm" not in kwargs:
29 | kwargs["norm"] = default_norm
30 | if "act" not in kwargs:
31 | kwargs["act"] = default_act
32 | super().__init__(*args, **kwargs)
33 |
34 | return ConvBlock
35 |
36 |
37 | ConvBlock2d: type[AbstractConvBlock] = _conv_block(
38 | LayerT(nn.Conv2d), LayerT(nn.BatchNorm2d, num_features="in_ch"), LayerT(nn.ReLU, inplace=True)
39 | )
40 |
41 | ConvBlock3d: type[AbstractConvBlock] = _conv_block(
42 | LayerT(nn.Conv3d), LayerT(nn.BatchNorm3d, num_features="in_ch"), LayerT(nn.ReLU, inplace=True)
43 | )
44 |
45 |
46 | class WSConv2d(nn.Conv2d):
47 | @override
48 | def forward(self, x: torch.Tensor) -> torch.Tensor:
49 | w = self.weight
50 | w = (w - w.mean(dim=(1, 2, 3), keepdim=True)) / (w.std(dim=(1, 2, 3), keepdim=True) + 1e-5)
51 | return nn.functional.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
52 |
53 |
54 | class WSConv3d(nn.Conv3d):
55 | @override
56 | def forward(self, x: torch.Tensor) -> torch.Tensor:
57 | w = self.weight
58 | w = (w - w.mean(dim=(1, 2, 3, 4), keepdim=True)) / (w.std(dim=(1, 2, 3, 4), keepdim=True) + 1e-5)
59 | return nn.functional.conv3d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
60 |
--------------------------------------------------------------------------------
/mipcandy/frontend/prototype.py:
--------------------------------------------------------------------------------
1 | from typing import override
2 |
3 |
4 | from mipcandy.types import Setting, Settings
5 |
6 |
7 | class Frontend(object):
8 | def __init__(self, secrets: Settings) -> None:
9 | self._secrets: Settings = secrets
10 |
11 | def require_nonempty_secret(self, entry: str, *, require_type: type | None = None) -> Setting:
12 | if entry not in self._secrets:
13 | raise ValueError(f"Missing secret {entry}")
14 | secret = self._secrets[entry]
15 | if require_type is None or isinstance(secret, require_type):
16 | return secret
17 | raise ValueError(f"Invalid secret type {type(secret)}, {require_type} expected")
18 |
19 | def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_params: float,
20 | num_macs: float, num_epochs: int, early_stop_tolerance: int) -> None:
21 | ...
22 |
23 | def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]],
24 | early_stop_tolerance: int) -> None:
25 | ...
26 |
27 | def on_experiment_completed(self, experiment_id: str) -> None:
28 | ...
29 |
30 | def on_experiment_interrupted(self, experiment_id: str, error: Exception) -> None:
31 | ...
32 |
33 |
34 | def create_hybrid_frontend(*frontends: Frontend) -> type[Frontend]:
35 | class HybridFrontend(Frontend):
36 | def __init__(self, secrets: Settings) -> None:
37 | super().__init__(secrets)
38 |
39 | @override
40 | def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float,
41 | num_params: float, num_epochs: int, early_stop_tolerance: int) -> None:
42 | for frontend in frontends:
43 | frontend.on_experiment_created(experiment_id, trainer, model, note, num_macs, num_params, num_epochs,
44 | early_stop_tolerance)
45 |
46 | @override
47 | def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]],
48 | early_stop_tolerance: int) -> None:
49 | for frontend in frontends:
50 | frontend.on_experiment_updated(experiment_id, epoch, metrics, early_stop_tolerance)
51 |
52 | @override
53 | def on_experiment_completed(self, experiment_id: str) -> None:
54 | for frontend in frontends:
55 | frontend.on_experiment_completed(experiment_id)
56 |
57 | @override
58 | def on_experiment_interrupted(self, experiment_id: str, error: Exception) -> None:
59 | for frontend in frontends:
60 | frontend.on_experiment_interrupted(experiment_id, error)
61 |
62 | return HybridFrontend
63 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MIP Candy: A Candy for Medical Image Processing
2 |
3 | 
4 | 
5 | 
6 | 
7 |
8 | 
9 |
10 | MIP Candy is Project Neura's next-generation infrastructure framework for medical image processing. It defines a
11 | handful of common network architectures with their corresponding training, inference, and evaluation pipelines, all
12 | ready to use out of the box. It also provides integrations with popular frontend dashboards such as Notion, WandB,
13 | and TensorBoard.
14 |
15 | We provide a flexible and extensible framework for medical image processing researchers to quickly prototype their
16 | ideas. MIP Candy takes care of the rest, so you can focus on the key experiment design.
17 |
18 | :link: [Home](https://mipcandy.projectneura.org)
19 |
20 | :link: [Docs](https://mipcandy-docs.projectneura.org)
21 |
22 | ## Key Features
23 |
24 | Why MIP Candy? :thinking:
25 |
26 |
27 | ### Easy adaptation to fit your needs
28 | We provide plenty of easy-to-use training techniques that seamlessly support your customized experiments.
29 |
30 | - Sliding window
31 | - ROI inspection
32 | - ROI cropping to align dataset shape (100% or 33% foreground)
33 | - Automatic padding
34 | - ...
35 |
36 | You only need to override one method to create a trainer for your network architecture.
37 |
38 | ```python
39 | from typing import override
40 |
41 | from torch import nn
42 | from mipcandy import SegmentationTrainer
43 |
44 |
45 | class MyTrainer(SegmentationTrainer):
46 | @override
47 | def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
48 | ...
49 | ```
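
If you also want sliding-window training (one of the techniques listed above), a `SlidingSegmentationTrainer` preset
is exported as well. The sketch below is an assumption based on the `get_window_shape()` hook defined in
`mipcandy/sliding_window.py`; check the preset's actual interface before relying on it.

```python
from typing import override

from torch import nn
from mipcandy import SlidingSegmentationTrainer, Shape


class MySlidingTrainer(SlidingSegmentationTrainer):
    @override
    def get_window_shape(self) -> Shape:
        # hypothetical 2D patch size for the sliding window
        return 256, 256

    @override
    def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
        ...
```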
50 |
51 |
52 |
53 | ### Satisfying command-line UI design
54 |
55 |
56 |
57 |
58 | ### Built-in 2D and 3D visualization for intuitive understanding
59 |
60 |
61 |
62 |
63 | ### High availability with interruption tolerance
64 | Interrupted experiments can be resumed with ease.
65 |
66 |
67 |
68 |
69 | ### Support of various frontend platforms for remote monitoring
70 |
71 | MIP Candy supports [Notion](https://mipcandy-projectneura.notion.site), WandB, and TensorBoard.
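
Several of these frontends can be combined into a single monitor with `create_hybrid_frontend`. A minimal sketch,
assuming your `secrets.yml` provides `wandb_entity` and `wandb_project` for `WandBFrontend` (the keys required by
`NotionFrontend` are defined in `mipcandy/frontend/notion_fe.py`):

```python
from mipcandy import NotionFrontend, create_hybrid_frontend, load_secrets
from mipcandy.frontend.wandb_fe import WandBFrontend  # optional, needs the `wandb` package

secrets = load_secrets()  # reads the git-ignored secrets.yml bundled with the package
# Build a frontend class whose callbacks are forwarded to both dashboards
HybridFrontend = create_hybrid_frontend(NotionFrontend(secrets), WandBFrontend(secrets))
frontend = HybridFrontend(secrets)
```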
72 |
73 |
74 |
75 |
76 | ## Installation
77 |
78 | Note that MIP Candy requires **Python >= 3.12**.
79 |
80 | ```shell
81 | pip install "mipcandy[standard]"
82 | ```
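
The install also provides a small `mipcandy` console script (declared in `pyproject.toml`) for editing the bundled
`settings.yml` and the git-ignored `secrets.yml`. A minimal sketch based on `mipcandy/__entry__.py`; the keys and
values below are examples only:

```shell
mipcandy -c setting -kv note "baseline run"
mipcandy -c secret -kv wandb_entity my-team
```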
83 |
84 | ## Quick Start
85 |
86 | Below is a simple example of an nnU-Net-style training run. The batch size is set to 1 because the samples in the
87 | dataset vary in shape, although you can use an `ROIDataset` to align them.
88 |
89 | ```python
90 | from typing import override
91 |
92 | import torch
93 | from mipcandy_bundles.unet import UNetTrainer
94 | from torch.utils.data import DataLoader
95 |
96 | from mipcandy import download_dataset, NNUNetDataset
97 |
98 |
99 | class PH2(NNUNetDataset):
100 | @override
101 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
102 | image, label = super().load(idx)
103 | return image.squeeze(0).permute(2, 0, 1), label
104 |
105 |
106 | download_dataset("nnunet_datasets/PH2", "tutorial/datasets/PH2")
107 | dataset, val_dataset = PH2("tutorial/datasets/PH2", device="cuda").fold()
108 | dataloader = DataLoader(dataset, 1, shuffle=True)
109 | val_dataloader = DataLoader(val_dataset, 1, shuffle=False)
110 | trainer = UNetTrainer("tutorial", dataloader, val_dataloader, device="cuda")
111 | trainer.train(1000, note="a nnU-Net style example")
112 | ```
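
After training, predictions can be scored against labels with the `Evaluator` from `mipcandy/evaluation.py`. A minimal
sketch, assuming the exported binary metrics accept an `(output, label)` pair directly (their full signatures are
defined in `mipcandy/metrics.py`):

```python
import torch

from mipcandy import Evaluator, dice_similarity_coefficient_binary, iou_binary

# Two hypothetical binary masks (prediction vs. ground truth) just to exercise the API
outputs = [torch.randint(0, 2, (1, 64, 64)).float()]
labels = [torch.randint(0, 2, (1, 64, 64)).float()]

evaluator = Evaluator(dice_similarity_coefficient_binary, iou_binary)
result = evaluator.evaluate(outputs, labels)
print(result.mean_metrics)  # per-metric averages keyed by the metric functions' names
worst = result.min("dice_similarity_coefficient_binary")  # the case with the lowest Dice score
```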
--------------------------------------------------------------------------------
/mipcandy/layer.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from typing import Any, Generator, Self, Mapping
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from mipcandy.types import Device, AmbiguousShape
8 |
9 |
10 | def batch_int_multiply(f: float, *n: int) -> Generator[int, None, None]:
11 | for i in n:
12 | r = i * f
13 | if not r.is_integer():
14 |             raise ValueError(f"Cannot scale {i} by a factor of {f} and keep it an integer")
15 | yield int(r)
16 |
17 |
18 | def batch_int_divide(f: float, *n: int) -> Generator[int, None, None]:
19 | return batch_int_multiply(1 / f, *n)
20 |
21 |
22 | class LayerT(object):
23 | def __init__(self, m: type[nn.Module], **kwargs) -> None:
24 | self.m: type[nn.Module] = m
25 | self.kwargs: dict[str, Any] = kwargs
26 |
27 | def update(self, *, must_exist: bool = True, inplace: bool = False, **kwargs) -> Self:
28 | if not inplace:
29 | return self.copy().update(must_exist=must_exist, inplace=True, **kwargs)
30 | for k, v in kwargs.items():
31 | if not must_exist or k in self.kwargs:
32 | self.kwargs[k] = v
33 | return self
34 |
35 | def assemble(self, *args, **kwargs) -> nn.Module:
36 | self_kwargs = self.kwargs.copy()
37 | for k, v in self_kwargs.items():
38 | if isinstance(v, str) and v in kwargs:
39 | self_kwargs[k] = kwargs.pop(v)
40 | return self.m(*args, **self_kwargs, **kwargs)
41 |
42 | def copy(self) -> Self:
43 | return self.__class__(self.m, **self.kwargs)
44 |
45 |
46 | class HasDevice(object):
47 | def __init__(self, device: Device) -> None:
48 | self._device: Device = device
49 |
50 | def device(self, *, device: Device | None = None) -> None | Device:
51 | if device is None:
52 | return self._device
53 | else:
54 | self._device = device
55 |
56 |
57 | def auto_device() -> Device:
58 | if torch.cuda.is_available():
59 | return f"cuda:{max(range(torch.cuda.device_count()),
60 | key=lambda i: torch.cuda.memory_reserved(i) - torch.cuda.memory_allocated(i))}"
61 | if torch.backends.mps.is_available():
62 | return "mps"
63 | return "cpu"
64 |
65 |
66 | class WithPaddingModule(HasDevice):
67 | def __init__(self, device: Device) -> None:
68 | super().__init__(device)
69 | self._padding_module: nn.Module | None = None
70 | self._restoring_module: nn.Module | None = None
71 | self._padding_module_built: bool = False
72 |
73 | def build_padding_module(self) -> nn.Module | None:
74 | return None
75 |
76 | def build_restoring_module(self, padding_module: nn.Module | None) -> nn.Module | None:
77 | return None
78 |
79 | def _lazy_load_padding_module(self) -> None:
80 | if self._padding_module_built:
81 | return
82 | self._padding_module = self.build_padding_module()
83 | if self._padding_module:
84 | self._padding_module = self._padding_module.to(self._device)
85 | self._restoring_module = self.build_restoring_module(self._padding_module)
86 | if self._restoring_module:
87 | self._restoring_module = self._restoring_module.to(self._device)
88 | self._padding_module_built = True
89 |
90 | def get_padding_module(self) -> nn.Module | None:
91 | self._lazy_load_padding_module()
92 | return self._padding_module
93 |
94 | def get_restoring_module(self) -> nn.Module | None:
95 | self._lazy_load_padding_module()
96 | return self._restoring_module
97 |
98 |
99 | class WithNetwork(HasDevice, metaclass=ABCMeta):
100 | def __init__(self, device: Device) -> None:
101 | super().__init__(device)
102 |
103 | @abstractmethod
104 | def build_network(self, example_shape: AmbiguousShape) -> nn.Module:
105 | raise NotImplementedError
106 |
107 | def build_network_from_checkpoint(self, example_shape: AmbiguousShape, checkpoint: Mapping[str, Any]) -> nn.Module:
108 | """
109 | Internally exposed interface for overriding. Use `load_model()` instead.
110 | """
111 | network = self.build_network(example_shape)
112 | network.load_state_dict(checkpoint)
113 | return network
114 |
115 | def load_model(self, example_shape: AmbiguousShape, *, checkpoint: Mapping[str, Any] | None = None) -> nn.Module:
116 | if checkpoint:
117 | return self.build_network_from_checkpoint(example_shape, checkpoint).to(self._device)
118 | return self.build_network(example_shape).to(self._device)
119 |
--------------------------------------------------------------------------------
/mipcandy/evaluation.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable, Sequence, Generator, override
3 |
4 | import torch
5 |
6 | from mipcandy.data import SupervisedDataset, MergedDataset, Loader, DatasetFromMemory
7 | from mipcandy.inference import parse_predictant, Predictor
8 | from mipcandy.types import SupportedPredictant
9 |
10 |
11 | @dataclass
12 | class EvalCase(object):
13 | metrics: dict[str, float]
14 | output: torch.Tensor
15 | label: torch.Tensor
16 | image: torch.Tensor | None = None
17 | filename: str | None = None
18 |
19 |
20 | class EvalResult(Sequence[EvalCase]):
21 | def __init__(self, metrics: dict[str, list[float]], outputs: list[torch.Tensor], labels: list[torch.Tensor], *,
22 | images: list[torch.Tensor] | None = None, filenames: list[str] | None = None) -> None:
23 | if len(outputs) != len(labels):
24 | raise ValueError(f"Unmatched number of outputs ({len(outputs)}) and labels ({len(labels)})")
25 | self.metrics: dict[str, list[float]] = metrics
26 | self.mean_metrics: dict[str, float] = {name: sum(values) / len(values) for name, values in metrics.items()}
27 | self.images: list[torch.Tensor] | None = images
28 | self.outputs: list[torch.Tensor] = outputs
29 | self.labels: list[torch.Tensor] = labels
30 | self.filenames: list[str] | None = filenames
31 |
32 | @override
33 | def __len__(self) -> int:
34 | return len(self.outputs)
35 |
36 | @override
37 | def __getitem__(self, item: int) -> EvalCase:
38 | return EvalCase({name: values[item] for name, values in self.metrics.items()}, self.outputs[item],
39 | self.labels[item], self.images[item] if self.images else None,
40 | self.filenames[item] if self.filenames else None)
41 |
42 | def _select(self, metric: str, n: int, descending: bool) -> Generator[EvalCase, None, None]:
43 | o_values = self.metrics[metric]
44 | values = o_values.copy()
45 | values.sort(reverse=descending)
46 | for value in values[:n]:
47 | yield self[o_values.index(value)]
48 |
49 | def min(self, metric: str) -> EvalCase:
50 | return self.min_n(metric, 1)[0]
51 |
52 | def min_n(self, metric: str, n: int) -> tuple[EvalCase, ...]:
53 | return tuple(self._select(metric, n, False))
54 |
55 | def max(self, metric: str) -> EvalCase:
56 | return self.max_n(metric, 1)[0]
57 |
58 | def max_n(self, metric: str, n: int) -> tuple[EvalCase, ...]:
59 | return tuple(self._select(metric, n, True))
60 |
61 |
62 | class Evaluator(object):
63 | def __init__(self, *metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) -> None:
64 | self._metrics: tuple[Callable[[torch.Tensor, torch.Tensor], torch.Tensor], ...] = metrics
65 |
66 | def _evaluate_dataset(self, x: SupervisedDataset, *, prefilled_outputs: list[torch.Tensor] | None = None,
67 | prefilled_labels: list[torch.Tensor] | None = None) -> EvalResult:
68 | metrics = {}
69 | outputs = prefilled_outputs if prefilled_outputs else []
70 | labels = prefilled_labels if prefilled_labels else []
71 | for output, label in x:
72 | if not prefilled_outputs:
73 | outputs.append(output)
74 | if not prefilled_labels:
75 | labels.append(label)
76 | for m in self._metrics:
77 | if m.__name__ not in metrics:
78 | metrics[m.__name__] = []
79 | metrics[m.__name__].append(m(output, label).item())
80 | return EvalResult(metrics, outputs, labels)
81 |
82 | def evaluate_dataset(self, x: SupervisedDataset) -> EvalResult:
83 | return self._evaluate_dataset(x)
84 |
85 | def evaluate(self, outputs: SupportedPredictant, labels: SupportedPredictant) -> EvalResult:
86 | outputs, filenames = parse_predictant(outputs, Loader)
87 | labels, _ = parse_predictant(labels, Loader, as_label=True)
88 | r = self._evaluate_dataset(MergedDataset(DatasetFromMemory(outputs), DatasetFromMemory(labels)),
89 | prefilled_outputs=outputs, prefilled_labels=labels)
90 | r.filenames = filenames
91 | return r
92 |
93 | def predict_and_evaluate(self, x: SupportedPredictant, labels: SupportedPredictant,
94 | predictor: Predictor) -> EvalResult:
95 | x, filenames = parse_predictant(x, Loader)
96 | outputs = [e.cpu() for e in predictor.predict(x)]
97 | r = self.evaluate(outputs, labels)
98 | r.images = x
99 | r.filenames = filenames
100 | return r
101 |
--------------------------------------------------------------------------------
/mipcandy/sliding_window.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from dataclasses import dataclass
3 | from typing import Literal
4 |
5 | import torch
6 | from torch import nn
7 |
8 | from mipcandy.layer import HasDevice
9 | from mipcandy.types import Shape
10 |
11 |
12 | @dataclass
13 | class SWMetadata(object):
14 | kernel: Shape
15 | stride: tuple[int, int] | tuple[int, int, int]
16 | ndim: Literal[2, 3]
17 | batch_size: int
18 | out_size: Shape
19 | n: int
20 |
21 |
22 | class SlidingWindow(HasDevice, metaclass=ABCMeta):
23 | sliding_window_batch_size: int | None = None
24 |
25 | @abstractmethod
26 | def get_window_shape(self) -> Shape:
27 | raise NotImplementedError
28 |
29 | def get_batch_size(self) -> int | None:
30 | return self.sliding_window_batch_size
31 |
32 | def gaussian_1d(self, k: int, *, sigma_scale: float = 0.5) -> torch.Tensor:
33 | x = torch.linspace(-1.0, 1.0, steps=k, device=self._device)
34 | sigma = sigma_scale
35 | g = torch.exp(-0.5 * (x / sigma) ** 2)
36 | g /= g.max()
37 | return g
38 |
39 | def do_sliding_window(self, t: torch.Tensor) -> tuple[torch.Tensor, SWMetadata]:
40 | window_shape = self.get_window_shape()
41 | if not (len(window_shape) + 2 == t.ndim):
42 | raise RuntimeError("Unmatched number of dimensions")
43 | stride = window_shape
44 | if len(stride) == 2:
45 | kernel = stride[0] * 2, stride[1] * 2
46 | b, c, h, w = t.shape
47 | t = nn.functional.unfold(t, kernel, stride=stride)
48 | n = t.shape[-1]
49 | kh, kw = kernel
50 | return (t.transpose(1, 2).contiguous().view(b * n, c, kh, kw),
51 | SWMetadata(kernel, stride, 2, b, (h, w), n))
52 | else:
53 | b, c, d, h, w = t.shape
54 | sd, sh, sw = stride
55 | kd, kh, kw = kernel = sd * 2, sh * 2, sw * 2
56 | image_windows = []
57 | for z in range(0, d - kd + 1, sd):
58 | for y in range(0, h - kh + 1, sh):
59 | for x in range(0, w - kw + 1, sw):
60 | image_windows.append(t[:, :, z:z + kd, y:y + kh, x:x + kw])
61 | t = torch.stack(image_windows, dim=0)
62 | n = t.shape[0]
63 | return (t.permute(0, 1, 2, 3, 4, 5).contiguous().view(b * n, c, kd, kh, kw),
64 | SWMetadata(kernel, stride, 3, b, (d, h, w), n))
65 |
66 | def revert_sliding_window(self, t: torch.Tensor, metadata: SWMetadata, *, clamp_min: float = 1e-8) -> torch.Tensor:
67 | kernel = metadata.kernel
68 | stride = metadata.stride
69 | dims = metadata.ndim
70 | b = metadata.batch_size
71 | out_size = metadata.out_size
72 | n = metadata.n
73 | dtype = t.dtype
74 | if dims == 2:
75 | kh, kw = kernel
76 | gh = self.gaussian_1d(kh)
77 | gw = self.gaussian_1d(kw)
78 | w2d = (gh[:, None] * gw[None, :]).to(dtype)
79 | w2d /= w2d.max()
80 | w2d = w2d.view(1, 1, kh, kw)
81 | bn, c, _, _ = t.shape
82 | if bn != b * n:
83 | raise RuntimeError("Inconsistent number of windows for reverting sliding window")
84 | weighted = t * w2d
85 | patches = weighted.view(b, n, c, kh, kw)
86 | cols = patches.view(b, n, c * kh * kw).transpose(1, 2).contiguous()
87 | numerator = nn.functional.fold(cols, out_size, kernel, stride=stride)
88 | w_cols = w2d.expand(b, n, 1, kh, kw).contiguous().view(b, n, 1 * kh * kw).transpose(1, 2)
89 | denominator = nn.functional.fold(w_cols, out_size, kernel, stride=stride)
90 | denominator = denominator.clamp_min(clamp_min)
91 | return numerator / denominator
92 | else:
93 | kd, kh, kw = kernel
94 | sd, sh, sw = stride
95 | d, h, w = out_size
96 | gd = self.gaussian_1d(kd)
97 | gh = self.gaussian_1d(kh)
98 | gw = self.gaussian_1d(kw)
99 | w3d = (gd[:, None, None] * gh[None, :, None] * gw[None, None, :]).to(dtype)
100 | w3d /= w3d.max()
101 | w3d = w3d.view(1, 1, kd, kh, kw)
102 | bn, c, _, _, _ = t.shape
103 | if bn != b * n:
104 | raise RuntimeError("Inconsistent number of windows for reverting sliding window")
105 | canvas = torch.zeros((b, c, d, h, w), dtype=dtype, device=self._device)
106 | acc_w = torch.zeros((b, 1, d, h, w), dtype=dtype, device=self._device)
107 | idx = 0
108 | for z in range(0, d - kd + 1, sd):
109 | for y in range(0, h - kh + 1, sh):
110 | for x in range(0, w - kw + 1, sw):
111 | window = t[idx * b:(idx + 1) * b]
112 | window *= w3d
113 | canvas[:, :, z:z + kd, y:y + kh, x:x + kw] += window
114 | acc_w[:, :, z:z + kd, y:y + kh, x:x + kw] += w3d
115 | idx += 1
116 | acc_w = acc_w.clamp_min(clamp_min)
117 | return canvas / acc_w
118 |
--------------------------------------------------------------------------------
/mipcandy/data/visualization.py:
--------------------------------------------------------------------------------
1 | from importlib.util import find_spec
2 | from math import ceil
3 | from multiprocessing import get_context
4 | from os import PathLike
5 | from typing import Literal
6 | from warnings import warn
7 |
8 | import numpy as np
9 | import torch
10 | from matplotlib import pyplot as plt
11 | from torch import nn
12 |
13 | from mipcandy.common import ColorizeLabel
14 | from mipcandy.data.convertion import auto_convert
15 | from mipcandy.data.geometric import ensure_num_dimensions
16 |
17 |
18 | def visualize2d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray",
19 | blocking: bool = False, screenshot_as: str | PathLike[str] | None = None) -> None:
20 | image = image.detach().cpu()
21 | if image.ndim < 2:
22 | raise ValueError(f"`image` must have at least 2 dimensions, got {image.shape}")
23 | if image.ndim > 3:
24 | image = ensure_num_dimensions(image, 3)
25 | if image.ndim == 3:
26 | if image.shape[0] == 1:
27 | image = image.squeeze(0)
28 | else:
29 | image = image.permute(1, 2, 0)
30 | image = auto_convert(image)
31 | plt.imshow(image.numpy(), cmap, vmin=0, vmax=255)
32 | plt.title(title)
33 | plt.axis("off")
34 | if screenshot_as:
35 | plt.savefig(screenshot_as)
36 | if blocking:
37 | plt.close()
38 | return
39 | plt.show(block=blocking)
40 |
41 |
42 | def _visualize3d_with_pyvista(image: np.ndarray, title: str | None, cmap: str,
43 | screenshot_as: str | PathLike[str] | None) -> None:
44 | from pyvista import Plotter
45 | p = Plotter(title=title, off_screen=bool(screenshot_as))
46 | p.add_volume(image, cmap=cmap)
47 | if screenshot_as:
48 | p.screenshot(screenshot_as)
49 | else:
50 | p.show()
51 |
52 |
53 | def visualize3d(image: torch.Tensor, *, title: str | None = None, cmap: str = "gray", max_volume: int = 1e6,
54 | backend: Literal["auto", "matplotlib", "pyvista"] = "auto", blocking: bool = False,
55 | screenshot_as: str | PathLike[str] | None = None) -> None:
56 | image = image.detach().float().cpu()
57 | if image.ndim < 3:
58 | raise ValueError(f"`image` must have at least 3 dimensions, got {image.shape}")
59 | if image.ndim > 3:
60 | image = ensure_num_dimensions(image, 3)
61 | d, h, w = image.shape
62 | total = d * h * w
63 | ratio = int(ceil((total / max_volume) ** (1 / 3))) if total > max_volume else 1
64 | if ratio > 1:
65 | image = ensure_num_dimensions(nn.functional.avg_pool3d(ensure_num_dimensions(image, 5), kernel_size=ratio,
66 | stride=ratio, ceil_mode=True), 3)
67 | image /= image.max()
68 | image = image.numpy()
69 | if backend == "auto":
70 | backend = "pyvista" if find_spec("pyvista") else "matplotlib"
71 | match backend:
72 | case "matplotlib":
73 | warn("Using Matplotlib for 3D visualization is inefficient and inaccurate, consider using PyVista")
74 | face_colors = getattr(plt.cm, cmap)(image)
75 | face_colors[..., 3] = image * (image > 0)
76 | fig = plt.figure()
77 | ax = fig.add_subplot(111, projection="3d")
78 | ax.voxels(image, facecolors=face_colors)
79 | ax.set_title(title)
80 | if screenshot_as:
81 | fig.savefig(screenshot_as)
82 | if blocking:
83 | plt.close()
84 | return
85 | plt.show(block=blocking)
86 | case "pyvista":
87 | image = image.transpose(1, 2, 0)
88 | if blocking:
89 | return _visualize3d_with_pyvista(image, title, cmap, screenshot_as)
90 | ctx = get_context("spawn")
91 | return ctx.Process(target=_visualize3d_with_pyvista, args=(image, title, cmap, screenshot_as),
92 | daemon=False).start()
93 |
94 |
95 | def overlay(image: torch.Tensor, label: torch.Tensor, *, max_label_opacity: float = .5,
96 | label_colorizer: ColorizeLabel | None = ColorizeLabel()) -> torch.Tensor:
97 |     if image.ndim < 2 or label.ndim < 2:
98 |         raise ValueError("`image` and `label` must each have at least 2 dimensions; only 2D overlays are supported")
99 | image = ensure_num_dimensions(image, 3)
100 | label = ensure_num_dimensions(label, 2)
101 | image = auto_convert(image)
102 | if image.shape[0] == 1:
103 | image = image.repeat(3, 1, 1)
104 | image_c, image_shape = image.shape[0], image.shape[1:]
105 | label_shape = label.shape
106 | if image_shape != label_shape:
107 | raise ValueError(f"Unmatched shapes {image_shape} and {label_shape}")
108 | alpha = (label > 0).int()
109 | if label_colorizer:
110 | label = label_colorizer(label)
111 | if label.shape[0] == 4:
112 | alpha = label[-1]
113 | label = label[:-1]
114 | elif label.shape[0] == 1:
115 | label = label.repeat(3, 1, 1)
116 | if not (image_c == label.shape[0] == 3):
117 | raise ValueError("Unsupported number of channels")
118 | if alpha.max() > 0:
119 | alpha = alpha * max_label_opacity / alpha.max()
120 | return image * (1 - alpha) + label * alpha
121 |
--------------------------------------------------------------------------------
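
A short usage sketch for the helpers above, assuming `mipcandy` and Matplotlib are installed; the image and label tensors are synthetic and the output filename is arbitrary. `overlay` blends a colorized label over the image, and `visualize2d` renders or saves the result.

import torch
from mipcandy.data import visualize2d, overlay

image = torch.rand(1, 256, 256)                  # synthetic grayscale image
label = torch.zeros(256, 256, dtype=torch.long)  # binary mask with a square foreground
label[96:160, 96:160] = 1

blended = overlay(image, label)                  # (3, 256, 256) blend of image and colorized label
visualize2d(blended, title="overlay demo", blocking=True, screenshot_as="overlay_demo.png")
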
/mipcandy/presets/segmentation.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 | from typing import override
3 |
4 | import torch
5 | from torch import nn, optim
6 |
7 | from mipcandy.common import AbsoluteLinearLR, DiceBCELossWithLogits
8 | from mipcandy.data import visualize2d, visualize3d, overlay, auto_convert
9 | from mipcandy.sliding_window import SWMetadata
10 | from mipcandy.training import Trainer, TrainerToolbox, SlidingTrainer
11 | from mipcandy.types import Params, Shape
12 |
13 |
14 | class SegmentationTrainer(Trainer, metaclass=ABCMeta):
15 | num_classes: int = 1
16 | include_bg: bool = True
17 |
18 | def _save_preview(self, x: torch.Tensor, title: str, quality: float) -> None:
19 | path = f"{self.experiment_folder()}/{title} (preview).png"
20 | if x.ndim == 3 and x.shape[0] in (1, 3, 4):
21 | visualize2d(auto_convert(x), title=title, blocking=True, screenshot_as=path)
22 | elif x.ndim == 4 and x.shape[0] == 1:
23 | visualize3d(x, title=title, max_volume=int(quality * 1e6), blocking=True, screenshot_as=path)
24 |
25 | @override
26 | def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.Tensor, *,
27 | quality: float = .75) -> None:
28 | output = output.sigmoid()
29 | self._save_preview(image, "input", quality)
30 | self._save_preview(label, "label", quality)
31 | self._save_preview(output, "prediction", quality)
32 | if image.ndim == label.ndim == output.ndim == 3 and label.shape[0] == output.shape[0] == 1:
33 | visualize2d(overlay(image, label), title="expected", blocking=True,
34 | screenshot_as=f"{self.experiment_folder()}/expected (preview).png")
35 | visualize2d(overlay(image, output), title="actual", blocking=True,
36 | screenshot_as=f"{self.experiment_folder()}/actual (preview).png")
37 |
38 | @override
39 | def build_criterion(self) -> nn.Module:
40 | return DiceBCELossWithLogits(self.num_classes, include_bg=self.include_bg)
41 |
42 | @override
43 | def build_optimizer(self, params: Params) -> optim.Optimizer:
44 | return optim.AdamW(params)
45 |
46 | @override
47 | def build_scheduler(self, optimizer: optim.Optimizer, num_epochs: int) -> optim.lr_scheduler.LRScheduler:
48 | return AbsoluteLinearLR(optimizer, -8e-6 / len(self._dataloader), 1e-2)
49 |
50 | @override
51 | def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
52 | str, float]]:
53 | masks = toolbox.model(images)
54 | loss, metrics = toolbox.criterion(masks, labels)
55 | loss.backward()
56 | return loss.item(), metrics
57 |
58 | @override
59 | def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
60 | str, float], torch.Tensor]:
61 | image, label = image.unsqueeze(0), label.unsqueeze(0)
62 | mask = (toolbox.ema if toolbox.ema else toolbox.model)(image)
63 | loss, metrics = toolbox.criterion(mask, label)
64 | return -loss.item(), metrics, mask.squeeze(0)
65 |
66 |
67 | class SlidingSegmentationTrainer(SlidingTrainer, SegmentationTrainer, metaclass=ABCMeta):
68 | sliding_window_shape: Shape = (128, 128)
69 |
70 | @override
71 | def backward_windowed(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox,
72 | metadata: SWMetadata) -> tuple[float, dict[str, float]]:
73 | masks = toolbox.model(images)
74 | loss, metrics = toolbox.criterion(masks, labels)
75 | loss.backward()
76 | return loss.item(), metrics
77 |
78 | @override
79 | def validate_case_windowed(self, images: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox,
80 | metadata: SWMetadata) -> tuple[float, dict[str, float], torch.Tensor]:
81 | batch_size = self.get_batch_size()
82 | model = toolbox.ema if toolbox.ema else toolbox.model
83 | if batch_size is None or batch_size >= images.shape[0]:
84 | outputs = model(images)
85 | else:
86 | output_list: list[torch.Tensor] = []
87 | for i in range(0, images.shape[0], batch_size):
88 | batch = images[i:i + batch_size]
89 | output_list.append(model(batch))
90 | outputs = torch.cat(output_list, dim=0)
91 | outputs = self.revert_sliding_window(outputs, metadata)
92 | loss, metrics = toolbox.criterion(outputs, label.unsqueeze(0))
93 | return -loss.item(), metrics, outputs.squeeze(0)
94 |
95 | @override
96 | def get_window_shape(self) -> Shape:
97 | return self.sliding_window_shape
98 |
99 |
100 | class SlidingValidationTrainer(SlidingSegmentationTrainer, metaclass=ABCMeta):
101 | """
102 |     Use this when the training data comes from RandomROIDataset (already cropped into patches) while the
103 |     validation data consists of full volumes that require sliding-window inference.
104 | """
105 | @override
106 | def backward_windowed(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox,
107 | metadata: SWMetadata) -> tuple[float, dict[str, float]]:
108 | raise RuntimeError("`backward_windowed()` should not be called in `SlidingValidationTrainer`")
109 |
110 | @override
111 | def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
112 | str, float]]:
113 | return SegmentationTrainer.backward(self, images, labels, toolbox)
114 |
115 | @override
116 | def validate_case_windowed(self, images: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox,
117 | metadata: SWMetadata) -> tuple[float, dict[str, float], torch.Tensor]:
118 | return super().validate_case_windowed(images, label, toolbox, metadata)
119 |
120 | @override
121 | def get_window_shape(self) -> Shape:
122 | return self.sliding_window_shape
123 |
--------------------------------------------------------------------------------
/mipcandy/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Protocol, Literal
2 |
3 | import torch
4 |
5 | from mipcandy.types import Device
6 |
7 |
8 | def _args_check(output: torch.Tensor, label: torch.Tensor, *, dtype: torch.dtype | None = None,
9 | device: Device | None = None) -> tuple[torch.dtype, Device]:
10 | if output.shape != label.shape:
11 | raise ValueError(f"Output ({output.shape}) and label ({label.shape}) must have the same shape")
12 | if (output_dtype := output.dtype) != label.dtype or dtype and output_dtype != dtype:
13 |         raise TypeError(f"Output ({output_dtype}) and label ({label.dtype}) must have the same dtype{f' ({dtype} expected)' if dtype else ''}")
14 | if (output_device := output.device) != label.device:
15 | raise RuntimeError(f"Output ({output.device}) and label ({label.device}) must be on the same device")
16 | if device and output_device != device:
17 | raise RuntimeError(f"Tensors are expected to be on {device}, but instead they are on {output.device}")
18 | return output_dtype, output_device
19 |
20 |
21 | class Metric(Protocol):
22 | def __call__(self, output: torch.Tensor, label: torch.Tensor, *, if_empty: float = ...) -> torch.Tensor: ...
23 |
24 |
25 | def do_reduction(x: torch.Tensor, method: Literal["mean", "median", "sum", "none"] = "mean") -> torch.Tensor:
26 | match method:
27 | case "mean":
28 | return x.mean()
29 | case "median":
30 | return x.median()
31 | case "sum":
32 | return x.sum()
33 | case "none":
34 | return x
35 |
36 |
37 | def apply_multiclass_to_binary(metric: Metric, output: torch.Tensor, label: torch.Tensor, num_classes: int | None,
38 | if_empty: float, *, reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor:
39 | _args_check(output, label, dtype=torch.int)
40 | if not num_classes:
41 | num_classes = max(output.max().item(), label.max().item())
42 | if num_classes == 0:
43 | return torch.tensor(if_empty, dtype=torch.float)
44 | else:
45 | x = torch.tensor([metric(output == cls, label == cls, if_empty=if_empty) for cls in range(1, num_classes + 1)])
46 | return do_reduction(x, reduction)
47 |
48 |
49 | def dice_similarity_coefficient_binary(output: torch.Tensor, label: torch.Tensor, *,
50 | if_empty: float = 1) -> torch.Tensor:
51 | _args_check(output, label, dtype=torch.bool)
52 | volume_sum = output.sum() + label.sum()
53 | if volume_sum == 0:
54 | return torch.tensor(if_empty, dtype=torch.float)
55 | return 2 * (output & label).sum() / volume_sum
56 |
57 |
58 | def dice_similarity_coefficient_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None,
59 | if_empty: float = 1) -> torch.Tensor:
60 | return apply_multiclass_to_binary(dice_similarity_coefficient_binary, output, label, num_classes, if_empty)
61 |
62 |
63 | def soft_dice_coefficient(output: torch.Tensor, label: torch.Tensor, *,
64 | smooth: float = 1e-5, include_bg: bool = True) -> torch.Tensor:
65 | _args_check(output, label)
66 | axes = tuple(range(2, output.ndim))
67 | intersection = (output * label).sum(dim=axes)
68 | dice = (2 * intersection + smooth) / (output.sum(dim=axes) + label.sum(dim=axes) + smooth)
69 | if not include_bg:
70 | dice = dice[:, 1:]
71 | return dice.mean()
72 |
73 |
74 | def accuracy_binary(output: torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor:
75 | _args_check(output, label, dtype=torch.bool)
76 | numerator = (output & label).sum() + (~output & ~label).sum()
77 | denominator = numerator + (output & ~label).sum() + (label & ~output).sum()
78 | return torch.tensor(if_empty, dtype=torch.float) if denominator == 0 else numerator / denominator
79 |
80 |
81 | def accuracy_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None,
82 | if_empty: float = 1) -> torch.Tensor:
83 | return apply_multiclass_to_binary(accuracy_binary, output, label, num_classes, if_empty)
84 |
85 |
86 | def _precision_or_recall(output: torch.Tensor, label: torch.Tensor, if_empty: float,
87 | is_precision: bool) -> torch.Tensor:
88 | _args_check(output, label, dtype=torch.bool)
89 | tp = (output & label).sum()
90 | denominator = output.sum() if is_precision else label.sum()
91 | return torch.tensor(if_empty, dtype=torch.float) if denominator == 0 else tp / denominator
92 |
93 |
94 | def precision_binary(output: torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor:
95 | return _precision_or_recall(output, label, if_empty, True)
96 |
97 |
98 | def precision_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None,
99 | if_empty: float = 1) -> torch.Tensor:
100 | return apply_multiclass_to_binary(precision_binary, output, label, num_classes, if_empty)
101 |
102 |
103 | def recall_binary(output: torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor:
104 | return _precision_or_recall(output, label, if_empty, False)
105 |
106 |
107 | def recall_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None,
108 | if_empty: float = 1) -> torch.Tensor:
109 | return apply_multiclass_to_binary(recall_binary, output, label, num_classes, if_empty)
110 |
111 |
112 | def iou_binary(output: torch.Tensor, label: torch.Tensor, *, if_empty: float = 1) -> torch.Tensor:
113 | _args_check(output, label, dtype=torch.bool)
114 | denominator = (output | label).sum()
115 | return torch.tensor(if_empty, dtype=torch.float) if denominator == 0 else (output & label).sum() / denominator
116 |
117 |
118 | def iou_multiclass(output: torch.Tensor, label: torch.Tensor, *, num_classes: int | None = None,
119 | if_empty: float = 1) -> torch.Tensor:
120 | return apply_multiclass_to_binary(iou_binary, output, label, num_classes, if_empty)
121 |
--------------------------------------------------------------------------------
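
A small worked example of the metric helpers above. The binary variants expect boolean masks of identical shape; the multiclass variants expect integer class maps (dtype `torch.int`) and average the per-class binary scores over classes 1..num_classes.

import torch
from mipcandy.metrics import dice_similarity_coefficient_binary, iou_multiclass

output = torch.zeros(4, 4, dtype=torch.bool)
label = torch.zeros(4, 4, dtype=torch.bool)
output[0, :] = True   # prediction: 4 foreground pixels
label[0, 2:] = True   # ground truth: 2 of them ...
label[1, :2] = True   # ... plus 2 pixels the prediction misses

print(dice_similarity_coefficient_binary(output, label))  # 2 * 2 / (4 + 4) = 0.5

pred = torch.tensor([[1, 1], [2, 0]], dtype=torch.int)
gt = torch.tensor([[1, 2], [2, 0]], dtype=torch.int)
print(iou_multiclass(pred, gt, num_classes=2))             # mean IoU over classes 1 and 2
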
/mipcandy/inference.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 | from math import ceil, log10
3 | from os import PathLike, listdir
4 | from os.path import isdir, basename, exists
5 | from typing import Sequence, override
6 |
7 | import torch
8 | from torch import nn
9 |
10 | from mipcandy.common import Pad2d, Pad3d, Restore2d, Restore3d
11 | from mipcandy.data import save_image, Loader, UnsupervisedDataset, PathBasedUnsupervisedDataset
12 | from mipcandy.layer import WithPaddingModule, WithNetwork
13 | from mipcandy.sliding_window import SlidingWindow
14 | from mipcandy.types import SupportedPredictant, Device, AmbiguousShape
15 |
16 |
17 | def parse_predictant(x: SupportedPredictant, loader: type[Loader], *, as_label: bool = False) -> tuple[list[
18 | torch.Tensor], list[str] | None]:
19 | if isinstance(x, str):
20 | if isdir(x):
21 | cases = listdir(x)
22 | return [loader.do_load(f"{x}/{case}", is_label=as_label) for case in cases], cases
23 | return [loader.do_load(x, is_label=as_label)], [basename(x)]
24 | if isinstance(x, torch.Tensor):
25 | return [x], None
26 | r, filenames = [], None
27 | for case in x:
28 | if isinstance(case, str):
29 | if not filenames:
30 | filenames = []
31 | r.append(loader.do_load(case, is_label=as_label))
32 | filenames.append(case[case.rfind("/") + 1:])
33 | elif filenames:
34 | raise TypeError("`x` should be single-typed")
35 | elif isinstance(case, torch.Tensor):
36 | r.append(case)
37 | else:
38 | raise TypeError(f"Unexpected type of element {type(case)}")
39 | return r, filenames
40 |
41 |
42 | class Predictor(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
43 | def __init__(self, experiment_folder: str | PathLike[str], example_shape: AmbiguousShape, *,
44 | checkpoint: str = "checkpoint_best.pth", device: Device = "cpu") -> None:
45 | WithPaddingModule.__init__(self, device)
46 | WithNetwork.__init__(self, device)
47 | self._experiment_folder: str = experiment_folder
48 | self._example_shape: AmbiguousShape = example_shape
49 | self._checkpoint: str = checkpoint
50 | self._model: nn.Module | None = None
51 |
52 | def lazy_load_model(self) -> None:
53 | if self._model:
54 | return
55 | self._model = self.load_model(self._example_shape, checkpoint=torch.load(
56 | f"{self._experiment_folder}/{self._checkpoint}"
57 | ))
58 | self._model.eval()
59 |
60 | def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Tensor:
61 | self.lazy_load_model()
62 | image = image.to(self._device)
63 | if not batch:
64 | image = image.unsqueeze(0)
65 | padding_module = self.get_padding_module()
66 | if padding_module:
67 | image = padding_module(image)
68 | output = self._model(image)
69 | restoring_module = self.get_restoring_module()
70 | if restoring_module:
71 | output = restoring_module(output)
72 | return output if batch else output.squeeze(0)
73 |
74 | def _predict(self, x: SupportedPredictant | UnsupervisedDataset) -> tuple[list[torch.Tensor], list[str] | None]:
75 | if isinstance(x, PathBasedUnsupervisedDataset):
76 | return [self.predict_image(case) for case in x], x.paths()
77 | if isinstance(x, UnsupervisedDataset):
78 | return [self.predict_image(case) for case in x], None
79 | images, filenames = parse_predictant(x, Loader)
80 | return [self.predict_image(image) for image in images], filenames
81 |
82 | def predict(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]:
83 | return self._predict(x)[0]
84 |
85 | @staticmethod
86 | def save_prediction(output: torch.Tensor, path: str | PathLike[str]) -> None:
87 | save_image(output, path)
88 |
89 | def save_predictions(self, outputs: Sequence[torch.Tensor], folder: str | PathLike[str], *,
90 | filenames: Sequence[str | PathLike[str]] | None = None) -> None:
91 | if not exists(folder):
92 | raise FileNotFoundError(f"Folder {folder} does not exist")
93 | if not filenames:
94 |             num_digits = max(ceil(log10(len(outputs))), 1)
95 | filenames = [f"prediction_{str(i).zfill(num_digits)}.{
96 | "png" if output.ndim == 3 and output.shape[0] in (1, 3) else "mha"}" for i, output in enumerate(outputs)]
97 | for i, prediction in enumerate(outputs):
98 | self.save_prediction(prediction, f"{folder}/{filenames[i]}")
99 |
100 | def predict_to_files(self, x: SupportedPredictant | UnsupervisedDataset,
101 | folder: str | PathLike[str]) -> list[str] | None:
102 | outputs, filenames = self._predict(x)
103 | self.save_predictions(outputs, folder, filenames=filenames)
104 | return filenames
105 |
106 | def __call__(self, x: SupportedPredictant | UnsupervisedDataset) -> list[torch.Tensor]:
107 | return self.predict(x)
108 |
109 |
110 | class SlidingPredictor(Predictor, SlidingWindow, metaclass=ABCMeta):
111 | @override
112 | def build_padding_module(self) -> nn.Module | None:
113 | window_shape = self.get_window_shape()
114 | return (Pad2d if len(window_shape) == 2 else Pad3d)(window_shape)
115 |
116 | @override
117 | def build_restoring_module(self, padding_module: nn.Module | None) -> nn.Module | None:
118 | if not isinstance(padding_module, (Pad2d, Pad3d)):
119 | raise TypeError("`padding_module` should be either `Pad2d` or `Pad3d`")
120 | window_shape = self.get_window_shape()
121 | return (Restore2d if len(window_shape) == 2 else Restore3d)(padding_module)
122 |
123 | @override
124 | def predict_image(self, image: torch.Tensor, *, batch: bool = False) -> torch.Tensor:
125 | if not batch:
126 | image = image.unsqueeze(0)
127 | images, metadata = self.do_sliding_window(image)
128 | outputs = super().predict_image(images, batch=True)
129 | outputs = self.revert_sliding_window(outputs, metadata)
130 | return outputs if batch else outputs.squeeze(0)
131 |
--------------------------------------------------------------------------------
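
A brief sketch of `parse_predictant`, which normalizes the accepted input types before prediction: tensors pass through unchanged (with no filenames), while paths are loaded through the given `Loader`. The directory path in the comment is a hypothetical placeholder.

import torch
from mipcandy.data import Loader
from mipcandy.inference import parse_predictant

images, filenames = parse_predictant([torch.rand(1, 64, 64), torch.rand(1, 64, 64)], Loader)
print(len(images), filenames)  # 2 None

# A directory would instead be loaded case by case, returning the filenames as well, e.g.:
# images, filenames = parse_predictant("datasets/imagesTs", Loader)
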
/mipcandy/frontend/notion_fe.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import override, Literal
3 |
4 | from requests import get, post, patch, Response
5 |
6 | from mipcandy.frontend.prototype import Frontend
7 | from mipcandy.types import Settings
8 |
9 |
10 | class NotionFrontend(Frontend):
11 | def __init__(self, secrets: Settings) -> None:
12 | super().__init__(secrets)
13 | self._api_key: str = self.require_nonempty_secret("notion_api_key", require_type=str)
14 | self._database_id: str = self.require_nonempty_secret("notion_database_id", require_type=str)
15 | self._headers: dict[str, str] = {
16 | "Authorization": f"Bearer {self._api_key}",
17 | "Content-Type": "application/json",
18 | "Notion-Version": "2022-06-28"
19 | }
20 | self._num_epochs: int = 1
21 | self._early_stop_tolerance: int = -1
22 | self._start_time: str = ""
23 | self._page_id: str = ""
24 |
25 | def retrieve_database(self) -> Response:
26 | return get(f"https://api.notion.com/v1/databases/{self._database_id}", headers=self._headers)
27 |
28 | def query_database(self, *, experiment_id: str | None = None) -> Response:
29 | json = {"filter": {"property": "Experiment ID", "title": {"equals": experiment_id}}} if experiment_id else None
30 | return post(f"https://api.notion.com/v1/databases/{self._database_id}/query", json=json, headers=self._headers)
31 |
32 | def select_experiment(self, experiment_id: str) -> str:
33 | experiments = self.query_database(experiment_id=experiment_id)
34 | if experiments.status_code != 200:
35 | raise RuntimeError(f"Failed to query database: {experiments.json()}")
36 | experiments = experiments.json()["results"]
37 | if len(experiments) == 1:
38 | return experiments[0]["id"]
39 | if len(experiments) > 1:
40 | raise RuntimeError(f"Found multiple experiments with the same ID {experiment_id}")
41 | return ""
42 |
43 | def new_experiment(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float,
44 | num_params: float) -> Response:
45 | self._start_time = datetime.now().astimezone().strftime("%Y-%m-%dT%H:%M:%S.000%z")
46 | properties = {
47 | "Experiment ID": {"title": [{"text": {"content": experiment_id}}]},
48 | "Status": {"status": {"name": "In Progress"}},
49 | "Progress": {"number": 0},
50 | "Early Stop": {"number": 1},
51 | "Trainer": {"select": {"name": trainer}},
52 | "Model": {"select": {"name": model}},
53 | "Time": {"date": {"start": self._start_time}},
54 | "Note": {"rich_text": [{"text": {"content": note}}]},
55 | "MACs (G)": {"number": round(num_macs, 1)},
56 | "Params (M)": {"number": round(num_params, 1)},
57 | "Epoch": {"number": 0},
58 | "Score": {"number": 0},
59 | }
60 | page_id = self.select_experiment(experiment_id)
61 | if page_id:
62 | self._page_id = page_id
63 | return patch(f"https://api.notion.com/v1/pages/{page_id}", json={"properties": properties},
64 | headers=self._headers)
65 | res = post("https://api.notion.com/v1/pages", json={
66 | "parent": {"database_id": self._database_id},
67 | "icon": {"external": {"url": "https://www.notion.so/icons/science_gray.svg"}},
68 | "properties": properties
69 | }, headers=self._headers)
70 | self._page_id = res.json()["id"]
71 | return res
72 |
73 | def update_experiment(self, experiment_id: str, status: Literal["In Progress", "Completed", "Interrupted"],
74 | *, epoch: int | None = None, score: float | None = None,
75 | early_stop_tolerance: int | None = None, observation: str | None = None) -> Response:
76 | if not self._page_id:
77 | raise RuntimeError(f"Experiment {experiment_id} has not been created")
78 | properties = {"Status": {"status": {"name": status}}}
79 | if epoch is not None:
80 | properties["Progress"] = {"number": epoch / self._num_epochs}
81 | properties["Epoch"] = {"number": epoch}
82 | if early_stop_tolerance is not None:
83 | properties["Early Stop"] = {"number": max(early_stop_tolerance, 0) / self._early_stop_tolerance}
84 | if score is not None:
85 | properties["Score"] = {"number": round(score, 4)}
86 | if observation is not None:
87 | properties["Observation"] = {"rich_text": [{"text": {"content": observation}}]}
88 | if status == "Completed":
89 | properties["Progress"] = {"number": 1}
90 | properties["Time"] = {"date": {"start": self._start_time,
91 | "end": datetime.now().astimezone().strftime("%Y-%m-%dT%H:%M:%S.000%z")}}
92 | return patch(f"https://api.notion.com/v1/pages/{self._page_id}", json={"properties": properties},
93 | headers=self._headers)
94 |
95 | @override
96 | def on_experiment_created(self, experiment_id: str, trainer: str, model: str, note: str, num_macs: float,
97 | num_params: float, num_epochs: int, early_stop_tolerance: int) -> None:
98 | self._num_epochs = num_epochs
99 | self._early_stop_tolerance = early_stop_tolerance
100 | res = self.new_experiment(experiment_id, trainer, model, note, num_macs * 1e-9, num_params * 1e-6)
101 | if res.status_code != 200:
102 | raise RuntimeError(f"Failed to create experiment: {res.json()}")
103 |
104 | @override
105 | def on_experiment_updated(self, experiment_id: str, epoch: int, metrics: dict[str, list[float]],
106 | early_stop_tolerance: int) -> None:
107 | try:
108 | self.update_experiment(experiment_id, "In Progress", epoch=epoch, score=max(metrics["val score"]),
109 | early_stop_tolerance=early_stop_tolerance)
110 | except RuntimeError:
111 | pass
112 |
113 | @override
114 | def on_experiment_completed(self, experiment_id: str) -> None:
115 | res = self.update_experiment(experiment_id, "Completed")
116 | if res.status_code != 200:
117 | raise RuntimeError(f"Failed to update experiment: {res.json()}")
118 |
119 | @override
120 | def on_experiment_interrupted(self, experiment_id: str, error: Exception) -> None:
121 | res = self.update_experiment(experiment_id, "Interrupted", observation=repr(error))
122 | if res.status_code != 200:
123 | raise RuntimeError(f"Failed to update experiment: {res.json()}")
124 |
--------------------------------------------------------------------------------
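
A hedged sketch of driving the frontend above by hand. It assumes `Settings` accepts a plain mapping, that the two secret keys shown are the only ones required, and that the Notion database already has the referenced properties; the identifiers are placeholders and the calls issue real API requests.

from mipcandy.frontend.notion_fe import NotionFrontend

secrets = {"notion_api_key": "secret_xxx", "notion_database_id": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
frontend = NotionFrontend(secrets)

# Register a run, push one progress update, then mark it completed
frontend.on_experiment_created("exp-001", "SegmentationTrainer", "UNet", "demo run",
                               num_macs=12.3e9, num_params=31.0e6, num_epochs=100, early_stop_tolerance=20)
frontend.on_experiment_updated("exp-001", 1, {"val score": [0.42]}, 19)
frontend.on_experiment_completed("exp-001")
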
/mipcandy/common/module/preprocess.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 | from typing import Literal
3 |
4 | import torch
5 | from torch import nn
6 |
7 | from mipcandy.types import Colormap, Shape2d, Shape3d
8 |
9 |
10 | class Pad(nn.Module):
11 | def __init__(self, *, value: int = 0, mode: str = "constant", batch: bool = True) -> None:
12 | super().__init__()
13 | self._value: int = value
14 | self._mode: str = mode
15 | self.batch: bool = batch
16 | self._paddings: tuple[int, int, int, int, int, int] | tuple[int, int, int, int] | None = None
17 | self.requires_grad_(False)
18 |
19 | @staticmethod
20 | def _c_t(size: int, min_factor: int) -> int:
21 | """
22 |         Compute the padded target size along a single dimension
23 | """
24 | return ceil(size / min_factor) * min_factor
25 |
26 | @staticmethod
27 | def _c_p(size: int, min_factor: int) -> tuple[int, int]:
28 | """
29 |         Compute the (before, after) padding along a single dimension
30 | """
31 | excess = Pad._c_t(size, min_factor) - size
32 | before = excess // 2
33 | return before, excess - before
34 |
35 |
36 | class Pad2d(Pad):
37 | def __init__(self, min_factor: int | Shape2d, *, value: int = 0, mode: str = "constant",
38 | batch: bool = True) -> None:
39 | super().__init__(value=value, mode=mode, batch=batch)
40 | self._min_factor: Shape2d = (min_factor,) * 2 if isinstance(min_factor, int) else min_factor
41 |
42 | def paddings(self) -> tuple[int, int, int, int] | None:
43 | return self._paddings
44 |
45 | def padded_shape(self, in_shape: tuple[int, int, ...]) -> tuple[int, int, ...]:
46 | return *in_shape[:-2], self._c_t(in_shape[-2], self._min_factor[0]), self._c_t(
47 | in_shape[-1], self._min_factor[1])
48 |
49 | def forward(self, x: torch.Tensor) -> torch.Tensor:
50 | if self.batch:
51 | _, _, h, w = x.shape
52 | suffix = (0,) * 4
53 | else:
54 | _, h, w = x.shape
55 | suffix = (0,) * 2
56 | self._paddings = self._c_p(h, self._min_factor[0]) + self._c_p(w, self._min_factor[1])
57 | return nn.functional.pad(x, self._paddings[::-1] + suffix, self._mode, self._value)
58 |
59 |
60 | class Pad3d(Pad):
61 | def __init__(self, min_factor: int | Shape3d, *, value: int = 0, mode: str = "constant",
62 | batch: bool = True) -> None:
63 | super().__init__(value=value, mode=mode, batch=batch)
64 | self._min_factor: Shape3d = (min_factor,) * 3 if isinstance(min_factor, int) else min_factor
65 |
66 | def paddings(self) -> tuple[int, int, int, int, int, int] | None:
67 | return self._paddings
68 |
69 | def padded_shape(self, in_shape: tuple[int, int, int, ...]) -> tuple[int, int, int, ...]:
70 | return (*in_shape[:-3], self._c_t(in_shape[-3], self._min_factor[0]), self._c_t(
71 | in_shape[-2], self._min_factor[1]), self._c_t(in_shape[-1], self._min_factor[2]))
72 |
73 | def forward(self, x: torch.Tensor) -> torch.Tensor:
74 | if self.batch:
75 | _, _, d, h, w = x.shape
76 | suffix = (0,) * 4
77 | else:
78 | _, d, h, w = x.shape
79 | suffix = (0,) * 2
80 | self._paddings = self._c_p(d, self._min_factor[0]) + self._c_p(h, self._min_factor[1]) + self._c_p(
81 | w, self._min_factor[2])
82 | return nn.functional.pad(x, self._paddings[::-1] + suffix, self._mode, self._value)
83 |
84 |
85 | class Restore2d(nn.Module):
86 | def __init__(self, conjugate_padding: Pad2d) -> None:
87 | super().__init__()
88 | self.conjugate_padding: Pad2d = conjugate_padding
89 | self.requires_grad_(False)
90 |
91 | def forward(self, x: torch.Tensor) -> torch.Tensor:
92 | paddings = self.conjugate_padding.paddings()
93 | if not paddings:
94 | raise ValueError("Paddings are not set yet, did you forget to pad before restoring?")
95 | pad_h0, pad_h1, pad_w0, pad_w1 = paddings
96 | if self.conjugate_padding.batch:
97 | _, _, h, w = x.shape
98 | return x[:, :, pad_h0: h - pad_h1, pad_w0: w - pad_w1]
99 | _, h, w = x.shape
100 | return x[:, pad_h0: h - pad_h1, pad_w0: w - pad_w1]
101 |
102 |
103 | class Restore3d(nn.Module):
104 | def __init__(self, conjugate_padding: Pad3d) -> None:
105 | super().__init__()
106 | self.conjugate_padding: Pad3d = conjugate_padding
107 | self.requires_grad_(False)
108 |
109 | def forward(self, x: torch.Tensor) -> torch.Tensor:
110 | paddings = self.conjugate_padding.paddings()
111 | if not paddings:
112 | raise ValueError("Paddings are not set yet, did you forget to pad before restoring?")
113 | pad_d0, pad_d1, pad_h0, pad_h1, pad_w0, pad_w1 = paddings
114 | if self.conjugate_padding.batch:
115 | _, _, d, h, w = x.shape
116 | return x[:, :, pad_d0: d - pad_d1, pad_h0: h - pad_h1, pad_w0: w - pad_w1]
117 | _, d, h, w = x.shape
118 | return x[:, pad_d0: d - pad_d1, pad_h0: h - pad_h1, pad_w0: w - pad_w1]
119 |
120 |
121 | class Normalize(nn.Module):
122 | def __init__(self, *, domain: tuple[float | None, float | None] = (0, None), strict: bool = False,
123 | method: Literal["linear", "intercept", "cut"] = "linear") -> None:
124 | super().__init__()
125 | self._domain: tuple[float | None, float | None] = domain
126 | self._strict: bool = strict
127 | self._method: Literal["linear", "intercept", "cut"] = method
128 |
129 | def forward(self, x: torch.Tensor) -> torch.Tensor:
130 | left, right = self._domain
131 | if left is None and right is None:
132 | return x
133 | r_l, r_r = x.min(), x.max()
134 | match self._method:
135 | case "linear":
136 | if left is None or (left < r_l and not self._strict):
137 | left = r_l
138 | if right is None or (right > r_r and not self._strict):
139 | right = r_r
140 | numerator = right - left
141 | if numerator == 0:
142 | numerator = 1
143 | denominator = r_r - r_l
144 | if denominator == 0:
145 | denominator = 1
146 | return (x - r_l) * numerator / denominator + left
147 | case "intercept":
148 | if left is not None and right is None:
149 | return x - r_l + left if r_l < left or self._strict else x
150 | elif left is None and right is not None:
151 | return x - r_r + right if r_r > right or self._strict else x
152 | else:
153 | raise ValueError("Cannot use intercept normalization when both ends are fixed")
154 | case "cut":
155 | if self._strict:
156 | raise ValueError("Method \"cut\" cannot be strict")
157 | if left is not None:
158 | x[x < left] = left
159 | if right is not None:
160 | x[x > right] = right
161 | return x
162 |
163 |
164 | class ColorizeLabel(nn.Module):
165 | def __init__(self, *, colormap: Colormap | None = None) -> None:
166 | super().__init__()
167 | if not colormap:
168 | colormap = []
169 | for r in range(8):
170 | for g in range(8):
171 | for b in range(32):
172 |                     colormap.append([r * 32, g * 32, 255 - b * 8])  # step of 8 keeps the blue channel within [0, 255]
173 | self._colormap: torch.Tensor = torch.tensor(colormap)
174 |
175 | def forward(self, x: torch.Tensor) -> torch.Tensor:
176 | cmap = self._colormap.to(x.device)
177 | return (
178 | torch.cat([cmap[(x > 0).int()].permute(2, 0, 1), x.unsqueeze(0)]) if 0 <= x.min() < x.max() <= 1
179 | else cmap[x.int()].permute(2, 0, 1)
180 | )
181 |
--------------------------------------------------------------------------------
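
A quick round-trip sketch for the padding and normalization modules above. `Pad2d` and `Restore2d` are re-exported from `mipcandy.common` (as used elsewhere in the package); `Normalize` is imported from its module path here to avoid assuming it is re-exported.

import torch
from mipcandy.common import Pad2d, Restore2d
from mipcandy.common.module.preprocess import Normalize

pad = Pad2d(16)              # pad H and W up to the next multiple of 16
restore = Restore2d(pad)     # crops the padding back off, using the paddings recorded by `pad`
x = torch.rand(2, 1, 100, 120)
padded = pad(x)              # -> (2, 1, 112, 128)
restored = restore(padded)   # -> (2, 1, 100, 120)
assert restored.shape == x.shape

norm = Normalize(domain=(0, 1), strict=True)   # linear rescale into [0, 1]
y = norm(torch.randn(1, 64, 64) * 100)
print(y.min().item(), y.max().item())          # ~0.0 and ~1.0
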
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/mipcandy/data/dataset.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from json import dump
3 | from os import PathLike, listdir, makedirs
4 | from os.path import exists
5 | from random import choices
6 | from shutil import copy2
7 | from typing import Literal, override, Self, Sequence, TypeVar, Generic, Any
8 |
9 | import torch
10 | from pandas import DataFrame
11 | from torch.utils.data import Dataset
12 |
13 | from mipcandy.data.io import load_image
14 | from mipcandy.layer import HasDevice
15 | from mipcandy.types import Transform, Device
16 |
17 |
18 | class KFPicker(object, metaclass=ABCMeta):
19 | @staticmethod
20 | @abstractmethod
21 | def pick(n: int, fold: Literal[0, 1, 2, 3, 4, "all"]) -> tuple[int, ...]:
22 | raise NotImplementedError
23 |
24 |
25 | class OrderedKFPicker(KFPicker):
26 | @staticmethod
27 | @override
28 | def pick(n: int, fold: Literal[0, 1, 2, 3, 4, "all"]) -> tuple[int, ...]:
29 | if fold == "all":
30 |             return tuple(range(0, n, 5))
31 | size = n // 5
32 | return tuple(range(size * fold, size * (fold + 1)))
33 |
34 |
35 | class RandomKFPicker(OrderedKFPicker):
36 | @staticmethod
37 | @override
38 | def pick(n: int, fold: Literal[0, 1, 2, 3, 4, "all"]) -> tuple[int, ...]:
39 | return tuple(choices(range(n), k=n // 5)) if fold == "all" else super().pick(n, fold)
40 |
41 |
42 | class Loader(object):
43 | @staticmethod
44 | def do_load(path: str | PathLike[str], *, is_label: bool = False, device: Device = "cpu", **kwargs) -> torch.Tensor:
45 | return load_image(path, is_label=is_label, device=device, **kwargs)
46 |
47 |
48 | T = TypeVar("T")
49 |
50 |
51 | class _AbstractDataset(Dataset, Loader, HasDevice, Generic[T], Sequence[T], metaclass=ABCMeta):
52 | @abstractmethod
53 | def load(self, idx: int) -> T:
54 | raise NotImplementedError
55 |
56 | @override
57 | def __getitem__(self, idx: int) -> T:
58 | return self.load(idx)
59 |
60 |
61 | D = TypeVar("D", bound=Sequence[Any])
62 |
63 |
64 | class UnsupervisedDataset(_AbstractDataset[torch.Tensor], Generic[D], metaclass=ABCMeta):
65 | """
66 | Do not use this as a generic class. Only parameterize it if you are inheriting from it.
67 | """
68 |
69 | def __init__(self, images: D, *, device: Device = "cpu") -> None:
70 | super().__init__(device)
71 | self._images: D = images
72 |
73 | @override
74 | def __len__(self) -> int:
75 | return len(self._images)
76 |
77 |
78 | class SupervisedDataset(_AbstractDataset[tuple[torch.Tensor, torch.Tensor]], Generic[D], metaclass=ABCMeta):
79 | """
80 | Do not use this as a generic class. Only parameterize it if you are inheriting from it.
81 | """
82 |
83 | def __init__(self, images: D, labels: D, *, device: Device = "cpu") -> None:
84 | super().__init__(device)
85 | if len(images) != len(labels):
86 | raise ValueError(f"Unmatched number of images {len(images)} and labels {len(labels)}")
87 | self._images: D = images
88 | self._labels: D = labels
89 |
90 | @override
91 | def __len__(self) -> int:
92 | return len(self._images)
93 |
94 | @abstractmethod
95 | def construct_new(self, images: D, labels: D) -> Self:
96 | raise NotImplementedError
97 |
98 | def fold(self, *, fold: Literal[0, 1, 2, 3, 4, "all"] = "all", picker: type[KFPicker] = OrderedKFPicker) -> tuple[
99 | Self, Self]:
100 | indexes = picker.pick(len(self), fold)
101 | images_train = []
102 | labels_train = []
103 | images_val = []
104 | labels_val = []
105 | for i in range(len(self)):
106 | if i in indexes:
107 | images_val.append(self._images[i])
108 | labels_val.append(self._labels[i])
109 | else:
110 | images_train.append(self._images[i])
111 | labels_train.append(self._labels[i])
112 | return self.construct_new(images_train, labels_train), self.construct_new(images_val, labels_val)
113 |
114 |
115 | class DatasetFromMemory(UnsupervisedDataset[Sequence[torch.Tensor]]):
116 | def __init__(self, images: Sequence[torch.Tensor], device: Device = "cpu") -> None:
117 | super().__init__(images, device=device)
118 |
119 | @override
120 | def load(self, idx: int) -> torch.Tensor:
121 | return self._images[idx].to(self._device)
122 |
123 |
124 | class MergedDataset(SupervisedDataset[UnsupervisedDataset]):
125 | def __init__(self, images: UnsupervisedDataset, labels: UnsupervisedDataset, *, device: Device = "cpu") -> None:
126 | super().__init__(images, labels, device=device)
127 |
128 | @override
129 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
130 | return self._images[idx].to(self._device), self._labels[idx].to(self._device)
131 |
132 | @override
133 | def construct_new(self, images: UnsupervisedDataset, labels: UnsupervisedDataset) -> Self:
134 | return MergedDataset(DatasetFromMemory(images), DatasetFromMemory(labels), device=self._device)
135 |
136 |
137 | class ComposeDataset(_AbstractDataset[tuple[torch.Tensor, torch.Tensor] | torch.Tensor]):
138 | def __init__(self, bases: Sequence[SupervisedDataset] | Sequence[UnsupervisedDataset], *,
139 | device: Device = "cpu") -> None:
140 | super().__init__(device)
141 | self._bases: dict[tuple[int, int], SupervisedDataset | UnsupervisedDataset] = {}
142 | self._len = 0
143 | for dataset in bases:
144 | end = len(dataset)
145 | self._bases[(self._len, self._len + end)] = dataset
146 | self._len += end
147 |
148 | @override
149 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
150 | for (start, end), base in self._bases.items():
151 | if start <= idx < end:
152 | return base.load(idx - start)
153 | raise IndexError(f"Index {idx} out of range [0, {self._len})")
154 |
155 | @override
156 | def __len__(self) -> int:
157 | return self._len
158 |
159 |
160 | class PathBasedUnsupervisedDataset(UnsupervisedDataset[list[str]], metaclass=ABCMeta):
161 | def paths(self) -> list[str]:
162 | return self._images
163 |
164 | def save_paths(self, to: str | PathLike[str]) -> None:
165 |         match (fmt := str(to).split(".")[-1]):
166 | case "csv":
167 | df = DataFrame([{"image": image_path} for image_path in self.paths()])
168 | df.index = range(len(df))
169 | df.index.name = "case"
170 | df.to_csv(to)
171 | case "json":
172 | with open(to, "w") as f:
173 | dump([{"image": image_path} for image_path in self.paths()], f)
174 | case "txt":
175 | with open(to, "w") as f:
176 | for image_path in self.paths():
177 | f.write(f"{image_path}\n")
178 | case _:
179 | raise ValueError(f"Unsupported file extension: {fmt}")
180 |
181 |
182 | class SimpleDataset(PathBasedUnsupervisedDataset):
183 | def __init__(self, folder: str | PathLike[str], *, device: Device = "cpu") -> None:
184 | images = listdir(folder)
185 | images.sort()
186 | super().__init__(images, device=device)
187 | self._folder: str = folder
188 |
189 | @override
190 | def load(self, idx: int) -> torch.Tensor:
191 | return self.do_load(f"{self._folder}/{self._images[idx]}", device=self._device)
192 |
193 |
194 | class PathBasedSupervisedDataset(SupervisedDataset[list[str]], metaclass=ABCMeta):
195 | def paths(self) -> list[tuple[str, str]]:
196 | return [(self._images[i], self._labels[i]) for i in range(len(self))]
197 |
198 | def save_paths(self, to: str | PathLike[str]) -> None:
199 |         match (fmt := str(to).split(".")[-1]):
200 | case "csv":
201 | df = DataFrame([{"image": image_path, "label": label_path} for image_path, label_path in self.paths()])
202 | df.index = range(len(df))
203 | df.index.name = "case"
204 | df.to_csv(to)
205 | case "json":
206 | with open(to, "w") as f:
207 | dump([{"image": image_path, "label": label_path} for image_path, label_path in self.paths()], f)
208 | case "txt":
209 | with open(to, "w") as f:
210 | for image_path, label_path in self.paths():
211 | f.write(f"{image_path}\t{label_path}\n")
212 | case _:
213 | raise ValueError(f"Unsupported file extension: {fmt}")
214 |
215 |
216 | class NNUNetDataset(PathBasedSupervisedDataset):
217 | def __init__(self, folder: str | PathLike[str], *, split: str | Literal["Tr", "Ts"] = "Tr", prefix: str = "",
218 | align_spacing: bool = False, image_transform: Transform | None = None,
219 | label_transform: Transform | None = None, device: Device = "cpu") -> None:
220 | images: list[str] = [f for f in listdir(f"{folder}/images{split}") if f.startswith(prefix)]
221 | images.sort()
222 | labels: list[str] = [f for f in listdir(f"{folder}/labels{split}") if f.startswith(prefix)]
223 | labels.sort()
224 | self._multimodal_images: list[list[str]] = []
225 | if len(images) == len(labels):
226 | super().__init__(images, labels, device=device)
227 | else:
228 | super().__init__([""] * len(labels), labels, device=device)
229 | current_case = ""
230 | for image in images:
231 | case = image[:image.rfind("_")]
232 | if case != current_case:
233 | self._multimodal_images.append([])
234 | current_case = case
235 | self._multimodal_images[-1].append(image)
236 | if len(self._multimodal_images) != len(self._labels):
237 | raise ValueError("Unmatched number of images and labels")
238 | self._folder: str = folder
239 | self._split: str = split
240 | self._folded: bool = False
241 | self._prefix: str = prefix
242 | self._align_spacing: bool = align_spacing
243 | self._image_transform: Transform | None = image_transform
244 | self._label_transform: Transform | None = label_transform
245 |
246 | @staticmethod
247 | def _create_subset(folder: str) -> None:
248 | if exists(folder) and len(listdir(folder)) > 0:
249 | raise FileExistsError(f"{folder} already exists and is not empty")
250 | makedirs(folder, exist_ok=True)
251 |
252 | @override
253 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
254 | image = torch.cat([self.do_load(
255 | f"{self._folder}/images{self._split}/{path}", align_spacing=self._align_spacing, device=self._device
256 | ) for path in self._multimodal_images[idx]]) if self._multimodal_images else self.do_load(
257 | f"{self._folder}/images{self._split}/{self._images[idx]}", align_spacing=self._align_spacing,
258 | device=self._device
259 | )
260 | label = self.do_load(
261 | f"{self._folder}/labels{self._split}/{self._labels[idx]}", is_label=True, align_spacing=self._align_spacing,
262 | device=self._device
263 | )
264 | if self._image_transform:
265 | image = self._image_transform(image)
266 | if self._label_transform:
267 | label = self._label_transform(label)
268 | return image, label
269 |
270 | def save(self, split: str | Literal["Tr", "Ts"], *, target_folder: str | PathLike[str] | None = None) -> None:
271 | target_base = target_folder if target_folder else self._folder
272 | images_target = f"{target_base}/images{split}"
273 | labels_target = f"{target_base}/labels{split}"
274 | self._create_subset(images_target)
275 | self._create_subset(labels_target)
276 | for image_path, label_path in self.paths():
277 | copy2(f"{self._folder}/images{self._split}/{image_path}", f"{images_target}/{image_path}")
278 | copy2(f"{self._folder}/labels{self._split}/{label_path}", f"{labels_target}/{label_path}")
279 | self._split = split
280 | self._folded = False
281 |
282 | @override
283 | def construct_new(self, images: list[str], labels: list[str]) -> Self:
284 | if self._folded:
285 | raise ValueError("Cannot construct a new dataset from a fold")
286 | new = self.__class__(self._folder, split=self._split, prefix=self._prefix, align_spacing=self._align_spacing,
287 | image_transform=self._image_transform, label_transform=self._label_transform,
288 | device=self._device)
289 | new._images = images
290 | new._labels = labels
291 | new._folded = True
292 | return new
293 |
294 |
295 | class BinarizedDataset(SupervisedDataset[D]):
296 | def __init__(self, base: SupervisedDataset[D], positive_ids: tuple[int, ...]) -> None:
297 |         super().__init__(base._images, base._labels, device=base._device)
298 | self._base: SupervisedDataset[D] = base
299 | self._positive_ids: tuple[int, ...] = positive_ids
300 |
301 | @override
302 | def construct_new(self, images: D, labels: D) -> Self:
303 | raise NotImplementedError
304 |
305 | @override
306 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
307 | image, label = self._base.load(idx)
308 | for pid in self._positive_ids:
309 | label[label == pid] = -1
310 | label[label > 0] = 0
311 | label[label == -1] = 1
312 | return image, label
313 |
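A minimal usage sketch of the dataset classes above, assuming the nnU-Net style class is NNUNetDataset (as in the landing-page example) and importing BinarizedDataset from its module path; the folder name and positive label IDs below are made up:

    from mipcandy import NNUNetDataset
    from mipcandy.data.dataset import BinarizedDataset

    # hypothetical nnU-Net style folder with imagesTr/ and labelsTr/ subfolders
    base = NNUNetDataset("datasets/Task01_Example", split="Tr")
    # label IDs 1 and 2 become foreground (1); every other ID becomes background (0)
    binary = BinarizedDataset(base, (1, 2))
    image, label = binary.load(0)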
--------------------------------------------------------------------------------
/mipcandy/data/inspection.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 | from json import dump, load
3 | from os import PathLike
4 | from typing import Sequence, override, Callable, Self, Any
5 |
6 | import numpy as np
7 | import torch
8 | from rich.console import Console
9 | from rich.progress import Progress, SpinnerColumn
10 | from torch import nn
11 |
12 | from mipcandy.data.dataset import SupervisedDataset
13 | from mipcandy.data.geometric import crop
14 | from mipcandy.layer import HasDevice
15 | from mipcandy.types import Device, Shape, AmbiguousShape
16 |
17 |
18 | def format_bbox(bbox: Sequence[int]) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]:
19 | if len(bbox) == 4:
20 | return bbox[0], bbox[1], bbox[2], bbox[3]
21 | elif len(bbox) == 6:
22 | return bbox[0], bbox[1], bbox[2], bbox[3], bbox[4], bbox[5]
23 | else:
24 | raise ValueError(f"Invalid bbox with {len(bbox)} elements")
25 |
26 |
27 | @dataclass
28 | class InspectionAnnotation(object):
29 | shape: AmbiguousShape
30 | foreground_bbox: tuple[int, int, int, int] | tuple[int, int, int, int, int, int]
31 | ids: tuple[int, ...]
32 |
33 | def foreground_shape(self) -> Shape:
34 | r = (self.foreground_bbox[1] - self.foreground_bbox[0], self.foreground_bbox[3] - self.foreground_bbox[2])
35 | return r if len(self.foreground_bbox) == 4 else r + (self.foreground_bbox[5] - self.foreground_bbox[4],)
36 |
37 | def center_of_foreground(self) -> tuple[int, int] | tuple[int, int, int]:
38 | r = (round((self.foreground_bbox[1] + self.foreground_bbox[0]) * .5),
39 | round((self.foreground_bbox[3] + self.foreground_bbox[2]) * .5))
40 | return r if len(self.shape) == 2 else r + (round((self.foreground_bbox[5] + self.foreground_bbox[4]) * .5),)
41 |
42 | def to_dict(self) -> dict[str, tuple[int, ...]]:
43 | return asdict(self)
44 |
45 |
46 | class InspectionAnnotations(HasDevice, Sequence[InspectionAnnotation]):
47 | def __init__(self, dataset: SupervisedDataset, background: int, *annotations: InspectionAnnotation,
48 | device: Device = "cpu") -> None:
49 | super().__init__(device)
50 | self._dataset: SupervisedDataset = dataset
51 | self._background: int = background
52 | self._annotations: tuple[InspectionAnnotation, ...] = annotations
53 | self._shapes: tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape] | None = None
54 | self._foreground_shapes: tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape] | None = None
55 | self._statistical_foreground_shape: Shape | None = None
56 | self._foreground_heatmap: torch.Tensor | None = None
57 | self._center_of_foregrounds: tuple[int, int] | tuple[int, int, int] | None = None
58 | self._foreground_offsets: tuple[int, int] | tuple[int, int, int] | None = None
59 | self._roi_shape: Shape | None = None
60 |
61 | def dataset(self) -> SupervisedDataset:
62 | return self._dataset
63 |
64 | def background(self) -> int:
65 | return self._background
66 |
67 | def annotations(self) -> tuple[InspectionAnnotation, ...]:
68 | return self._annotations
69 |
70 | @override
71 | def __getitem__(self, item: int) -> InspectionAnnotation:
72 | return self._annotations[item]
73 |
74 | @override
75 | def __len__(self) -> int:
76 | return len(self._annotations)
77 |
78 | def save(self, path: str | PathLike[str]) -> None:
79 | with open(path, "w") as f:
80 | dump({"background": self._background, "annotations": [a.to_dict() for a in self._annotations]}, f)
81 |
82 | def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], AmbiguousShape]) -> tuple[
83 | AmbiguousShape | None, AmbiguousShape, AmbiguousShape]:
84 | depths = []
85 | widths = []
86 | heights = []
87 | for annotation in self._annotations:
88 | shape = get_shape(annotation)
89 | if len(shape) == 2:
90 | heights.append(shape[0])
91 | widths.append(shape[1])
92 | else:
93 | depths.append(shape[0])
94 | heights.append(shape[1])
95 | widths.append(shape[2])
96 | return tuple(depths) if depths else None, tuple(heights), tuple(widths)
97 |
98 | def shapes(self) -> tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape]:
99 | if self._shapes:
100 | return self._shapes
101 | self._shapes = self._get_shapes(lambda annotation: annotation.shape)
102 | return self._shapes
103 |
104 | def foreground_shapes(self) -> tuple[AmbiguousShape | None, AmbiguousShape, AmbiguousShape]:
105 | if self._foreground_shapes:
106 | return self._foreground_shapes
107 | self._foreground_shapes = self._get_shapes(lambda annotation: annotation.foreground_shape())
108 | return self._foreground_shapes
109 |
110 | def statistical_foreground_shape(self, *, percentile: float = .95) -> Shape:
111 | if self._statistical_foreground_shape:
112 | return self._statistical_foreground_shape
113 | depths, heights, widths = self.foreground_shapes()
114 | percentile *= 100
115 | sfs = (round(np.percentile(heights, percentile)), round(np.percentile(widths, percentile)))
116 | self._statistical_foreground_shape = (round(np.percentile(depths, percentile)),) + sfs if depths else sfs
117 | return self._statistical_foreground_shape
118 |
119 | def crop_foreground(self, i: int, *, expand_ratio: float = 1) -> tuple[torch.Tensor, torch.Tensor]:
120 | image, label = self._dataset[i]
121 | annotation = self._annotations[i]
122 | bbox = list(annotation.foreground_bbox)
123 | shape = annotation.foreground_shape()
124 | for dim_idx, size in enumerate(shape):
125 | left = int((expand_ratio - 1) * size // 2)
126 | right = int((expand_ratio - 1) * size - left)
127 | bbox[dim_idx * 2] = max(0, bbox[dim_idx * 2] - left)
128 | bbox[dim_idx * 2 + 1] = min(bbox[dim_idx * 2 + 1] + right, annotation.shape[dim_idx])
129 | return crop(image.unsqueeze(0), bbox).squeeze(0), crop(label.unsqueeze(0), bbox).squeeze(0)
130 |
131 | def foreground_heatmap(self) -> torch.Tensor:
132 |         if self._foreground_heatmap is not None:
133 | return self._foreground_heatmap
134 | depths, heights, widths = self.foreground_shapes()
135 | max_shape = (max(depths), max(heights), max(widths)) if depths else (max(heights), max(widths))
136 | accumulated_label = torch.zeros((1, *max_shape), device=self._device)
137 | for i, (_, label) in enumerate(self._dataset):
138 | annotation = self._annotations[i]
139 | paddings = [0, 0, 0, 0]
140 | shape = annotation.foreground_shape()
141 | for j, size in enumerate(max_shape):
142 | left = (size - shape[j]) // 2
143 | right = size - shape[j] - left
144 | paddings.append(right)
145 | paddings.append(left)
146 | paddings.reverse()
147 | accumulated_label += nn.functional.pad(
148 | crop((label != self._background).unsqueeze(0), annotation.foreground_bbox), paddings
149 | ).squeeze(0)
150 | self._foreground_heatmap = accumulated_label.squeeze(0)
151 | return self._foreground_heatmap
152 |
153 | def center_of_foregrounds(self) -> tuple[int, int] | tuple[int, int, int]:
154 | if self._center_of_foregrounds:
155 | return self._center_of_foregrounds
156 | heatmap = self.foreground_heatmap()
157 | center = (heatmap.sum(dim=1).argmax().item(), heatmap.sum(dim=0).argmax().item()) if heatmap.ndim == 2 else (
158 | heatmap.sum(dim=(1, 2)).argmax().item(),
159 | heatmap.sum(dim=(0, 2)).argmax().item(),
160 | heatmap.sum(dim=(0, 1)).argmax().item(),
161 | )
162 | self._center_of_foregrounds = center
163 | return self._center_of_foregrounds
164 |
165 | def center_of_foregrounds_offsets(self) -> tuple[int, int] | tuple[int, int, int]:
166 | if self._foreground_offsets:
167 | return self._foreground_offsets
168 | center = self.center_of_foregrounds()
169 | depths, heights, widths = self.foreground_shapes()
170 | max_shape = (max(depths), max(heights), max(widths)) if depths else (max(heights), max(widths))
171 | offsets = (round(center[0] - max_shape[0] * .5), round(center[1] - max_shape[1] * .5))
172 | self._foreground_offsets = offsets + (round(center[2] - max_shape[2] * .5),) if depths else offsets
173 | return self._foreground_offsets
174 |
175 | def set_roi_shape(self, roi_shape: Shape | None) -> None:
176 | if roi_shape is not None:
177 | depths, heights, widths = self.shapes()
178 | if depths:
179 | if roi_shape[0] > min(depths) or roi_shape[1] > min(heights) or roi_shape[2] > min(widths):
180 | raise ValueError(f"ROI shape {roi_shape} exceeds minimum image shape ({min(depths)}, {min(heights)}, {min(widths)})")
181 | else:
182 | if roi_shape[0] > min(heights) or roi_shape[1] > min(widths):
183 | raise ValueError(f"ROI shape {roi_shape} exceeds minimum image shape ({min(heights)}, {min(widths)})")
184 | self._roi_shape = roi_shape
185 |
186 | def roi_shape(self, *, percentile: float = .95) -> Shape:
187 | if self._roi_shape:
188 | return self._roi_shape
189 | sfs = self.statistical_foreground_shape(percentile=percentile)
190 | if len(sfs) == 2:
191 | sfs = (None, *sfs)
192 | depths, heights, widths = self.shapes()
193 | roi_shape = (min(min(heights), sfs[1]), min(min(widths), sfs[2]))
194 | if depths:
195 | roi_shape = (min(min(depths), sfs[0]),) + roi_shape
196 | self._roi_shape = roi_shape
197 | return self._roi_shape
198 |
199 | def roi(self, i: int, *, percentile: float = .95) -> tuple[int, int, int, int] | tuple[
200 | int, int, int, int, int, int]:
201 | annotation = self._annotations[i]
202 | roi_shape = self.roi_shape(percentile=percentile)
203 | offsets = self.center_of_foregrounds_offsets()
204 | center = annotation.center_of_foreground()
205 | roi = []
206 |         for dim_idx, position in enumerate(center):
207 |             left = roi_shape[dim_idx] // 2
208 |             right = roi_shape[dim_idx] - left
209 |             offset = min(max(offsets[dim_idx], left - position), annotation.shape[dim_idx] - right - position)
210 |             roi.append(position + offset - left)
211 |             roi.append(position + offset + right)
212 | return tuple(roi)
213 |
214 | def crop_roi(self, i: int, *, percentile: float = .95) -> tuple[torch.Tensor, torch.Tensor]:
215 | image, label = self._dataset[i]
216 | roi = self.roi(i, percentile=percentile)
217 | return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0)
218 |
219 |
220 | def _lists_to_tuples(pairs: Sequence[tuple[str, Any]]) -> dict[str, Any]:
221 | return {k: tuple(v) if isinstance(v, list) else v for k, v in pairs}
222 |
223 |
224 | def load_inspection_annotations(path: str | PathLike[str], dataset: SupervisedDataset) -> InspectionAnnotations:
225 | with open(path) as f:
226 | obj = load(f, object_pairs_hook=_lists_to_tuples)
227 | return InspectionAnnotations(dataset, obj["background"], *(
228 | InspectionAnnotation(**row) for row in obj["annotations"]
229 | ))
230 |
231 |
232 | def inspect(dataset: SupervisedDataset, *, background: int = 0, console: Console = Console()) -> InspectionAnnotations:
233 | r = []
234 | with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=console) as progress:
235 | task = progress.add_task("Inspecting dataset...", total=len(dataset))
236 | for _, label in dataset:
237 | progress.update(task, advance=1, description=f"Inspecting dataset {tuple(label.shape)}")
238 | indices = (label != background).nonzero()
239 | mins = indices.min(dim=0)[0].tolist()
240 | maxs = indices.max(dim=0)[0].tolist()
241 | bbox = (mins[1], maxs[1] + 1, mins[2], maxs[2] + 1)
242 | r.append(InspectionAnnotation(
243 |                 label.shape[1:], bbox if label.ndim == 3 else bbox + (mins[3], maxs[3] + 1), tuple(label.unique().tolist())
244 | ))
245 | return InspectionAnnotations(dataset, background, *r, device=dataset.device())
246 |
247 |
248 | class ROIDataset(SupervisedDataset[list[torch.Tensor]]):
249 | def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95) -> None:
250 | super().__init__([], [])
251 | self._annotations: InspectionAnnotations = annotations
252 | self._percentile: float = percentile
253 |
254 | @override
255 | def __len__(self) -> int:
256 | return len(self._annotations)
257 |
258 | @override
259 | def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self:
260 | return self.__class__(self._annotations, percentile=self._percentile)
261 |
262 | @override
263 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
264 | return self._annotations.crop_roi(idx, percentile=self._percentile)
265 |
266 |
267 | class RandomROIDataset(ROIDataset):
268 | def __init__(self, annotations: InspectionAnnotations, *, percentile: float = .95,
269 | foreground_oversample_percentage: float = .33, min_foreground_samples: int = 500,
270 | max_foreground_samples: int = 10000, min_percent_coverage: float = .01) -> None:
271 | super().__init__(annotations, percentile=percentile)
272 | self._fg_oversample: float = foreground_oversample_percentage
273 | self._min_fg_samples: int = min_foreground_samples
274 | self._max_fg_samples: int = max_foreground_samples
275 | self._min_coverage: float = min_percent_coverage
276 | self._fg_locations_cache: dict[int, tuple[tuple[int, ...], ...] | None] = {}
277 |
278 | def _get_foreground_locations(self, idx: int) -> tuple[tuple[int, ...], ...] | None:
279 | if idx not in self._fg_locations_cache:
280 | _, label = self._annotations.dataset()[idx]
281 | indices = (label != self._annotations.background()).nonzero()[:, 1:]
282 | if len(indices) == 0:
283 | self._fg_locations_cache[idx] = None
284 | elif len(indices) <= self._min_fg_samples:
285 | self._fg_locations_cache[idx] = tuple(tuple(coord.tolist()) for coord in indices)
286 | else:
287 | target_samples = min(
288 | self._max_fg_samples,
289 | max(self._min_fg_samples, int(np.ceil(len(indices) * self._min_coverage)))
290 | )
291 | sampled_idx = torch.randperm(len(indices))[:target_samples]
292 | sampled = indices[sampled_idx]
293 | self._fg_locations_cache[idx] = tuple(tuple(coord.tolist()) for coord in sampled)
294 | return self._fg_locations_cache[idx]
295 |
296 | def _random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[int, int, int, int, int, int]:
297 | annotation = self._annotations[idx]
298 | roi_shape = self._annotations.roi_shape(percentile=self._percentile)
299 | roi = []
300 | for dim_size, patch_size in zip(annotation.shape, roi_shape):
301 | left = patch_size // 2
302 | right = patch_size - left
303 | min_center = left
304 | max_center = dim_size - right
305 | center = torch.randint(min_center, max_center + 1, (1,)).item()
306 | roi.append(center - left)
307 | roi.append(center + right)
308 | return tuple(roi)
309 |
310 | def _foreground_guided_random_roi(self, idx: int) -> tuple[int, int, int, int] | tuple[
311 | int, int, int, int, int, int]:
312 | annotation = self._annotations[idx]
313 | roi_shape = self._annotations.roi_shape(percentile=self._percentile)
314 | foreground_locations = self._get_foreground_locations(idx)
315 |
316 | if foreground_locations is None or len(foreground_locations) == 0:
317 | return self._random_roi(idx)
318 |
319 | fg_idx = torch.randint(0, len(foreground_locations), (1,)).item()
320 | fg_position = foreground_locations[fg_idx]
321 |
322 | roi = []
323 | for fg_pos, dim_size, patch_size in zip(fg_position, annotation.shape, roi_shape):
324 | left = patch_size // 2
325 | right = patch_size - left
326 | center = max(left, min(fg_pos, dim_size - right))
327 | roi.append(center - left)
328 | roi.append(center + right)
329 | return tuple(roi)
330 |
331 | @override
332 | def construct_new(self, images: list[torch.Tensor], labels: list[torch.Tensor]) -> Self:
333 | return self.__class__(self._annotations, percentile=self._percentile,
334 | foreground_oversample_percentage=self._fg_oversample,
335 | min_foreground_samples=self._min_fg_samples,
336 | max_foreground_samples=self._max_fg_samples,
337 | min_percent_coverage=self._min_coverage)
338 |
339 | @override
340 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
341 | image, label = self._annotations.dataset()[idx]
342 | force_fg = torch.rand(1).item() < self._fg_oversample
343 | if force_fg:
344 | roi = self._foreground_guided_random_roi(idx)
345 | else:
346 | roi = self._random_roi(idx)
347 | return crop(image.unsqueeze(0), roi).squeeze(0), crop(label.unsqueeze(0), roi).squeeze(0)
348 |
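A brief sketch of the inspection workflow built from the functions above: inspect() makes one pass over the labels to record shapes, foreground bounding boxes, and label IDs; the annotations can be cached as JSON and reloaded later to drive ROI-cropped datasets (file and variable names are illustrative):

    from mipcandy.data.inspection import inspect, load_inspection_annotations, ROIDataset, RandomROIDataset

    annotations = inspect(dataset, background=0)                    # dataset: any SupervisedDataset
    annotations.save("annotations.json")                            # cache bounding boxes and label IDs

    annotations = load_inspection_annotations("annotations.json", dataset)
    roi_dataset = ROIDataset(annotations, percentile=.95)           # deterministic ROI crops
    patch_dataset = RandomROIDataset(annotations,                   # random, foreground-oversampled crops
                                     foreground_oversample_percentage=.33)
    image, label = patch_dataset.load(0)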
--------------------------------------------------------------------------------
/mipcandy/training.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from dataclasses import dataclass
3 | from datetime import datetime
4 | from hashlib import md5
5 | from json import load, dump
6 | from os import PathLike, urandom, makedirs, environ
7 | from os.path import exists
8 | from random import seed as random_seed, randint
9 | from shutil import copy
10 | from threading import Lock
11 | from time import time
12 | from typing import Sequence, override, Callable, Self
13 |
14 | import numpy as np
15 | import torch
16 | from matplotlib import pyplot as plt
17 | from pandas import DataFrame, read_csv
18 | from rich.console import Console
19 | from rich.progress import Progress, SpinnerColumn
20 | from rich.table import Table
21 | from torch import nn, optim
22 | from torch.utils.data import DataLoader
23 |
24 | from mipcandy.common import Pad2d, Pad3d, quotient_regression, quotient_derivative, quotient_bounds
25 | from mipcandy.config import load_settings, load_secrets
26 | from mipcandy.frontend import Frontend
27 | from mipcandy.layer import WithPaddingModule, WithNetwork
28 | from mipcandy.sanity_check import sanity_check
29 | from mipcandy.sliding_window import SWMetadata, SlidingWindow
30 | from mipcandy.types import Params, Setting, AmbiguousShape
31 |
32 |
33 | def try_append(new: float, to: dict[str, list[float]], key: str) -> None:
34 | if key in to:
35 | to[key].append(new)
36 | else:
37 | to[key] = [new]
38 |
39 |
40 | def try_append_all(new: dict[str, float], to: dict[str, list[float]]) -> None:
41 | for key, value in new.items():
42 | try_append(value, to, key)
43 |
44 |
45 | @dataclass
46 | class TrainerToolbox(object):
47 | model: nn.Module
48 | optimizer: optim.Optimizer
49 | scheduler: optim.lr_scheduler.LRScheduler
50 | criterion: nn.Module
51 | ema: optim.swa_utils.AveragedModel | None = None
52 |
53 |
54 | @dataclass
55 | class TrainerTracker(object):
56 | epoch: int = 0
57 | best_score: float = float("-inf")
58 | worst_case: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
59 |
60 |
61 | class Trainer(WithPaddingModule, WithNetwork, metaclass=ABCMeta):
62 | def __init__(self, trainer_folder: str | PathLike[str], dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]],
63 | validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]], *, recoverable: bool = True,
64 | device: torch.device | str = "cpu", console: Console = Console()) -> None:
65 | WithPaddingModule.__init__(self, device)
66 | WithNetwork.__init__(self, device)
67 | self._trainer_folder: str = trainer_folder
68 | self._trainer_variant: str = self.__class__.__name__
69 | self._experiment_id: str = "tbd"
70 | self._dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = dataloader
71 | self._validation_dataloader: DataLoader[tuple[torch.Tensor, torch.Tensor]] = validation_dataloader
72 | self._unrecoverable: bool | None = not recoverable # None if the trainer is recovered
73 | self._console: Console = console
74 | self._metrics: dict[str, list[float]] = {}
75 | self._epoch_metrics: dict[str, list[float]] = {}
76 | self._frontend: Frontend = Frontend({})
77 | self._lock: Lock = Lock()
78 | self._tracker: TrainerTracker = TrainerTracker()
79 |
80 | # Recovery methods (PR #108 at https://github.com/ProjectNeura/MIPCandy/pull/108)
81 |
82 | def save_everything_for_recovery(self, toolbox: TrainerToolbox, tracker: TrainerTracker,
83 | **training_arguments) -> None:
84 | if self._unrecoverable:
85 | return
86 | torch.save(toolbox.optimizer.state_dict(), f"{self.experiment_folder()}/optimizer.pth")
87 | torch.save(toolbox.scheduler.state_dict(), f"{self.experiment_folder()}/scheduler.pth")
88 | torch.save(toolbox.criterion.state_dict(), f"{self.experiment_folder()}/criterion.pth")
89 | torch.save(tracker, f"{self.experiment_folder()}/tracker.pt")
90 | with open(f"{self.experiment_folder()}/training_arguments.json", "w") as f:
91 | dump(training_arguments, f)
92 |
93 | def load_tracker(self) -> TrainerTracker:
94 | return torch.load(f"{self.experiment_folder()}/tracker.pt", weights_only=False)
95 |
96 | def load_training_arguments(self) -> dict[str, Setting]:
97 | with open(f"{self.experiment_folder()}/training_arguments.json") as f:
98 | return load(f)
99 |
100 | def load_metrics(self) -> dict[str, list[float]]:
101 | df = read_csv(f"{self.experiment_folder()}/metrics.csv", index_col="epoch")
102 | return {column: df[column].astype(float).tolist() for column in df.columns}
103 |
104 | def load_toolbox(self, num_epochs: int, example_shape: AmbiguousShape) -> TrainerToolbox:
105 | toolbox = self._build_toolbox(num_epochs, example_shape, model=self.load_model(
106 | example_shape, checkpoint=torch.load(f"{self.experiment_folder()}/checkpoint_latest.pth")
107 | ))
108 | toolbox.optimizer.load_state_dict(torch.load(f"{self.experiment_folder()}/optimizer.pth"))
109 | toolbox.scheduler.load_state_dict(torch.load(f"{self.experiment_folder()}/scheduler.pth"))
110 | toolbox.criterion.load_state_dict(torch.load(f"{self.experiment_folder()}/criterion.pth"))
111 | return toolbox
112 |
113 | def recover_from(self, experiment_id: str) -> Self:
114 | self._experiment_id = experiment_id
115 | if not exists(self.experiment_folder()):
116 | raise FileNotFoundError(f"Experiment folder {self.experiment_folder()} not found")
117 | self._metrics = self.load_metrics()
118 | self._tracker = self.load_tracker()
119 | self._unrecoverable = None
120 | return self
121 |
122 | def continue_training(self, num_epochs: int) -> None:
123 | if not self.recovery():
124 | raise RuntimeError("Must call `recover_from()` before continuing training")
125 | self.train(num_epochs, **self.load_training_arguments())
126 |
127 | # Getters
128 |
129 | def trainer_folder(self) -> str:
130 | return self._trainer_folder
131 |
132 | def trainer_variant(self) -> str:
133 | return self._trainer_variant
134 |
135 | def experiment_id(self) -> str:
136 | return self._experiment_id
137 |
138 | def dataloader(self) -> DataLoader[tuple[torch.Tensor, torch.Tensor]]:
139 | return self._dataloader
140 |
141 | def validation_dataloader(self) -> DataLoader[tuple[torch.Tensor, torch.Tensor]]:
142 | return self._validation_dataloader
143 |
144 | def console(self) -> Console:
145 | return self._console
146 |
147 | def metrics(self) -> dict[str, list[float]]:
148 | return self._metrics.copy()
149 |
150 | def frontend(self) -> Frontend:
151 | return self._frontend
152 |
153 | def tracker(self) -> TrainerTracker:
154 | return self._tracker
155 |
156 | # Enhanced getters
157 |
158 | def initialized(self) -> bool:
159 | return self._experiment_id != "tbd"
160 |
161 | def recovery(self) -> bool:
162 | return self._unrecoverable is None
163 |
164 | def experiment_folder(self) -> str:
165 | return f"{self._trainer_folder}/{self._trainer_variant}/{self._experiment_id}"
166 |
167 | def predict_maximum_validation_score(self, num_epochs: int, *, degree: int = 5) -> tuple[int, float]:
168 | val_scores = np.array(self._metrics["val score"])
169 | a, b = quotient_regression(np.arange(len(val_scores)), val_scores, degree, degree)
170 | da, db = quotient_derivative(a, b)
171 | max_roc = float(da[0] / db[0])
172 | max_val_score = float(a[0] / b[0])
173 | bounds = quotient_bounds(a, b, None, max_val_score * (1 - max_roc), x_start=0, x_stop=num_epochs, x_step=1)
174 | return (round(bounds[1]) + 1, max_val_score) if bounds else (0, 0)
175 |
176 | def etc(self, epoch: int, num_epochs: int, *, target_epoch: int | None = None,
177 | val_score_prediction_degree: int = 5) -> float:
178 | if not target_epoch:
179 | target_epoch, _ = self.predict_maximum_validation_score(num_epochs, degree=val_score_prediction_degree)
180 | epoch_durations = self._metrics["epoch duration"]
181 | return sum(epoch_durations) * (target_epoch - epoch) / len(epoch_durations)
182 |
183 | # Setters
184 |
185 | def set_frontend(self, frontend: type[Frontend], *, path_to_secrets: str | PathLike[str] | None = None) -> None:
186 | self._frontend = frontend(load_secrets(path=path_to_secrets) if path_to_secrets else load_secrets())
187 |
188 | def set_seed(self, seed: int) -> None:
189 | np.random.seed(seed)
190 | torch.manual_seed(seed)
191 | torch.cuda.manual_seed(seed)
192 | torch.cuda.manual_seed_all(seed)
193 | torch.backends.cudnn.benchmark = False
194 | torch.backends.cudnn.deterministic = True
195 | random_seed(seed)
196 | np.random.seed(seed)
197 | environ['PYTHONHASHSEED'] = str(seed)
198 | if self.initialized():
199 | self.log(f"Set to manual seed {seed}")
200 |
201 | # Initialization methods
202 |
203 | def _allocate_experiment_folder(self) -> str:
204 | self._experiment_id = datetime.now().strftime("%Y%m%d-%H-") + md5(urandom(8)).hexdigest()[:4]
205 | experiment_folder = self.experiment_folder()
206 | return self._allocate_experiment_folder() if exists(experiment_folder) else experiment_folder
207 |
208 | def allocate_experiment_folder(self) -> str:
209 | return self.experiment_folder() if self.initialized() else self._allocate_experiment_folder()
210 |
211 | def init_experiment(self) -> None:
212 | if self.recovery():
213 | self.log(f"Training progress recovered from {self._experiment_id} from epoch {self._tracker.epoch}")
214 | return
215 | if self.initialized():
216 | raise RuntimeError("Experiment already initialized")
217 | makedirs(self._trainer_folder, exist_ok=True)
218 | experiment_folder = self.allocate_experiment_folder()
219 | makedirs(experiment_folder)
220 | t = datetime.now()
221 | with open(f"{experiment_folder}/logs.txt", "w") as f:
222 |             f.write(f"File created by MIP Candy, copyright (C) {t.year} Project Neura. All rights reserved\n")
223 | self.log(f"Experiment (ID {self._experiment_id}) created at {t}")
224 | self.log(f"Trainer: {self.__class__.__name__}")
225 |
226 | # Logging utilities
227 |
228 | def log(self, msg: str, *, on_screen: bool = True) -> None:
229 | msg = f"[{datetime.now()}] {msg}"
230 | if self.initialized():
231 | with open(f"{self.experiment_folder()}/logs.txt", "a") as f:
232 | f.write(f"{msg}\n")
233 | if on_screen:
234 | with self._lock:
235 | self._console.print(msg)
236 |
237 | def record(self, metric: str, value: float) -> None:
238 | try_append(value, self._epoch_metrics, metric)
239 |
240 | def _record(self, metric: str, value: float) -> None:
241 | try_append(value, self._metrics, metric)
242 |
243 | def record_all(self, metrics: dict[str, float]) -> None:
244 | try_append_all(metrics, self._epoch_metrics)
245 |
246 | def _bump_metrics(self) -> None:
247 | for metric, values in self._epoch_metrics.items():
248 | epoch_overall = sum(values) / len(values)
249 | try_append(epoch_overall, self._metrics, metric)
250 | self._epoch_metrics.clear()
251 |
252 | def save_metrics(self) -> None:
253 | df = DataFrame(self._metrics)
254 | df.index = range(1, len(df) + 1)
255 | df.index.name = "epoch"
256 | df.to_csv(f"{self.experiment_folder()}/metrics.csv")
257 |
258 | def save_metric_curve(self, name: str, values: Sequence[float]) -> None:
259 | name = name.capitalize()
260 | plt.plot(values)
261 | plt.title(f"{name} over Epoch")
262 | plt.xlabel("Epoch")
263 | plt.ylabel(name)
264 | plt.grid()
265 | plt.savefig(f"{self.experiment_folder()}/{name.lower()}.png")
266 | plt.close()
267 |
268 | def save_metric_curve_combo(self, metrics: dict[str, Sequence[float]], *, title: str = "All Metrics") -> None:
269 | for name, values in metrics.items():
270 | plt.plot(values, label=name.capitalize())
271 | plt.title(title)
272 | plt.xlabel("Epoch")
273 | plt.legend()
274 | plt.grid()
275 | plt.savefig(f"{self.experiment_folder()}/{title.lower()}.png")
276 | plt.close()
277 |
278 | def save_metric_curves(self, *, names: Sequence[str] | None = None) -> None:
279 | if names is None:
280 | for name, values in self._metrics.items():
281 | self.save_metric_curve(name, values)
282 | else:
283 | for name in names:
284 | self.save_metric_curve(name, self._metrics[name])
285 |
286 | def save_progress(self, *, names: Sequence[str] = ("combined loss", "val score")) -> None:
287 | self.save_metric_curve_combo({name: self._metrics[name] for name in names}, title="Progress")
288 |
289 | def save_preview(self, image: torch.Tensor, label: torch.Tensor, output: torch.Tensor, *,
290 | quality: float = .75) -> None:
291 | ...
292 |
293 | def show_metrics(self, epoch: int, *, metrics: dict[str, list[float]] | None = None, prefix: str = "training",
294 | epochwise: bool = True, skip: Callable[[str, list[float]], bool] | None = None) -> None:
295 | if not metrics:
296 | metrics = self._metrics
297 | prefix = prefix.capitalize()
298 | table = Table(title=f"Epoch {epoch} {prefix}")
299 | table.add_column("Metric")
300 | table.add_column("Mean Value", style="green")
301 | table.add_column("Span", style="cyan")
302 | table.add_column("Diff", style="magenta")
303 | for metric, values in metrics.items():
304 | if skip and skip(metric, values):
305 | continue
306 | span = f"[{min(values):.4f}, {max(values):.4f}]"
307 | if epochwise:
308 | value = f"{values[-1]:.4f}"
309 | diff = f"{values[-1] - values[-2]:+.4f}" if len(values) > 1 else "N/A"
310 | else:
311 | mean = sum(values) / len(values)
312 | value = f"{mean:.4f}"
313 | diff = f"{mean - self._metrics[metric][-1]:+.4f}" if metric in self._metrics else "N/A"
314 | table.add_row(metric, value, span, diff)
315 | self.log(f"{prefix} {metric}: {value} @{span} ({diff})")
316 |         with self._lock:
317 |             self._console.print(table)
318 |
319 | # Builder interfaces
320 |
321 | @abstractmethod
322 | def build_optimizer(self, params: Params) -> optim.Optimizer:
323 | raise NotImplementedError
324 |
325 | @abstractmethod
326 | def build_scheduler(self, optimizer: optim.Optimizer, num_epochs: int) -> optim.lr_scheduler.LRScheduler:
327 | raise NotImplementedError
328 |
329 | @abstractmethod
330 | def build_criterion(self) -> nn.Module:
331 | raise NotImplementedError
332 |
333 | def _build_toolbox(self, num_epochs: int, example_shape: AmbiguousShape, *,
334 | model: nn.Module | None = None) -> TrainerToolbox:
335 | if not model:
336 | model = self.load_model(example_shape)
337 | optimizer = self.build_optimizer(model.parameters())
338 | scheduler = self.build_scheduler(optimizer, num_epochs)
339 | criterion = self.build_criterion().to(self._device)
340 | return TrainerToolbox(model, optimizer, scheduler, criterion)
341 |
342 | def build_toolbox(self, num_epochs: int, example_shape: AmbiguousShape) -> TrainerToolbox:
343 | return self._build_toolbox(num_epochs, example_shape)
344 |
345 | # Training methods
346 |
347 | @abstractmethod
348 | def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
349 | str, float]]:
350 | raise NotImplementedError
351 |
352 | def train_batch(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
353 | str, float]]:
354 | toolbox.optimizer.zero_grad()
355 | loss, metrics = self.backward(images, labels, toolbox)
356 | toolbox.optimizer.step()
357 | toolbox.scheduler.step()
358 | if toolbox.ema:
359 | toolbox.ema.update_parameters(toolbox.model)
360 | return loss, metrics
361 |
362 | def train_epoch(self, epoch: int, toolbox: TrainerToolbox) -> None:
363 | toolbox.model.train()
364 | if toolbox.ema:
365 | toolbox.ema.train()
366 | with Progress(*Progress.get_default_columns(), SpinnerColumn(), console=self._console) as progress:
367 | epoch_prog = progress.add_task(f"Epoch {epoch}", total=len(self._dataloader))
368 | for images, labels in self._dataloader:
369 | images, labels = images.to(self._device), labels.to(self._device)
370 | padding_module = self.get_padding_module()
371 | if padding_module:
372 | images, labels = padding_module(images), padding_module(labels)
373 | progress.update(epoch_prog, description=f"Training epoch {epoch} {tuple(images.shape)}")
374 | loss, metrics = self.train_batch(images, labels, toolbox)
375 | self.record("combined loss", loss)
376 | self.record_all(metrics)
377 | progress.update(epoch_prog, advance=1, description=f"Training epoch {epoch} ({loss:.4f})")
378 | self._bump_metrics()
379 |
380 | def train(self, num_epochs: int, *, note: str = "", num_checkpoints: int = 5, ema: bool = True,
381 | seed: int | None = None, early_stop_tolerance: int = 5, val_score_prediction: bool = True,
382 | val_score_prediction_degree: int = 5, save_preview: bool = True, preview_quality: float = .75) -> None:
383 | training_arguments = self.filter_train_params(**locals())
384 | self.init_experiment()
385 | if note:
386 | self.log(f"Note: {note}")
387 | if seed is None:
388 | seed = randint(0, 100)
389 | self.set_seed(seed)
390 | example_input = self._dataloader.dataset[0][0].to(self._device).unsqueeze(0)
391 | padding_module = self.get_padding_module()
392 | if padding_module:
393 | example_input = padding_module(example_input)
394 | example_shape = tuple(example_input.shape[1:])
395 | self.log(f"Example input shape: {example_shape}")
396 | toolbox = self.load_toolbox(num_epochs, example_shape) if self.recovery() else self.build_toolbox(
397 | num_epochs, example_shape)
398 | model_name = toolbox.model.__class__.__name__
399 | sanity_check_result = sanity_check(toolbox.model, example_shape, device=self._device)
400 | self.log(f"Model: {model_name}")
401 | self.log(str(sanity_check_result))
402 | self.log(f"Example output shape: {tuple(sanity_check_result.output.shape)}")
403 | if ema:
404 | toolbox.ema = optim.swa_utils.AveragedModel(toolbox.model)
405 | checkpoint_path = lambda v: f"{self.experiment_folder()}/checkpoint_{v}.pth"
406 | es_tolerance = early_stop_tolerance
407 | self._frontend.on_experiment_created(self._experiment_id, self._trainer_variant, model_name, note,
408 | sanity_check_result.num_macs, sanity_check_result.num_params, num_epochs,
409 | early_stop_tolerance)
410 | try:
411 | for epoch in range(self._tracker.epoch, self._tracker.epoch + num_epochs):
412 | if early_stop_tolerance == -1:
413 | epoch -= 1
414 | self.log(f"Early stopping triggered because the validation score has not improved for {
415 | es_tolerance} epochs")
416 | break
417 | self._tracker.epoch = epoch
418 | # Training
419 | t0 = time()
420 | self.train_epoch(epoch, toolbox)
421 | lr = toolbox.scheduler.get_last_lr()[0]
422 | self._record("learning rate", lr)
423 | self.show_metrics(epoch, skip=lambda m, _: m.startswith("val ") or m == "epoch duration")
424 | torch.save(toolbox.model.state_dict(), checkpoint_path("latest"))
425 |                 if epoch % max(num_epochs // num_checkpoints, 1) == 0:
426 | copy(checkpoint_path("latest"), checkpoint_path(epoch))
427 | self.log(f"Epoch {epoch} checkpoint saved")
428 | self.log(f"Epoch {epoch} training completed in {time() - t0:.1f} seconds")
429 | # Validation
430 | score, metrics = self.validate(toolbox)
431 | self._record("val score", score)
432 | msg = f"Validation score: {score:.4f}"
433 | if epoch > 1:
434 | msg += f" ({score - self._metrics["val score"][-2]:+.4f})"
435 | self.log(msg)
436 | if val_score_prediction and epoch > val_score_prediction_degree:
437 | target_epoch, max_score = self.predict_maximum_validation_score(
438 | num_epochs, degree=val_score_prediction_degree
439 | )
440 | self.log(f"Maximum validation score {max_score:.4f} predicted at epoch {target_epoch}")
441 | etc = self.etc(epoch, num_epochs, target_epoch=target_epoch)
442 | self.log(f"Estimated time of completion in {etc:.1f} seconds at {datetime.fromtimestamp(
443 | time() + etc):%m-%d %H:%M:%S}")
444 | self.show_metrics(epoch, metrics=metrics, prefix="validation", epochwise=False)
445 | if score > self._tracker.best_score:
446 | copy(checkpoint_path("latest"), checkpoint_path("best"))
447 | self.log(f"======== Best checkpoint updated ({self._tracker.best_score:.4f} -> {
448 | score:.4f}) ========")
449 | self._tracker.best_score = score
450 | early_stop_tolerance = es_tolerance
451 | if save_preview:
452 | self.save_preview(*self._tracker.worst_case, quality=preview_quality)
453 | else:
454 | early_stop_tolerance -= 1
455 | epoch_duration = time() - t0
456 | self._record("epoch duration", epoch_duration)
457 | self.log(f"Epoch {epoch} completed in {epoch_duration:.1f} seconds")
458 | self.log(f"=============== Best Validation Score {self._tracker.best_score:.4f} ===============")
459 | self.save_metrics()
460 | self.save_progress()
461 | self.save_metric_curves()
462 | self.save_everything_for_recovery(toolbox, self._tracker, **training_arguments)
463 | self._frontend.on_experiment_updated(self._experiment_id, epoch, self._metrics, early_stop_tolerance)
464 | except Exception as e:
465 | self.log("Training interrupted")
466 | self.log(repr(e))
467 | self._frontend.on_experiment_interrupted(self._experiment_id, e)
468 | raise e
469 | else:
470 | self.log("Training completed")
471 | self._frontend.on_experiment_completed(self._experiment_id)
472 |
473 | @staticmethod
474 | def filter_train_params(**kwargs) -> dict[str, Setting]:
475 | return {k: v for k, v in kwargs.items() if k in (
476 | "note", "num_checkpoints", "ema", "seed", "early_stop_tolerance", "val_score_prediction",
477 | "val_score_prediction_degree", "save_preview", "preview_quality"
478 | )}
479 |
480 | def train_with_settings(self, num_epochs: int, **kwargs) -> None:
481 | settings = self.filter_train_params(**load_settings())
482 | settings.update(kwargs)
483 | self.train(num_epochs, **settings)
484 |
485 | # Validation methods
486 |
487 | @abstractmethod
488 | def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
489 | str, float], torch.Tensor]:
490 | raise NotImplementedError
491 |
492 | def validate(self, toolbox: TrainerToolbox) -> tuple[float, dict[str, list[float]]]:
493 | if self._validation_dataloader.batch_size != 1:
494 | raise RuntimeError("Validation dataloader should have batch size 1")
495 | toolbox.model.eval()
496 | if toolbox.ema:
497 | toolbox.ema.eval()
498 | score = 0
499 | worst_score = float("+inf")
500 | metrics = {}
501 | num_cases = len(self._validation_dataloader)
502 | with torch.no_grad(), Progress(
503 | *Progress.get_default_columns(), SpinnerColumn(), console=self._console
504 | ) as progress:
505 |             val_prog = progress.add_task("Validating", total=num_cases)
506 | for image, label in self._validation_dataloader:
507 | image, label = image.to(self._device), label.to(self._device)
508 | padding_module = self.get_padding_module()
509 | if padding_module:
510 | image, label = padding_module(image), padding_module(label)
511 | image, label = image.squeeze(0), label.squeeze(0)
512 | progress.update(val_prog, description=f"Validating {tuple(image.shape)}")
513 | case_score, case_metrics, output = self.validate_case(image, label, toolbox)
514 | score += case_score
515 | if case_score < worst_score:
516 | self._tracker.worst_case = (image, label, output)
517 | worst_score = case_score
518 | try_append_all(case_metrics, metrics)
519 | progress.update(val_prog, advance=1, description=f"Validating ({case_score:.4f})")
520 | return score / num_cases, metrics
521 |
522 | def __call__(self, *args, **kwargs) -> None:
523 | self.train(*args, **kwargs)
524 |
525 | @override
526 | def __str__(self) -> str:
527 | return f"{self.__class__.__name__} {self._experiment_id}"
528 |
529 |
530 | class SlidingTrainer(Trainer, SlidingWindow, metaclass=ABCMeta):
531 | @override
532 | def build_padding_module(self) -> nn.Module | None:
533 | window_shape = self.get_window_shape()
534 | return (Pad2d if len(window_shape) == 2 else Pad3d)(window_shape)
535 |
536 | @abstractmethod
537 | def validate_case_windowed(self, images: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox,
538 | metadata: SWMetadata) -> tuple[float, dict[str, float], torch.Tensor]:
539 | raise NotImplementedError
540 |
541 | @override
542 | def validate_case(self, image: torch.Tensor, label: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
543 | str, float], torch.Tensor]:
544 | images, metadata = self.do_sliding_window(image.unsqueeze(0))
545 | return self.validate_case_windowed(images, label, toolbox, metadata)
546 |
547 | @abstractmethod
548 | def backward_windowed(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox,
549 | metadata: SWMetadata) -> tuple[float, dict[str, float]]:
550 | raise NotImplementedError
551 |
552 | @override
553 | def backward(self, images: torch.Tensor, labels: torch.Tensor, toolbox: TrainerToolbox) -> tuple[float, dict[
554 | str, float]]:
555 | images, metadata = self.do_sliding_window(images)
556 | labels, _ = self.do_sliding_window(labels)
557 | return self.backward_windowed(images, labels, toolbox, metadata)
558 |
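A short sketch of the recovery workflow behind the methods above (PR #108): recover_from() restores the metrics and tracker from an existing experiment folder, and continue_training() reloads the toolbox and the saved training arguments before resuming. MyTrainer and the experiment ID are placeholders:

    # MyTrainer: a hypothetical Trainer subclass implementing the abstract builders, backward(), and validate_case()
    trainer = MyTrainer("experiments", dataloader, val_dataloader, device="cuda")
    trainer.recover_from("20250101-12-ab3f")  # placeholder experiment ID under experiments/MyTrainer/
    trainer.continue_training(100)            # resumes training for up to 100 more epochs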
--------------------------------------------------------------------------------
/home/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | MIP Candy
6 |
7 |
8 |
597 |
598 |
599 |
618 |
619 |
620 |
638 |
639 |
640 |
641 |
642 | A candy for medical image processing
643 | Next-generation infrastructure for fast prototyping in machine learning.
644 |
645 | MIP Candy brings ready-to-use training, inference, and evaluation pipelines together with
646 | aesthetics, so you can focus on your experiments, not boilerplate.
647 |
648 |
649 |
659 |
660 |
661 |
662 |
663 |
664 |
665 |
666 |
667 |
668 |
669 |
670 |
671 |
672 |
673 |
674 |
719 |
720 |
721 |
722 |
726 |
727 |
728 |
729 | Training
730 | Easy adaptation to your workflow
731 |
732 | Override a single method to plug in your own network architecture. Grab a tool from the box and
733 | customize your experiments.
734 |
735 |
736 | Sliding window
737 | ROI inspection & cropping
738 | Automatic padding & shape alignment
739 |
740 |
741 |
742 |
743 | Interface
744 | Thoughtful command-line UX
745 |
746 | A clean CLI layout makes it easy to configure experiments, track progress, and resume work without
747 | digging through scripts.
748 |
749 |
750 |
751 |
752 |
753 |
754 |
755 | Visualization
756 | Built-in 2D & 3D views
757 |
758 | Inspect slices or volumes directly from the training pipeline for intuitive understanding of your
759 | data and predictions.
760 |
761 |
762 |
763 |
764 |
765 |
766 |
767 | Reliability
768 | Interruption-tolerant runs
769 |
770 | Experiments can be safely paused and resumed with built-in recovery mechanisms, so cluster hiccups
771 | don’t cost you progress.
772 |
773 |
774 |
775 |
776 |
777 |
778 |
779 | Dashboards
780 | Remote monitoring ready
781 |
782 | Connect to Notion, Weights & Biases, and TensorBoard for rich experiment tracking and sharing
783 | with your team.
784 |
785 |
786 |
787 |
788 |
789 |
790 |
791 | Python 3.12+
792 | Modern Python, modern stack
793 |
794 | Built for Python 3.12 and above, MIP Candy takes advantage of modern typing and ecosystem
795 | improvements out of the box.
796 |
797 |
798 |
799 |
800 |
801 |
802 |
803 | Custom trainers
804 | Adapt a new network in one method.
805 |
806 | Start from SegmentationTrainer and only implement the network construction.
807 | MIP Candy handles data flow, loss computation, augmentation, checkpointing, and evaluation out of the box.
808 |
809 |
810 |
811 | Example — Custom model integration
812 | from typing import override
813 | from torch import nn
814 | from mipcandy import SegmentationTrainer
815 |
816 |
817 | class MyTrainer(SegmentationTrainer):
818 | @override
819 | def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
820 | ...
821 |
822 |
823 | Provide your architecture once. MIP Candy takes care of the entire training pipeline.
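For instance, continuing the imports in the snippet above, a deliberately tiny network (purely illustrative; any nn.Module that fits your task works) could be returned from build_network:

    class MyTrainer(SegmentationTrainer):
        @override
        def build_network(self, example_shape: tuple[int, ...]) -> nn.Module:
            in_channels = example_shape[0]  # example_shape is the padded input shape without the batch dimension
            return nn.Sequential(
                nn.Conv2d(in_channels, 16, 3, padding=1), nn.ReLU(),
                nn.Conv2d(16, 1, 1),
            )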
824 |
825 |
826 |
827 |
828 |
829 |
830 |
836 |
837 |
843 |
844 |
845 |
846 |
847 |
848 |
849 |
850 | Quick start
851 | Train like a Pro in a few lines
852 |
853 | Download a dataset, create a dataset wrapper, and hand it to a bundled trainer. Below is an example
854 | using the PH2 dataset with batch size 1 to accommodate varying shapes; you can also use a
855 | ROIDataset to align them, as shown in the sketch after the example.
856 |
857 |
858 |
859 | Example — nnU-Net style training
860 | from typing import override
861 |
862 | import torch
863 | from mipcandy_bundles.unet import UNetTrainer
864 | from torch.utils.data import DataLoader
865 |
866 | from mipcandy import download_dataset, NNUNetDataset
867 |
868 |
869 | class PH2(NNUNetDataset):
870 | @override
871 | def load(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
872 | image, label = super().load(idx)
873 | return image.squeeze(0).permute(2, 0, 1), label
874 |
875 |
876 | download_dataset("nnunet_datasets/PH2", "tutorial/datasets/PH2")
877 | dataset, val_dataset = PH2("tutorial/datasets/PH2", device="cuda").fold()
878 | dataloader = DataLoader(dataset, 1, shuffle=True)
879 | val_dataloader = DataLoader(val_dataset, 1, shuffle=False)
880 | trainer = UNetTrainer("tutorial", dataloader, val_dataloader, device="cuda")
881 | trainer.train(1000, note="a nnU-Net style example")
882 |
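To align varying sample shapes instead of training with batch size 1, a hedged sketch continuing the example above (the module path is assumed from the repository layout; dataset and DataLoader are the ones already defined):

    from mipcandy.data.inspection import inspect, ROIDataset

    annotations = inspect(dataset, background=0)   # one pass over the training labels
    aligned = ROIDataset(annotations)              # every sample is cropped to a shared ROI shape
    dataloader = DataLoader(aligned, 4, shuffle=True)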
883 |
884 |
885 |
886 |
887 |
888 |
889 |
890 | Installation
891 | Install MIP Candy
892 |
893 | MIP Candy requires Python ≥ 3.12.
894 | Install the standard bundle from PyPI:
895 |
896 |
897 |
898 |
899 |
900 |
901 |
902 |
903 |
904 | $
905 | pip install "mipcandy[standard]"
906 |
907 |
908 |
909 |
910 | Stand on the Giants
911 | Install MIP Candy Bundles
912 |
913 | MIP Candy Bundles provide verified model architectures with corresponding trainers and predictors.
914 | You can install them along with MIP Candy.
915 |
916 |
917 |
918 |
919 |
920 |
921 |
922 |
923 |
924 | $
925 | pip install "mipcandy[all]"
926 |
927 |
928 |
929 |
930 |
931 |
932 |
933 |
934 |
938 |
939 |
953 |
954 |
957 |
958 |
959 |
960 |
961 |
--------------------------------------------------------------------------------