├── src └── umetrix │ ├── __init__.py │ ├── render.py │ ├── notebooks.py │ └── core.py ├── .gitignore ├── tests ├── data │ └── unet.tif ├── test_metrics.py └── conftest.py ├── .github └── workflows │ ├── test.yml │ └── linting.yml ├── pyproject.toml ├── .pre-commit-config.yaml ├── LICENSE.md ├── README.md └── notebooks └── unet_segmentation_metrics.ipynb /src/umetrix/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import calculate, batch # NOQA: F401 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb_checkpoints 2 | *.egg-info/ 3 | .DS_Store 4 | __pycache__ 5 | notebooks/ 6 | _version.py 7 | -------------------------------------------------------------------------------- /tests/data/unet.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lowe-lab-ucl/unet_segmentation_metrics/HEAD/tests/data/unet.tif -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | paths-ignore: 9 | - "**.md" 10 | - "**.rst" 11 | 12 | jobs: 13 | tests: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout source 17 | uses: actions/checkout@v3 18 | 19 | - name: Set up python 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: "3.10" 23 | cache: "pip" 24 | cache-dependency-path: "pyproject.toml" 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install pytest 30 | 31 | - name: Install umetrics 32 | run: | 33 | pip install -e . 
34 | - name: Run tests 35 | run: | 36 | pytest 37 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "umetrix" 7 | authors = [ 8 | {name = "Alan R. Lowe", email = "a.lowe@ucl.ac.uk"} 9 | ] 10 | description = "UNet Segmentation Metrics" 11 | readme = "README.md" 12 | requires-python = ">=3.8" 13 | keywords = ["image analysis"] 14 | license = {text = "BSD-3-Clause"} 15 | classifiers = [ 16 | "Programming Language :: Python :: 3" 17 | ] 18 | dependencies = [ 19 | "matplotlib", 20 | "numpy", 21 | "pandas", 22 | "scikit-learn", 23 | "scikit-image>=0.20.0" # to include the spacing argument in regionprops 24 | ] 25 | dynamic = ["version"] 26 | 27 | [tool.setuptools.packages.find] 28 | where = ["src"] 29 | include = ["umetrix*"] 30 | 31 | [tool.setuptools_scm] 32 | local_scheme = "no-local-version" 33 | write_to = "src/umetrix/_version.py" 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/charliermarsh/ruff-pre-commit 3 | rev: v0.0.262 4 | hooks: 5 | - id: ruff 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.4.0 8 | hooks: 9 | - id: check-case-conflict 10 | - id: check-docstring-first 11 | - id: check-executables-have-shebangs 12 | - id: check-merge-conflict 13 | - id: check-toml 14 | - id: end-of-file-fixer 15 | - id: mixed-line-ending 16 | args: [--fix=lf] 17 | - id: trailing-whitespace 18 | args: [--markdown-linebreak-ext=md] 19 | - repo: https://github.com/psf/black 20 | rev: 23.3.0 21 | hooks: 22 | - id: black 23 | - repo: https://github.com/pappasam/toml-sort 24 | rev: v0.23.0 25 | hooks: 26 | - 
id: toml-sort-fix 27 | -------------------------------------------------------------------------------- /.github/workflows/linting.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | 9 | jobs: 10 | linting: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout source 14 | uses: actions/checkout@v3 15 | 16 | - name: Cache pre-commit 17 | uses: actions/cache@v3 18 | with: 19 | path: ~/.cache/pre-commit 20 | key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }} 21 | 22 | - name: Set up python 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: "3.x" 26 | cache: "pip" 27 | cache-dependency-path: "pyproject.toml" 28 | 29 | - name: Install dependencies 30 | run: |- 31 | python -m pip install pre-commit 32 | pre-commit install 33 | 34 | - name: Run pre-commit 35 | run: pre-commit run --all-files --color always 36 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Alan R. Lowe (quantumjot) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 9 | of the Software, and to permit persons to whom the Software is furnished to do 10 | so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PyPI](https://img.shields.io/pypi/v/umetrix)](https://pypi.org/project/umetrix) 2 | [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 3 | [![umetrix](https://github.com/lowe-lab-ucl/unet_segmentation_metrics/actions/workflows/test.yml/badge.svg)](https://github.com/lowe-lab-ucl/unet_segmentation_metrics/actions/workflows/test.yml) 4 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) 5 | 6 | # UNet segmentation metrics 7 | 8 | *WORK IN PROGRESS* 9 | 10 | Simple Python 3 tools to assess the performance of UNet segmentation networks 11 | (or any other segmentation method) by comparing the prediction to a ground truth 12 | image. 
13 | 14 | Use it to calculate: 15 | + Jaccard metric for object detection 16 | + Intersection over Union (IoU) for object segmentation accuracy 17 | + Localization (positional) error for estimating MOTP during tracking 18 | + Pixel identity 19 | 20 | TODO: 21 | + [x] Add strict matching with IoU threshold 22 | + [ ] Add confusion matrix for multi-label/classification type tasks 23 | 24 | 25 | ### Single image usage 26 | 27 | ```python 28 | import umetrix 29 | from skimage.io import imread 30 | 31 | y_true = imread('true.tif') 32 | y_pred = imread('pred.tif') 33 | 34 | 35 | # can now make the calculation strict, by only considering objects that have 36 | # an IoU above a threshold as being true positives 37 | result = umetrix.calculate( 38 | y_true, 39 | y_pred, 40 | strict=True, 41 | iou_threshold=0.5 42 | ) 43 | 44 | print(result.results) 45 | ``` 46 | 47 | returns: 48 | 49 | ``` 50 | ============================ 51 | Segmentation Metrics (n=1) 52 | ============================ 53 | Strict: True (IoU > 0.5) 54 | n_true_labels: 354 55 | n_pred_labels: 362 56 | n_true_positives: 345 57 | n_false_positives: 10 58 | n_false_negatives: 0 59 | IoU: 0.999 60 | Jaccard: 0.972 61 | pixel_identity: 0.998 62 | localization_error: 0.010 63 | ``` 64 | 65 | 66 | ### Batch processing 67 | 68 | ```python 69 | import umetrix 70 | 71 | # provide a list of file pairs ('true', 'prediction') 72 | files = [ 73 | ('true0.tif', 'pred0.tif'), 74 | ('true1.tif', 'pred1.tif'), 75 | ('true2.tif', 'pred2.tif') 76 | ] 77 | 78 | batch_result = umetrix.batch(files) 79 | ``` 80 | 81 | Returns aggregate statistics over the batch. The Jaccard index is calculated over 82 | all found objects, while the other metrics (IoU, localization error, pixel identity) are averaged. 83 | 84 | 85 | ### Installation 86 | 87 | 1. First clone the repo: 88 | ```sh 89 | $ git clone https://github.com/quantumjot/unet_segmentation_metrics.git 90 | ``` 91 | 92 | 2. 
(Optional, but advised) Create a conda environment: 93 | ```sh 94 | $ conda create -n umetrix python=3.9 95 | $ conda activate umetrix 96 | ``` 97 | 98 | 3. Install the package 99 | ```sh 100 | $ cd unet_segmentation_metrics 101 | $ pip install . 102 | ``` 103 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | import umetrix 5 | 6 | 7 | STRICT_PARAMS = [(False, 0.0), (True, 0.1), (True, 0.2), (True, 0.5), (True, 0.7)] 8 | 9 | 10 | @pytest.mark.parametrize("strict,iou_threshold", STRICT_PARAMS) 11 | def test_calculate(image_pair, strict, iou_threshold): 12 | """Run the metrics on a pair of images.""" 13 | y_true, y_pred, stats = image_pair 14 | IoU = stats["IoU"] 15 | 16 | result = umetrix.calculate( 17 | y_true, y_pred, strict=strict, iou_threshold=iou_threshold 18 | ) 19 | 20 | # calculate the real number of true postives based on strict matching 21 | real_tp = int(IoU > result.iou_threshold) if strict else int(IoU > 0) 22 | 23 | assert result.n_true_labels == 1 24 | assert result.n_pred_labels == 1 25 | assert result.n_true_positives == real_tp 26 | assert result.n_false_positives == 1 - real_tp 27 | 28 | 29 | def test_calculate_no_true(image_pair): 30 | """Test a pair of images where there is no object in the GT.""" 31 | y_true, y_pred, _ = image_pair 32 | y_true = np.zeros_like(y_pred) 33 | 34 | result = umetrix.calculate(y_true, y_pred) 35 | assert result.n_true_labels == 0 36 | assert result.n_pred_labels == 1 37 | assert result.n_true_positives == 0 38 | assert result.n_false_negatives == 0 39 | assert result.n_false_positives == 1 40 | 41 | 42 | def test_calculate_no_pred(image_pair): 43 | """Test a pair of images where there is no object in the prediction.""" 44 | y_true, y_pred, _ = image_pair 45 | y_pred = np.zeros_like(y_true) 46 | 47 | result = umetrix.calculate(y_true, 
y_pred) 48 | assert result.n_true_labels == 1 49 | assert result.n_pred_labels == 0 50 | assert result.n_true_positives == 0 51 | assert result.n_false_negatives == 1 52 | assert result.n_false_positives == 0 53 | 54 | 55 | @pytest.mark.parametrize("strict,iou_threshold", STRICT_PARAMS) 56 | def test_calculate_grid(image_grid, strict, iou_threshold): 57 | """Test a multi-instance segmentation.""" 58 | y_true, y_pred, stats = image_grid 59 | result = umetrix.calculate( 60 | y_true, y_pred, strict=strict, iou_threshold=iou_threshold 61 | ) 62 | 63 | n_iou_over_threshold = sum([iou > iou_threshold for iou in stats["IoU"]]) 64 | n_iou_under_threshold = stats["n_pairs"] - n_iou_over_threshold if strict else 0 65 | n_tp = n_iou_over_threshold 66 | n_fp = stats["n_false_positive"] + n_iou_under_threshold 67 | n_fn = stats["n_false_negative"] + n_iou_under_threshold 68 | 69 | assert result.n_true_labels == stats["n_true"] 70 | assert result.n_pred_labels == stats["n_pred"] 71 | assert result.n_true_positives == n_tp 72 | assert result.n_false_positives == n_fp 73 | assert result.n_false_negatives == n_fn 74 | 75 | 76 | @pytest.mark.parametrize("strict,iou_threshold", STRICT_PARAMS) 77 | def test_real(real_image_pair, strict, iou_threshold): 78 | """Test a real image pair.""" 79 | y_true, y_pred = real_image_pair 80 | result = umetrix.calculate( 81 | y_true, y_pred, strict=strict, iou_threshold=iou_threshold 82 | ) 83 | assert result.n_true_labels == 13 84 | assert result.n_pred_labels == 14 85 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import numpy.typing as npt 4 | 5 | from pathlib import Path 6 | from skimage.io import imread 7 | from skimage.util import montage 8 | from typing import Tuple 9 | 10 | from umetrix.core import IoU 11 | 12 | SEED = 12343 13 | RNG = 
np.random.default_rng(seed=SEED) 14 | 15 | 16 | def _synthetic_image(sz: int = 32) -> npt.NDArray: 17 | image = np.zeros((sz, sz), dtype=np.uint8) 18 | boxsz = RNG.integers(low=sz // 4, high=sz - 1) 19 | xlo, ylo = RNG.integers(low=1, high=sz - boxsz, size=(2,)) 20 | image[xlo : xlo + boxsz, ylo : ylo + boxsz] = 1 21 | return image 22 | 23 | 24 | # def _IoU(y_true: npt.NDArray, y_pred: npt.NDArray) -> float: 25 | # union = np.sum(np.logical_or(y_true, y_pred)) 26 | # intersection = np.sum(np.logical_and(y_true, y_pred)) 27 | # return intersection / union 28 | 29 | 30 | @pytest.fixture 31 | def image_grid(N: int = 3, sz: int = 32) -> Tuple[npt.NDArray, npt.NDArray, dict]: 32 | image_types = RNG.choice( 33 | ["pair", "missing_true", "missing_pred"], size=(N * N,) 34 | ).tolist() 35 | true_stack = np.zeros((N * N, sz, sz), dtype=np.uint8) 36 | pred_stack = np.zeros((N * N, sz, sz), dtype=np.uint8) 37 | 38 | ious = [] 39 | 40 | n_true_positive = 0 41 | n_false_positive = 0 42 | n_false_negative = 0 43 | 44 | for idx, img_type in enumerate(image_types): 45 | if img_type == "pair": 46 | true_stack[idx, ...] = _synthetic_image() 47 | pred_stack[idx, ...] = _synthetic_image() 48 | iou = IoU(true_stack[idx, ...], pred_stack[idx, ...]) 49 | ious.append(iou) 50 | if iou > 0: 51 | n_true_positive += 1 52 | else: 53 | n_false_positive += 1 54 | n_false_negative += 1 55 | elif img_type == "missing_true": 56 | pred_stack[idx, ...] = _synthetic_image() 57 | ious.append(0.0) 58 | n_false_positive += 1 59 | else: 60 | true_stack[idx, ...] 
= _synthetic_image() 61 | ious.append(0.0) 62 | n_false_negative += 1 63 | 64 | # number of pairs where there is some overlap 65 | n_pairs = image_types.count("pair") 66 | n_missing_true = image_types.count("missing_true") 67 | n_missing_pred = image_types.count("missing_pred") 68 | 69 | stats = { 70 | "n_pairs": n_pairs, 71 | "n_true": n_pairs + n_missing_pred, 72 | "n_pred": n_pairs + n_missing_true, 73 | "n_true_positive": n_true_positive, 74 | "n_false_positive": n_false_positive, 75 | "n_false_negative": n_false_negative, 76 | "n_total": len(image_types), 77 | "IoU": ious, 78 | } 79 | 80 | return ( 81 | montage(true_stack, rescale_intensity=False, grid_shape=(N, N)), 82 | montage(pred_stack, rescale_intensity=False, grid_shape=(N, N)), 83 | stats, 84 | ) 85 | 86 | 87 | @pytest.fixture 88 | def image_pair() -> Tuple[npt.NDArray, npt.NDArray, dict]: 89 | y_true = _synthetic_image() 90 | y_pred = _synthetic_image() 91 | stats = {"IoU": IoU(y_true, y_pred)} 92 | return y_true, y_pred, stats 93 | 94 | 95 | @pytest.fixture 96 | def real_image_pair() -> Tuple[npt.NDArray, npt.NDArray]: 97 | filename = Path(__file__).parent.resolve() / "data" / "unet.tif" 98 | img = (imread(filename) > 0).astype(np.uint8) 99 | return img[0, ...], img[1, ...] 
100 | -------------------------------------------------------------------------------- /src/umetrix/render.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import matplotlib.patches as patches 4 | 5 | 6 | def plot_metrics(seg_metrics): 7 | pred = seg_metrics._predicted 8 | ref = seg_metrics._reference 9 | 10 | iou = [None] * len(ref.labels) 11 | IoU = seg_metrics.per_object_IoU 12 | for i, tp in enumerate(seg_metrics.true_positives): 13 | iou[tp[0] - 1] = "{:.2f}".format(IoU[i]) 14 | 15 | fig, ax = plt.subplots(1, figsize=(16, 12)) 16 | ax.imshow(seg_metrics.image_overlay) 17 | 18 | for i, (sy, sx) in enumerate(ref.bboxes): 19 | r = patches.Rectangle( 20 | (sx.start, sy.start), 21 | sx.stop - sx.start, 22 | sy.stop - sy.start, 23 | edgecolor="g", 24 | facecolor="None", 25 | ) 26 | ax.add_patch(r) 27 | ax.text( 28 | sx.start, sy.start, "{}, IoU: {}".format(i, iou[i]), fontsize=6, color="w" 29 | ) 30 | for i, (sy, sx) in enumerate(pred.bboxes): 31 | r = patches.Rectangle( 32 | (sx.start, sy.start), 33 | sx.stop - sx.start, 34 | sy.stop - sy.start, 35 | edgecolor="m", 36 | facecolor="None", 37 | ) 38 | ax.add_patch(r) 39 | # ax.text(sx.start, sy.start, '{}'.format(i), fontsize=6, color='w') 40 | 41 | bboxes = pred.bboxes 42 | for fp in seg_metrics.false_positives: 43 | sy, sx = bboxes[fp - 1] 44 | w, h = sx.stop - sx.start, sy.stop - sy.start 45 | r = patches.Rectangle( 46 | (sx.start, sy.start), 47 | w, 48 | h, 49 | edgecolor="r", 50 | facecolor=(1.0, 0.0, 0.0, 0.0), 51 | linewidth=2, 52 | ) 53 | ax.add_patch(r) 54 | 55 | bboxes = ref.bboxes 56 | for fn in seg_metrics.false_negatives: 57 | sy, sx = bboxes[fn - 1] 58 | w, h = sx.stop - sx.start, sy.stop - sy.start 59 | r = patches.Rectangle( 60 | (sx.start, sy.start), 61 | w, 62 | h, 63 | edgecolor="c", 64 | facecolor=(0.0, 1.0, 1.0, 0.0), 65 | linewidth=2, 66 | ) 67 | ax.add_patch(r) 68 | plt.axis("off") 69 | plt.show() 
70 | 71 | 72 | def make_bboxes(bbox_slices): 73 | """Calculate bboxes for napari""" 74 | minr = [sxy[0].start for sxy in bbox_slices] 75 | minc = [sxy[1].start for sxy in bbox_slices] 76 | maxr = [sxy[0].stop for sxy in bbox_slices] 77 | maxc = [sxy[1].stop for sxy in bbox_slices] 78 | 79 | bbox_rect = np.array([[minr, minc], [maxr, minc], [maxr, maxc], [minr, maxc]]) 80 | bbox_rect = np.moveaxis(bbox_rect, 2, 0) 81 | return bbox_rect 82 | 83 | 84 | def render_metrics_napari(seg_metrics): 85 | """Render the segmentation metrics for visualization in Napari.""" 86 | 87 | # pred = seg_metrics._predicted 88 | ref = seg_metrics._reference 89 | 90 | properties = {"iou": seg_metrics.per_object_IoU} 91 | 92 | bboxes = make_bboxes(ref.bboxes) 93 | tp_idx = np.asarray(seg_metrics.true_positives)[:, 0] - 1 94 | bboxes = bboxes[tp_idx, ...] 95 | 96 | # specify the display parameters for the text 97 | text_parameters = { 98 | "text": "IoU: {iou:.2f}\n", 99 | "size": 8, 100 | "color": "white", 101 | "anchor": "upper_left", 102 | "translation": [-2, 0], 103 | } 104 | 105 | return bboxes, properties, text_parameters 106 | -------------------------------------------------------------------------------- /src/umetrix/notebooks.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import itertools 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | from matplotlib import colors, colormaps 8 | 9 | from umetrix.core import METRICS 10 | 11 | 12 | DARK_CONTEXT = { 13 | "axes.edgecolor": "gray", 14 | "xtick.color": "gray", 15 | "ytick.color": "gray", 16 | "axes.labelcolor": "gray", 17 | "font.size": 18, 18 | } 19 | 20 | 21 | def _text_color_based_on_value(value: float, cmap: colors.Colormap) -> str: 22 | rgb = cmap(value) 23 | luminance = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2] 24 | return "k" if luminance > 0.5 else "w" 25 | 26 | 27 | def _header() -> str: 28 | css = ( 29 | ".row {display: 
inline-flex; align-items: flex-start;}\n" 30 | ".metrics {float: left;}\n" 31 | ".confusion {width: 200px; float: left;}" 32 | ) 33 | return f"" 34 | 35 | 36 | def _footer() -> str: 37 | return "" 38 | 39 | 40 | def render_metrics_html(metrics) -> str: 41 | """Render the metrics to HTML""" 42 | 43 | html_table = _render_table(metrics) 44 | encoded_cm = _render_confusion(metrics) 45 | html_strict = ( 46 | f"

Strict matching (IoU threshold: {metrics.iou_threshold})

" 47 | if metrics.strict 48 | else "" 49 | ) 50 | 51 | html = ( 52 | _header() 53 | + "

Segmentation Metrics

" 54 | + html_strict 55 | + "
" 56 | + html_table 57 | + "
" 58 | + f"" 59 | + "
" 60 | + _footer() 61 | ) 62 | 63 | return html 64 | 65 | 66 | def _render_table(metrics) -> str: 67 | """Render the table of results""" 68 | 69 | def _get_f_string(m): 70 | val = getattr(metrics, m) 71 | return f"{val:.3f}" if isinstance(val, float) else f"{val:d}" 72 | 73 | return ( 74 | "" 75 | + "".join( 76 | [f"" for m in METRICS] 77 | ) 78 | + "
Metric
{m}" + _get_f_string(m) + "
" 79 | ) 80 | 81 | 82 | def _render_confusion(metrics, *, cmap: str = "Blues") -> str: 83 | """Render a confusion matrix as an image""" 84 | grid = np.zeros((2, 2), dtype=float) 85 | grid[1, 1] = metrics.n_true_positives 86 | grid[0, 1] = metrics.n_false_positives 87 | grid[1, 0] = metrics.n_false_negatives 88 | cmap = colormaps[cmap] 89 | 90 | with plt.rc_context(DARK_CONTEXT): 91 | _, ax = plt.subplots(figsize=(4, 4)) 92 | ax.pcolor(grid, cmap=cmap) 93 | ax.set_xticks([0.5, 1.5], labels=["Negative", "Positive"]) 94 | ax.set_yticks( 95 | [0.5, 1.5], labels=["Negative", "Positive"], rotation=90, va="center" 96 | ) 97 | 98 | for i, j in itertools.product(range(2), range(2)): 99 | ax.text( 100 | i + 0.5, 101 | j + 0.5, 102 | grid[i, j].astype(int), 103 | ha="center", 104 | va="center", 105 | color=_text_color_based_on_value(grid[i, j] / np.max(grid), cmap), 106 | ) 107 | ax.set_ylabel("Predicted") 108 | ax.set_xlabel("Ground truth") 109 | stream = io.BytesIO() 110 | plt.savefig(stream, format="png", bbox_inches="tight", transparent=True) 111 | plt.close() 112 | return base64.b64encode(stream.getvalue()).decode("utf-8") 113 | -------------------------------------------------------------------------------- /src/umetrix/core.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import enum 4 | import numpy as np 5 | import numpy.typing as npt 6 | 7 | from skimage.io import imread 8 | from scipy.ndimage import label 9 | from scipy.ndimage import center_of_mass 10 | from scipy.ndimage import find_objects 11 | from scipy.optimize import linear_sum_assignment 12 | 13 | from typing import Dict, Tuple 14 | 15 | from umetrix import render 16 | 17 | 18 | DEFAULT_MAXIMUM_COST = 1e8 19 | 20 | 21 | class Metrics(str, enum.Enum): 22 | N_TRUE_LABELS = "n_true_labels" 23 | N_PRED_LABELS = "n_pred_labels" 24 | N_TRUE_POSITIVES = "n_true_positives" 25 | N_FALSE_POSITIVES = "n_false_positives" 26 | 
N_FALSE_NEGATIVES = "n_false_negatives" 27 | IOU = "IoU" 28 | JACCARD = "Jaccard" 29 | PIXEL_IDENTITY = "pixel_identity" 30 | LOCALIZATION_ERROR = "localization_error" 31 | 32 | 33 | METRICS = ( 34 | "n_true_labels", 35 | "n_pred_labels", 36 | "n_true_positives", 37 | "n_false_positives", 38 | "n_false_negatives", 39 | "IoU", 40 | "Jaccard", 41 | "pixel_identity", 42 | "localization_error", 43 | ) 44 | 45 | 46 | def IoU(ref: npt.NDArray, pred: npt.NDArray) -> float: 47 | """Calculate the IoU between two binary masks.""" 48 | intersection = np.sum(np.logical_and(ref, pred)) 49 | union = np.sum(np.logical_or(ref, pred)) 50 | iou = 0.0 if union == 0 else intersection / union 51 | return iou 52 | 53 | 54 | def find_matches( 55 | ref: LabeledSegmentation, 56 | pred: LabeledSegmentation, 57 | *, 58 | strict: bool = False, 59 | iou_threshold: float = 0.5, 60 | ) -> Dict: 61 | """Perform matching between the reference and the predicted image. 62 | 63 | Parameters 64 | ---------- 65 | ref : 66 | The reference (ground truth) segmentation. 67 | pred : 68 | The predicted segmentation. 69 | strict : bool 70 | Whether to use strict matching, i.e. only allowing matches above a 71 | threshold IoU value. 72 | iou_threshold : 73 | A threshold value to use when strict matching. 74 | 75 | Return 76 | ------ 77 | matches : dict 78 | A dictionary of matches between the two images. 
79 | """ 80 | 81 | # make an infinite cost matrix, so that we only consider matches where 82 | # there is some overlap in the masks 83 | cost_matrix = np.full((len(ref.labels), len(pred.labels)), DEFAULT_MAXIMUM_COST) 84 | 85 | for r_id, ref_label in enumerate(ref.labels): 86 | mask = ref.labeled == ref_label 87 | _matches = [m for m in np.unique(pred.labeled[mask]) if m > 0] 88 | for pred_label in _matches: 89 | p_id = pred.labels.index(pred_label) 90 | reward = IoU(mask, pred.labeled == pred_label) 91 | if (reward < iou_threshold) and strict: 92 | continue 93 | cost_matrix[r_id, p_id] = 1.0 - reward 94 | 95 | # if it's strict, make sure every element is above the threshold 96 | if strict: 97 | cost_threshold = 1.0 - iou_threshold 98 | cost_mask = cost_matrix == DEFAULT_MAXIMUM_COST 99 | assert np.all(cost_matrix[~cost_mask] <= cost_threshold) 100 | 101 | # solve it using JV 102 | sol_row, sol_col = linear_sum_assignment(cost_matrix) 103 | 104 | # remove infeasible solutions 105 | edges = [ 106 | (ref.labels[r], pred.labels[c], 1.0 - cost_matrix[r, c]) 107 | for r, c in zip(sol_row, sol_col) 108 | if cost_matrix[r, c] <= 1 109 | ] 110 | 111 | # return a default dictionary if there are no matches 112 | if not edges: 113 | matches = { 114 | "true_matches": [], 115 | "true_matches_IoU": [], 116 | "in_ref_only": set(ref.labels), 117 | "in_pred_only": set(pred.labels), 118 | } 119 | return matches 120 | 121 | # find the labels that haven't been used 122 | used_ref, used_pred, IoUs = zip(*edges) 123 | in_ref_only = set(ref.labels).difference(used_ref) 124 | in_pred_only = set(pred.labels).difference(used_pred) 125 | 126 | # return a dictionary of found matches 127 | matches = { 128 | "true_matches": list(set(zip(used_ref, used_pred))), 129 | "true_matches_IoU": IoUs, 130 | "in_ref_only": in_ref_only, 131 | "in_pred_only": in_pred_only, 132 | } 133 | 134 | return matches 135 | 136 | 137 | class MetricResults(object): 138 | def __init__(self, metrics): 139 | assert 
isinstance(metrics, SegmentationMetrics) 140 | self._images = 1 141 | self._metrics = metrics 142 | 143 | # list of metrics that are aggregated 144 | self._agg = ( 145 | "n_true_labels", 146 | "n_pred_labels", 147 | "n_true_positives", 148 | "n_false_positives", 149 | "n_false_negatives", 150 | "per_object_IoU", 151 | "per_object_localization_error", 152 | "per_image_pixel_identity", 153 | ) 154 | 155 | def __getattr__(self, key): 156 | return getattr(self._metrics, key) 157 | 158 | @property 159 | def n_images(self) -> int: 160 | if any([getattr(self, m) is None for m in self._agg]): 161 | return 0 162 | else: 163 | return self._images 164 | 165 | def __add__(self, result: MetricResults) -> MetricResults: 166 | assert isinstance(result, MetricResults) 167 | for m in self._agg: 168 | setattr(self, m, getattr(result, m) + getattr(self, m)) 169 | self._images += 1 170 | return self 171 | 172 | def __repr__(self) -> str: 173 | title = f" Segmentation Metrics (n={self.n_images})\n" 174 | hbar = "=" * len(title) + "\n" 175 | r = hbar + title + hbar 176 | if self.strict: 177 | r += f"Strict: {self.strict} (IoU > {self.iou_threshold})\n" 178 | for m in METRICS: 179 | mval = getattr(self, m) 180 | if isinstance(mval, float): 181 | r += f"{m}: {mval:.3f}\n" 182 | else: 183 | r += f"{m}: {mval}\n" 184 | return r 185 | 186 | def _repr_html_(self): 187 | from umetrix.notebooks import render_metrics_html 188 | 189 | return render_metrics_html(self) 190 | 191 | @property 192 | def localization_error(self) -> float: 193 | return np.mean(self.per_object_localization_error) 194 | 195 | @property 196 | def IoU(self) -> float: 197 | return np.mean(self.per_object_IoU) 198 | 199 | @property 200 | def Jaccard(self) -> float: 201 | """Jaccard metric""" 202 | tp = self.n_true_positives 203 | fn = self.n_false_negatives 204 | fp = self.n_false_positives 205 | return tp / (tp + fn + fp) 206 | 207 | @property 208 | def pixel_identity(self) -> float: 209 | return 
np.mean(self.per_image_pixel_identity) 210 | 211 | @staticmethod 212 | def merge(results: list) -> MetricResults: 213 | """Merge n results together and return a single object.""" 214 | merged = results.pop(0) 215 | for result in results: 216 | assert isinstance(result, MetricResults) 217 | assert result.n_images == 1 218 | merged = merged + result 219 | return merged 220 | 221 | 222 | class SegmentationMetrics: 223 | """A class for calculating various segmentation metrics to assess the 224 | accuracy of a trained model. 225 | 226 | Parameters 227 | ---------- 228 | reference : array 229 | An array containing labeled objects from the ground truth. 230 | predicted : array 231 | An array containing labeled objects from the segmentation algorithm. 232 | strict : bool 233 | Whether to disregard matches with a low IoU score. 234 | iou_threshold : float 235 | Threshold IoU for strict matching. 236 | 237 | Properties 238 | ---------- 239 | Jaccard : float 240 | The Jaccard index calculated according to the notes below. 241 | IoU : float 242 | The Intersection over Union metric. 243 | localisation_precision : float 244 | The localisation precision. 245 | true_positives : int 246 | Number of TP predictions. 247 | false_positives : int 248 | Number of FP predictions. 249 | false_negatives : int 250 | Number of FN predicitons. 251 | 252 | 253 | Notes 254 | ----- 255 | The Jaccard metric is calculated accordingly: 256 | 257 | FP = number of objects in predicted but not in reference 258 | TP = number of objects in both 259 | TN = background correctly segmented (not used) 260 | FN = number of objects in true but not in predicted 261 | 262 | J = TP / (TP+FP+FN) 263 | 264 | The IoU is calculated as the intersection of the binary segmentation 265 | divided by the union. 
class SegmentationMetrics:
    """A class for calculating various segmentation metrics to assess the
    accuracy of a trained model.

    Parameters
    ----------
    reference : LabeledSegmentation
        Labeled objects from the ground truth.
    predicted : LabeledSegmentation
        Labeled objects from the segmentation algorithm.
    strict : bool
        Whether to disregard matches with a low IoU score (keyword only,
        default ``False``). Forwarded to ``find_matches``.
    iou_threshold : float
        Threshold IoU for strict matching (keyword only, default ``0.5``).
        Must lie in the closed interval [0, 1].

    Properties
    ----------
    results : MetricResults
        Summary object built from this instance.
    true_positives, false_positives, false_negatives : sequence
        The matched / unmatched objects reported by ``find_matches``.
    n_true_positives, n_false_positives, n_false_negatives : int
        Lengths of the sequences above.
    n_true_labels, n_pred_labels : int
        Number of labels in the reference / predicted segmentation.
    per_object_IoU
        IoU values of the true matches, as reported by ``find_matches``.
    per_image_pixel_identity : list
        Single-element list with the fraction of identical pixels.
    per_object_localization_error : list
        Squared centroid distance for every true positive.

    Raises
    ------
    TypeError
        If ``reference`` or ``predicted`` is not a ``LabeledSegmentation``,
        or ``strict`` is not a ``bool``.
    ValueError
        If ``iou_threshold`` is outside [0, 1].

    Notes
    -----
    The Jaccard metric is calculated accordingly:

        FP = number of objects in predicted but not in reference
        TP = number of objects in both
        TN = background correctly segmented (not used)
        FN = number of objects in true but not in predicted

        J = TP / (TP+FP+FN)

    The IoU is calculated as the intersection of the binary segmentation
    divided by the union.
    """

    def __init__(
        self, reference: LabeledSegmentation, predicted: LabeledSegmentation, **kwargs
    ):
        # Explicit raises (rather than ``assert``) keep input validation
        # active when Python runs with ``-O``.
        if not isinstance(predicted, LabeledSegmentation):
            raise TypeError(
                "predicted must be a LabeledSegmentation, "
                f"found: {type(predicted).__name__}"
            )
        if not isinstance(reference, LabeledSegmentation):
            raise TypeError(
                "reference must be a LabeledSegmentation, "
                f"found: {type(reference).__name__}"
            )

        self._reference = reference
        self._predicted = predicted
        self._strict = kwargs.get("strict", False)
        self._iou_threshold = kwargs.get("iou_threshold", 0.5)

        # Both end points are accepted, hence the closed interval in the
        # message.
        if self.iou_threshold < 0.0 or self.iou_threshold > 1.0:
            raise ValueError(
                f"IoU threshold should be in [0, 1], found: {self.iou_threshold}"
            )
        if not isinstance(self.strict, bool):
            raise TypeError(
                f"strict must be a bool, found: {type(self.strict).__name__}"
            )

        # find the matches
        self._matches = find_matches(
            self._reference,
            self._predicted,
            strict=self.strict,
            iou_threshold=self.iou_threshold,
        )

    @property
    def strict(self) -> bool:
        """Whether strict matching was requested."""
        return self._strict

    @property
    def iou_threshold(self) -> float:
        """IoU threshold passed to the matching step."""
        return self._iou_threshold

    @property
    def results(self):
        """A new ``MetricResults`` built from this object."""
        return MetricResults(self)

    @property
    def image_overlay(self):
        """Stack predicted/reference/predicted images on a trailing axis.

        The stacked array is multiplied by 127.
        NOTE(review): presumably meant to be shown as an RGB image in which
        the reference occupies the middle (green) channel - confirm.
        """
        return (
            np.stack(
                [self._predicted.image, self._reference.image, self._predicted.image],
                axis=-1,
            )
            * 127
        )

    @property
    def n_true_labels(self):
        """Number of labels in the reference segmentation."""
        return self._reference.n_labels

    @property
    def n_pred_labels(self):
        """Number of labels in the predicted segmentation."""
        return self._predicted.n_labels

    @property
    def true_positives(self):
        """Only one match between reference and predicted."""
        return self._matches["true_matches"]

    @property
    def false_negatives(self):
        """No match in predicted for reference object."""
        return self._matches["in_ref_only"]

    @property
    def false_positives(self):
        """Combination of non unique matches and unmatched objects."""
        return self._matches["in_pred_only"]

    @property
    def n_true_positives(self):
        """Number of true positive matches."""
        return len(self.true_positives)

    @property
    def n_false_negatives(self):
        """Number of reference objects without a match."""
        return len(self.false_negatives)

    @property
    def n_false_positives(self):
        """Number of predicted objects counted as false positives."""
        return len(self.false_positives)

    @property
    def per_object_IoU(self):
        """Intersection over Union (IoU) metric"""
        return self._matches["true_matches_IoU"]

    @property
    def per_image_pixel_identity(self):
        """Calculate the per-image pixel identity.

        Returns a single-element list holding the fraction of positions at
        which the reference and predicted ``image`` arrays are equal.
        """
        n_tot = np.prod(self._reference.image.shape)
        return [np.sum(self._reference.image == self._predicted.image) / n_tot]

    @property
    def per_object_localization_error(self):
        """Calculate the per-object localization error.

        For every true positive the error is the *squared* Euclidean
        distance between the reference and predicted centroids (no square
        root is taken).
        """
        ref_centroids = self._reference.centroids
        tgt_centroids = self._predicted.centroids
        positional_error = []
        for m in self.true_positives:
            # Each match is indexed as (reference label, predicted label);
            # labels start at 1, the centroid lists at 0.
            true_centroid = np.array(ref_centroids[m[0] - 1])
            pred_centroid = np.array(tgt_centroids[m[1] - 1])
            positional_error.append(np.sum((true_centroid - pred_centroid) ** 2))
        return positional_error

    def plot(self):
        """Plot the metrics using ``render.plot_metrics``."""
        render.plot_metrics(self)

    def to_napari(self):
        """Return the output of ``render.render_metrics_napari``."""
        return render.render_metrics_napari(self)

    def __repr__(self):
        return self.results.__repr__()

    def _repr_html_(self):
        return self.results._repr_html_()
class LabeledSegmentation:
    """A helper class to enable simple calculation of accuracy statistics for
    image segmentation output.

    The input image is converted to a boolean mask (any non-zero value is
    foreground) and passed to ``label``; the labeled array and the number of
    labels it returns are stored as ``labeled`` and ``n_labels``.

    Parameters
    ----------
    image : array
        The segmentation image. Stored unchanged as ``image``.
    """

    def __init__(self, image: npt.NDArray):
        self.image = image
        self.labeled, self.n_labels = label(image.astype(bool))

    @property
    def shape(self) -> Tuple[int, ...]:
        """Shape of the original image."""
        return self.image.shape

    @property
    def bboxes(self):
        """Bounding box (tuple of slices) of every labeled object."""
        # One pass over the labeled image yields all boxes, ordered by label,
        # instead of building a full-size mask per label.
        return find_objects(self.labeled)

    @property
    def labels(self):
        """The label values, ``1 .. n_labels`` inclusive."""
        return range(1, self.n_labels + 1)

    @property
    def centroids(self):
        """Centre of mass (tuple of floats) of every labeled object."""
        if not self.n_labels:
            return []
        # Unit weights make the per-label centre of mass equal to the centre
        # of mass of that label's binary mask.
        weights = np.ones(self.labeled.shape, dtype=float)
        index = np.arange(1, self.n_labels + 1)
        return center_of_mass(weights, labels=self.labeled, index=index)

    @property
    def areas(self):
        """Pixel count of every labeled object."""
        counts = np.bincount(self.labeled.ravel(), minlength=self.n_labels + 1)
        # Entry 0 counts the background and is dropped.
        return list(counts[1 : self.n_labels + 1])


def calculate(reference, predicted, **kwargs):
    """Take a predicted image and compare with the reference image.

    Compute various metrics.

    Parameters
    ----------
    reference : array
        Ground truth segmentation image.
    predicted : array
        Predicted segmentation image, with the same shape as ``reference``.
    **kwargs
        Forwarded to ``SegmentationMetrics``.

    Returns
    -------
    SegmentationMetrics

    Raises
    ------
    ValueError
        If the two images do not have the same shape.
    """

    ref = LabeledSegmentation(reference)
    tgt = LabeledSegmentation(predicted)

    # make sure they are the same size; an explicit raise (rather than
    # ``assert``) keeps the check active under ``python -O``
    if ref.shape != tgt.shape:
        raise ValueError(
            "reference and predicted images must have the same shape, "
            f"found: {ref.shape} and {tgt.shape}"
        )

    return SegmentationMetrics(ref, tgt, **kwargs)


def batch(files, **kwargs):
    """Batch process a list of files.

    Parameters
    ----------
    files : iterable
        Pairs of ``(reference_file, predicted_file)``; each file is loaded
        with ``imread``.
    **kwargs
        Forwarded to ``calculate``.

    Returns
    -------
    MetricResults
        The per-pair results combined with ``MetricResults.merge``.

    Raises
    ------
    ValueError
        If ``files`` contains no pairs.
    """
    metrix = [
        calculate(imread(f_ref), imread(f_pred), **kwargs).results
        for f_ref, f_pred in files
    ]
    if not metrix:
        raise ValueError("No file pairs were provided to batch.")
    return MetricResults.merge(metrix)


if __name__ == "__main__":
    pass
follows in batch mode:\n", 15 | "+ `n_true_labels` is the sum of all true labels, etc\n", 16 | "+ `IoU` is the mean IoU of all found objects\n", 17 | "+ `Jaccard` is the Jaccard index over all found objects\n", 18 | "+ `localization_error` is the mean error for all found objects\n", 19 | "+ `pixel_identity` is the per image pixel identity" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 1, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "import umetrix\n", 29 | "\n", 30 | "import numpy as np\n", 31 | "from skimage.io import imread" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# load a ground truth - prediction image pair\n", 41 | "p = \"../tests/data/unet.tif\"\n", 42 | "s = imread(p)\n", 43 | "y_true = s[-2, ...]\n", 44 | "y_pred = s[-1, ...]" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 3, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "result = umetrix.calculate(y_true, y_pred)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 4, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "data": { 63 | "text/html": [ 64 | "

Segmentation Metrics

Metric
n_true_labels13
n_pred_labels14
n_true_positives12
n_false_positives2
n_false_negatives1
IoU0.807
Jaccard0.800
pixel_identity0.959
localization_error15.524
" 67 | ], 68 | "text/plain": [ 69 | "============================\n", 70 | " Segmentation Metrics (n=1)\n", 71 | "============================\n", 72 | "n_true_labels: 13\n", 73 | "n_pred_labels: 14\n", 74 | "n_true_positives: 12\n", 75 | "n_false_positives: 2\n", 76 | "n_false_negatives: 1\n", 77 | "IoU: 0.807\n", 78 | "Jaccard: 0.800\n", 79 | "pixel_identity: 0.959\n", 80 | "localization_error: 15.524" 81 | ] 82 | }, 83 | "execution_count": 4, 84 | "metadata": {}, 85 | "output_type": "execute_result" 86 | } 87 | ], 88 | "source": [ 89 | "result" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "## visualize the metrics" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 5, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "image/png": "", 107 | "text/plain": [ 108 | "
" 109 | ] 110 | }, 111 | "metadata": {}, 112 | "output_type": "display_data" 113 | } 114 | ], 115 | "source": [ 116 | "result.plot()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "### now perform the calculation with strict matching only" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "data": { 133 | "text/html": [ 134 | "

Segmentation Metrics

Strict matching (IoU threshold: 0.5)

Metric
n_true_labels13
n_pred_labels14
n_true_positives11
n_false_positives3
n_false_negatives2
IoU0.841
Jaccard0.688
pixel_identity0.959
localization_error5.305
" 137 | ], 138 | "text/plain": [ 139 | "============================\n", 140 | " Segmentation Metrics (n=1)\n", 141 | "============================\n", 142 | "Strict: True (IoU > 0.5)\n", 143 | "n_true_labels: 13\n", 144 | "n_pred_labels: 14\n", 145 | "n_true_positives: 11\n", 146 | "n_false_positives: 3\n", 147 | "n_false_negatives: 2\n", 148 | "IoU: 0.841\n", 149 | "Jaccard: 0.688\n", 150 | "pixel_identity: 0.959\n", 151 | "localization_error: 5.305" 152 | ] 153 | }, 154 | "execution_count": 6, 155 | "metadata": {}, 156 | "output_type": "execute_result" 157 | } 158 | ], 159 | "source": [ 160 | "result = umetrix.calculate(y_true, y_pred, strict=True, iou_threshold=0.5)\n", 161 | "result" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 7, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "data": { 171 | "image/png": "", 172 | "text/plain": [ 173 | "
" 174 | ] 175 | }, 176 | "metadata": {}, 177 | "output_type": "display_data" 178 | } 179 | ], 180 | "source": [ 181 | "result.plot()" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [] 190 | } 191 | ], 192 | "metadata": { 193 | "kernelspec": { 194 | "display_name": "Python [conda env:btrack]", 195 | "language": "python", 196 | "name": "conda-env-btrack-py" 197 | }, 198 | "language_info": { 199 | "codemirror_mode": { 200 | "name": "ipython", 201 | "version": 3 202 | }, 203 | "file_extension": ".py", 204 | "mimetype": "text/x-python", 205 | "name": "python", 206 | "nbconvert_exporter": "python", 207 | "pygments_lexer": "ipython3", 208 | "version": "3.10.9" 209 | } 210 | }, 211 | "nbformat": 4, 212 | "nbformat_minor": 4 213 | } 214 | --------------------------------------------------------------------------------