├── tests
├── __init__.py
├── probly
│ ├── __init__.py
│ ├── plot
│ │ ├── __init__.py
│ │ └── test_credal.py
│ ├── train
│ │ ├── __init__.py
│ │ ├── bayesian
│ │ │ ├── __init__.py
│ │ │ └── test_torch.py
│ │ ├── calibration
│ │ │ ├── __init__.py
│ │ │ └── test_torch.py
│ │ └── evidential
│ │ │ ├── __init__.py
│ │ │ └── test_torch.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── test_errors.py
│ │ ├── test_sets.py
│ │ ├── test_torch.py
│ │ └── test_probabilities.py
│ ├── datasets
│ │ ├── __init__.py
│ │ └── test_torch.py
│ ├── evaluation
│ │ ├── __init__.py
│ │ ├── test_tasks.py
│ │ └── test_metrics.py
│ ├── traverse_nn
│ │ └── __init__.py
│ ├── quantification
│ │ ├── __init__.py
│ │ ├── test_regression.py
│ │ └── test_classification.py
│ ├── representation
│ │ └── __init__.py
│ ├── transformation
│ │ ├── __init__.py
│ │ ├── dropout
│ │ │ ├── __init__.py
│ │ │ └── test_common.py
│ │ ├── bayesian
│ │ │ ├── __init__.py
│ │ │ ├── test_common.py
│ │ │ └── test_torch.py
│ │ ├── dropconnect
│ │ │ ├── __init__.py
│ │ │ └── test_common.py
│ │ ├── evidential
│ │ │ ├── classification
│ │ │ │ ├── __init__.py
│ │ │ │ ├── test_common.py
│ │ │ │ └── test_torch.py
│ │ │ └── regression
│ │ │ │ ├── test_common.py
│ │ │ │ └── test_torch.py
│ │ └── ensemble
│ │ │ └── test_common.py
│ ├── fixtures
│ │ ├── common.py
│ │ ├── torch_data.py
│ │ ├── flax_models.py
│ │ └── torch_models.py
│ ├── general_utils.py
│ ├── tests_deprecations
│ │ ├── __init__.py
│ │ ├── deprecated_features.py
│ │ ├── utils.py
│ │ └── test_deprecations.py
│ ├── flax_utils.py
│ ├── torch_utils.py
│ └── markers.py
├── pytraverse
│ └── __init__.py
├── lazy_dispatch
│ ├── __init__.py
│ └── test_isinstance.py
└── conftest.py
├── src
├── probly
│ ├── train
│ │ ├── __init__.py
│ │ ├── bayesian
│ │ │ ├── __init__.py
│ │ │ └── torch.py
│ │ ├── evidential
│ │ │ └── __init__.py
│ │ └── calibration
│ │ │ ├── __init__.py
│ │ │ └── torch.py
│ ├── layers
│ │ ├── __init__.py
│ │ └── flax.py
│ ├── plot
│ │ ├── __init__.py
│ │ └── credal.py
│ ├── datasets
│ │ └── __init__.py
│ ├── evaluation
│ │ ├── __init__.py
│ │ └── tasks.py
│ ├── quantification
│ │ ├── __init__.py
│ │ └── regression.py
│ ├── representation
│ │ ├── distribution
│ │ │ └── __init__.py
│ │ ├── credal_set
│ │ │ ├── __init__.py
│ │ │ ├── jax.py
│ │ │ ├── torch.py
│ │ │ └── credal_set.py
│ │ ├── __init__.py
│ │ ├── sampling
│ │ │ ├── __init__.py
│ │ │ ├── flax_sampler.py
│ │ │ ├── jax_sample.py
│ │ │ ├── torch_sample.py
│ │ │ ├── torch_sampler.py
│ │ │ ├── sample.py
│ │ │ └── sampler.py
│ │ └── representer.py
│ ├── lazy_types.py
│ ├── transformation
│ │ ├── evidential
│ │ │ ├── __init__.py
│ │ │ ├── classification
│ │ │ │ ├── torch.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── common.py
│ │ │ └── regression
│ │ │ │ ├── __init__.py
│ │ │ │ ├── torch.py
│ │ │ │ └── common.py
│ │ ├── dropout
│ │ │ ├── torch.py
│ │ │ ├── flax.py
│ │ │ ├── __init__.py
│ │ │ └── common.py
│ │ ├── bayesian
│ │ │ ├── __init__.py
│ │ │ ├── torch.py
│ │ │ └── common.py
│ │ ├── dropconnect
│ │ │ ├── torch.py
│ │ │ ├── flax.py
│ │ │ ├── __init__.py
│ │ │ └── common.py
│ │ ├── __init__.py
│ │ └── ensemble
│ │ │ ├── __init__.py
│ │ │ ├── torch.py
│ │ │ ├── common.py
│ │ │ └── flax.py
│ ├── __init__.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── errors.py
│ │ ├── sets.py
│ │ ├── torch.py
│ │ └── probabilities.py
│ ├── traverse_nn
│ │ ├── __init__.py
│ │ ├── common.py
│ │ ├── torch.py
│ │ └── flax.py
│ └── predictor.py
├── lazy_dispatch
│ ├── __init__.py
│ ├── load.py
│ └── isinstance.py
└── pytraverse
│ ├── __init__.py
│ └── generic.py
├── docs
├── source
│ ├── references.rst
│ ├── _static
│ │ ├── logo
│ │ │ ├── logo_dark.png
│ │ │ └── logo_light.png
│ │ └── github-mark.svg
│ ├── _templates
│ │ └── sidebar
│ │ │ └── footer.html
│ ├── api.rst
│ ├── methods.md
│ ├── index.rst
│ └── conf.py
├── Makefile
└── make.bat
├── .editorconfig
├── CHANGELOG.md
├── .pre-commit-config.yaml
├── .github
├── PULL_REQUEST_TEMPLATE.md
├── workflows
│ ├── deploy-docs.yml
│ ├── publish-release.yml
│ └── ci.yml
└── CONTRIBUTING.md
├── LICENSE
├── README.md
└── notebooks
└── examples
└── lazy_dispatch_test.ipynb
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the packages."""
2 |
--------------------------------------------------------------------------------
/src/probly/train/__init__.py:
--------------------------------------------------------------------------------
1 | """Train module for probly."""
2 |
--------------------------------------------------------------------------------
/tests/probly/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the probly package."""
2 |
--------------------------------------------------------------------------------
/tests/probly/plot/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the plot module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/train/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the train module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the utils module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/utils/test_errors.py:
--------------------------------------------------------------------------------
1 | """Tests for errors functions."""
2 |
--------------------------------------------------------------------------------
/tests/pytraverse/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the pytraverse package."""
2 |
--------------------------------------------------------------------------------
/src/probly/layers/__init__.py:
--------------------------------------------------------------------------------
1 | """Init module for layer implementations."""
2 |
--------------------------------------------------------------------------------
/src/probly/plot/__init__.py:
--------------------------------------------------------------------------------
1 | """Init module for plot implementations."""
2 |
--------------------------------------------------------------------------------
/tests/lazy_dispatch/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the lazy_dispatch package."""
2 |
--------------------------------------------------------------------------------
/tests/probly/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the datasets module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the evaluation module."""
2 |
--------------------------------------------------------------------------------
/src/probly/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """Init module for dataset implementations."""
2 |
--------------------------------------------------------------------------------
/tests/probly/traverse_nn/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the traverse_nn module."""
2 |
--------------------------------------------------------------------------------
/src/probly/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | """Init module for evaluation implementations."""
2 |
--------------------------------------------------------------------------------
/src/probly/train/bayesian/__init__.py:
--------------------------------------------------------------------------------
1 | """Train functionality for Bayesian models."""
2 |
--------------------------------------------------------------------------------
/tests/probly/quantification/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the quantification module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/representation/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the representation module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/train/bayesian/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the Bayesian train module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/transformation/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the transformation module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/transformation/dropout/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the dropout module."""
2 |
--------------------------------------------------------------------------------
/src/probly/train/evidential/__init__.py:
--------------------------------------------------------------------------------
1 | """Train functionality for evidential models."""
2 |
--------------------------------------------------------------------------------
/tests/probly/train/calibration/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the calibration train module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/train/evidential/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the evidential train module."""
2 |
--------------------------------------------------------------------------------
/tests/probly/transformation/bayesian/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the bayesian module."""
2 |
--------------------------------------------------------------------------------
/docs/source/references.rst:
--------------------------------------------------------------------------------
1 | References
2 | ==========
3 |
4 | .. bibliography::
5 | :cited:
6 |
--------------------------------------------------------------------------------
/src/probly/train/calibration/__init__.py:
--------------------------------------------------------------------------------
1 | """Train functionality for calibration methods."""
2 |
--------------------------------------------------------------------------------
/tests/probly/transformation/dropconnect/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the dropconnect module."""
2 |
--------------------------------------------------------------------------------
/src/probly/quantification/__init__.py:
--------------------------------------------------------------------------------
1 | """Init module for quantification measure implementations."""
2 |
--------------------------------------------------------------------------------
/src/probly/representation/distribution/__init__.py:
--------------------------------------------------------------------------------
1 | """Second-order distribution representations."""
2 |
--------------------------------------------------------------------------------
/tests/probly/transformation/evidential/classification/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for the classification module."""
2 |
--------------------------------------------------------------------------------
/docs/source/_static/logo/logo_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pwhofman/probly/HEAD/docs/source/_static/logo/logo_dark.png
--------------------------------------------------------------------------------
/docs/source/_static/logo/logo_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pwhofman/probly/HEAD/docs/source/_static/logo/logo_light.png
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | end_of_line = lf
5 | charset = utf-8
6 | insert_final_newline = true
7 | indent_style = space
8 | indent_size = 4
9 | trim_trailing_whitespace = true
10 |
--------------------------------------------------------------------------------
/src/probly/representation/credal_set/__init__.py:
--------------------------------------------------------------------------------
1 | """Credal set representations."""
2 |
3 | from __future__ import annotations
4 |
5 | from .credal_set import CredalSet, credal_set_from_sample
6 |
7 | __all__ = [
8 | "CredalSet",
9 | "credal_set_from_sample",
10 | ]
11 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Test fixtures for probly."""
2 |
3 | from __future__ import annotations
4 |
5 | pytest_plugins = [
6 | "tests.probly.fixtures.common",
7 | "tests.probly.fixtures.torch_models",
8 | "tests.probly.fixtures.torch_data",
9 | "tests.probly.fixtures.flax_models",
10 | ]
11 |
--------------------------------------------------------------------------------
/src/probly/lazy_types.py:
--------------------------------------------------------------------------------
1 | """Collection of fully-qualified type names for lazy type checking."""
2 |
3 | from __future__ import annotations
4 |
5 | TORCH_MODULE = "torch.nn.modules.module.Module"
6 | TORCH_TENSOR = "torch.Tensor"
7 |
8 | FLAX_MODULE = "flax.nnx.module.Module"
9 | JAX_ARRAY = "jax.Array"
10 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | This changelog is updated with every release of `probly`.
3 |
4 | ## Development
5 |
6 | - Added possibility to create an ensemble of torch models without resetting the weights of each model.
7 |
8 | ## 0.1.0 (2024-03-14)
9 | Initial pre-release of `probly` without functionalities.
10 |
--------------------------------------------------------------------------------
/src/probly/transformation/evidential/__init__.py:
--------------------------------------------------------------------------------
1 | """Evidential module for probly."""
2 |
3 | from probly.transformation.evidential.classification import evidential_classification
4 | from probly.transformation.evidential.regression import evidential_regression
5 |
6 | __all__ = ["evidential_classification", "evidential_regression"]
7 |
--------------------------------------------------------------------------------
/tests/probly/utils/test_sets.py:
--------------------------------------------------------------------------------
1 | """Tests for utils.sets functions."""
2 |
3 | from __future__ import annotations
4 |
5 | from probly.utils.sets import powerset
6 |
7 |
def test_powerset() -> None:
    """Check ``powerset`` on inputs of size zero, one and two.

    Each case pairs an input list with the exact list of tuples expected
    back, including the order of the subsets.
    """
    cases = [
        ([], [()]),
        ([1], [(), (1,)]),
        ([1, 2], [(), (1,), (2,), (1, 2)]),
    ]
    for items, expected in cases:
        assert powerset(items) == expected
12 |
--------------------------------------------------------------------------------
/docs/source/_templates/sidebar/footer.html:
--------------------------------------------------------------------------------
1 |
6 |
--------------------------------------------------------------------------------
/src/probly/__init__.py:
--------------------------------------------------------------------------------
1 | """probly: Uncertainty Representation and Quantification for Machine Learning."""
2 |
3 | __version__ = "0.3.1"
4 |
5 | from probly import (
6 | datasets as datasets,
7 | evaluation as evaluation,
8 | plot as plot,
9 | quantification as quantification,
10 | representation as representation,
11 | train as train,
12 | transformation as transformation,
13 | utils as utils,
14 | )
15 |
--------------------------------------------------------------------------------
/src/probly/transformation/dropout/torch.py:
--------------------------------------------------------------------------------
1 | """Torch dropout implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from torch import nn
6 |
7 | from .common import register
8 |
9 |
def prepend_torch_dropout(obj: nn.Module, p: float) -> nn.Sequential:
    """Wrap ``obj`` so that a dropout layer runs immediately before it.

    Args:
        obj: The layer to wrap; it is reused as-is, not copied.
        p: Dropout probability passed to ``nn.Dropout``.

    Returns:
        A two-element ``nn.Sequential``: the new ``nn.Dropout`` followed by ``obj``.
    """
    dropout_layer = nn.Dropout(p=p)
    return nn.Sequential(dropout_layer, obj)
13 |
14 |
15 | register(nn.Linear, prepend_torch_dropout)
16 |
--------------------------------------------------------------------------------
/tests/probly/fixtures/common.py:
--------------------------------------------------------------------------------
1 | """Common fixtures for tests."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | from probly.predictor import Predictor
8 |
9 |
@pytest.fixture
def dummy_predictor() -> Predictor:
    """Provide a minimal ``Predictor`` whose call returns its input unchanged."""

    class DummyPredictor(Predictor):
        """Identity predictor used as a stand-in in tests."""

        def __call__(self, x: float) -> float:
            return x

    predictor = DummyPredictor()
    return predictor
19 |
--------------------------------------------------------------------------------
/src/probly/transformation/bayesian/__init__.py:
--------------------------------------------------------------------------------
1 | """Bayesian implementation for uncertainty quantification."""
2 |
3 | from __future__ import annotations
4 |
5 | from probly.lazy_types import TORCH_MODULE
6 |
7 | from . import common
8 |
# Re-export the entry points defined in ``common`` at package level.
bayesian = common.bayesian
register = common.register


## Torch
# Deferred backend hookup: the callback is registered under the fully-qualified
# torch module type name, and its only action is to import the ``torch``
# submodule of this package.
# NOTE(review): presumably the callback fires the first time a torch module is
# dispatched on, so torch is not imported eagerly -- confirm in ``lazy_dispatch``.
@common.bayesian_traverser.delayed_register(TORCH_MODULE)
def _(_: type) -> None:
    from . import torch as torch  # noqa: PLC0415
17 |
--------------------------------------------------------------------------------
/src/probly/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utils module for probly library."""
2 |
3 | from .probabilities import differential_entropy_gaussian, intersection_probability, kl_divergence_gaussian
4 | from .sets import capacity, moebius, powerset
5 |
# Public API of ``probly.utils``. Each exported name is listed exactly once;
# the original list repeated "differential_entropy_gaussian".
__all__ = [
    "capacity",
    "differential_entropy_gaussian",
    "intersection_probability",
    "kl_divergence_gaussian",
    "moebius",
    "powerset",
]
15 |
--------------------------------------------------------------------------------
/src/probly/representation/__init__.py:
--------------------------------------------------------------------------------
1 | """Uncertainty representations for models."""
2 |
3 | from probly.representation import sampling
4 | from probly.representation.credal_set import CredalSet, credal_set_from_sample
5 | from probly.representation.representer import Representer
6 | from probly.representation.sampling import Sample, Sampler
7 |
8 | __all__ = [
9 | "CredalSet",
10 | "Representer",
11 | "Sample",
12 | "Sampler",
13 | "credal_set_from_sample",
14 | "sampling",
15 | ]
16 |
--------------------------------------------------------------------------------
/src/probly/transformation/dropconnect/torch.py:
--------------------------------------------------------------------------------
1 | """Torch dropconnect implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from torch import nn
6 |
7 | from probly.layers.torch import DropConnectLinear
8 |
9 | from .common import register
10 |
11 |
def replace_torch_dropconnect(obj: nn.Linear, p: float) -> DropConnectLinear:
    """Build the DropConnect counterpart of a torch linear layer.

    Args:
        obj: The linear layer handed to ``DropConnectLinear``.
        p: Probability forwarded to ``DropConnectLinear`` as ``p``.

    Returns:
        The newly constructed ``DropConnectLinear`` layer.
    """
    dropconnect_layer = DropConnectLinear(obj, p=p)
    return dropconnect_layer
15 |
16 |
17 | register(nn.Linear, replace_torch_dropconnect)
18 |
--------------------------------------------------------------------------------
/src/probly/transformation/dropconnect/flax.py:
--------------------------------------------------------------------------------
1 | """Flax dropconnect implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from flax import nnx
6 |
7 | from probly.layers.flax import DropConnectLinear
8 |
9 | from .common import register
10 |
11 |
def replace_flax_dropconnect(obj: nnx.Linear, p: float) -> DropConnectLinear:
    """Build the DropConnect counterpart of a flax linear layer.

    Args:
        obj: The linear layer handed to ``DropConnectLinear``.
        p: Probability forwarded to ``DropConnectLinear`` as ``rate``.

    Returns:
        The newly constructed ``DropConnectLinear`` layer.
    """
    dropconnect_layer = DropConnectLinear(obj, rate=p)
    return dropconnect_layer
15 |
16 |
17 | register(nnx.Linear, replace_flax_dropconnect)
18 |
--------------------------------------------------------------------------------
/src/lazy_dispatch/__init__.py:
--------------------------------------------------------------------------------
1 | """Lazy alternatives to singledispatch and isinstance."""
2 |
3 | from lazy_dispatch.isinstance import LazyType, lazy_isinstance, lazy_issubclass
4 | from lazy_dispatch.load import lazy_callable, lazy_import
5 | from lazy_dispatch.singledispatch import is_valid_dispatch_type, lazydispatch
6 |
7 | __all__ = [
8 | "LazyType",
9 | "is_valid_dispatch_type",
10 | "lazy_callable",
11 | "lazy_import",
12 | "lazy_isinstance",
13 | "lazy_issubclass",
14 | "lazydispatch",
15 | ]
16 |
--------------------------------------------------------------------------------
/src/probly/transformation/evidential/classification/torch.py:
--------------------------------------------------------------------------------
1 | """Torch evidential classification implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from torch import nn
6 |
7 | from probly.transformation.evidential.classification.common import register
8 |
9 |
def append_activation_torch(obj: nn.Module) -> nn.Sequential:
    """Wrap ``obj`` so that its output is passed through a Softplus activation.

    Args:
        obj: The base model; it is reused as-is, not copied.

    Returns:
        A two-element ``nn.Sequential``: ``obj`` followed by a new ``nn.Softplus``.
    """
    activation = nn.Softplus()
    return nn.Sequential(obj, activation)
13 |
14 |
15 | register(nn.Module, append_activation_torch)
16 |
--------------------------------------------------------------------------------
/src/probly/utils/errors.py:
--------------------------------------------------------------------------------
1 | """All errors and warnings used in the probly package."""
2 |
3 | from __future__ import annotations
4 |
5 | from warnings import warn
6 |
7 |
def raise_deprecation_warning(
    message: str,
    deprecated_in: str,
    removed_in: str,
) -> None:
    """Emit a ``DeprecationWarning`` that names the affected versions.

    Despite the name, nothing is raised: the warning goes through
    ``warnings.warn``.

    Args:
        message: Text describing the deprecated feature; the version note is
            appended to it after a single space.
        deprecated_in: Version in which the feature became deprecated.
        removed_in: Version in which the feature will be removed.
    """
    version_note = (
        f" This feature is deprecated in version {deprecated_in} and will be removed in version {removed_in}."
    )
    # stacklevel=2 attributes the warning to the caller of this helper
    # instead of to this module.
    warn(message + version_note, DeprecationWarning, stacklevel=2)
16 |
--------------------------------------------------------------------------------
/docs/source/api.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | =============
3 |
4 | .. autosummary::
5 | :toctree: api/
6 | :recursive:
7 |
8 | probly.datasets
9 | probly.evaluation
10 | probly.layers
11 | probly.plot
12 | probly.predictor
13 | probly.quantification
14 | probly.representation
15 | probly.train
16 | probly.transformation
17 | probly.traverse_nn
18 | probly.utils
19 |
20 |
21 | .. automodule:: probly
22 | :members:
23 | :undoc-members:
24 | :show-inheritance:
25 | :special-members: __init__, __call__
26 |
--------------------------------------------------------------------------------
/src/probly/transformation/dropout/flax.py:
--------------------------------------------------------------------------------
1 | """Flax dropout implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from flax.nnx import Dropout, Linear, Sequential
8 |
9 | from .common import register
10 |
11 | if TYPE_CHECKING:
12 | from collections.abc import Callable
13 |
14 |
def prepend_flax_dropout(obj: Callable, p: float) -> Sequential:
    """Wrap ``obj`` so that a dropout layer runs immediately before it.

    Args:
        obj: The layer (any callable) to wrap; it is reused as-is.
        p: Value passed as the first positional argument of ``Dropout``.

    Returns:
        A ``Sequential`` holding the new ``Dropout`` followed by ``obj``.
    """
    dropout_layer = Dropout(p)
    return Sequential(dropout_layer, obj)
18 |
19 |
20 | register(Linear, prepend_flax_dropout)
21 |
--------------------------------------------------------------------------------
/src/probly/transformation/evidential/regression/__init__.py:
--------------------------------------------------------------------------------
1 | """Evidential regression implementation for uncertainty quantification."""
2 |
3 | from __future__ import annotations
4 |
5 | from probly.lazy_types import TORCH_MODULE
6 | from probly.transformation.evidential.regression import common
7 |
# Re-export the entry points defined in ``common`` at package level.
evidential_regression = common.evidential_regression
register = common.register


## Torch
# Deferred backend hookup: the callback is registered under the fully-qualified
# torch module type name, and its only action is to import the ``torch``
# submodule of this package.
# NOTE(review): presumably the callback fires the first time a torch module is
# dispatched on, so torch is not imported eagerly -- confirm in ``lazy_dispatch``.
@common.evidential_regression_traverser.delayed_register(TORCH_MODULE)
def _(_: type) -> None:
    from . import torch as torch  # noqa: PLC0415
16 |
--------------------------------------------------------------------------------
/src/probly/transformation/evidential/classification/__init__.py:
--------------------------------------------------------------------------------
1 | """Evidential classification implementation for uncertainty quantification."""
2 |
3 | from __future__ import annotations
4 |
5 | from probly.lazy_types import TORCH_MODULE
6 | from probly.transformation.evidential.classification import common
7 |
# Re-export the entry points defined in ``common`` at package level.
evidential_classification = common.evidential_classification
register = common.register


## Torch
# Deferred backend hookup: the callback is registered under the fully-qualified
# torch module type name, and its only action is to import the ``torch``
# submodule of this package.
# NOTE(review): presumably the callback fires the first time a torch module is
# dispatched on, so torch is not imported eagerly -- confirm in ``lazy_dispatch``.
@common.evidential_classification_appender.delayed_register(TORCH_MODULE)
def _(_: type) -> None:
    from . import torch as torch  # noqa: PLC0415
16 |
--------------------------------------------------------------------------------
/src/probly/transformation/__init__.py:
--------------------------------------------------------------------------------
1 | """Transformations for models."""
2 |
3 | from probly.transformation.bayesian import bayesian
4 | from probly.transformation.dropconnect import dropconnect
5 | from probly.transformation.dropout import dropout
6 | from probly.transformation.ensemble import ensemble
7 | from probly.transformation.evidential.classification import evidential_classification
8 | from probly.transformation.evidential.regression import evidential_regression
9 |
10 | __all__ = ["bayesian", "dropconnect", "dropout", "ensemble", "evidential_classification", "evidential_regression"]
11 |
--------------------------------------------------------------------------------
/src/probly/representation/sampling/__init__.py:
--------------------------------------------------------------------------------
1 | """Representation predictors that create representations from finite samples."""
2 |
3 | from __future__ import annotations
4 |
5 | from .sample import ArraySample, ListSample, Sample, create_sample
6 | from .sampler import CLEANUP_FUNCS, Sampler, SamplingStrategy, get_sampling_predictor, sampler_factory
7 |
8 | __all__ = [
9 | "CLEANUP_FUNCS",
10 | "ArraySample",
11 | "ListSample",
12 | "Sample",
13 | "Sampler",
14 | "SamplingStrategy",
15 | "create_sample",
16 | "get_sampling_predictor",
17 | "sampler_factory",
18 | ]
19 |
--------------------------------------------------------------------------------
/tests/probly/general_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for testing."""
2 |
3 | from __future__ import annotations
4 |
5 | import numpy as np
6 |
# Tolerance for the non-negativity check: the machine epsilon of np.float64.
# We assume (for now) that no float type more precise than float64 is used.
# See https://numpy.org/doc/stable/reference/generated/numpy.finfo.html.
VALIDATE_EPS = np.finfo(np.float64).eps


def validate_uncertainty(uncertainty: np.ndarray) -> None:
    """Assert that ``uncertainty`` is a well-formed array of uncertainty values.

    The checks run in this order: the value is a NumPy array, it contains no
    NaN, it contains no infinity, and every entry is at least
    ``-VALIDATE_EPS`` (non-negative up to float64 rounding).

    Raises:
        AssertionError: If any of the checks fails.
    """
    assert isinstance(uncertainty, np.ndarray)
    assert not np.any(np.isnan(uncertainty))
    assert not np.any(np.isinf(uncertainty))
    lower_bound = -VALIDATE_EPS
    assert np.all(uncertainty >= lower_bound)
17 |
--------------------------------------------------------------------------------
/src/probly/transformation/ensemble/__init__.py:
--------------------------------------------------------------------------------
1 | """Ensemble implementation for uncertainty quantification."""
2 |
3 | from __future__ import annotations
4 |
5 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE
6 |
7 | from . import common
8 |
# Re-export the entry points defined in ``common`` at package level.
ensemble = common.ensemble
register = common.register


## Torch
# Deferred backend hookup: the callback is registered under the fully-qualified
# torch module type name, and its only action is to import the ``torch``
# submodule of this package.
# NOTE(review): presumably the callback fires the first time a torch module is
# dispatched on, so torch is not imported eagerly -- confirm in ``lazy_dispatch``.
@common.ensemble_generator.delayed_register(TORCH_MODULE)
def _(_: type) -> None:
    from . import torch as torch  # noqa: PLC0415


## Flax
# Same deferred hookup for the flax backend, keyed on the fully-qualified
# flax module type name.
@common.ensemble_generator.delayed_register(FLAX_MODULE)
def _(_: type) -> None:
    from . import flax as flax  # noqa: PLC0415
23 |
--------------------------------------------------------------------------------
/src/probly/transformation/dropout/__init__.py:
--------------------------------------------------------------------------------
1 | """Dropout ensemble implementation for uncertainty quantification."""
2 |
3 | from __future__ import annotations
4 |
5 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE
6 |
7 | from . import common
8 |
# Re-export the entry points defined in ``common`` at package level.
dropout = common.dropout
register = common.register


## Torch
# Deferred backend hookup: the callback is registered under the fully-qualified
# torch module type name, and its only action is to import the ``torch``
# submodule of this package.
# NOTE(review): presumably the callback fires the first time a torch module is
# dispatched on, so torch is not imported eagerly -- confirm in ``lazy_dispatch``.
@common.dropout_traverser.delayed_register(TORCH_MODULE)
def _(_: type) -> None:
    from . import torch as torch  # noqa: PLC0415


## Flax
# Same deferred hookup for the flax backend, keyed on the fully-qualified
# flax module type name.
@common.dropout_traverser.delayed_register(FLAX_MODULE)
def _(_: type) -> None:
    from . import flax as flax  # noqa: PLC0415
23 |
--------------------------------------------------------------------------------
/src/probly/transformation/dropconnect/__init__.py:
--------------------------------------------------------------------------------
1 | """DropConnect implementation for uncertainty quantification."""
2 |
3 | from __future__ import annotations
4 |
5 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE
6 |
7 | from . import common
8 |
# Re-export the entry points defined in ``common`` at package level.
dropconnect = common.dropconnect
register = common.register


## Torch
# Deferred backend hookup: the callback is registered under the fully-qualified
# torch module type name, and its only action is to import the ``torch``
# submodule of this package.
# NOTE(review): presumably the callback fires the first time a torch module is
# dispatched on, so torch is not imported eagerly -- confirm in ``lazy_dispatch``.
@common.dropconnect_traverser.delayed_register(TORCH_MODULE)
def _(_: type) -> None:
    from . import torch as torch  # noqa: PLC0415


## Flax
# Same deferred hookup for the flax backend, keyed on the fully-qualified
# flax module type name.
@common.dropconnect_traverser.delayed_register(FLAX_MODULE)
def _(_: type) -> None:
    from . import flax as flax  # noqa: PLC0415
23 |
--------------------------------------------------------------------------------
/tests/probly/fixtures/torch_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | torch = pytest.importorskip("torch")
6 |
7 | from torch import Tensor, nn # noqa: E402
8 |
9 |
@pytest.fixture
def sample_classification_data() -> tuple[Tensor, Tensor]:
    """Provide a random batch of inputs with matching integer targets.

    Returns:
        A pair ``(inputs, targets)``: ``inputs`` is drawn with ``torch.randn``
        in shape ``(2, 3, 5, 5)`` and ``targets`` holds two integers drawn
        from ``{0, 1}``.
    """
    batch_size = 2
    features = torch.randn(batch_size, 3, 5, 5)
    labels = torch.randint(low=0, high=2, size=(batch_size,))
    return features, labels
15 |
16 |
@pytest.fixture
def sample_outputs(
    torch_conv_linear_model: nn.Module,
) -> tuple[Tensor, Tensor]:
    """Provide model outputs for a random batch together with integer targets.

    Args:
        torch_conv_linear_model: Model fixture that is called on the random batch.

    Returns:
        A pair ``(outputs, targets)``: ``outputs`` is whatever the model returns
        for a ``torch.randn`` batch of shape ``(2, 3, 5, 5)`` and ``targets``
        holds two integers drawn from ``{0, 1}``.
    """
    batch = torch.randn(2, 3, 5, 5)
    predictions = torch_conv_linear_model(batch)
    labels = torch.randint(0, 2, (2,))
    return predictions, labels
24 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/src/probly/traverse_nn/__init__.py:
--------------------------------------------------------------------------------
1 | """Traverser utilities for neural networks."""
2 |
3 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE
4 |
5 | from . import common
6 |
7 | ## NN
8 |
9 | LAYER_COUNT = common.LAYER_COUNT
10 | is_first_layer = common.is_first_layer
11 |
12 | layer_count_traverser = common.layer_count_traverser
13 | nn_traverser = common.nn_traverser
14 |
15 | nn_compose = common.compose
16 |
17 |
18 | ## Torch
@nn_traverser.delayed_register(TORCH_MODULE)
def _(_: type) -> None:
    """Import the torch traverser implementation on demand."""
    from . import torch as torch  # noqa: PLC0415
22 |
23 |
24 | ## Flax
@nn_traverser.delayed_register(FLAX_MODULE)
def _(_: type) -> None:
    """Import the flax traverser implementation on demand."""
    from . import flax as flax  # noqa: PLC0415
28 |
--------------------------------------------------------------------------------
/src/probly/predictor.py:
--------------------------------------------------------------------------------
1 | """Protocols and ABCs for representation wrappers."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Protocol, Unpack, runtime_checkable
6 |
7 | from lazy_dispatch.singledispatch import lazydispatch
8 |
9 |
10 | @runtime_checkable
11 | class Predictor[In, KwIn, Out](Protocol):
12 | """Protocol for generic predictors."""
13 |
14 | def __call__(self, *args: In, **kwargs: Unpack[KwIn]) -> Out:
15 | """Call the wrapper with input data."""
16 | ...
17 |
18 |
19 | @lazydispatch
20 | def predict[In, KwIn, Out](predictor: Predictor[In, KwIn, Out], *args: In, **kwargs: Unpack[KwIn]) -> Out:
21 | """Generic predict function."""
22 | return predictor(*args, **kwargs)
23 |
--------------------------------------------------------------------------------
/src/probly/representation/representer.py:
--------------------------------------------------------------------------------
1 | """Generic representation builder."""
2 |
3 | from __future__ import annotations
4 |
5 | from abc import ABC
6 | from typing import TYPE_CHECKING
7 |
8 | if TYPE_CHECKING:
9 | from probly.predictor import Predictor
10 |
11 |
12 | class Representer[In, KwIn, Out](ABC):
13 | """Abstract base class for representation builders."""
14 |
15 | predictor: Predictor[In, KwIn, Out]
16 |
17 | def __init__(self, predictor: Predictor[In, KwIn, Out]) -> None:
18 | """Initialize the representer with a predictor.
19 |
20 | Args:
21 | predictor: Predictor[In, KwIn, Out], the predictor to be used for building representations.
22 | """
23 | self.predictor = predictor
24 |
--------------------------------------------------------------------------------
/tests/probly/transformation/dropout/test_common.py:
--------------------------------------------------------------------------------
1 | """Test for dropout models."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | from probly.predictor import Predictor
8 | from probly.transformation import dropout
9 |
10 |
def test_invalid_p_value(dummy_predictor: Predictor) -> None:
    """Check that ``dropout`` rejects a probability outside the valid range [0, 1].

    The call is expected to raise a ``ValueError`` whose message reports the
    offending value of ``p``.

    Args:
        dummy_predictor: Generic predictor fixture passed to ``dropout``.
    """
    p = 2
    # ``match`` is applied as a regular expression, so the trailing period is
    # escaped to require the literal character instead of "any character".
    expected_message = rf"The probability p must be between 0 and 1, but got {p} instead\."
    with pytest.raises(ValueError, match=expected_message):
        dropout(dummy_predictor, p=p)
23 |
--------------------------------------------------------------------------------
/tests/probly/transformation/dropconnect/test_common.py:
--------------------------------------------------------------------------------
1 | """Test for dropconnect models."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | from probly.predictor import Predictor
8 | from probly.transformation import dropconnect
9 |
10 |
def test_invalid_p_value(dummy_predictor: Predictor) -> None:
    """Check that ``dropconnect`` rejects a probability outside the valid range [0, 1].

    The call is expected to raise a ``ValueError`` whose message reports the
    offending value of ``p``.

    Args:
        dummy_predictor: Generic predictor fixture passed to ``dropconnect``.
    """
    p = 2
    # ``match`` is applied as a regular expression, so the trailing period is
    # escaped to require the literal character instead of "any character".
    expected_message = rf"The probability p must be between 0 and 1, but got {p} instead\."
    with pytest.raises(ValueError, match=expected_message):
        dropconnect(dummy_predictor, p=p)
23 |
--------------------------------------------------------------------------------
/tests/probly/tests_deprecations/__init__.py:
--------------------------------------------------------------------------------
1 | """A test module which contains tests for deprecations.
2 |
3 | This module is used to ensure that deprecated features are properly flagged and that, once
4 | their scheduled removal is due, they are indeed no longer present in the codebase.
5 |
6 | Usage:
7 | - Adding Deprecations: If you add a new deprecation, you should add a test here to ensure that
8 | it is raised properly. Ensure that in each deprecation message, you include the version
9 | of future removal, so that users can easily identify when the feature will be removed.
10 | The test should check that the deprecation warning includes a version.
11 | - Removing Deprecations: If you remove a deprecated feature, ensure that the test for that
12 | feature is also removed. This helps keep the test suite clean and focused on current
13 | features.
14 | """
15 |
--------------------------------------------------------------------------------
/src/probly/transformation/evidential/regression/torch.py:
--------------------------------------------------------------------------------
1 | """Torch evidential regression implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from torch import nn
8 |
9 | from probly.layers.torch import NormalInverseGammaLinear
10 | from probly.transformation.evidential.regression.common import REPLACED_LAST_LINEAR, register
11 |
12 | if TYPE_CHECKING:
13 | from pytraverse import State, TraverserResult
14 |
15 |
def replace_last_torch_nig(obj: nn.Linear, state: State) -> TraverserResult:
    """Replace a linear layer by a NormalInverseGammaLinear layer of matching size.

    The new layer takes its input/output feature counts, its device and the
    presence of a bias from ``obj``. The traversal state is flagged via
    ``REPLACED_LAST_LINEAR`` to record that the replacement took place.

    Args:
        obj: The linear layer to replace.
        state: The current traversal state; updated in place.

    Returns:
        The replacement layer together with the updated state.
    """
    state[REPLACED_LAST_LINEAR] = True
    nig_layer = NormalInverseGammaLinear(
        obj.in_features,
        obj.out_features,
        device=obj.weight.device,
        bias=obj.bias is not None,
    )
    return nig_layer, state
25 |
26 |
27 | register(nn.Linear, replace_last_torch_nig)
28 |
--------------------------------------------------------------------------------
/src/probly/representation/credal_set/jax.py:
--------------------------------------------------------------------------------
1 | """JAX credal set implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | import jax
6 |
7 | from probly.representation.credal_set.credal_set import CredalSet, credal_set_from_sample
8 | from probly.representation.sampling.jax_sample import JaxArraySample
9 |
10 |
@credal_set_from_sample.register(JaxArraySample)
class JaxArrayCredalSet(CredalSet[jax.Array]):
    """A credal set implementation for JAX arrays."""

    def __init__(self, sample: JaxArraySample) -> None:
        """Initialize the JAX array credal set.

        Args:
            sample: The sample whose underlying array defines the credal set.
        """
        self.array = sample.array

    def lower(self) -> jax.Array:
        """Compute the lower envelope of the credal set (minimum over axis 1)."""
        return self.array.min(axis=1)

    def upper(self) -> jax.Array:
        """Compute the upper envelope of the credal set (maximum over axis 1)."""
        return self.array.max(axis=1)
26 |
--------------------------------------------------------------------------------
/tests/probly/train/bayesian/test_torch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | from probly.train.bayesian.torch import ELBOLoss, collect_kl_divergence
6 | from probly.transformation import bayesian
7 | from tests.probly.torch_utils import validate_loss
8 |
9 | torch = pytest.importorskip("torch")
10 |
11 | from torch import Tensor, nn # noqa: E402
12 |
13 |
def test_elbo_loss(
    sample_classification_data: tuple[Tensor, Tensor],
    torch_conv_linear_model: nn.Module,
) -> None:
    """ELBOLoss yields a valid loss both with its default argument and with 0.0."""
    inputs, targets = sample_classification_data
    model: nn.Module = bayesian(torch_conv_linear_model)
    outputs = model(inputs)

    # First the default-constructed criterion, then the one built with 0.0.
    for ctor_args in ((), (0.0,)):
        criterion = ELBOLoss(*ctor_args)
        loss = criterion(outputs, targets, collect_kl_divergence(model))
        validate_loss(loss)
29 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/github-mark.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/probly/representation/credal_set/torch.py:
--------------------------------------------------------------------------------
1 | """Torch credal set implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | import torch
6 |
7 | from probly.representation.credal_set.credal_set import CredalSet, credal_set_from_sample
8 | from probly.representation.sampling.torch_sample import TorchTensorSample
9 |
10 |
@credal_set_from_sample.register(TorchTensorSample)
class TorchTensorCredalSet(CredalSet[torch.Tensor]):
    """Credal set backed by a torch tensor taken from a ``TorchTensorSample``."""

    def __init__(self, sample: TorchTensorSample) -> None:
        """Keep a reference to the tensor held by ``sample``."""
        self.tensor = sample.tensor

    def lower(self) -> torch.Tensor:
        """Return the lower envelope, i.e. the minimum values along dimension 1."""
        lowest, _ = self.tensor.min(dim=1)
        return lowest

    def upper(self) -> torch.Tensor:
        """Return the upper envelope, i.e. the maximum values along dimension 1."""
        highest, _ = self.tensor.max(dim=1)
        return highest
26 |
--------------------------------------------------------------------------------
/tests/probly/transformation/evidential/regression/test_common.py:
--------------------------------------------------------------------------------
1 | """Tests for evidential regression models."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import Any, cast
6 |
7 | from probly.predictor import Predictor
8 | from probly.transformation.evidential.regression import evidential_regression
9 |
10 |
def test_unknown_base_returns_self(dummy_predictor: Predictor) -> None:
    """Check that a predictor without a registered implementation is returned unchanged.

    Parameters:
        dummy_predictor (Predictor): The generic predictor fixture supplied by pytest.
    """
    result = evidential_regression(cast(Any, dummy_predictor))

    # Identity, not equality: the very same object must come back.
    assert result is dummy_predictor
26 |
--------------------------------------------------------------------------------
/tests/probly/flax_utils.py:
--------------------------------------------------------------------------------
1 | """Util functions for flax tests."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | flax = pytest.importorskip("flax")
8 |
9 | from flax import nnx # noqa: E402
10 |
11 |
def count_layers(model: nnx.Module, layer_type: type[nnx.Module]) -> int:
    """Count how many modules yielded by ``model.iter_modules()`` are instances of ``layer_type``.

    Parameters:
        model: The neural network model containing the layers to be counted.
        layer_type: The type of layer to count within the model.

    Returns:
        The number of modules of the specified type found in the model.
    """
    # Each item is a pair whose second element is the module; booleans sum as 0/1.
    return sum(isinstance(module, layer_type) for _, module in model.iter_modules())
28 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: check-json
6 | - id: check-yaml
7 | - id: check-toml
8 | - id: end-of-file-fixer
9 | - id: trailing-whitespace
10 | - id: mixed-line-ending
11 | - id: check-merge-conflict
12 | - repo: https://github.com/astral-sh/ruff-pre-commit
13 | rev: v0.12.0
14 | hooks:
15 | - id: ruff-check
16 | args: [--fix, --exit-non-zero-on-fix, --no-cache]
17 | - id: ruff-format
18 | args: [--verbose]
19 |
20 | - repo: https://github.com/pre-commit/mirrors-mypy
21 | rev: v1.18.2
22 | hooks:
23 | - id: mypy
24 | exclude: docs
25 | args:
26 | - "--ignore-missing-imports"
27 | - "--scripts-are-modules"
28 | # Since pre-commit env does not have all dependencies installed,
29 | # we ignore unused ignores to avoid false positives:
30 | - "--no-warn-unused-ignores"
31 |
--------------------------------------------------------------------------------
/src/probly/representation/sampling/flax_sampler.py:
--------------------------------------------------------------------------------
1 | """Sampling preparation for flax."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from flax import nnx
8 |
9 | from . import sampler
10 |
11 | if TYPE_CHECKING:
12 | from lazy_dispatch.isinstance import LazyType
13 | from pytraverse import State
14 |
15 |
def _enforce_train_mode(obj: nnx.Module, state: State) -> tuple[nnx.Module, State]:
    """Turn off ``deterministic`` on ``obj`` and schedule a cleanup that turns it back on.

    Objects whose ``deterministic`` attribute is missing or falsy are left untouched.
    """
    if not getattr(obj, "deterministic", False):
        return obj, state

    obj.deterministic = False  # type: ignore[attr-defined]
    # The cleanup callback sets the flag back to True.
    state[sampler.CLEANUP_FUNCS].add(lambda: setattr(obj, "deterministic", True))
    return obj, state
22 |
23 |
def register_forced_train_mode(cls: LazyType) -> None:
    """Register ``cls`` with the sampling preparation traverser so it is forced into train mode."""
    traverser = sampler.sampling_preparation_traverser
    traverser.register(cls, _enforce_train_mode)
30 |
31 |
32 | register_forced_train_mode(
33 | nnx.Dropout,
34 | )
35 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ## Issue
2 | Please link the corresponding GitHub issue. If an issue does not already exist,
3 | please open one to describe the bug or feature request before creating a pull request.
4 |
5 | This allows us to discuss the proposal and helps avoid unnecessary work.
6 |
7 | ## Motivation and Context
8 |
9 | ---
10 |
11 | ## Public API Changes
12 |
13 | - [ ] No Public API changes
14 | - [ ] Yes, Public API changes (Details below)
15 |
16 | ---
17 |
18 | ## How Has This Been Tested?
19 |
20 | ---
21 |
22 | ## Checklist
23 |
24 | - [ ] The changes have been tested locally.
25 | - [ ] Documentation has been updated (if the public API or usage changes).
26 | - [ ] An entry has been added to [`CHANGELOG.md`](https://github.com/pwhofman/probly/blob/main/CHANGELOG.md) (if relevant for users).
27 | - [ ] The code follows the project's [style guidelines](https://github.com/pwhofman/probly/blob/main/.github/CONTRIBUTING.md) and passes minimal code style checks.
28 | - [ ] I have considered the impact of these changes on the public API.
29 |
30 | ---
31 |
--------------------------------------------------------------------------------
/src/probly/representation/sampling/jax_sample.py:
--------------------------------------------------------------------------------
1 | """JAX sample implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | import jax
6 | import jax.numpy as jnp
7 |
8 | from .sample import Sample, create_sample
9 |
10 |
@create_sample.register(jax.Array)
class JaxArraySample(Sample[jax.Array]):
    """A sample implementation for JAX arrays.

    The individual draws are stacked along axis 1, so ``self.array`` follows the
    convention ``[instances, samples, ...]`` (``[instances, samples, classes]``
    for two-dimensional predictions).
    """

    def __init__(self, samples: list[jax.Array]) -> None:
        """Initialize the JAX array sample.

        Args:
            samples: One array per draw, all of the same shape, with the leading
                axis indexing the instances.
        """
        # Stacking directly on axis 1 equals ``stack(samples).transpose(1, 0, 2)``
        # for 2-D inputs, and also works for other numbers of trailing axes.
        self.array = jnp.stack(samples, axis=1)

    def mean(self) -> jax.Array:
        """Compute the mean of the sample over the sample axis (axis 1)."""
        return jnp.mean(self.array, axis=1)

    def std(self, ddof: int = 1) -> jax.Array:
        """Compute the standard deviation over the sample axis.

        Args:
            ddof: Delta degrees of freedom forwarded to ``jnp.std``.
        """
        return jnp.std(self.array, axis=1, ddof=ddof)

    def var(self, ddof: int = 1) -> jax.Array:
        """Compute the variance over the sample axis.

        Args:
            ddof: Delta degrees of freedom forwarded to ``jnp.var``.
        """
        return jnp.var(self.array, axis=1, ddof=ddof)
30 |
--------------------------------------------------------------------------------
/src/probly/representation/sampling/torch_sample.py:
--------------------------------------------------------------------------------
1 | """Torch sample implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | import torch
6 |
7 | from .sample import Sample, create_sample
8 |
9 |
@create_sample.register(torch.Tensor)
class TorchTensorSample(Sample[torch.Tensor]):
    """A sample implementation for torch tensors.

    The individual draws are stacked along dimension 1, so ``self.tensor`` follows
    the convention ``[instances, samples, ...]`` (``[instances, samples, classes]``
    for two-dimensional predictions).
    """

    def __init__(self, samples: list[torch.Tensor]) -> None:
        """Initialize the torch tensor sample.

        Args:
            samples: One tensor per draw, all of the same shape, with the leading
                dimension indexing the instances.
        """
        # Stacking directly on dim 1 equals ``stack(samples).permute(1, 0, 2)``
        # for 2-D inputs, and also works for other numbers of trailing dimensions.
        self.tensor = torch.stack(samples, dim=1)

    def mean(self) -> torch.Tensor:
        """Compute the mean of the sample over the sample dimension (dim 1)."""
        return self.tensor.mean(dim=1)

    def std(self, ddof: int = 1) -> torch.Tensor:
        """Compute the standard deviation over the sample dimension.

        Args:
            ddof: Delta degrees of freedom, passed to torch as ``correction``.
        """
        return self.tensor.std(dim=1, correction=ddof)

    def var(self, ddof: int = 1) -> torch.Tensor:
        """Compute the variance over the sample dimension.

        Args:
            ddof: Delta degrees of freedom, passed to torch as ``correction``.
        """
        return self.tensor.var(dim=1, correction=ddof)
29 |
--------------------------------------------------------------------------------
/src/probly/transformation/bayesian/torch.py:
--------------------------------------------------------------------------------
1 | """Torch Bayesian implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from torch import nn
6 |
7 | from probly.layers.torch import BayesConv2d, BayesLinear
8 |
9 | from .common import register
10 |
11 |
def replace_torch_bayesian_linear(
    obj: nn.Linear,
    use_base_weights: bool,
    posterior_std: float,
    prior_mean: float,
    prior_std: float,
) -> BayesLinear:
    """Build the BayesLinear layer that takes the place of ``obj``.

    All arguments are forwarded unchanged, in order, to the ``BayesLinear`` constructor.
    """
    bayes_layer = BayesLinear(obj, use_base_weights, posterior_std, prior_mean, prior_std)
    return bayes_layer
21 |
22 |
def replace_torch_bayesian_conv2d(
    obj: nn.Conv2d,
    use_base_weights: bool,
    posterior_std: float,
    prior_mean: float,
    prior_std: float,
) -> BayesConv2d:
    """Build the BayesConv2d layer that takes the place of ``obj``.

    All arguments are forwarded unchanged, in order, to the ``BayesConv2d`` constructor.
    """
    bayes_layer = BayesConv2d(obj, use_base_weights, posterior_std, prior_mean, prior_std)
    return bayes_layer
32 |
33 |
34 | register(nn.Linear, replace_torch_bayesian_linear)
35 | register(nn.Conv2d, replace_torch_bayesian_conv2d)
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 the probly team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/deploy-docs.yml:
--------------------------------------------------------------------------------
1 | name: Deploy Sphinx Docs to GitHub Pages
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | jobs:
9 | build:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Checkout repository
14 | uses: actions/checkout@v4
15 |
16 | - name: Set up Python and uv
17 | id: setup-uv-python
18 | uses: astral-sh/setup-uv@v5
19 | with:
20 | python-version: "3.12"
21 |
22 | - name: Install dependencies
23 | run: uv sync --dev
24 |
25 | - name: Install pandoc
26 | run: sudo apt-get install -y pandoc
27 |
28 | - name: Copy content to docs/source folder
29 | run: |
30 | cp .github/CONTRIBUTING.md docs/source/contributing.md
31 | cp -r notebooks/examples/ docs/source/
32 |
33 | - name: Build Sphinx documentation
34 | working-directory: ./docs
35 | run: |
36 | make clean html
37 |
38 | - name: Deploy to GitHub Pages
39 | uses: peaceiris/actions-gh-pages@v3
40 | with:
41 | github_token: ${{ secrets.GITHUB_TOKEN }}
42 | publish_dir: ./docs/build/html
43 | enable_jekyll: false
44 |
--------------------------------------------------------------------------------
/src/probly/representation/sampling/torch_sampler.py:
--------------------------------------------------------------------------------
1 | """Sampling preparation for torch."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | import torch.nn
8 |
9 | from probly.layers.torch import DropConnectLinear
10 |
11 | from . import sampler
12 |
13 | if TYPE_CHECKING:
14 | from lazy_dispatch.isinstance import LazyType
15 | from pytraverse import State
16 |
17 |
def _enforce_train_mode(obj: torch.nn.Module, state: State) -> tuple[torch.nn.Module, State]:
    """Put ``obj`` into train mode and schedule a cleanup that switches it back off.

    Modules that are already in train mode are left untouched.
    """
    if obj.training:
        return obj, state

    obj.train()
    # The cleanup callback disables train mode again.
    state[sampler.CLEANUP_FUNCS].add(lambda: obj.train(False))
    return obj, state
24 |
25 |
def register_forced_train_mode(cls: LazyType) -> None:
    """Register ``cls`` with the sampling preparation traverser so it is forced into train mode."""
    traverser = sampler.sampling_preparation_traverser
    traverser.register(cls, _enforce_train_mode)
32 |
33 |
34 | register_forced_train_mode(
35 | torch.nn.Dropout
36 | | torch.nn.Dropout1d
37 | | torch.nn.Dropout2d
38 | | torch.nn.Dropout3d
39 | | torch.nn.AlphaDropout
40 | | torch.nn.FeatureAlphaDropout
41 | | DropConnectLinear,
42 | )
43 |
--------------------------------------------------------------------------------
/tests/probly/torch_utils.py:
--------------------------------------------------------------------------------
1 | """Util functions for torch tests."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | torch = pytest.importorskip("torch")
8 |
9 | from torch import Tensor, nn # noqa: E402
10 |
11 |
def validate_loss(loss: Tensor) -> None:
    """Assert that ``loss`` is a zero-dimensional, finite and non-negative tensor."""
    assert isinstance(loss, Tensor)
    assert loss.dim() == 0
    # Finite means neither NaN nor +/- infinity.
    assert torch.isfinite(loss)
    assert loss.item() >= 0
18 |
19 |
def count_layers(model: nn.Module, layer_type: type[nn.Module]) -> int:
    """Count how many modules returned by ``model.modules()`` are instances of ``layer_type``.

    Parameters:
        model: The neural network model containing the layers to be counted.
        layer_type: The type of layer to count within the model.

    Returns:
        The number of modules of the specified type found in the model.
    """
    # Booleans sum as 0/1, so this counts the matching modules.
    return sum(isinstance(module, layer_type) for module in model.modules())
36 |
--------------------------------------------------------------------------------
/.github/workflows/publish-release.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Publish Release
10 |
11 | on:
12 | release:
13 | types: [published]
14 | branches: [main]
15 |
16 | permissions:
17 | contents: read
18 |
19 | jobs:
20 | deploy:
21 |
22 | runs-on: ubuntu-latest
23 |
24 | steps:
25 | - uses: actions/checkout@v4
26 |
27 | - name: Set up Python
28 | uses: actions/setup-python@v5
29 | with:
30 | python-version: '3.12'
31 |
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | pip install build
36 |
37 | - name: Build package
38 | run: python -m build
39 |
40 | - name: Publish package
41 | uses: pypa/gh-action-pypi-publish@v1.12.4
42 | with:
43 | user: __token__
44 | password: ${{ secrets.PYPI_API_TOKEN }}
45 |
--------------------------------------------------------------------------------
/tests/probly/plot/test_credal.py:
--------------------------------------------------------------------------------
1 | """Tests for the plot module."""
2 |
3 | from __future__ import annotations
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pytest
8 |
9 | from probly.plot.credal import credal_set_plot, simplex_plot
10 |
11 |
def test_simplex_plot_outputs() -> None:
    """simplex_plot returns a ternary figure/axes pair whose first collection has one offset per row."""
    probs = np.array([[1 / 3, 1 / 3, 1 / 3], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
    figure, axes = simplex_plot(probs)

    assert isinstance(figure, plt.Figure)
    assert isinstance(axes, plt.Axes)
    assert axes.name == "ternary"

    plotted_points = axes.collections[0].get_offsets()
    assert plotted_points.shape[0] == len(probs)
19 |
20 |
def test_credal_set_plot() -> None:
    """credal_set_plot draws valid input on ternary axes and rejects input without vertices."""
    # Valid case: the barycenter plus the three corners of the simplex.
    valid_probs = np.array([[1 / 3, 1 / 3, 1 / 3], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
    figure, axes = credal_set_plot(valid_probs)
    assert isinstance(figure, plt.Figure)
    assert isinstance(axes, plt.Axes)
    assert axes.name == "ternary"
    assert axes.collections[0].get_offsets().shape[0] == len(valid_probs)

    # Two identical points are expected to trigger the empty-vertices error.
    identical_probs = np.array([[1 / 3, 1 / 3, 1 / 3], [1 / 3, 1 / 3, 1 / 3]])
    expected_message = "The set of vertices is empty. Please check the probabilities in the credal set."
    with pytest.raises(ValueError, match=expected_message):
        credal_set_plot(identical_probs)
35 |
--------------------------------------------------------------------------------
/tests/probly/transformation/evidential/classification/test_common.py:
--------------------------------------------------------------------------------
1 | """Tests for evidential classification registration and dispatch."""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Callable
6 | from typing import Any, cast
7 |
8 | from probly.transformation.evidential.classification.common import (
9 | evidential_classification,
10 | register,
11 | )
12 |
13 |
def test_multiple_types_are_handled_independently() -> None:
    """Handlers registered for two unrelated classes are each dispatched to their own type."""

    class ModelA:
        def __init__(self, mid: str) -> None:
            self.id = mid

    class ModelB:
        def __init__(self, mid: str) -> None:
            self.id = mid

    def appender_a(base: ModelA) -> str:
        return f"A_Enhanced({base.id})"

    def appender_b(base: ModelB) -> str:
        return f"B_Enhanced({base.id})"

    # Register in the same order as before: ModelA first, then ModelB.
    for model_cls, handler in ((ModelA, appender_a), (ModelB, appender_b)):
        register(model_cls, cast(Callable[..., object], handler))

    first = ModelA("001")
    second = ModelB("002")

    outcome_a = evidential_classification(cast(Any, first))
    outcome_b = evidential_classification(cast(Any, second))

    assert cast(str, outcome_a) == "A_Enhanced(001)"
    assert cast(str, outcome_b) == "B_Enhanced(002)"
40 |
--------------------------------------------------------------------------------
/src/pytraverse/__init__.py:
--------------------------------------------------------------------------------
1 | """Generic functional datastructure traverser utilities."""
2 |
3 | from . import composition, core, decorators, generic
4 |
5 | ## Core
6 |
7 | Variable = core.Variable
8 | GlobalVariable = core.GlobalVariable
9 | StackVariable = core.StackVariable
10 | ComputedVariable = core.ComputedVariable
11 | computed = ComputedVariable # Alias for convenience (intended to be used as a decorator)
12 |
13 | type State[T] = core.State[T]
14 | type TraverserResult[T] = core.TraverserResult[T]
15 | type TraverserCallback[T] = core.TraverserCallback[T]
16 | type Traverser[T] = core.Traverser[T]
17 |
18 | traverse_with_state = core.traverse_with_state
19 | traverse = core.traverse
20 |
21 | ## Traverser Decorator
22 |
23 | traverser = decorators.traverser
24 |
25 | ## Composition
26 |
27 | sequential = composition.sequential
28 | top_sequential = composition.top_sequential
29 | SingledispatchTraverser = composition.SingledispatchTraverser
30 | LazySingledispatchTraverser = composition.LazydispatchTraverser
31 |
32 | singledispatch_traverser = composition.SingledispatchTraverser
33 | lazydispatch_traverser = composition.LazydispatchTraverser
34 |
35 | ## Generic traverser
36 |
37 | generic_traverser = generic.generic_traverser
38 | CLONE = generic.CLONE
39 | TRAVERSE_KEYS = generic.TRAVERSE_KEYS
40 | TRAVERSE_REVERSED = generic.TRAVERSE_REVERSED
41 |
--------------------------------------------------------------------------------
/src/probly/transformation/ensemble/torch.py:
--------------------------------------------------------------------------------
1 | """Torch ensemble implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from torch import nn
6 |
7 | from probly.traverse_nn import nn_compose, nn_traverser
8 | from pytraverse import CLONE, singledispatch_traverser, traverse
9 |
10 | from .common import register
11 |
12 | reset_traverser = singledispatch_traverser[nn.Module](name="reset_traverser")
13 |
14 |
@reset_traverser.register
def _(obj: nn.Module) -> nn.Module:
    """Re-initialise ``obj`` in place when it exposes ``reset_parameters``.

    Modules without a ``reset_parameters`` attribute are returned untouched.
    """
    try:
        reset = obj.reset_parameters
    except AttributeError:
        # Nothing to reset for this module type.
        return obj
    reset()  # type: ignore[operator]
    return obj
20 |
21 |
def _reset_copy(module: nn.Module) -> nn.Module:
    """Traverse ``module`` with ``CLONE`` enabled, applying ``reset_traverser`` along the way."""
    resetting_traverser = nn_compose(reset_traverser)
    initial_state = {CLONE: True}
    return traverse(module, resetting_traverser, init=initial_state)
24 |
25 |
def _copy(module: nn.Module) -> nn.Module:
    """Traverse ``module`` with ``CLONE`` enabled using the plain ``nn_traverser``."""
    initial_state = {CLONE: True}
    return traverse(module, nn_traverser, init=initial_state)
28 |
29 |
def generate_torch_ensemble(
    obj: nn.Module,
    num_members: int,
    reset_params: bool = True,
) -> nn.ModuleList:
    """Build a torch ensemble by copying the base model ``num_members`` times.

    Args:
        obj: The base module that every ensemble member is derived from.
        num_members: The number of copies to create.
        reset_params: If True, each copy is produced by ``_reset_copy``;
            otherwise by ``_copy``.

    Returns:
        An ``nn.ModuleList`` holding the ``num_members`` copies.
    """
    make_member = _reset_copy if reset_params else _copy
    members = [make_member(obj) for _ in range(num_members)]
    return nn.ModuleList(members)
39 |
40 |
41 | register(nn.Module, generate_torch_ensemble)
42 |
--------------------------------------------------------------------------------
/tests/probly/transformation/ensemble/test_common.py:
--------------------------------------------------------------------------------
1 | """Tests for common ensemble generation."""
2 |
3 | from __future__ import annotations
4 |
5 | from unittest.mock import Mock
6 |
7 | import pytest
8 |
9 | from probly.predictor import Predictor
10 | from probly.transformation import ensemble
11 | from probly.transformation.ensemble.common import ensemble_generator, register
12 |
13 |
def test_unregistered_type_raises(dummy_predictor: Predictor) -> None:
    """Calling the generator for a type without a registration raises NotImplementedError."""
    expected_message = f"No ensemble generator is registered for type {type(dummy_predictor)}"
    with pytest.raises(NotImplementedError, match=expected_message):
        ensemble_generator(dummy_predictor)
22 |
23 |
def test_registered_generator_called(dummy_predictor: Predictor) -> None:
    """A registered generator receives the call and its return value is passed through."""
    sentinel = object()
    generator = Mock()
    generator.return_value = sentinel
    register(type(dummy_predictor), generator)

    outcome = ensemble(dummy_predictor, num_members=4)

    # ``ensemble`` forwards its own default for ``reset_params``.
    generator.assert_called_once_with(dummy_predictor, num_members=4, reset_params=True)
    assert outcome is sentinel
40 |
--------------------------------------------------------------------------------
/src/probly/transformation/evidential/classification/common.py:
--------------------------------------------------------------------------------
1 | """Shared evidential classification implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from lazy_dispatch import lazydispatch
8 |
9 | if TYPE_CHECKING:
10 | from collections.abc import Callable
11 |
12 | from lazy_dispatch.isinstance import LazyType
13 | from probly.predictor import Predictor
14 |
15 |
16 | @lazydispatch
17 | def evidential_classification_appender[In, KwIn, Out](base: Predictor[In, KwIn, Out]) -> Predictor[In, KwIn, Out]:
18 | """Append an evidential classification activation function to a base model."""
19 | msg = f"No evidential classification appender registered for type {type(base)}"
20 | raise NotImplementedError(msg)
21 |
22 |
def register(cls: LazyType, appender: Callable) -> None:
    """Register ``appender`` as the evidential classification appender for ``cls``."""
    add_registration = evidential_classification_appender.register
    add_registration(cls=cls, func=appender)
26 |
27 |
28 | def evidential_classification[T: Predictor](base: T) -> T:
29 | """Create an evidential classification predictor from a base predictor.
30 |
31 | Args:
32 | base: Predictor, The base model to be used for evidential classification.
33 |
34 | Returns:
35 | Predictor, The evidential classification predictor.
36 | """
37 | return evidential_classification_appender(base)
38 |
--------------------------------------------------------------------------------
/src/probly/transformation/ensemble/common.py:
--------------------------------------------------------------------------------
1 | """Shared ensemble implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from lazy_dispatch import lazydispatch
8 |
9 | if TYPE_CHECKING:
10 | from collections.abc import Callable
11 |
12 | from lazy_dispatch.isinstance import LazyType
13 | from probly.predictor import Predictor
14 |
15 |
16 | @lazydispatch
17 | def ensemble_generator[In, KwIn, Out](base: Predictor[In, KwIn, Out]) -> Predictor[In, KwIn, Out]:
18 | """Generate an ensemble from a base model."""
19 | msg = f"No ensemble generator is registered for type {type(base)}"
20 | raise NotImplementedError(msg)
21 |
22 |
def register(cls: LazyType, generator: Callable) -> None:
    """Register ``generator`` as the ensemble generator for bases of type ``cls``."""
    add_registration = ensemble_generator.register
    add_registration(cls=cls, func=generator)
26 |
27 |
28 | def ensemble[T: Predictor](base: T, num_members: int, reset_params: bool = True) -> T:
29 | """Create an ensemble predictor from a base predictor.
30 |
31 | Args:
32 | base: Predictor, The base model to be used for the ensemble.
33 | num_members: The number of members in the ensemble.
34 | reset_params: Whether to reset the parameters of each member.
35 |
36 | Returns:
37 | Predictor, The ensemble predictor.
38 | """
39 | return ensemble_generator(base, num_members=num_members, reset_params=reset_params)
40 |
--------------------------------------------------------------------------------
/tests/probly/tests_deprecations/deprecated_features.py:
--------------------------------------------------------------------------------
1 | """Collects all deprecated behaviour tests."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from probly.utils.errors import raise_deprecation_warning
8 |
9 | if TYPE_CHECKING:
10 | import pytest
11 |
12 | from .utils import register_deprecated
13 |
14 |
@register_deprecated(
    name="ExampleFeature(NeverRemoveThis)",
    deprecated_in="1.0.0",
    removed_in="9.9.9",
)
def example_always_warn(requests: pytest.FixtureRequest) -> None:
    """An example feature that always raises a deprecation warning and uses a fixture.

    This "behavior" is just an example of how to handle deprecations in the codebase. The
    "feature" needs a fixture (in this case ``torch_conv_linear_model``), which is looked up
    through the ``requests`` argument, a ``pytest.FixtureRequest``. The warning is raised with
    the same ``deprecated_in``/``removed_in`` versions that are passed to ``register_deprecated``.

    Note:
        Do not delete this "feature", as it is a good example of how to handle deprecations and
        write your own deprecation tests.

    Args:
        requests: The fixture request used to resolve fixtures by name.

    """
    # The warning is raised first; the fixture lookup below only runs afterwards.
    raise_deprecation_warning(
        "This is an example of a deprecated feature that always raises a warning.",
        deprecated_in="1.0.0",
        removed_in="9.9.9",
    )
    _ = requests.getfixturevalue("torch_conv_linear_model")  # gets a fixture as an example of usage
39 |
--------------------------------------------------------------------------------
/tests/probly/quantification/test_regression.py:
--------------------------------------------------------------------------------
1 | """Tests for the regression module."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | import numpy as np
8 | import pytest
9 |
10 | from tests.probly.general_utils import validate_uncertainty
11 |
12 | if TYPE_CHECKING:
13 | from collections.abc import Callable
14 |
15 |
16 | from probly.quantification.regression import (
17 | conditional_differential_entropy,
18 | expected_conditional_variance,
19 | mutual_information,
20 | total_differential_entropy,
21 | total_variance,
22 | variance_conditional_expectation,
23 | )
24 |
25 |
@pytest.fixture
def sample_second_order_params() -> np.ndarray:
    """Return random parameters of shape (5, 10, 2): means in [..., 0], variances in [..., 1]."""
    shape = (5, 10)
    generator = np.random.default_rng()
    means = generator.uniform(0, 1000, shape)
    variances = generator.uniform(1e-20, 100, shape)
    return np.stack((means, variances), axis=2)
33 |
34 |
@pytest.mark.parametrize(
    "uncertainty_fn",
    [
        total_variance,
        expected_conditional_variance,
        variance_conditional_expectation,
        total_differential_entropy,
        conditional_differential_entropy,
        mutual_information,
    ],
)
def test_uncertainty_function(
    uncertainty_fn: Callable[[np.ndarray], np.ndarray],
    sample_second_order_params: np.ndarray,
) -> None:
    """Each regression uncertainty measure yields a result accepted by ``validate_uncertainty``."""
    validate_uncertainty(uncertainty_fn(sample_second_order_params))
52 |
--------------------------------------------------------------------------------
/tests/probly/utils/test_torch.py:
--------------------------------------------------------------------------------
1 | """Tests for utils.torch functions."""
2 |
3 | from __future__ import annotations
4 |
5 | import torch
6 | from torch.utils.data import DataLoader, TensorDataset
7 |
8 | from probly.utils.torch import temperature_softmax, torch_collect_outputs, torch_reset_all_parameters
9 |
10 |
def test_torch_reset_all_parameters(torch_conv_linear_model: torch.nn.Module) -> None:
    """Resetting all parameters changes the concatenated parameter vector of the model."""

    def snapshot(model: torch.nn.Module) -> torch.Tensor:
        # ``torch.cat`` builds a new tensor, so the snapshot is unaffected by later resets.
        flat_parts = [weights.flatten() for weights in model.parameters()]
        return torch.cat(flat_parts)

    params_before = snapshot(torch_conv_linear_model)
    torch_reset_all_parameters(torch_conv_linear_model)
    params_after = snapshot(torch_conv_linear_model)
    assert not torch.equal(params_before, params_after)
19 |
20 |
def test_torch_collect_outputs(torch_conv_linear_model: torch.nn.Module) -> None:
    """Collected outputs and targets have one row per sample in the loader."""
    inputs = torch.randn(2, 3, 5, 5)
    labels = torch.randn(2)
    loader = DataLoader(TensorDataset(inputs, labels))
    cpu = torch.device("cpu")
    outputs, targets = torch_collect_outputs(torch_conv_linear_model, loader, cpu)
    assert outputs.shape == (2, 2)
    assert targets.shape == (2,)
33 |
34 |
def test_temperature_softmax() -> None:
    """``temperature_softmax`` equals a softmax over dim 1 of the logits divided by the temperature."""
    logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    scaled_expected = torch.softmax(logits / 2.0, dim=1)
    plain_expected = torch.softmax(logits, dim=1)
    assert torch.equal(temperature_softmax(logits, 2.0), scaled_expected)
    # A tensor-valued temperature is accepted as well.
    assert torch.equal(temperature_softmax(logits, torch.tensor(1.0)), plain_expected)
39 |
--------------------------------------------------------------------------------
/src/probly/transformation/dropout/common.py:
--------------------------------------------------------------------------------
1 | """Shared dropout implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from probly.traverse_nn import is_first_layer, nn_compose
8 | from pytraverse import CLONE, GlobalVariable, lazydispatch_traverser, traverse
9 |
10 | if TYPE_CHECKING:
11 | from lazy_dispatch.isinstance import LazyType
12 | from probly.predictor import Predictor
13 | from pytraverse.composition import RegisteredLooseTraverser
14 |
15 | P = GlobalVariable[float]("P", "The probability of dropout.")
16 |
17 | dropout_traverser = lazydispatch_traverser[object](name="dropout_traverser")
18 |
19 |
def register(cls: LazyType, traverser: RegisteredLooseTraverser) -> None:
    """Register ``traverser`` for ``cls`` on the dropout traverser.

    The registration is skipped for the first layer and binds the traverser's ``p`` variable to ``P``.
    """
    bound_vars = {"p": P}
    dropout_traverser.register(cls=cls, traverser=traverser, skip_if=is_first_layer, vars=bound_vars)
23 |
24 |
25 | def dropout[T: Predictor](base: T, p: float = 0.25) -> T:
26 | """Create a Dropout predictor from a base predictor.
27 |
28 | Args:
29 | base: Predictor, The base model to be used for dropout.
30 | p: float, The probability of dropping out a neuron. Default is 0.25.
31 |
32 | Returns:
33 | Predictor, The DropOut predictor.
34 | """
35 | if p < 0 or p > 1:
36 | msg = f"The probability p must be between 0 and 1, but got {p} instead."
37 | raise ValueError(msg)
38 | return traverse(base, nn_compose(dropout_traverser), init={P: p, CLONE: True})
39 |
--------------------------------------------------------------------------------
/src/probly/transformation/dropconnect/common.py:
--------------------------------------------------------------------------------
1 | """Shared DropConnect implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from probly.traverse_nn import is_first_layer, nn_compose
8 | from pytraverse import CLONE, GlobalVariable, lazydispatch_traverser, traverse
9 |
10 | if TYPE_CHECKING:
11 | from lazy_dispatch.isinstance import LazyType
12 | from probly.predictor import Predictor
13 | from pytraverse.composition import RegisteredLooseTraverser
14 |
15 | P = GlobalVariable[float]("P", "The probability of dropconnect.")
16 |
17 | dropconnect_traverser = lazydispatch_traverser[object](name="dropconnect_traverser")
18 |
19 |
def register(cls: LazyType, traverser: RegisteredLooseTraverser) -> None:
    """Register ``traverser`` for ``cls`` on the DropConnect traverser.

    The registration is skipped for the first layer and binds the traverser's ``p`` variable to ``P``.
    """
    bound_vars = {"p": P}
    dropconnect_traverser.register(cls=cls, traverser=traverser, skip_if=is_first_layer, vars=bound_vars)
23 |
24 |
25 | def dropconnect[T: Predictor](base: T, p: float = 0.25) -> T:
26 | """Create a DropConnect predictor from a base predictor.
27 |
28 | Args:
29 | base: The base model to be used for dropout.
30 | p: The probability of dropping out a neuron. Default is 0.25.
31 |
32 | Returns:
33 | The DropConnect predictor.
34 | """
35 | if p < 0 or p > 1:
36 | msg = f"The probability p must be between 0 and 1, but got {p} instead."
37 | raise ValueError(msg)
38 | return traverse(base, nn_compose(dropconnect_traverser), init={P: p, CLONE: True})
39 |
--------------------------------------------------------------------------------
/tests/probly/train/calibration/test_torch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | from probly.train.calibration.torch import ExpectedCalibrationError, FocalLoss, LabelRelaxationLoss
6 | from tests.probly.torch_utils import validate_loss
7 |
8 | torch = pytest.importorskip("torch")
9 | from torch import Tensor # noqa: E402
10 |
11 |
def test_focal_loss(sample_outputs: tuple[Tensor, Tensor]) -> None:
    """The default ``FocalLoss`` produces a loss accepted by ``validate_loss``."""
    logits, labels = sample_outputs
    focal_loss = FocalLoss()
    validate_loss(focal_loss(logits, labels))
    # TODO(pwhofman): Add tests for different values of alpha and gamma
    # https://github.com/pwhofman/probly/issues/92
19 |
20 |
def test_expected_calibration_error(
    sample_outputs: tuple[Tensor, Tensor],
) -> None:
    """ECE yields a valid loss for the default settings and for a single bin."""
    logits, labels = sample_outputs
    probabilities = torch.softmax(logits, dim=1)
    # Each criterion is constructed right before it is evaluated.
    for options in ({}, {"num_bins": 1}):
        criterion = ExpectedCalibrationError(**options)
        validate_loss(criterion(probabilities, labels))
33 |
34 |
def test_label_relaxation_loss(
    sample_outputs: tuple[Tensor, Tensor],
) -> None:
    """Label relaxation yields a valid loss for the default settings and for ``alpha=1.0``."""
    logits, labels = sample_outputs
    # Each criterion is constructed right before it is evaluated.
    for options in ({}, {"alpha": 1.0}):
        criterion = LabelRelaxationLoss(**options)
        validate_loss(criterion(logits, labels))
46 |
--------------------------------------------------------------------------------
/src/probly/transformation/evidential/regression/common.py:
--------------------------------------------------------------------------------
1 | """Shared evidential regression implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from probly.traverse_nn import nn_compose
8 | from pytraverse import CLONE, TRAVERSE_REVERSED, GlobalVariable, lazydispatch_traverser, traverse
9 |
10 | if TYPE_CHECKING:
11 | from lazy_dispatch.isinstance import LazyType
12 | from probly.predictor import Predictor
13 | from pytraverse.composition import RegisteredLooseTraverser
14 |
15 | REPLACED_LAST_LINEAR = GlobalVariable[bool](
16 | "REPLACED_LAST_LINEAR",
17 | "Whether the last linear layer has been replaced with a NormalInverseGammaLinear layer.",
18 | default=False,
19 | )
20 |
21 | evidential_regression_traverser = lazydispatch_traverser[object](name="evidential_regression_traverser")
22 |
23 |
def register(cls: LazyType, traverser: RegisteredLooseTraverser) -> None:
    """Register ``traverser`` for ``cls`` on the evidential regression traverser.

    The registration is skipped whenever ``REPLACED_LAST_LINEAR`` is set in the traversal state.
    """

    def already_replaced(state):  # noqa: ANN001, ANN202
        return state[REPLACED_LAST_LINEAR]

    evidential_regression_traverser.register(cls=cls, traverser=traverser, skip_if=already_replaced)
27 |
28 |
29 | def evidential_regression[T: Predictor](base: T) -> T:
30 | """Create an evidential regression predictor from a base predictor.
31 |
32 | Args:
33 | base: Predictor, The base model to be used for evidential regression.
34 |
35 | Returns:
36 | Predictor, The evidential regression predictor.
37 | """
38 | return traverse(base, nn_compose(evidential_regression_traverser), init={TRAVERSE_REVERSED: True, CLONE: True})
39 |
--------------------------------------------------------------------------------
/tests/probly/evaluation/test_tasks.py:
--------------------------------------------------------------------------------
1 | """Tests for the tasks module."""
2 |
3 | from __future__ import annotations
4 |
5 | import numpy as np
6 | import pytest
7 |
8 | from probly.evaluation.tasks import out_of_distribution_detection, selective_prediction
9 |
10 |
def test_selective_prediction_shapes() -> None:
    """``selective_prediction`` returns a float and an array with one loss per bin."""
    n_bins = 5
    generator = np.random.default_rng()
    criterion = generator.random(10)
    losses = generator.random(10)
    auroc, bin_losses = selective_prediction(criterion, losses, n_bins=n_bins)
    assert isinstance(auroc, float)
    assert isinstance(bin_losses, np.ndarray)
    assert bin_losses.shape == (n_bins,)
17 |
18 |
def test_selective_prediction_order() -> None:
    """With losses equal to the criterion, the per-bin losses never increase."""
    ramp = np.linspace(0, 1, 10)
    bin_losses = selective_prediction(ramp, np.linspace(0, 1, 10), n_bins=5)[1]
    steps = np.diff(bin_losses)
    assert np.all(steps <= 0)
24 |
25 |
def test_selective_prediction_too_many_bins() -> None:
    """Requesting more bins than there are elements raises a ValueError."""
    generator = np.random.default_rng()
    expected_message = "The number of bins can not be larger than the number of elements criterion"
    with pytest.raises(ValueError, match=expected_message):
        selective_prediction(generator.random(5), generator.random(5), n_bins=10)
30 |
31 |
def test_out_of_distribution_detection_shape() -> None:
    """``out_of_distribution_detection`` returns a plain float."""
    generator = np.random.default_rng()
    in_scores = generator.random(10)
    out_scores = generator.random(10)
    assert isinstance(out_of_distribution_detection(in_scores, out_scores), float)
36 |
37 |
def test_out_of_distribution_detection_order() -> None:
    """Shifting the out-of-distribution scores up by one gives an AUROC of 0.995.

    The two score ranges touch only at 1.0 (the largest in-distribution score equals the
    smallest out-of-distribution score).
    """
    in_distribution = np.linspace(0, 1, 10)
    out_distribution = np.linspace(0, 1, 10) + 1
    auroc = out_of_distribution_detection(in_distribution, out_distribution)
    # Compare with a tolerance: the AUROC is a computed float, so exact equality is brittle.
    assert auroc == pytest.approx(0.995)
43 |
--------------------------------------------------------------------------------
/docs/source/methods.md:
--------------------------------------------------------------------------------
1 | # Implemented methods
2 | The following methods are currently implemented in `probly`.
3 |
4 | ## Representation
5 | ### Second-order Distributions
6 | These methods represent (epistemic) uncertainty by a second-order distribution over distributions.
7 | #### Bayesian Neural Networks
8 | {cite:p}`blundellWeightUncertainty2015`
9 |
10 | #### Dropout
11 | {cite:p}`galDropoutBayesian2016`
12 |
13 | #### DropConnect
14 | {cite:p}`mobinyDropConnectEffective2019`
15 |
16 | #### Deep Ensembles
17 | {cite:p}`lakshminarayananSimpleScalable2017`
18 |
19 | #### Evidential Deep Learning
20 | {cite:p}`sensoyEvidentialDeep2018`
21 | {cite:p}`aminiDeepEvidential2020`
22 |
23 | ### Credal Sets
24 | These methods represent (epistemic) uncertainty by a convex set of distributions.
25 | #### Credal Ensembling
26 | {cite:p}`nguyenCredalEnsembling2025`
27 |
28 | ## Quantification
29 | #### Upper / lower entropy
30 | {cite:p}`abellanDisaggregatedTotal2006`
31 |
32 | #### Generalized Hartley
33 | {cite:p}`abellanNonSpecificity2000`
34 |
35 | #### Entropy-based
36 | {cite:p}`depewegDecompositionUncertainty2018`
37 |
38 | #### Distance-based
39 | {cite:p}`saleSecondOrder2024`
40 |
41 | ### Conformal Prediction
42 | These methods represent uncertainty by a set of predictions.
43 | #### Split Conformal Prediction
44 | {cite:p}`angelopoulosGentleIntroduction2021`
45 | ## Calibration
46 | These methods adjust the model's probabilities to better reflect the true probabilities.
47 | #### Focal Loss
48 | {cite:p}`linFocalLoss2017`
49 |
50 | #### Label Relaxation
51 | {cite:p}`lienenFromLabel2021`
52 |
53 | #### Temperature Scaling
54 | {cite:p}`guoOnCalibration2017`
55 |
--------------------------------------------------------------------------------
/tests/probly/utils/test_probabilities.py:
--------------------------------------------------------------------------------
1 | """Tests for utils.probabilities functions."""
2 |
3 | from __future__ import annotations
4 |
5 | import numpy as np
6 |
7 | from probly.utils.probabilities import differential_entropy_gaussian, intersection_probability, kl_divergence_gaussian
8 |
9 |
def test_differential_entropy_gaussian() -> None:
    """Check reference values for a scalar input and for an array input with base e."""
    scalar_entropy = differential_entropy_gaussian(0.5)
    assert np.isclose(scalar_entropy, 1.54709559)

    array_entropy = differential_entropy_gaussian(np.array([1, 2]), base=np.e)
    expected = np.array([1.41893853, 1.76551212])
    assert np.allclose(array_entropy, expected)
13 |
14 |
def test_kl_divergence_gaussian() -> None:
    """Identical Gaussians have zero divergence; unit mean shifts at variance 0.1 give 5.0 in base e."""
    assert np.isclose(kl_divergence_gaussian(1.0, 1.0, 1.0, 1.0), 0.0)
    assert np.isclose(kl_divergence_gaussian(1.0, 1.0, 1.0, 1.0, base=np.e), 0.0)

    mean_p = np.array([0.0, 1.0])
    mean_q = np.array([1.0, 0.0])
    var_p = np.array([0.1, 0.1])
    var_q = np.array([0.1, 0.1])
    divergence = kl_divergence_gaussian(mean_p, var_p, mean_q, var_q, base=np.e)
    assert np.allclose(divergence, np.array([5.0, 5.0]))
23 |
24 |
def test_intersection_probability() -> None:
    """The intersection probability has one row per instance and each row sums to one."""
    generator = np.random.default_rng()

    # Random credal sets with 2 and then 10 classes.
    for n_classes in (2, 10):
        sampled = generator.dirichlet(np.ones(n_classes), size=(5, 5))
        result = intersection_probability(sampled)
        assert result.shape == (5, n_classes)
        assert np.allclose(np.sum(result, axis=1), 1.0)

    # A single instance whose five members are all uniform over three classes.
    uniform = np.array([1 / 3, 1 / 3, 1 / 3] * 5).reshape(1, 5, 3)
    result = intersection_probability(uniform)
    assert result.shape == (1, 3)
    assert np.allclose(np.sum(result, axis=1), 1.0)
42 |
--------------------------------------------------------------------------------
/tests/probly/transformation/bayesian/test_common.py:
--------------------------------------------------------------------------------
1 | """Tests for Bayesian models."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | from probly.predictor import Predictor
8 | from probly.transformation import bayesian
9 |
10 |
def test_invalid_prior_std_value(dummy_predictor: Predictor) -> None:
    """A negative ``prior_std`` makes ``bayesian`` raise a ValueError with the expected message."""
    invalid_std = -1.0
    expected_message = (
        f"The prior standard deviation prior_std must be greater than 0, but got {invalid_std} instead."
    )
    with pytest.raises(ValueError, match=expected_message):
        bayesian(dummy_predictor, prior_std=invalid_std)
24 |
25 |
def test_invalid_posterior_std_value(dummy_predictor: Predictor) -> None:
    """A negative ``posterior_std`` makes ``bayesian`` raise a ValueError with the expected message."""
    invalid_std = -1.0
    prefix = "The initial posterior standard deviation posterior_std must be greater than 0, "
    expected_message = prefix + f"but got {invalid_std} instead."
    with pytest.raises(ValueError, match=expected_message):
        bayesian(dummy_predictor, posterior_std=invalid_std)
42 |
--------------------------------------------------------------------------------
/src/probly/traverse_nn/common.py:
--------------------------------------------------------------------------------
1 | """Generic traverser helpers for neural networks."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytraverse as t
6 |
7 | LAYER_COUNT = t.GlobalVariable[int](
8 | "LAYER_COUNT",
9 | "The DFS index of the current layer/module.",
10 | default=0,
11 | )
12 | FLATTEN_SEQUENTIAL = t.StackVariable[bool](
13 | "FLATTEN_SEQUENTIAL",
14 | "Whether to flatten sequential modules after making changes.",
15 | default=True,
16 | )
17 |
18 |
@t.computed
def is_first_layer(state: t.State) -> bool:
    """Return True while ``LAYER_COUNT`` in ``state`` is still zero."""
    current_index = state[LAYER_COUNT]
    return current_index == 0
23 |
24 |
25 | layer_count_traverser = t.singledispatch_traverser[object](name="layer_count_traverser")
26 |
27 | nn_traverser = t.lazydispatch_traverser[object](name="nn_traverser")
28 |
29 |
def compose(
    traverser: t.Traverser,
    nn_traverser: t.Traverser = nn_traverser,
    name: str | None = None,
) -> t.Traverser:
    """Wrap a custom traverser with neural network traversal and layer counting.

    The three parts are handed to ``t.sequential`` in a fixed order: the neural network
    traverser, then the custom traverser, then ``layer_count_traverser``.

    Args:
        traverser: The custom traverser to compose with the NN traverser.
        nn_traverser: The neural network traverser to use. Defaults to the module's
            ``nn_traverser``.
        name: Optional name for the composed traverser.

    Returns:
        The sequential traverser built from the three parts.
    """
    stages = (nn_traverser, traverser, layer_count_traverser)
    return t.sequential(*stages, name=name)
51 |
--------------------------------------------------------------------------------
/src/probly/utils/sets.py:
--------------------------------------------------------------------------------
1 | """Utility functions regarding sets."""
2 |
3 | from __future__ import annotations
4 |
5 | import itertools
6 | from typing import TYPE_CHECKING
7 |
8 | import numpy as np
9 |
10 | if TYPE_CHECKING:
11 | from collections.abc import Iterable
12 |
13 |
def powerset(iterable: Iterable[int]) -> list[tuple[int, ...]]:
    """Generate the power set of a given iterable.

    Subsets are ordered by increasing size, starting with the empty tuple; within one
    size they follow the order produced by ``itertools.combinations``.

    Args:
        iterable: The elements to build subsets from. It is materialised once, so
            one-shot iterators are supported.

    Returns:
        A list with one tuple per subset of the given elements.

    """
    s = list(iterable)
    subsets_by_size = (itertools.combinations(s, r) for r in range(len(s) + 1))
    return list(itertools.chain.from_iterable(subsets_by_size))
25 |
26 |
def capacity(q: np.ndarray, a: Iterable[int]) -> np.ndarray:
    """Compute the capacity of set q given set a.

    For every instance, the probabilities of the classes in ``a`` are summed per sample
    and the minimum of these sums over the samples is returned.

    Args:
        q: numpy.ndarray, shape (n_instances, n_samples, n_classes)
        a: Iterable of class indices indicating a subset of classes. Any iterable of
            ints is accepted (list, tuple, set, generator, ...); an empty subset
            yields a capacity of zero.

    Returns:
        min_capacity: numpy.ndarray, shape (n_instances,), capacity of q given a

    """
    # Build an explicit integer index array: indexing with a set or generator directly
    # does not select the listed classes.
    indices = np.fromiter(a, dtype=np.intp)
    selected_sum = np.sum(q[:, :, indices], axis=2)
    min_capacity = np.min(selected_sum, axis=1)
    return min_capacity
40 |
41 |
def moebius(q: np.ndarray, a: Iterable[int]) -> np.ndarray:
    """Compute the Moebius function of a set q given a set a.

    The value is the alternating sum of ``capacity(q, b)`` over all non-empty subsets
    ``b`` of ``a``, with sign ``(-1) ** |a - b|``.

    Args:
        q: numpy.ndarray, shape (n_instances, n_samples, n_classes)
        a: Iterable of class indices indicating a subset of classes.

    Returns:
        m_a: numpy.ndarray, shape (n_instances,), moebius value of q given a

    """
    # Materialise ``a`` once: it is needed both for the power set and for the sign of
    # every term, and a one-shot iterator would be exhausted after the first use.
    members = tuple(a)
    member_set = set(members)
    m_a = np.zeros(q.shape[0])
    # The first entry of the power set is the empty set, which is skipped.
    for b in powerset(members)[1:]:
        dl = len(member_set - set(b))
        m_a += ((-1) ** dl) * capacity(q, b)
    return m_a
59 |
--------------------------------------------------------------------------------
/src/probly/representation/credal_set/credal_set.py:
--------------------------------------------------------------------------------
1 | """Classes representing credal sets."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from lazy_dispatch.singledispatch import lazydispatch
8 | from probly.lazy_types import JAX_ARRAY, TORCH_TENSOR
9 | from probly.representation.sampling.sample import ArraySample
10 |
11 | if TYPE_CHECKING:
12 | import numpy as np
13 |
14 |
15 | class CredalSet[T]:
16 | """Base class for credal sets."""
17 |
18 | def lower(self) -> T:
19 | """Compute the lower envelope of the credal set."""
20 | msg = "lower method not implemented."
21 | raise NotImplementedError(msg)
22 |
23 | def upper(self) -> T:
24 | """Compute the upper envelope of the credal set."""
25 | msg = "upper method not implemented."
26 | raise NotImplementedError(msg)
27 |
28 |
29 | credal_set_from_sample = lazydispatch[type[CredalSet], CredalSet](CredalSet)
30 |
31 |
32 | @credal_set_from_sample.register(ArraySample)
33 | class ArrayCredalSet[T](CredalSet[T]):
34 | """A credal set of predictions stored in a numpy array."""
35 |
36 | def __init__(self, sample: ArraySample) -> None:
37 | """Initialize the array credal set."""
38 | self.array: np.ndarray = sample.array
39 |
40 | def lower(self) -> T:
41 | """Compute the lower envelope of the credal set."""
42 | return self.array.min(axis=1) # type: ignore[no-any-return]
43 |
44 | def upper(self) -> T:
45 | """Compute the upper envelope of the credal set."""
46 | return self.array.max(axis=1) # type: ignore[no-any-return]
47 |
48 |
@credal_set_from_sample.delayed_register(TORCH_TENSOR)
def _(_: type) -> None:
    # Hook registered for TORCH_TENSOR: its only effect is importing the sibling ``torch``
    # submodule. Presumably that module registers its own credal set on import (not
    # visible here; confirm in the submodule).
    from . import torch as torch  # noqa: PLC0414, PLC0415
52 |
53 |
@credal_set_from_sample.delayed_register(JAX_ARRAY)
def _(_: type) -> None:
    # Hook registered for JAX_ARRAY: its only effect is importing the sibling ``jax``
    # submodule. Presumably that module registers its own credal set on import (not
    # visible here; confirm in the submodule).
    from . import jax as jax  # noqa: PLC0414, PLC0415
57 |
--------------------------------------------------------------------------------
/src/lazy_dispatch/load.py:
--------------------------------------------------------------------------------
1 | """Lazy import utilities."""
2 |
3 | from __future__ import annotations
4 |
5 | import importlib
6 | import importlib.util
7 | import sys
8 | from typing import TYPE_CHECKING, Any, overload
9 |
10 | if TYPE_CHECKING:
11 | from collections.abc import Callable, Iterable
12 | from types import ModuleType
13 |
14 |
def lazy_import(name: str, package: str | None = None, register: bool = False) -> ModuleType:
    """Lazily import a module.

    If the module is already present in ``sys.modules`` it is returned as is. Otherwise a
    module object backed by ``importlib.util.LazyLoader`` is created, so the module's code
    only runs on first attribute access.

    Args:
        name: Absolute module name, or a relative one (leading dots) together with ``package``.
        package: Anchor package used to resolve a relative ``name``.
        register: Whether to store the lazily loaded module in ``sys.modules``.

    Returns:
        The already imported module or the newly created lazy module.

    Raises:
        ImportError: If the module cannot be found or has no loader.
    """
    # ``sys.modules`` is keyed by absolute names, so a relative name must be resolved
    # before the cache lookup and before registering the module.
    fullname = importlib.util.resolve_name(name, package)
    if fullname in sys.modules:
        return sys.modules[fullname]

    spec = importlib.util.find_spec(fullname)

    if spec is None or spec.loader is None:
        msg = f"Module {name} not found"
        raise ImportError(msg)

    loader: importlib.util.LazyLoader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    if register:
        sys.modules[fullname] = module
    loader.exec_module(module)
    return module
33 |
34 |
@overload
def lazy_callable(
    module: ModuleType | str,
    attrs: str,
    package: str | None = None,
    register: bool = False,
) -> Callable: ...


@overload
def lazy_callable(
    module: ModuleType | str,
    attrs: Iterable[str],
    package: str | None = None,
    register: bool = False,
) -> list[Callable]: ...


def lazy_callable(
    module: ModuleType | str,
    attrs: str | Iterable[str],
    package: str | None = None,
    register: bool = False,
) -> Callable | list[Callable]:
    """Build wrappers that look up module attributes only when they are called.

    Args:
        module: A module object, or a module name that is passed to ``lazy_import``.
        attrs: One attribute name, or an iterable of attribute names.
        package: Forwarded to ``lazy_import`` when ``module`` is a name.
        register: Forwarded to ``lazy_import`` when ``module`` is a name.

    Returns:
        A single wrapper for a string ``attrs``, otherwise a list with one
        wrapper per attribute name.
    """
    if isinstance(module, str):
        module = lazy_import(module, package=package, register=register)

    # Several names: build one wrapper per attribute on the resolved module.
    if not isinstance(attrs, str):
        return [lazy_callable(module, single) for single in attrs]

    def fn(*args: Any, **kwargs: Any) -> Any:  # noqa: ANN401
        # The attribute lookup is deferred to call time.
        target = getattr(module, attrs)
        return target(*args, **kwargs)

    return fn
71 |
--------------------------------------------------------------------------------
/tests/probly/markers.py:
--------------------------------------------------------------------------------
1 | """Markers for optional dependencies.
2 |
3 | Markers can be imported in the test files to skip tests if certain packages are not installed.
4 | Note that we want to avoid marked tests as best as possible, but sometimes dependencies are lagging
5 | behind the latest version of the package or the python version.
6 |
7 | Markers should only be used on optional dependencies, i.e. packages that are not required to run
8 | `probly`.
9 |
10 | """
11 |
12 | from __future__ import annotations
13 |
14 | import importlib.util
15 |
16 | import pytest
17 |
18 | __all__ = [
19 | "skip_if_no_keras",
20 | "skip_if_no_lightgbm",
21 | "skip_if_no_sklearn",
22 | "skip_if_no_tabpfn",
23 | "skip_if_no_tensorflow",
24 | "skip_if_no_torchvision",
25 | "skip_if_no_xgboost",
26 | ]
27 |
28 |
def is_installed(pkg_name: str) -> bool:
    """Check if a package is installed without importing it.

    Args:
        pkg_name: Importable name of the package, e.g. ``"sklearn"``.

    Returns:
        True if an import spec can be found for ``pkg_name``, False otherwise.
    """
    try:
        return importlib.util.find_spec(pkg_name) is not None
    except (ImportError, ValueError):
        # ``find_spec`` raises instead of returning None for dotted names whose parent
        # package is missing (ModuleNotFoundError) and for modules without a usable
        # ``__spec__`` (ValueError); both mean "not usable" for a skip marker.
        return False
32 |
33 |
34 | # torch related: (torch itself is a main dependency) -----------------------------------------------
35 | skip_if_no_torchvision = pytest.mark.skipif(not is_installed("torchvision"), reason="torchvision is not installed")
36 |
37 | # sklearn-like: ------------------------------------------------------------------------------------
38 | skip_if_no_sklearn = pytest.mark.skipif(not is_installed("sklearn"), reason="sklearn is not installed")
39 |
40 | skip_if_no_xgboost = pytest.mark.skipif(not is_installed("xgboost"), reason="xgboost is not installed")
41 |
42 | skip_if_no_lightgbm = pytest.mark.skipif(not is_installed("lightgbm"), reason="lightgbm is not installed")
43 |
44 | # tensorflow related: ------------------------------------------------------------------------------
45 | skip_if_no_tensorflow = pytest.mark.skipif(not is_installed("tensorflow"), reason="tensorflow is not installed")
46 |
47 | skip_if_no_keras = pytest.mark.skipif(not is_installed("keras"), reason="keras is not installed")
48 |
49 | # misc: --------------------------------------------------------------------------------------------
50 | skip_if_no_tabpfn = pytest.mark.skipif(not is_installed("tabpfn"), reason="TabPFN is not available.")
51 |
--------------------------------------------------------------------------------
/src/probly/utils/torch.py:
--------------------------------------------------------------------------------
1 | """Utility functions for PyTorch models."""
2 |
3 | from __future__ import annotations
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from tqdm import tqdm
8 |
9 |
@torch.no_grad()
def torch_collect_outputs(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Collect outputs and targets from a model for a given data loader.

    Batches are gathered in lists and concatenated once at the end. This avoids
    re-copying all previous results on every batch, and keeps the dtypes of the
    model outputs and targets (concatenating onto a float-typed empty seed tensor
    is subject to type promotion, e.g. turning integer labels into floats).

    Args:
        model: torch.nn.Module, model to collect outputs from
        loader: torch.utils.data.DataLoader, data loader to collect outputs from
        device: torch.device, device to move data to
    Returns:
        outputs: torch.Tensor, shape (n_instances, n_classes), model outputs
        targets: torch.Tensor, shape (n_instances,), target labels
    """
    output_batches: list[torch.Tensor] = []
    target_batches: list[torch.Tensor] = []
    for inpt, target in tqdm(loader, desc="Batches"):
        output_batches.append(model(inpt.to(device)))
        target_batches.append(target.to(device))

    if not output_batches:
        # Empty loader: keep returning empty tensors on the requested device.
        return torch.empty(0, device=device), torch.empty(0, device=device)

    return torch.cat(output_batches, dim=0), torch.cat(target_batches, dim=0)
32 |
33 |
def torch_reset_all_parameters(module: torch.nn.Module) -> None:
    """Reset all parameters of a torch module.

    The module itself and every module nested below it, at any depth, is reset
    if it defines ``reset_parameters``. Modules without that method are skipped.

    Args:
        module: torch.nn.Module, module to reset parameters

    """
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()
    # Recurse instead of only looking at direct children, so that layers inside
    # nested containers (e.g. a Sequential within a Sequential) are reset as well.
    for child in module.children():
        torch_reset_all_parameters(child)
46 |
47 |
def temperature_softmax(logits: torch.Tensor, temperature: float | torch.Tensor) -> torch.Tensor:
    """Apply a temperature-scaled softmax over the last dimension.

    The logits are divided by the temperature before the softmax is taken; the
    last dimension of ``logits`` is treated as the class dimension.

    Args:
        logits: torch.Tensor, shape (n_instances, n_classes), logits to apply softmax on
        temperature: float or torch.Tensor, temperature the logits are divided by
    Returns:
        torch.Tensor of the same shape as ``logits`` with the scaled softmax values
    """
    scaled_logits = logits / temperature
    return F.softmax(scaled_logits, dim=-1)
62 |
--------------------------------------------------------------------------------
/src/probly/train/bayesian/torch.py:
--------------------------------------------------------------------------------
1 | """Collection of torch Bayesian training functions."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | if TYPE_CHECKING:
8 | from probly.predictor import Predictor
9 |
10 | import torch
11 | from torch import nn
12 | import torch.nn.functional as F
13 |
14 | from probly.layers.torch import BayesConv2d, BayesLinear # noqa: TC001, required by traverser
15 | from probly.traverse_nn import nn_compose
16 | from pytraverse import GlobalVariable, State, TraverserResult, singledispatch_traverser, traverse_with_state
17 |
18 | KL_DIVERGENCE = GlobalVariable[torch.Tensor]("KL_DIVERGENCE", default=0.0)
19 |
20 |
@singledispatch_traverser[object]
def kl_divergence_traverser(
    obj: BayesLinear | BayesConv2d,
    state: State,
) -> TraverserResult[BayesLinear | BayesConv2d]:
    """Traverser to compute the KL divergence of a Bayesian layer.

    Adds ``obj.kl_divergence`` to the ``KL_DIVERGENCE`` entry of ``state``.

    Args:
        obj: The Bayesian layer that is visited.
        state: The traversal state holding the accumulated KL divergence.

    Returns:
        The unchanged layer together with the updated state.
    """
    # NOTE(review): the parameter annotation presumably selects which layer types this
    # traverser is dispatched on (see the noqa on the layer import) - confirm in pytraverse.
    state[KL_DIVERGENCE] += obj.kl_divergence
    return obj, state
29 |
30 |
def collect_kl_divergence(model: Predictor) -> torch.Tensor:
    """Collect the KL divergence of the Bayesian model by summing the KL divergence of each Bayesian layer.

    Args:
        model: The model to traverse.

    Returns:
        The ``KL_DIVERGENCE`` entry of the state after traversing ``model``.
    """
    # The traversed model itself is not needed, only the resulting state.
    # NOTE(review): without any Bayesian layer this presumably yields the variable's
    # default (0.0, a float rather than a tensor) - confirm.
    _, state = traverse_with_state(model, nn_compose(kl_divergence_traverser))
    return state[KL_DIVERGENCE]
35 |
36 |
class ELBOLoss(nn.Module):
    """Evidence lower bound loss based on :cite:`blundellWeightUncertainty2015`.

    The loss is the cross-entropy of the predictions plus the KL divergence
    weighted by ``kl_penalty``.

    Attributes:
        kl_penalty: float, weight for KL divergence term
    """

    def __init__(self, kl_penalty: float = 1e-5) -> None:
        """Create the loss module.

        Args:
            kl_penalty: float, weight for KL divergence term
        """
        super().__init__()
        self.kl_penalty = kl_penalty

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor, kl: torch.Tensor) -> torch.Tensor:
        """Compute the ELBO loss.

        Args:
            inputs: torch.Tensor of size (n_instances, n_classes)
            targets: torch.Tensor of size (n_instances,)
            kl: torch.Tensor, KL divergence of the model
        Returns:
            loss: torch.Tensor, cross-entropy plus weighted KL divergence
        """
        data_term = F.cross_entropy(inputs, targets)
        complexity_term = self.kl_penalty * kl
        return data_term + complexity_term
65 |
--------------------------------------------------------------------------------
/src/probly/transformation/ensemble/flax.py:
--------------------------------------------------------------------------------
1 | """Ensemble Flax implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from dataclasses import asdict, is_dataclass
6 | from typing import Any
7 |
8 | from flax import linen as nn
9 | import jax
10 | import jax.numpy as jnp
11 |
12 | from .common import register
13 |
14 |
class FlaxEnsemble(nn.Module):
    """Ensemble that instantiates ``n_members`` copies of a base module and aggregates their outputs."""

    # Class that is instantiated once per ensemble member.
    base_module: type
    # Keyword arguments passed to ``base_module`` for every member (None means no arguments).
    base_kwargs: dict[str, Any] | None = None
    # Number of ensemble members.
    n_members: int = 10

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        *,
        return_all: bool = False,
        **call_kwargs: object,
    ) -> jnp.ndarray | dict[str, jnp.ndarray]:
        """Apply each ensemble member and aggregate.

        Args:
            x: Input passed unchanged to every member.
            return_all: If True, return the stacked member outputs instead of their mean.
            **call_kwargs: Extra keyword arguments forwarded to every member call.

        Returns:
            The member outputs stacked along axis 1 if ``return_all`` is True,
            otherwise their mean over that axis. Both are computed leaf-wise via
            ``jax.tree_util.tree_map``.
        """
        outputs: list[object] = []
        for i in range(self.n_members):
            ctor_kwargs = self.base_kwargs or {}
            # Every member is created under its own name ``member_<i>``.
            member = self.base_module(**ctor_kwargs, name=f"member_{i}")
            y = member(x, **call_kwargs)
            outputs.append(y)

        # Stack corresponding leaves of all member outputs along a new axis 1.
        stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=1), *outputs)
        if return_all:
            return stacked  # type: ignore[return-value]
        averaged = jax.tree_util.tree_map(lambda arr: jnp.mean(arr, axis=1), stacked)
        return averaged  # type: ignore[return-value]
43 |
44 |
def generate_flax_ensemble(
    model: nn.Module | type[nn.Module] | str,
    n_members: int,
) -> FlaxEnsemble:
    """Create a :class:`FlaxEnsemble` from a module instance or class.

    For a module class, members are constructed without arguments. For a
    dataclass instance, its public init fields (except ``name`` and ``parent``)
    are reused as constructor arguments for every member.

    Args:
        model: Module class or module instance to build the ensemble from.
        n_members: Number of ensemble members.

    Returns:
        The ensemble wrapping the class of ``model``.
    """
    from dataclasses import fields  # noqa: PLC0415

    if isinstance(model, type):
        return FlaxEnsemble(base_module=model, base_kwargs=None, n_members=n_members)

    base_kwargs: dict[str, Any] | None = None
    if is_dataclass(model):
        # Copy the fields shallowly. ``dataclasses.asdict`` would deep-copy every value
        # and recursively turn nested dataclass instances (such as submodules stored as
        # fields) into plain dicts, which are not valid constructor arguments.
        # Fields with ``init=False`` are skipped because the constructor does not accept them.
        filtered = {
            field.name: getattr(model, field.name)
            for field in fields(model)
            if field.init and not field.name.startswith("_") and field.name not in ("name", "parent")
        }
        if filtered:
            base_kwargs = filtered
    return FlaxEnsemble(base_module=model.__class__, base_kwargs=base_kwargs, n_members=n_members)
62 |
63 |
64 | register(nn.Module, generate_flax_ensemble)
65 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # probly: Uncertainty Representation and Quantification for Machine Learning
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | [PyPI version](https://badge.fury.io/py/probly)
10 | [PyPI project page](https://pypi.org/project/probly)
11 | [Downloads](https://pepy.tech/project/probly)
12 | [Code coverage](https://codecov.io/gh/pwhofman/probly)
13 | [Contributing guide](.github/CONTRIBUTING.md)
14 | [License: MIT](https://opensource.org/licenses/MIT)
15 |
16 |
17 | ## 🛠️ Install
18 | `probly` is intended to work with **Python 3.12 and above**. Installation can be done via `pip` or
19 | `uv`:
20 |
21 | ```sh
22 | pip install probly
23 | ```
24 |
25 | ```sh
26 | uv add probly
27 | ```
28 |
29 | ## ⭐ Quickstart
30 |
31 | `probly` makes it very easy to make models uncertainty-aware and perform several downstream tasks:
32 |
33 | ```python
34 | import probly
35 | import torch.nn.functional as F
36 |
37 | net = ... # get neural network
38 | model = probly.transformation.dropout(net) # make neural network a Dropout model
39 | train(model) # train model as usual
40 |
41 | data = ... # get data
42 | data_ood = ... # get out of distribution data
43 | sampler = probly.representation.Sampler(model, num_samples=20)
44 | sample = sampler.predict(data) # predict an uncertainty representation
45 | sample_ood = sampler.predict(data_ood)
46 |
47 | eu = probly.quantification.classification.mutual_information(sample) # quantify model's epistemic uncertainty
48 | eu_ood = probly.quantification.classification.mutual_information(sample_ood)
49 |
50 | auroc = probly.evaluation.tasks.out_of_distribution_detection(eu, eu_ood) # evaluate model's uncertainty
51 | ```
52 |
53 | ## 📜 License
54 | This project is licensed under the [MIT License](https://github.com/pwhofman/probly/blob/main/LICENSE).
55 |
56 | ---
57 | Built with ❤️ by the probly team.
58 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. probly documentation master file, created by
2 | sphinx-quickstart on Tue Apr 8 15:10:07 2025.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | The `probly` Python package
7 | ===========================
8 |
9 | `probly` is a Python package for **uncertainty representation** and **quantification** for machine learning.
10 |
11 | Installation
12 | ~~~~~~~~~~~~
13 | `probly` is intended to work with **Python 3.12 and above**. Installation can be done via `pip` or
14 | `uv`:
15 |
16 | .. code-block:: sh
17 |
18 | pip install probly
19 |
20 | or
21 |
22 | .. code-block:: sh
23 |
24 | uv add probly
25 |
26 | Quickstart
27 | ~~~~~~~~~~~~
28 | `probly` makes it very easy to make models uncertainty-aware and perform several downstream tasks:
29 |
30 | .. code-block:: python
31 |
32 | import probly
33 | import torch.nn.functional as F
34 |
35 | net = ... # get neural network
36 | model = probly.representation.Dropout(net) # make neural network a Dropout model
37 | train(model) # train model as usual
38 |
39 | data = ... # get data
40 | preds = model.predict_representation(data) # predict an uncertainty representation
41 | eu = probly.quantification.classification.mutual_information(preds) # compute model's epistemic uncertainty
42 |
43 | data_ood = ... # get out of distribution data
44 | preds_ood = model.predict_representation(data_ood)
45 | eu_ood = probly.quantification.classification.mutual_information(preds_ood)
46 | auroc = probly.evaluation.tasks.out_of_distribution_detection(eu, eu_ood) # compute the AUROC score for out of distribution detection
47 |
48 |
49 |
50 | .. toctree::
51 | :maxdepth: 1
52 | :caption: Content
53 | :hidden:
54 |
55 | methods
56 | references
57 |
58 | .. toctree::
59 | :maxdepth: 0
60 | :caption: Tutorials
61 | :hidden:
62 |
63 | examples/fashionmnist_ood_ensemble
64 | examples/label_relaxation_calibration
65 | examples/sklearn_selective_prediction
66 | examples/synthetic_regression_dropout
67 | examples/temperature_scaling_calibration
68 | examples/train_bnn_classification
69 | examples/train_evidential_classification
70 | examples/train_evidential_regression
71 |
72 | .. toctree::
73 | :maxdepth: 2
74 | :caption: API
75 | :hidden:
76 |
77 | api
78 |
79 | .. toctree::
80 | :maxdepth: 1
81 | :caption: Development
82 | :hidden:
83 |
84 | contributing
85 |
--------------------------------------------------------------------------------
/.github/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to probly 🏔️
2 | Probly is still in early development and we welcome contributions
3 | in many forms. If you have an idea for a new feature, a bug fix, or
4 | any other suggestion for improvement, please open an issue on GitHub.
5 | If you would like to contribute code, keep reading!
6 |
7 | ## What to work on ❓
8 | We want to offer support for PyTorch, HuggingFace, and sklearn models. We are interested in
9 | any contributions that translate existing features to these libraries. Furthermore, we are
10 | interested in any new features within the following scope:
11 | - Representation methods;
12 | - Quantification methods;
13 | - Calibration methods;
14 | - Downstream tasks;
15 | - Datasets and dataloaders.
16 |
17 | If you have other suggestions, please open an issue on GitHub such that we can discuss it.
18 |
19 | ## Development setup 🔬
20 | The recommended workflow for contributing to probly is:
21 | 1. Fork the `main` branch of the repository on GitHub.
22 | 2. Clone the fork locally.
23 | 3. Commit the changes.
24 | 4. Push the changes to the fork.
25 | 5. Create a pull request to the `main` branch of the repository.
26 |
27 | ### Setting up the development environment
28 | Once you have cloned the fork, you can set up your development environment.
29 | You will need a Python 3.12+ environment. We recommend using [uv](https://docs.astral.sh/uv/) as a package manager.
30 | To set up the development environment, run the following command:
31 | ```sh
32 | uv sync --dev
33 | ```
34 | This will install all the required dependencies in your Python environment.
35 |
36 | ## Guidelines 🚓
37 | Here are some guidelines to follow when contributing to probly.
38 |
39 | ### General
40 | If you use code from other sources, make sure to carefully look at the license and give credit to the original author(s).
41 | If the feature you are implementing is based on a paper, make sure to include a reference
42 | in the docstring.
43 |
44 | ### Code style
45 | We use [Ruff](https://docs.astral.sh/ruff/) for linting and formatting, the rules of which can be found in the [pyproject.toml](https://github.com/pwhofman/probly/blob/main/pyproject.toml) file.
46 | However, if your development environment is set up correctly, the Ruff pre-commit hook should take care of this for you.
47 |
48 | ### Documentation
49 | If you are adding new features, make sure to document them in the docstring;
50 | our docstrings follow the [Google Style Guide](https://google.github.io/styleguide/pyguide.html#docstrings).
51 |
--------------------------------------------------------------------------------
/tests/probly/tests_deprecations/utils.py:
--------------------------------------------------------------------------------
1 | """A module containing all deprecated features / behaviors in probly. Adapted from the shapiq package.
2 |
3 | Usage:
4 | To add a new deprecated feature or behavior, create an instance of the `DeprecatedFeature` and
5 | append it to the `DEPRECATED_FEATURES` list. The `call` attribute should be a callable that
6 | triggers the deprecation warning when executed. The `deprecated_in` and `removed_in` attributes
7 | should specify the version in which the feature was deprecated and the version in which it will
8 | be removed, respectively.
9 |
10 | """
11 |
12 | from __future__ import annotations
13 |
14 | import importlib
15 | import pathlib
16 | import pkgutil
17 | from typing import TYPE_CHECKING, NamedTuple
18 |
19 | if TYPE_CHECKING:
20 | from collections.abc import Callable
21 |
22 | import pytest
23 |
24 | DeprecatedTestFunc = Callable[[pytest.FixtureRequest], None]
25 |
26 |
class DeprecatedFeature(NamedTuple):
    """A named tuple to represent a deprecated feature."""

    # Human-readable name of the feature.
    name: str
    # Version in which the feature was deprecated.
    deprecated_in: str
    # Version in which the feature will be removed.
    removed_in: str
    # Callable that should trigger the deprecation warning when executed.
    call: Callable[[pytest.FixtureRequest], None]
34 |
35 |
36 | DEPRECATED_FEATURES: list[DeprecatedFeature] = []
37 |
38 |
def register_deprecated(
    name: str,
    deprecated_in: str,
    removed_in: str,
) -> Callable[[DeprecatedTestFunc], DeprecatedTestFunc]:
    """Build a decorator that records a deprecated feature.

    Args:
        name: The name of the deprecated feature.
        deprecated_in: The version in which the feature was deprecated.
        removed_in: The version in which the feature will be removed.

    Returns:
        A decorator that appends the decorated callable to ``DEPRECATED_FEATURES``
        and returns it unchanged.
    """

    def _decorator(func: DeprecatedTestFunc) -> DeprecatedTestFunc:
        feature = DeprecatedFeature(name, deprecated_in, removed_in, func)
        DEPRECATED_FEATURES.append(feature)
        return func

    return _decorator
60 |
61 |
62 | # auto-import all deprecated modules from this folder in the current package
def _auto_import_deprecated_modules() -> None:
    """Import every module found next to this file within the current package.

    Modules named ``__init__`` or ``features`` are skipped.
    """
    current_path = pathlib.Path(__file__).parent
    for module_info in pkgutil.iter_modules([str(current_path)]):
        name = module_info.name
        # NOTE(review): the importing is presumably done for its side effect of running
        # ``register_deprecated`` decorators in those modules - confirm.
        if name not in {"__init__", "features"}:
            importlib.import_module(f"{__package__}.{name}")
69 |
70 |
71 | _auto_import_deprecated_modules()
72 |
--------------------------------------------------------------------------------
/src/probly/evaluation/tasks.py:
--------------------------------------------------------------------------------
1 | """Collection of downstream tasks to evaluate the performance of uncertainty pipelines."""
2 |
3 | from __future__ import annotations
4 |
5 | import numpy as np
6 |
7 | # TODO(mmshlk): remove sklearn dependency - https://github.com/pwhofman/probly/issues/132
8 | import sklearn.metrics as sm
9 |
10 |
def selective_prediction(criterion: np.ndarray, losses: np.ndarray, n_bins: int = 50) -> tuple[float, np.ndarray]:
    """Selective prediction downstream task for evaluation.

    The criterion is used to sort the losses. In line with the uncertainty
    literature the sorting is done in descending order, i.e. the losses with
    the largest criterion are rejected first. Bin ``i`` holds the mean loss of
    the instances that remain after rejecting the first ``i`` bins.

    Args:
        criterion: numpy.ndarray shape (n_instances,), criterion values
        losses: numpy.ndarray shape (n_instances,), loss values
        n_bins: int, number of bins
    Returns:
        auroc: float, area under the loss curve
        bin_losses: numpy.ndarray shape (n_bins,), loss per bin

    Raises:
        ValueError: If ``n_bins`` is larger than the number of losses.
    """
    n_losses = len(losses)
    if n_bins > n_losses:
        msg = "The number of bins can not be larger than the number of elements criterion"
        raise ValueError(msg)

    # Order the losses from the largest to the smallest criterion value.
    descending = np.argsort(criterion)[::-1]
    ranked_losses = losses[descending]

    step = n_losses // n_bins
    bin_losses = np.empty(n_bins)
    bin_losses[:] = [np.mean(ranked_losses[(i * step) :]) for i in range(n_bins)]

    # Area under the loss curve over an evenly spaced grid on [0, 1].
    rejection_grid = np.linspace(0, 1, n_bins)
    auroc = sm.auc(rejection_grid, bin_losses)
    return auroc, bin_losses
41 |
42 |
def out_of_distribution_detection(in_distribution: np.ndarray, out_distribution: np.ndarray) -> float:
    """Score how well a prediction functional separates id from ood data.

    In-distribution values are labelled 0 and out-of-distribution values 1; the
    result is the area under the ROC curve of the values against these labels.
    The functional is commonly epistemic uncertainty, but can also be e.g.
    softmax confidence.

    Args:
        in_distribution: in-distribution prediction functionals
        out_distribution: out-of-distribution prediction functionals
    Returns:
        auroc: float, area under the roc curve

    """
    scores = np.concatenate((in_distribution, out_distribution))
    id_labels = np.zeros(len(in_distribution))
    ood_labels = np.ones(len(out_distribution))
    labels = np.concatenate((id_labels, ood_labels))
    return float(sm.roc_auc_score(labels, scores))
59 |
--------------------------------------------------------------------------------
/src/probly/plot/credal.py:
--------------------------------------------------------------------------------
1 | """Collection of credal plotting functions."""
2 |
3 | from __future__ import annotations
4 |
5 | import matplotlib.pyplot as plt
6 | import mpltern # noqa: F401, required for ternary projection, do not remove
7 | import numpy as np
8 |
9 |
def simplex_plot(probs: np.ndarray) -> tuple[plt.Figure, plt.Axes]:
    """Plot probability distributions on the simplex.

    Creates a new figure with a ternary axes and scatters the first three
    columns of ``probs`` on it.

    Args:
        probs: numpy.ndarray of shape (n_instances, n_classes)

    Returns:
        fig: matplotlib figure
        ax: matplotlib axes
    """
    fig = plt.figure()
    # The "ternary" projection is provided by the mpltern import of this module.
    ax = fig.add_subplot(projection="ternary")
    ax.scatter(probs[:, 0], probs[:, 1], probs[:, 2])
    return fig, ax
24 |
25 |
def credal_set_plot(probs: np.ndarray) -> tuple[plt.Figure, plt.Axes]:
    """Plot credal sets based on intervals of lower and upper probabilities.

    The credal set is the part of the simplex in which every class probability
    lies between its minimum and maximum over the samples. Its vertices are
    computed first; the figure is only created once they are known to exist, so
    that a failing call does not leave an unused figure behind.

    Args:
        probs: numpy.ndarray of shape (n_samples, n_classes)

    Returns:
        fig: matplotlib figure
        ax: matplotlib axes

    Raises:
        ValueError: If no vertex of the credal set could be determined.
    """
    lower_probs = np.min(probs, axis=0)
    upper_probs = np.max(probs, axis=0)
    lower_idxs = np.argmin(probs, axis=0)
    upper_idxs = np.argmax(probs, axis=0)
    # Samples that attain a lower or upper bound for some class.
    edge_probs = np.vstack((probs[lower_idxs], probs[upper_idxs]))

    # Fix two classes at one of their bounds each; the third class follows from the
    # sum-to-one constraint and the point is kept if it respects its own bounds.
    vertices_ = []
    for i, j, k in [(0, 1, 2), (1, 2, 0), (0, 2, 1)]:
        for x in [lower_probs[i], upper_probs[i]]:
            for y in [lower_probs[j], upper_probs[j]]:
                z = 1 - x - y
                if lower_probs[k] <= z <= upper_probs[k]:
                    prob = [0, 0, 0]
                    prob[i] = x
                    prob[j] = y
                    prob[k] = z
                    vertices_.append(prob)
    vertices = np.array(vertices_)

    if len(vertices) == 0:
        msg = "The set of vertices is empty. Please check the probabilities in the credal set."
        raise ValueError(msg)

    # Order the vertices by angle around their centroid so they trace the polygon outline.
    center = np.mean(vertices, axis=0)
    angles = np.arctan2(vertices[:, 1] - center[1], vertices[:, 0] - center[0])
    vertices = vertices[np.argsort(angles)]
    vertices_closed = np.vstack([vertices, vertices[0]])

    fig = plt.figure()
    ax = fig.add_subplot(projection="ternary")
    ax.scatter(probs[:, 0], probs[:, 1], probs[:, 2])
    ax.fill(vertices_closed[:, 0], vertices_closed[:, 1], vertices_closed[:, 2])
    ax.plot(vertices_closed[:, 0], vertices_closed[:, 1], vertices_closed[:, 2])
    ax.scatter(edge_probs[:, 0], edge_probs[:, 1], edge_probs[:, 2])

    return fig, ax
72 |
--------------------------------------------------------------------------------
/src/probly/representation/sampling/sample.py:
--------------------------------------------------------------------------------
1 | """Classes representing prediction samples."""
2 |
3 | from __future__ import annotations
4 |
5 | from abc import ABC, abstractmethod
6 |
7 | import numpy as np
8 |
9 | from lazy_dispatch.singledispatch import lazydispatch
10 | from probly.lazy_types import JAX_ARRAY, TORCH_TENSOR
11 |
12 |
class Sample[T](ABC):
    """Abstract base class for samples.

    Subclasses must provide ``__init__``. The statistics below raise
    ``NotImplementedError`` unless a subclass overrides them.
    """

    @abstractmethod
    def __init__(self, samples: list[T]) -> None:
        """Initialize the sample.

        Args:
            samples: The list of sampled predictions.
        """
        ...

    def mean(self) -> T:
        """Compute the mean of the sample.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        msg = "mean method not implemented."
        raise NotImplementedError(msg)

    def std(self, ddof: int = 1) -> T:
        """Compute the standard deviation of the sample.

        Args:
            ddof: Delta degrees of freedom.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        msg = "std method not implemented."
        raise NotImplementedError(msg)

    def var(self, ddof: int = 1) -> T:
        """Compute the variance of the sample.

        Args:
            ddof: Delta degrees of freedom.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        msg = "var method not implemented."
        raise NotImplementedError(msg)
35 |
36 |
class ListSample[T](list[T], Sample[T]):
    """A sample of predictions stored in a list."""

    def __add__[S](self, other: list[S]) -> ListSample[T | S]:
        """Add two samples together.

        Args:
            other: The list to append to this sample.

        Returns:
            A new instance of the same class as ``self`` holding the concatenation.
        """
        # Re-wrap the plain list produced by ``list.__add__`` in the sample type.
        return type(self)(super().__add__(other))  # type: ignore[operator]
43 |
44 |
45 | create_sample = lazydispatch[type[Sample], Sample](ListSample, dispatch_on=lambda s: s[0])
46 |
47 |
48 | @create_sample.register(np.number | np.ndarray | float | int)
49 | class ArraySample[T](Sample[T]):
50 | """A sample of predictions stored in a numpy array."""
51 |
52 | def __init__(self, samples: list[T]) -> None:
53 | """Initialize the array sample."""
54 | self.array: np.ndarray = np.array(samples).transpose(1, 0, 2) # we use [instances, samples, classes]
55 |
56 | def mean(self) -> T:
57 | """Compute the mean of the sample."""
58 | return self.array.mean(axis=1) # type: ignore[no-any-return]
59 |
60 | def std(self, ddof: int = 1) -> T:
61 | """Compute the standard deviation of the sample."""
62 | return self.array.std(axis=1, ddof=ddof) # type: ignore[no-any-return]
63 |
64 | def var(self, ddof: int = 1) -> T:
65 | """Compute the variance of the sample."""
66 | return self.array.var(axis=1, ddof=ddof) # type: ignore[no-any-return]
67 |
68 |
@create_sample.delayed_register(TORCH_TENSOR)
def _(_: type) -> None:
    # The only effect of this hook is importing the sibling ``torch_sample`` module; the argument is ignored.
    # NOTE(review): presumably that module registers the torch-backed sample on import - confirm.
    from . import torch_sample as torch_sample  # noqa: PLC0414, PLC0415
72 |
73 |
@create_sample.delayed_register(JAX_ARRAY)
def _(_: type) -> None:
    # The only effect of this hook is importing the sibling ``jax_sample`` module; the argument is ignored.
    # NOTE(review): presumably that module registers the jax-backed sample on import - confirm.
    from . import jax_sample as jax_sample  # noqa: PLC0414, PLC0415
77 |
--------------------------------------------------------------------------------
/tests/probly/tests_deprecations/test_deprecations.py:
--------------------------------------------------------------------------------
1 | """This module contains tests for deprecations."""
2 |
3 | from __future__ import annotations
4 |
5 | from importlib.metadata import version
6 | import inspect
7 | import re
8 | import warnings
9 |
10 | from packaging.version import parse
11 | import pytest
12 |
13 | from .utils import DEPRECATED_FEATURES, DeprecatedFeature
14 |
15 |
def feature_should_be_removed(feature: DeprecatedFeature) -> None:
    """Fails if the feature is still accessible after its scheduled removal date.

    Args:
        feature: The deprecated feature to check against the installed ``probly`` version.
    """
    if parse(version("probly")) < parse(feature.removed_in):
        return

    # Only look up the registration site when it is needed for the failure message;
    # source inspection can itself raise if the source is unavailable.
    source_file = inspect.getsourcefile(feature.call)
    source_line = inspect.getsourcelines(feature.call)[1]
    pytest.fail(
        f"{feature.name} was scheduled for removal in {feature.removed_in} "
        f"but is still accessible. Remove the deprecated behavior and this test.\n"
        f"Feature registered at: {source_file}:{source_line}",
    )
26 |
27 |
def feature_raises_deprecation_warning(
    feature: DeprecatedFeature,
    request: pytest.FixtureRequest,
) -> None:
    """Fails if the feature does not raise a deprecation warning.

    The feature's callable is executed while recording warnings. At least one
    ``DeprecationWarning`` must mention both the deprecation and the removal
    version in the expected wording.

    Args:
        feature: The deprecated feature to exercise.
        request: The pytest request object handed to the feature's callable.
    """
    # Escape the versions and the final dot: unescaped, every "." would match any character.
    expected_msg = (
        rf"deprecated in version {re.escape(feature.deprecated_in)} "
        rf"and will be removed in version {re.escape(feature.removed_in)}\."
    )

    with warnings.catch_warnings(record=True) as caught:
        # Make sure deprecation warnings are recorded even if a filter would hide them.
        warnings.simplefilter("always", DeprecationWarning)
        feature.call(request)

    deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)]

    # check if any deprecation warning matches the expected regex
    if not any(re.search(expected_msg, str(w.message)) for w in deprecation_warnings):
        formatted = "\n".join(f"{w.category.__name__}: {w.message}" for w in deprecation_warnings)
        source_file = inspect.getsourcefile(feature.call)
        source_line = inspect.getsourcelines(feature.call)[1]
        pytest.fail(
            f"No matching DeprecationWarning for feature '{feature.name}'.\n"
            f"Expected regex: {expected_msg}\n"
            f"Warnings captured: {formatted or 'None'}\n"
            f"Feature registered at: {source_file}:{source_line}\n",
        )
52 |
53 |
@pytest.mark.parametrize("feature", DEPRECATED_FEATURES, ids=lambda f: f.name)
def test_deprecated_features(feature: DeprecatedFeature, request: pytest.FixtureRequest) -> None:
    """Checks every registered deprecated feature.

    A feature must not be past its removal version and must still raise a
    matching deprecation warning.
    """
    # check if the feature should already be removed
    feature_should_be_removed(feature)

    # check if the feature raises a correct deprecation warning
    feature_raises_deprecation_warning(feature, request)
62 |
--------------------------------------------------------------------------------
/src/probly/transformation/bayesian/common.py:
--------------------------------------------------------------------------------
1 | """Shared Bayesian implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | from probly.traverse_nn import nn_compose
8 | from pytraverse import CLONE, GlobalVariable, lazydispatch_traverser, traverse
9 |
10 | if TYPE_CHECKING:
11 | from lazy_dispatch.isinstance import LazyType
12 | from probly.predictor import Predictor
13 | from pytraverse.composition import RegisteredLooseTraverser
14 |
# Global variables carrying the Bayesian transformation settings. Their defaults
# are reused as the keyword defaults of `bayesian` in this module.
USE_BASE_WEIGHTS = GlobalVariable[bool]("USE_BASE_WEIGHTS", default=False)
POSTERIOR_STD = GlobalVariable[float]("POSTERIOR_STD", default=0.05)
PRIOR_MEAN = GlobalVariable[float]("PRIOR_MEAN", default=0.0)
PRIOR_STD = GlobalVariable[float]("PRIOR_STD", default=1.0)

# Dispatching traverser that `register` adds class-specific traversers to.
bayesian_traverser = lazydispatch_traverser[object](name="bayesian_traverser")
21 |
22 |
def register(cls: LazyType, traverser: RegisteredLooseTraverser) -> None:
    """Register ``traverser`` on ``bayesian_traverser`` for the class ``cls``.

    The registration passes a ``vars`` mapping from the names
    ``use_base_weights``, ``posterior_std``, ``prior_mean`` and ``prior_std``
    to the corresponding module-level global variables.

    Args:
        cls: The (possibly lazy) class to be replaced by Bayesian layers.
        traverser: The traverser handling instances of ``cls``.
    """
    bayesian_vars = {
        "use_base_weights": USE_BASE_WEIGHTS,
        "posterior_std": POSTERIOR_STD,
        "prior_mean": PRIOR_MEAN,
        "prior_std": PRIOR_STD,
    }
    bayesian_traverser.register(cls=cls, traverser=traverser, vars=bayesian_vars)
35 |
36 |
37 | def bayesian[T: Predictor](
38 | base: T,
39 | use_base_weights: bool = USE_BASE_WEIGHTS.default,
40 | posterior_std: float = POSTERIOR_STD.default,
41 | prior_mean: float = PRIOR_MEAN.default,
42 | prior_std: float = PRIOR_STD.default,
43 | ) -> T:
44 | """Create a Bayesian predictor from a base predictor.
45 |
46 | Args:
47 | base: The base model to be used for the Bayesian neural network.
48 | use_base_weights: bool, If True, the weights of the base model are used as the prior mean.
49 | posterior_std: float, The initial posterior standard deviation.
50 | prior_mean: float, The prior mean.
51 | prior_std: float, The prior standard deviation.
52 |
53 | Returns:
54 | The Bayesian predictor.
55 | """
56 | if posterior_std <= 0:
57 | msg = (
58 | "The initial posterior standard deviation posterior_std must be greater than 0, "
59 | f"but got {posterior_std} instead."
60 | )
61 | raise ValueError(msg)
62 | if prior_std <= 0:
63 | msg = f"The prior standard deviation prior_std must be greater than 0, but got {prior_std} instead."
64 | raise ValueError(msg)
65 | return traverse(
66 | base,
67 | nn_compose(bayesian_traverser),
68 | init={
69 | USE_BASE_WEIGHTS: use_base_weights,
70 | POSTERIOR_STD: posterior_std,
71 | PRIOR_MEAN: prior_mean,
72 | PRIOR_STD: prior_std,
73 | CLONE: True,
74 | },
75 | )
76 |
--------------------------------------------------------------------------------
/src/probly/utils/probabilities.py:
--------------------------------------------------------------------------------
1 | """General utility functions for all other modules."""
2 |
3 | from __future__ import annotations
4 |
5 | import numpy as np
6 |
7 |
def differential_entropy_gaussian(sigma2: float | np.ndarray, base: float = 2) -> float | np.ndarray:
    """Return the differential entropy of a Gaussian with variance ``sigma2``.

    The entropy ``0.5 * ln(2 * pi * e * sigma2)`` is computed in nats and then
    converted to the requested logarithm base.

    https://en.wikipedia.org/wiki/Differential_entropy
    Args:
        sigma2: float or numpy.ndarray shape (n_instances,), variance of the Gaussian distribution
        base: float, base of the logarithm
    Returns:
        diff_ent: float or numpy.ndarray shape (n_instances,), differential entropy of the Gaussian distribution
    """
    entropy_nats = 0.5 * np.log(2 * np.pi * np.e * sigma2)
    # Dividing by ln(base) changes the unit from nats to the requested base.
    return entropy_nats / np.log(base)
19 |
20 |
def kl_divergence_gaussian(
    mu1: float | np.ndarray,
    sigma21: float | np.ndarray,
    mu2: float | np.ndarray,
    sigma22: float | np.ndarray,
    base: float = 2,
) -> float | np.ndarray:
    """Compute the KL-divergence KL(N(mu1, sigma21) || N(mu2, sigma22)) of two Gaussians.

    The closed form
    ``0.5 * ln(sigma22 / sigma21) + (sigma21 + (mu1 - mu2)**2) / (2 * sigma22) - 0.5``
    yields the divergence in nats. The complete expression is then divided by
    ``ln(base)`` so that every term is expressed in the requested unit.

    https://en.wikipedia.org/wiki/Kullback-Leibler_divergence#Examples
    Args:
        mu1: float or numpy.ndarray shape (n_instances,), mean of the first Gaussian distribution
        sigma21: float or numpy.ndarray shape (n_instances,), variance of the first Gaussian distribution
        mu2: float or numpy.ndarray shape (n_instances,), mean of the second Gaussian distribution
        sigma22: float or numpy.ndarray shape (n_instances,), variance of the second Gaussian distribution
        base: float, base of the logarithm
    Returns:
        kl_div: float or numpy.ndarray shape (n_instances,), KL-divergence between the two Gaussian distributions
    """
    kl_nats = 0.5 * np.log(sigma22 / sigma21) + (sigma21 + (mu1 - mu2) ** 2) / (2 * sigma22) - 0.5
    # Convert the whole divergence, not only the log term: scaling just one
    # summand mixes units and can even report 0 for two different Gaussians.
    return kl_nats / np.log(base)
42 |
43 |
def intersection_probability(probs: np.ndarray) -> np.ndarray:
    """Compute the intersection probability of credal sets given as sampled distributions.

    Lower and upper class probabilities are taken as the minimum and maximum
    over the sample axis. The result is ``lower + alpha * (upper - lower)``,
    with ``alpha`` chosen per instance as ``(1 - sum(lower)) / sum(upper - lower)``.
    Follows :cite:`cuzzolinIntersectionProbability2022`.

    Args:
        probs: numpy.ndarray, shape (n_instances, n_samples, n_classes), credal sets
    Returns:
        int_probs: numpy.ndarray, shape (n_instances, n_classes), intersection probability of the credal sets
    """
    lower_bound = probs.min(axis=1)
    interval_width = probs.max(axis=1) - lower_bound
    total_width = interval_width.sum(axis=1)
    remaining_mass = 1 - lower_bound.sum(axis=1)

    # Instances whose intervals all have zero width keep alpha = 0, which also
    # avoids a division by zero.
    has_width = total_width != 0
    alpha = np.zeros(probs.shape[0])
    alpha[has_width] = remaining_mass[has_width] / total_width[has_width]

    return lower_bound + alpha[:, None] * interval_width
65 |
--------------------------------------------------------------------------------
/tests/lazy_dispatch/test_isinstance.py:
--------------------------------------------------------------------------------
1 | """Tests for the probly.lazy_dispatch.isinstance module."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | from lazy_dispatch.isinstance import LazyType, lazy_isinstance
8 |
9 |
class TestIsInstance:
    """Behavioural checks for ``lazy_isinstance`` on builtin types."""

    @pytest.mark.parametrize(
        ("obj", "classinfo", "expected"),
        [
            (5, int, True),
            (5, (int, str), True),
            (5, int | str, True),
            ("hello", str, True),
            ("hello", (int, str), True),
            ("hello", int | str, True),
            (5.0, int, False),
            (5.0, (int, str), False),
            (5.0, int | str, False),
            ([1, 2, 3], list, True),
            ([1, 2, 3], (list, dict), True),
            ([1, 2, 3], list | dict, True),
            ({"a": 1}, dict, True),
            ({"a": 1}, (list, dict), True),
            ({"a": 1}, list | (dict | list), True),
            (None, type(None), True),
            (None, (int, type(None)), True),
            (None, int | type(None), True),
        ],
    )
    def test_eager_builtin_types(self, obj: object, classinfo: LazyType, expected: bool) -> None:
        """Class info given as real types, tuples of types, or unions of types."""
        outcome = lazy_isinstance(obj, classinfo)
        assert outcome == expected

    @pytest.mark.parametrize(
        ("obj", "classinfo", "expected"),
        [
            (5, "int", True),
            (5, ("builtins.int", "builtins.str"), True),
            ("hello", "str", True),
            ("hello", ("int", "str"), True),
            (5.0, "builtins.int", False),
            (5.0, ("int", "str"), False),
            ([1, 2, 3], "builtins.list", True),
            ([1, 2, 3], ("list", "dict"), True),
            ({"a": 1}, "dict", True),
            ({"a": 1}, ("list", "builtins.dict"), True),
        ],
    )
    def test_lazy_builtin_types(self, obj: object, classinfo: LazyType, expected: bool) -> None:
        """Class info given only as type names, bare or module-qualified."""
        outcome = lazy_isinstance(obj, classinfo)
        assert outcome == expected

    @pytest.mark.parametrize(
        ("obj", "classinfo", "expected"),
        [
            (5, (str, "builtins.int"), True),
            (5, (int, "str"), True),
            ("hello", (int, "str"), True),
            ("hello", (str, "int"), True),
            (5.0, (str, "builtins.int"), False),
            (5.0, ("str", int), False),
            ([1, 2, 3], (dict, "builtins.list"), True),
            ([1, 2, 3], (list, "dict"), True),
            ({"a": 1}, (dict, "builtins.dict"), True),
            ({"a": 1}, ("dict", list), True),
        ],
    )
    def test_mixed_builtin_types(self, obj: object, classinfo: LazyType, expected: bool) -> None:
        """Class info given as tuples that mix real types and type names."""
        outcome = lazy_isinstance(obj, classinfo)
        assert outcome == expected
77 |
--------------------------------------------------------------------------------
/src/lazy_dispatch/isinstance.py:
--------------------------------------------------------------------------------
1 | """A lazy version of isinstance."""
2 |
3 | from __future__ import annotations
4 |
5 | from types import UnionType
6 | from typing import Any, Union, get_args, get_origin
7 |
8 | type LazyType = type | str | UnionType | tuple[LazyType, ...]
9 |
10 |
11 | def _is_union_type(cls: Any) -> bool: # noqa: ANN401
12 | return get_origin(cls) in {Union, UnionType}
13 |
14 |
15 | def _split_lazy_type(lazy_type: LazyType) -> tuple[set[type], set[str]]:
16 | """Split classinfo into a set of types and a set of strings."""
17 | if isinstance(lazy_type, str):
18 | return set(), {t.strip() for t in lazy_type.split("|")}
19 | if isinstance(lazy_type, type):
20 | return {lazy_type}, set()
21 | if isinstance(lazy_type, tuple):
22 | types: set[type] = set()
23 | strings: set[str] = set()
24 | for item in lazy_type:
25 | t, s = _split_lazy_type(item)
26 | types.update(t)
27 | strings.update(s)
28 | return types, strings
29 | if _is_union_type(lazy_type):
30 | types = set()
31 | strings = set()
32 | for arg in get_args(lazy_type):
33 | t, s = _split_lazy_type(arg)
34 | types.update(t)
35 | strings.update(s)
36 | return types, strings
37 |
38 | msg = f"Invalid classinfo: {lazy_type!r}"
39 | raise TypeError(msg)
40 |
41 |
42 | def _find_matching_string_type(cls: type, string_types: set[str] | dict[str, Any]) -> str | None:
43 | """Check if the type's name matches any of the strings."""
44 | module = cls.__module__
45 | qualname = cls.__qualname__
46 |
47 | if module == "builtins":
48 | for s in string_types:
49 | if qualname == s:
50 | return s
51 |
52 | for s in string_types:
53 | if f"{module}.{qualname}" == s:
54 | return s
55 |
56 | return None
57 |
58 |
def _find_closest_string_type(cls: type, string_types: set[str] | dict[str, Any]) -> tuple[type, str] | None:
    """Return the first class in the MRO of ``cls`` whose name is in ``string_types``.

    Returns:
        ``(matching_class, matching_string)`` for the first match in MRO order,
        or None when no class in the MRO matches.
    """
    # Lazy generator: the walk stops at the first match in MRO order.
    lookups = ((ancestor, _find_matching_string_type(ancestor, string_types)) for ancestor in cls.__mro__)
    return next(((ancestor, name) for ancestor, name in lookups if name is not None), None)
67 |
68 |
def lazy_isinstance(obj: object, class_or_tuple: LazyType, /) -> bool:
    """Return whether ``obj`` is an instance of a lazily specified class.

    Real types are checked with the builtin ``isinstance``. Type names are
    compared against the classes in the MRO of ``type(obj)``.
    """
    eager_types, type_names = _split_lazy_type(class_or_tuple)
    if eager_types and isinstance(obj, tuple(eager_types)):
        return True
    # The name-based lookup only runs when names were actually supplied.
    return bool(type_names) and _find_closest_string_type(type(obj), type_names) is not None
79 |
80 |
def lazy_issubclass(cls: type, class_or_tuple: LazyType, /) -> bool:
    """Return whether ``cls`` is a subclass of a lazily specified class.

    Real types are checked with the builtin ``issubclass``. Type names are
    compared against the classes in the MRO of ``cls``.
    """
    eager_types, type_names = _split_lazy_type(class_or_tuple)
    if eager_types and issubclass(cls, tuple(eager_types)):
        return True
    # The name-based lookup only runs when names were actually supplied.
    return bool(type_names) and _find_closest_string_type(cls, type_names) is not None
91 |
--------------------------------------------------------------------------------
/tests/probly/datasets/test_torch.py:
--------------------------------------------------------------------------------
1 | """Tests for the data module."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | if TYPE_CHECKING:
8 | from collections.abc import Callable
9 | from typing import Any
10 |
11 | from torchvision.datasets import CIFAR10, ImageNet
12 |
13 | import json
14 | from pathlib import Path
15 | from unittest.mock import MagicMock, mock_open, patch
16 |
17 | import numpy as np
18 | import torch
19 |
20 | from probly.datasets.torch import CIFAR10H, Benthic, ImageNetReaL, Plankton, Treeversity1, Treeversity6
21 |
22 |
def patch_cifar10_init(  # noqa: ARG001, the init requires these arguments
    self: CIFAR10,
    root: str,
    train: bool,
    transform: Callable[..., Any],
    download: bool,
) -> None:
    """Stand-in for ``CIFAR10.__init__`` that only stores ``root`` on the instance."""
    self.root = root
25 |
26 |
@patch("probly.datasets.torch.np.load")
@patch("probly.datasets.torch.torchvision.datasets.CIFAR10.__init__", new=patch_cifar10_init)
def test_cifar10h(mock_np_load: MagicMock, tmp_path: Path) -> None:
    """Every row of the CIFAR10H targets must sum to one."""
    n_rows, n_classes = 5, 10
    # The patched np.load returns an all-ones (5, 10) array instead of reading a file.
    mock_np_load.return_value = np.ones((n_rows, n_classes))

    dataset = CIFAR10H(root=str(tmp_path))
    dataset.data = [np.zeros((3, 32, 32))] * 2

    row_sums = torch.sum(dataset.targets, dim=1)
    assert torch.allclose(row_sums, torch.ones(n_rows))
35 |
36 |
def patch_imagenet_init(  # noqa: ARG001, the init requires these arguments
    self: ImageNet,
    root: str,
    split: str,
    transform: Callable[..., Any],
) -> None:
    """Stand-in for ``ImageNet.__init__`` that sets three samples and three classes."""
    image_paths = [
        "some/path/ILSVRC2012_val_00000001.JPEG",
        "some/path/ILSVRC2012_val_00000002.JPEG",
        "some/path/ILSVRC2012_val_00000003.JPEG",
    ]
    # Each sample is a (path, class index) pair; indices count up from 0.
    self.samples = list(zip(image_paths, range(len(image_paths))))
    self.classes = list(range(3))
44 |
45 |
@patch("probly.datasets.torch.torchvision.datasets.ImageNet.__init__", new=patch_imagenet_init)
@patch("pathlib.Path.open", new_callable=mock_open, read_data=json.dumps([[], [1], [1, 2]]))
def test_imagenetreal(mock_path_open: MagicMock, tmp_path: Path) -> None:  # noqa: ARG001, injected by patch
    """Every distribution in ``ImageNetReaL.dists`` must sum to one.

    ``Path.open`` is replaced by a ``mock_open`` mock that serves a JSON list
    of label lists, and ``ImageNet.__init__`` is replaced by ``patch_imagenet_init``.

    Args:
        mock_path_open: The mock created by the ``Path.open`` patch. Because that
            patch uses ``new_callable`` (no explicit ``new``), the mock is passed
            positionally and needs its own parameter; without it, the mock would
            be bound to ``tmp_path`` and the pytest fixture would never be used.
        tmp_path: The pytest temporary directory, used as the dataset root.
    """
    dataset = ImageNetReaL(str(tmp_path))
    for dist in dataset.dists:
        assert torch.isclose(torch.sum(dist), torch.tensor(1.0))
52 |
53 |
@patch("probly.datasets.torch.DCICDataset.__init__", return_value=None)
def test_benthic(mock_dcic_init: MagicMock) -> None:
    """``Benthic`` must call the ``DCICDataset`` initializer with ``<root>/Benthic``."""
    dataset_root = "some/path"
    Benthic(dataset_root, first_order=False)
    mock_dcic_init.assert_called_once_with(Path(dataset_root) / "Benthic", None, first_order=False)
60 |
61 |
@patch("probly.datasets.torch.DCICDataset.__init__", return_value=None)
def test_plankton(mock_dcic_init: MagicMock) -> None:
    """``Plankton`` must call the ``DCICDataset`` initializer with ``<root>/Plankton``."""
    dataset_root = "some/path"
    Plankton(dataset_root, first_order=False)
    mock_dcic_init.assert_called_once_with(Path(dataset_root) / "Plankton", None, first_order=False)
68 |
69 |
@patch("probly.datasets.torch.DCICDataset.__init__", return_value=None)
def test_treeversity1(mock_dcic_init: MagicMock) -> None:
    """``Treeversity1`` must call the ``DCICDataset`` initializer with ``<root>/Treeversity#1``."""
    dataset_root = "some/path"
    Treeversity1(dataset_root, first_order=False)
    mock_dcic_init.assert_called_once_with(Path(dataset_root) / "Treeversity#1", None, first_order=False)
76 |
77 |
@patch("probly.datasets.torch.DCICDataset.__init__", return_value=None)
def test_treeversity6(mock_dcic_init: MagicMock) -> None:
    """``Treeversity6`` must call the ``DCICDataset`` initializer with ``<root>/Treeversity#6``."""
    dataset_root = "some/path"
    Treeversity6(dataset_root, first_order=False)
    mock_dcic_init.assert_called_once_with(Path(dataset_root) / "Treeversity#6", None, first_order=False)
84 |
--------------------------------------------------------------------------------
/tests/probly/transformation/evidential/classification/test_torch.py:
--------------------------------------------------------------------------------
1 | """Test for torch classification models."""
2 |
3 | from __future__ import annotations
4 |
5 | from torch import nn
6 |
7 | from probly.transformation.evidential.classification.common import (
8 | evidential_classification,
9 | )
10 | from tests.probly.torch_utils import count_layers
11 |
12 |
def test_evidential_classification_appends_softplus_on_linear(torch_model_small_2d_2d: nn.Sequential) -> None:
    """The transformation adds exactly one Softplus and one Sequential to a linear model."""
    original = torch_model_small_2d_2d
    model = evidential_classification(original)

    # Layer counts per type, before and after the transformation.
    tracked_layers = (nn.Linear, nn.Softplus, nn.Sequential)
    before = {layer: count_layers(original, layer) for layer in tracked_layers}
    after = {layer: count_layers(model, layer) for layer in tracked_layers}

    assert model is not None
    assert isinstance(model, type(original))
    # Linear layers are untouched; one Softplus and one wrapping Sequential are new.
    assert before[nn.Linear] == after[nn.Linear]
    assert before[nn.Softplus] == (after[nn.Softplus] - 1)
    assert before[nn.Sequential] == (after[nn.Sequential] - 1)
36 |
37 |
def test_evidential_classification_appends_softplus_on_conv(torch_conv_linear_model: nn.Sequential) -> None:
    """The transformation adds exactly one Softplus and one Sequential to a conv model."""
    original = torch_conv_linear_model
    model = evidential_classification(original)

    # Layer counts per type, before and after the transformation.
    tracked_layers = (nn.Linear, nn.Softplus, nn.Sequential, nn.Conv2d)
    before = {layer: count_layers(original, layer) for layer in tracked_layers}
    after = {layer: count_layers(model, layer) for layer in tracked_layers}

    assert model is not None
    assert isinstance(model, type(original))
    # Linear and Conv2d layers are untouched; one Softplus and one wrapping Sequential are new.
    assert before[nn.Linear] == after[nn.Linear]
    assert before[nn.Softplus] == (after[nn.Softplus] - 1)
    assert before[nn.Sequential] == (after[nn.Sequential] - 1)
    assert before[nn.Conv2d] == after[nn.Conv2d]
66 |
--------------------------------------------------------------------------------
/tests/probly/fixtures/flax_models.py:
--------------------------------------------------------------------------------
1 | """Fixtures for models used in tests."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
# Skip this module when flax is not installed. The returned module is bound to
# `flax` (not `torch`) so the name matches the package it refers to.
flax = pytest.importorskip("flax")
8 | from flax import nnx # noqa: E402
9 | from flax.typing import Array # noqa: E402
10 |
11 |
@pytest.fixture
def flax_rngs() -> nnx.Rngs:
    """Provide an ``nnx.Rngs`` instance created with seed 0 for flax models."""
    seed = 0
    return nnx.Rngs(seed)
16 |
17 |
@pytest.fixture
def flax_model_small_2d_2d(flax_rngs: nnx.Rngs) -> nnx.Module:
    """Provide a sequential model of three ``Linear(2, 2)`` layers."""
    layers = [nnx.Linear(2, 2, rngs=flax_rngs) for _ in range(3)]
    return nnx.Sequential(*layers)
27 |
28 |
@pytest.fixture
def flax_conv_linear_model(flax_rngs: nnx.Rngs) -> nnx.Module:
    """Provide a small convolutional model with 3 input channels and 2 output neurons."""
    layers = [
        nnx.Conv(3, 5, (5, 5), rngs=flax_rngs),
        nnx.relu,
        nnx.flatten,
        nnx.Linear(5, 2, rngs=flax_rngs),
    ]
    return nnx.Sequential(*layers)
39 |
40 |
@pytest.fixture
def flax_regression_model_1d(flax_rngs: nnx.Rngs) -> nnx.Module:
    """Provide a small regression model with 2 input neurons and 1 output neuron."""
    layers = [
        nnx.Linear(2, 2, rngs=flax_rngs),
        nnx.relu,
        nnx.Linear(2, 1, rngs=flax_rngs),
    ]
    return nnx.Sequential(*layers)
50 |
51 |
@pytest.fixture
def flax_regression_model_2d(flax_rngs: nnx.Rngs) -> nnx.Module:
    """Provide a small regression model with 4 input neurons and 2 output neurons."""
    layers = [
        nnx.Linear(4, 4, rngs=flax_rngs),
        nnx.relu,
        nnx.Linear(4, 2, rngs=flax_rngs),
    ]
    return nnx.Sequential(*layers)
61 |
62 |
@pytest.fixture
def flax_custom_model(flax_rngs: nnx.Rngs) -> nnx.Module:
    """Provide an instance of a small hand-written ``nnx.Module``."""

    class TinyModel(nnx.Module):
        """A two-layer network: linear, ReLU, linear, softmax.

        Attributes:
            linear1 : The first linear layer with input size 10 and output size 20.
            activation : The ReLU function applied after the first linear layer.
            linear2 : The second linear layer with input size 20 and output size 4.
            softmax : The softmax function applied to the output of the second layer.
        """

        def __init__(self, rngs: nnx.Rngs) -> None:
            """Create the two linear layers and store the activation functions."""
            super().__init__()

            self.linear1 = nnx.Linear(10, 20, rngs=rngs)
            self.activation = nnx.relu
            self.linear2 = nnx.Linear(20, 4, rngs=rngs)
            self.softmax = nnx.softmax

        def __call__(self, x: Array) -> Array:
            """Apply linear1, ReLU, linear2 and softmax in that order.

            Parameters:
                x: Input array passed to the first linear layer.

            Returns:
                The output of the final softmax.
            """
            hidden = self.activation(self.linear1(x))
            return self.softmax(self.linear2(hidden))

    return TinyModel(rngs=flax_rngs)
102 |
--------------------------------------------------------------------------------
/tests/probly/transformation/evidential/regression/test_torch.py:
--------------------------------------------------------------------------------
1 | """Tests for torch evidential regression models."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 | import torch as th
7 | from torch import nn
8 |
9 | from probly.layers.torch import NormalInverseGammaLinear
10 | from probly.transformation.evidential.regression import evidential_regression
11 | from tests.probly.torch_utils import count_layers
12 |
13 | torch = pytest.importorskip("torch")
14 |
15 |
class TestEvidentialRegression:
    """Checks for ``evidential_regression`` applied to torch models."""

    def test_returns_a_clone(self, torch_model_small_2d_2d: nn.Sequential) -> None:
        """The transformed model must not be the same object as the input model."""
        transformed = evidential_regression(torch_model_small_2d_2d)

        assert transformed is not torch_model_small_2d_2d

    def test_replaces_only_last_linear_layer(self, torch_model_small_2d_2d: nn.Sequential) -> None:
        """Exactly one ``nn.Linear`` is exchanged for a NormalInverseGammaLinear layer.

        The transformed model must contain one ``nn.Linear`` less than the
        original and exactly one ``NormalInverseGammaLinear`` (NIG) layer.

        Parameters:
            torch_model_small_2d_2d: The torch model to be tested.
        """
        base = torch_model_small_2d_2d
        transformed = evidential_regression(base)

        linear_before = count_layers(base, nn.Linear)
        linear_after = count_layers(transformed, nn.Linear)
        nig_after = count_layers(transformed, NormalInverseGammaLinear)

        # One linear layer fewer ...
        assert linear_after == (linear_before - 1)
        # ... and exactly one NIG layer in its place.
        assert nig_after == 1

    def test_last_layer_replacement_and_integrity(self, torch_model_small_2d_2d: nn.Sequential) -> None:
        """The last layer becomes a NIG layer and all earlier layers keep their type.

        Parameters:
            torch_model_small_2d_2d: The torch model to be tested.
        """
        base = torch_model_small_2d_2d

        # Reference shape: the output of the untransformed model on one sample.
        sample = th.randn(1, 2)
        reference_shape = base(sample).shape

        transformed = evidential_regression(base)

        # The NIG layer returns a dictionary, so the shape of its "gamma"
        # tensor (the mean) is compared with the original output shape.
        prediction = transformed(sample)
        assert prediction["gamma"].shape == reference_shape

        # The layer at the last position must now be the NIG layer.
        last_index = len(base) - 1
        assert isinstance(transformed[last_index], NormalInverseGammaLinear)

        # Every layer before the last one must keep its original type.
        for position in range(last_index):
            assert type(base[position]) is type(transformed[position])
86 |
--------------------------------------------------------------------------------
/tests/probly/train/evidential/test_torch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pytest
4 |
5 | from probly.predictor import Predictor
6 | from probly.train.evidential.torch import (
7 | EvidentialCELoss,
8 | EvidentialKLDivergence,
9 | EvidentialLogLoss,
10 | EvidentialMSELoss,
11 | EvidentialNIGNLLLoss,
12 | EvidentialRegressionRegularization,
13 | )
14 | from probly.transformation import evidential_regression
15 | from tests.probly.torch_utils import validate_loss
16 |
17 | torch = pytest.importorskip("torch")
18 |
19 | from torch import Tensor, nn # noqa: E402
20 |
21 |
def test_evidential_log_loss(
    sample_classification_data: tuple[Tensor, Tensor],
    evidential_classification_model: nn.Module,
) -> None:
    """``EvidentialLogLoss`` on model outputs must pass ``validate_loss``."""
    features, labels = sample_classification_data
    predictions = evidential_classification_model(features)
    loss_fn = EvidentialLogLoss()
    validate_loss(loss_fn(predictions, labels))
31 |
32 |
def test_evidential_ce_loss(
    sample_classification_data: tuple[Tensor, Tensor],
    evidential_classification_model: nn.Module,
) -> None:
    """``EvidentialCELoss`` on model outputs must pass ``validate_loss``."""
    features, labels = sample_classification_data
    predictions = evidential_classification_model(features)
    loss_fn = EvidentialCELoss()
    validate_loss(loss_fn(predictions, labels))
42 |
43 |
def test_evidential_mse_loss(
    sample_classification_data: tuple[Tensor, Tensor],
    evidential_classification_model: nn.Module,
) -> None:
    """``EvidentialMSELoss`` on model outputs must pass ``validate_loss``."""
    features, labels = sample_classification_data
    predictions = evidential_classification_model(features)
    loss_fn = EvidentialMSELoss()
    validate_loss(loss_fn(predictions, labels))
53 |
54 |
def test_evidential_kl_divergence(
    sample_classification_data: tuple[Tensor, Tensor],
    evidential_classification_model: nn.Module,
) -> None:
    """``EvidentialKLDivergence`` on model outputs must pass ``validate_loss``."""
    features, labels = sample_classification_data
    predictions = evidential_classification_model(features)
    loss_fn = EvidentialKLDivergence()
    validate_loss(loss_fn(predictions, labels))
64 |
65 |
def test_evidential_nig_nll_loss(
    torch_regression_model_1d: nn.Module,
    torch_regression_model_2d: nn.Module,
) -> None:
    """``EvidentialNIGNLLLoss`` must pass ``validate_loss`` for both regression models."""
    # (base model, number of input features, number of targets)
    cases = (
        (torch_regression_model_1d, 2, 1),
        (torch_regression_model_2d, 4, 2),
    )
    for base_model, n_inputs, n_targets in cases:
        inputs = torch.randn(2, n_inputs)
        targets = torch.randn(2, n_targets)
        model: Predictor = evidential_regression(base_model)
        outputs = model(inputs)
        criterion = EvidentialNIGNLLLoss()
        validate_loss(criterion(outputs, targets))
85 |
86 |
def test_evidential_regression_regularization(
    torch_regression_model_1d: nn.Module,
    torch_regression_model_2d: nn.Module,
) -> None:
    """``EvidentialRegressionRegularization`` must pass ``validate_loss`` for both regression models."""
    # (base model, number of input features, number of targets)
    cases = (
        (torch_regression_model_1d, 2, 1),
        (torch_regression_model_2d, 4, 2),
    )
    for base_model, n_inputs, n_targets in cases:
        inputs = torch.randn(2, n_inputs)
        targets = torch.randn(2, n_targets)
        model: Predictor = evidential_regression(base_model)
        outputs = model(inputs)
        criterion = EvidentialRegressionRegularization()
        validate_loss(criterion(outputs, targets))
106 |
--------------------------------------------------------------------------------
/tests/probly/fixtures/torch_models.py:
--------------------------------------------------------------------------------
1 | """Fixtures for models used in tests."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | from probly.predictor import Predictor
8 | from probly.transformation import evidential_classification
9 |
10 | torch = pytest.importorskip("torch")
11 | from torch import Tensor, nn # noqa: E402
12 |
13 |
@pytest.fixture
def torch_model_small_2d_2d() -> nn.Module:
    """Provide a sequential model of three ``Linear(2, 2)`` layers."""
    layers = [nn.Linear(2, 2) for _ in range(3)]
    return nn.Sequential(*layers)
23 |
24 |
@pytest.fixture
def torch_conv_linear_model() -> nn.Module:
    """Provide a small convolutional model with 3 input channels and 2 output neurons."""
    layers = [
        nn.Conv2d(3, 5, 5),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(5, 2),
    ]
    return nn.Sequential(*layers)
35 |
36 |
@pytest.fixture
def torch_regression_model_1d() -> nn.Module:
    """Provide a small regression model with 2 input neurons and 1 output neuron."""
    layers = [
        nn.Linear(2, 2),
        nn.ReLU(),
        nn.Linear(2, 1),
    ]
    return nn.Sequential(*layers)
46 |
47 |
@pytest.fixture
def torch_regression_model_2d() -> nn.Module:
    """Provide a small regression model with 4 input neurons and 2 output neurons."""
    layers = [
        nn.Linear(4, 4),
        nn.ReLU(),
        nn.Linear(4, 2),
    ]
    return nn.Sequential(*layers)
57 |
58 |
@pytest.fixture
def torch_dropout_model() -> nn.Module:
    """Provide a small model with 2 input and 2 output neurons and one ``Dropout(0.5)`` layer."""
    layers = [
        nn.Linear(2, 2),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(2, 2),
    ]
    return nn.Sequential(*layers)
69 |
70 |
@pytest.fixture
def torch_custom_model() -> nn.Module:
    """Return an instance of a small hand-written ``nn.Module``."""

    class TinyModel(nn.Module):
        """A simple neural network model with two linear layers and activation functions.

        Attributes:
            linear1 : The first linear layer with input size 10 and output size 20.
            activation : The ReLU activation function applied after the first linear layer.
            linear2 : The second linear layer with input size 20 and output size 4.
            softmax : The softmax over the last dimension, normalizing the output into probabilities.
        """

        def __init__(self) -> None:
            """Initialize the TinyModel class."""
            super().__init__()

            self.linear1 = nn.Linear(10, 20)
            self.activation = nn.ReLU()
            self.linear2 = nn.Linear(20, 4)
            # An explicit dim avoids torch's deprecated implicit-dimension choice,
            # which warns on every forward pass. The last dimension is the output
            # dimension of linear2; for 1-D and 2-D inputs this is the same axis
            # the implicit choice would have picked.
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x: Tensor) -> Tensor:
            """Forward pass of the TinyModel model.

            Parameters:
                x: Input tensor to be processed by the forward method.

            Returns:
                Output tensor after being processed through the layers and activation functions.
            """
            x = self.linear1(x)
            x = self.activation(x)
            x = self.linear2(x)
            x = self.softmax(x)
            return x

    return TinyModel()
110 |
111 |
@pytest.fixture
def evidential_classification_model(
    torch_conv_linear_model: nn.Module,
) -> Predictor:
    """Return the ``torch_conv_linear_model`` fixture passed through ``evidential_classification``."""
    return evidential_classification(torch_conv_linear_model)
118 |
--------------------------------------------------------------------------------
/tests/probly/quantification/test_classification.py:
--------------------------------------------------------------------------------
1 | """Tests for the classification module."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | import numpy as np
8 | import pytest
9 |
10 | from tests.probly.general_utils import validate_uncertainty
11 |
12 | if TYPE_CHECKING:
13 | from collections.abc import Callable
14 |
15 | from probly.evaluation.metrics import brier_score, log_loss, spherical_score, zero_one_loss
16 | from probly.quantification.classification import (
17 | aleatoric_uncertainty_distance,
18 | conditional_entropy,
19 | epistemic_uncertainty_distance,
20 | evidential_uncertainty,
21 | expected_conditional_variance,
22 | expected_divergence,
23 | expected_entropy,
24 | expected_loss,
25 | generalized_hartley,
26 | lower_entropy,
27 | lower_entropy_convex_hull,
28 | mutual_information,
29 | total_entropy,
30 | total_uncertainty_distance,
31 | total_variance,
32 | upper_entropy,
33 | upper_entropy_convex_hull,
34 | variance_conditional_expectation,
35 | )
36 |
37 |
@pytest.fixture
def sample_second_order_data() -> tuple[np.ndarray, np.ndarray]:
    """Draw random second-order samples for 2 and for 3 classes.

    Each array has shape (10, 5, n_classes) and is drawn from a flat Dirichlet
    with an unseeded generator.
    """
    rng = np.random.default_rng()
    sample_shape = (10, 5)
    two_class = rng.dirichlet(np.ones(2), sample_shape)
    three_class = rng.dirichlet(np.ones(3), sample_shape)
    return two_class, three_class
44 |
45 |
@pytest.fixture
def simplex_uniform() -> np.ndarray:
    """Return a (1, 3, 3) array in which every row is the uniform 3-class distribution."""
    return np.full((1, 3, 3), 1 / 3)
49 |
50 |
@pytest.fixture
def simplex_vertices() -> np.ndarray:
    """Return a (1, 3, 3) integer array whose rows are the three one-hot vertices."""
    return np.expand_dims(np.eye(3, dtype=int), axis=0)
54 |
55 |
@pytest.mark.parametrize(
    "uncertainty_fn",
    [
        total_entropy,
        conditional_entropy,
        mutual_information,
        total_variance,
        expected_conditional_variance,
        variance_conditional_expectation,
        total_uncertainty_distance,
        aleatoric_uncertainty_distance,
        epistemic_uncertainty_distance,
        upper_entropy,
        lower_entropy,
        upper_entropy_convex_hull,
        lower_entropy_convex_hull,
        generalized_hartley,
    ],
)
def test_uncertainty_function(
    uncertainty_fn: Callable[[np.ndarray], np.ndarray],
    sample_second_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Each measure must produce valid uncertainties for the 2-class and the 3-class samples."""
    # The fixture yields (probs2d, probs3d); check them in that order.
    for probs in sample_second_order_data:
        validate_uncertainty(uncertainty_fn(probs))
85 |
86 |
@pytest.mark.parametrize("uncertainty_fn", [expected_loss, expected_entropy, expected_divergence])
def test_loss_uncertainty_function(
    uncertainty_fn: Callable[[np.ndarray, Callable[[np.ndarray, np.ndarray | None], np.ndarray]], np.ndarray],
    sample_second_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Each loss-based measure must produce valid uncertainties for every loss and sample set."""
    probs2d, probs3d = sample_second_order_data
    for loss_fn in [log_loss, brier_score, zero_one_loss, spherical_score]:
        for probs in (probs2d, probs3d):
            validate_uncertainty(uncertainty_fn(probs, loss_fn))
99 |
100 |
def test_lower_entropy(simplex_vertices: np.ndarray, simplex_uniform: np.ndarray) -> None:
    """Lower entropy is 0 on the simplex vertices and log2(3) on the uniform set."""
    assert lower_entropy(simplex_vertices) == pytest.approx(0.0)
    assert lower_entropy(simplex_uniform) == pytest.approx(1.5849625007)
107 |
108 |
def test_upper_entropy(simplex_vertices: np.ndarray, simplex_uniform: np.ndarray) -> None:
    """Upper entropy is log2(3) for both the vertex set and the uniform set."""
    assert upper_entropy(simplex_vertices) == pytest.approx(1.5849625007)
    assert upper_entropy(simplex_uniform) == pytest.approx(1.5849625007)
115 |
116 |
def test_evidential_uncertainty() -> None:
    """Evidential uncertainty of random evidence in [0, 100) must be a valid uncertainty."""
    evidence = np.random.default_rng().uniform(0, 100, (10, 3))
    validate_uncertainty(evidential_uncertainty(evidence))
122 |
--------------------------------------------------------------------------------
/src/probly/traverse_nn/torch.py:
--------------------------------------------------------------------------------
1 | """Traversal implementation for PyTorch modules."""
2 |
3 | from __future__ import annotations
4 |
5 | from collections import OrderedDict
6 | import copy
7 |
8 | from torch.nn import Module, Sequential
9 |
10 | import pytraverse as t
11 | from pytraverse import generic
12 | from pytraverse.decorators import traverser
13 |
14 | from . import common as tnn
15 |
16 | # Torch traversal variables
17 |
# Written by `_root_traverser` the first time it sees `None` here.
ROOT = t.StackVariable[Module | None]("ROOT", "A reference to the outermost module.")
# Falls back to the generic CLONE variable; `_clone_traverser` clears it after the first deepcopy.
CLONE = t.StackVariable[bool](
    "CLONE",
    "Whether to clone torch modules before making changes.",
    default=generic.CLONE,
)
# Falls back to the generic TRAVERSE_REVERSED variable; read by the module/sequential traversers.
TRAVERSE_REVERSED = t.StackVariable[bool](
    "TRAVERSE_REVERSED",
    "Whether to traverse elements in reverse order.",
    default=generic.TRAVERSE_REVERSED,
)
# Falls back to the framework-agnostic FLATTEN_SEQUENTIAL variable from `common`.
FLATTEN_SEQUENTIAL = t.StackVariable[bool](
    "FLATTEN_SEQUENTIAL",
    "Whether to flatten sequential torch modules after making changes.",
    default=tnn.FLATTEN_SEQUENTIAL,
)
34 |
35 | # Torch model cloning
36 |
37 |
@traverser(type=Module)
def _clone_traverser(
    obj: Module,
    state: t.State[Module],
) -> t.TraverserResult[Module]:
    """Deep-copy the module once if cloning is requested, then switch cloning off.

    Returns the (possibly copied) module together with the updated state.
    """
    if not state[CLONE]:
        return obj, state

    cloned = copy.deepcopy(obj)
    # Do not clone the module twice:
    state[CLONE] = False
    # After deepcopy, generic datastructures will have been cloned as well:
    state[generic.CLONE] = False
    return cloned, state
51 |
52 |
53 | # Torch model root tracking
54 |
55 |
@traverser(type=Module)
def _root_traverser(
    obj: Module,
    state: t.State[Module],
) -> t.TraverserResult[Module]:
    """Record the first visited module as ROOT and reset the layer counter to 0."""
    if state[ROOT] is not None:
        # A root is already recorded; nothing to do.
        return obj, state

    state[ROOT] = obj
    state[tnn.LAYER_COUNT] = 0
    return obj, state
65 |
66 |
67 | # Torch model layer counting
68 |
69 |
@tnn.layer_count_traverser.register(vars={"count": tnn.LAYER_COUNT}, update_vars=True)
def _module_counter(obj: Module, count: int) -> tuple[Module, dict[str, int]]:
    """Return the module unchanged together with LAYER_COUNT incremented by one."""
    incremented = {"count": count + 1}
    return obj, incremented
75 |
76 |
@tnn.layer_count_traverser.register
def _sequential_counter(obj: Sequential) -> Sequential:
    """Return the container unchanged; registered for `Sequential` so no count update is produced."""
    return obj  # Don't count sequential modules as layers.
80 |
81 |
82 | # Torch model traverser
83 |
# Dispatcher on which `_module_traverser` and `_sequential_traverser` are registered below.
_torch_traverser = t.singledispatch_traverser[Module](name="_torch_traverser")
85 |
86 |
@_torch_traverser.register
def _module_traverser(
    obj: Module,
    state: t.State[Module],
    traverse: t.TraverserCallback[Module],
) -> t.TraverserResult[Module]:
    """Traverse every direct child of ``obj`` and write the result back under the same name.

    Children are visited in reverse order when TRAVERSE_REVERSED is set.
    """
    named = obj.named_children()
    ordered = reversed(list(named)) if state[TRAVERSE_REVERSED] else named
    for child_name, child in ordered:
        replacement, state = traverse(child, state, child_name)
        setattr(obj, child_name, replacement)

    return obj, state
101 |
102 |
@_torch_traverser.register
def _sequential_traverser(
    obj: Sequential,
    state: t.State[Module],
    traverse: t.TraverserCallback[Module],
) -> t.TraverserResult[Module]:
    """Traverse the children of a `Sequential` and rebuild it, inlining nested `Sequential` results.

    When FLATTEN_SEQUENTIAL is off this defers to `_module_traverser`. Otherwise a
    new `Sequential` is returned; `obj` itself is not modified here.
    """
    if not state[FLATTEN_SEQUENTIAL]:
        return _module_traverser(obj, state, traverse)

    seq = []
    children = obj.named_children()
    traverse_reversed = state[TRAVERSE_REVERSED]
    if traverse_reversed:
        children = reversed(list(children))

    for name, module in children:
        new_module, state = traverse(module, state, name)
        if isinstance(new_module, Sequential):
            # Inline the nested container; entries are prefixed with the parent's name.
            sub_children = new_module.named_children()
            if traverse_reversed:
                # Keep `seq` consistently in reverse order so one final reversal restores it.
                sub_children = reversed(list(sub_children))
            for sub_name, sub_module in sub_children:
                seq.append((f"{name}_{sub_name}", sub_module))
        else:
            seq.append((name, new_module))

    # `seq` was collected back-to-front in reversed mode; restore forward order.
    new_obj = Sequential(OrderedDict(reversed(seq) if traverse_reversed else seq))

    return new_obj, state
132 |
133 |
134 | # Public API combining cloning, root tracking, and module traversing
135 |
# Combine cloning, root tracking and child traversal into one traverser
# (presumably applied in the listed order - confirm against `t.sequential`).
torch_traverser: t.Traverser[Module] = t.sequential(
    _clone_traverser,
    _root_traverser,
    _torch_traverser,
    name="torch_traverser",
)
# Expose the dispatcher's `register` on the public object.
torch_traverser.register = _torch_traverser.register  # type: ignore[attr-defined]

# Register this traverser for `Module` on the shared nn traverser.
tnn.nn_traverser.register(Module, torch_traverser)
145 |
--------------------------------------------------------------------------------
/src/probly/traverse_nn/flax.py:
--------------------------------------------------------------------------------
1 | """Traversal implementation for flax modules."""
2 |
3 | from __future__ import annotations
4 |
5 | import copy
6 | from typing import TYPE_CHECKING
7 |
8 | from flax.nnx.helpers import Sequential
9 | from flax.nnx.module import Module
10 |
11 | import pytraverse as t
12 | from pytraverse import generic
13 | from pytraverse.decorators import traverser
14 |
15 | from . import common as tnn
16 |
17 | if TYPE_CHECKING:
18 | from collections.abc import Iterable, Iterator
19 |
# Flax traversal variables

# Written by `_root_traverser` the first time it sees `None` here.
ROOT = t.StackVariable[Module | None]("ROOT", "A reference to the outermost module.")
# Falls back to the generic CLONE variable; `_clone_traverser` clears it after the first deepcopy.
CLONE = t.StackVariable[bool](
    "CLONE",
    "Whether to clone flax modules before making changes.",
    default=generic.CLONE,
)
# Falls back to the generic TRAVERSE_REVERSED variable; read by the module/sequential traversers.
TRAVERSE_REVERSED = t.StackVariable[bool](
    "TRAVERSE_REVERSED",
    "Whether to traverse elements in reverse order.",
    default=generic.TRAVERSE_REVERSED,
)
# Falls back to the framework-agnostic FLATTEN_SEQUENTIAL variable from `common`.
FLATTEN_SEQUENTIAL = t.StackVariable[bool](
    "FLATTEN_SEQUENTIAL",
    "Whether to flatten sequential flax modules after making changes.",
    default=tnn.FLATTEN_SEQUENTIAL,
)
38 |
# Flax model cloning
40 |
41 |
@traverser(type=Module)
def _clone_traverser(
    obj: Module,
    state: t.State[Module],
) -> t.TraverserResult[Module]:
    """Deep-copy the module once if cloning is requested, then switch cloning off.

    Returns the (possibly copied) module together with the updated state.
    """
    if not state[CLONE]:
        return obj, state

    cloned = copy.deepcopy(obj)
    # Do not clone the module twice:
    state[CLONE] = False
    # After deepcopy, generic datastructures will have been cloned as well:
    state[generic.CLONE] = False
    return cloned, state
55 |
56 |
57 | # Flax model root tracking
58 |
59 |
@traverser(type=Module)
def _root_traverser(
    obj: Module,
    state: t.State[Module],
) -> t.TraverserResult[Module]:
    """Record the first visited module as ROOT and reset the layer counter to 0."""
    if state[ROOT] is not None:
        # A root is already recorded; nothing to do.
        return obj, state

    state[ROOT] = obj
    state[tnn.LAYER_COUNT] = 0
    return obj, state
69 |
70 |
71 | # Flax model layer counting
72 |
73 |
@tnn.layer_count_traverser.register(vars={"count": tnn.LAYER_COUNT}, update_vars=True)
def _module_counter(obj: Module, count: int) -> tuple[Module, dict[str, int]]:
    """Return the module unchanged together with LAYER_COUNT incremented by one."""
    incremented = {"count": count + 1}
    return obj, incremented
79 |
80 |
@tnn.layer_count_traverser.register
def _sequential_counter(obj: Sequential) -> Sequential:
    """Return the container unchanged; registered for `Sequential` so no count update is produced."""
    return obj  # Don't count sequential modules as layers.
84 |
85 |
86 | # Flax model traverser
87 |
# Dispatcher on which `_module_traverser` and `_sequential_traverser` are registered below.
# NOTE(review): the identifier and its `name` still say "torch" although this module
# handles flax modules - looks like a copy-paste leftover; confirm before renaming.
_torch_traverser = t.singledispatch_traverser[Module](name="_torch_traverser")
89 |
90 |
@_torch_traverser.register
def _module_traverser(
    obj: Module,
    state: t.State[Module],
    traverse: t.TraverserCallback[Module],
) -> t.TraverserResult[Module]:
    """Traverse every child yielded by ``obj.iter_children()`` and write the result back.

    Children are visited in reverse order when TRAVERSE_REVERSED is set.
    """
    named: Iterator[tuple[str, Module]] = obj.iter_children()  # type: ignore[assignment]
    ordered = reversed(list(named)) if state[TRAVERSE_REVERSED] else named
    for child_name, child in ordered:
        replacement, state = traverse(child, state, child_name)
        setattr(obj, child_name, replacement)

    return obj, state
105 |
106 |
@_torch_traverser.register
def _sequential_traverser(
    obj: Sequential,
    state: t.State[Module],
    traverse: t.TraverserCallback[Module],
) -> t.TraverserResult[Module]:
    """Traverse the layers of a `Sequential` and rebuild it, inlining nested `Sequential` results.

    When FLATTEN_SEQUENTIAL is off this defers to `_module_traverser`. Otherwise a
    new `Sequential` is returned; `obj` itself is not modified here.
    """
    if not state[FLATTEN_SEQUENTIAL]:
        return _module_traverser(obj, state, traverse)

    seq: list[Module] = []
    children: Iterable[Module] = obj.layers  # type: ignore[assignment]
    traverse_reversed = state[TRAVERSE_REVERSED]
    if traverse_reversed:
        children = reversed(list(children))

    for module in children:
        # Unlike `_module_traverser`, no child name is passed to `traverse` here.
        new_module, state = traverse(module, state)
        if isinstance(new_module, Sequential):
            # Inline the nested container's layers instead of nesting it.
            sub_children: Iterable[Module] = new_module.layers  # type: ignore[assignment]
            if traverse_reversed:
                # Keep `seq` consistently in reverse order so one final reversal restores it.
                sub_children = reversed(list(sub_children))
            seq += sub_children
        else:
            seq.append(new_module)

    # `seq` was collected back-to-front in reversed mode; restore forward order.
    new_obj = Sequential(*(reversed(seq) if traverse_reversed else seq))

    return new_obj, state
135 |
136 |
137 | # Public API combining cloning, root tracking, and module traversing
138 |
# Combine cloning, root tracking and child traversal into one traverser
# (presumably applied in the listed order - confirm against `t.sequential`).
# NOTE(review): the public name says "torch" although this module handles flax
# modules; kept as-is because importers may rely on it.
torch_traverser: t.Traverser[Module] = t.sequential(
    _clone_traverser,
    _root_traverser,
    _torch_traverser,
    name="torch_traverser",
)
# Expose the dispatcher's `register` on the public object.
torch_traverser.register = _torch_traverser.register  # type: ignore[attr-defined]

# Register this traverser for flax `Module` on the shared nn traverser.
tnn.nn_traverser.register(Module, torch_traverser)
148 |
--------------------------------------------------------------------------------
/tests/probly/transformation/bayesian/test_torch.py:
--------------------------------------------------------------------------------
1 | """Test for torch bayesian models."""
2 |
3 | from __future__ import annotations
4 |
5 | import pytest
6 |
7 | from probly.layers.torch import BayesConv2d, BayesLinear
8 | from probly.transformation import bayesian
9 | from tests.probly.torch_utils import count_layers
10 |
11 | torch = pytest.importorskip("torch")
12 |
13 | from torch import nn # noqa: E402
14 |
15 |
class TestNetworkArchitectures:
    """Check the ``bayesian`` transformation on several torch architectures."""

    def test_linear_network_replacement(
        self,
        torch_model_small_2d_2d: nn.Sequential,
    ) -> None:
        """Every ``nn.Linear`` must become a ``BayesLinear``; the container layout must not change."""
        original = torch_model_small_2d_2d
        model = bayesian(original)

        tracked = (nn.Linear, BayesLinear, nn.Sequential)
        # Per-type layer counts before and after the transformation.
        before = {layer_type: count_layers(original, layer_type) for layer_type in tracked}
        after = {layer_type: count_layers(model, layer_type) for layer_type in tracked}

        # check that the model is not modified except for the bayesian layer
        assert model is not None
        assert isinstance(model, type(original))
        assert after[BayesLinear] == before[BayesLinear] + before[nn.Linear]
        assert after[nn.Linear] == 0
        assert before[nn.Sequential] == after[nn.Sequential]

    def test_convolutional_network(self, torch_conv_linear_model: nn.Sequential) -> None:
        """Conv and linear layers must both be replaced by their bayesian counterparts.

        No ``nn.Linear`` or ``nn.Conv2d`` may remain, the bayesian counts must grow by
        exactly the number of replaced layers, and the number of ``nn.Sequential``
        containers must stay the same.
        """
        original = torch_conv_linear_model
        model = bayesian(original)

        tracked = (nn.Linear, BayesConv2d, nn.Sequential, nn.Conv2d, BayesLinear)
        # Per-type layer counts before and after the transformation.
        before = {layer_type: count_layers(original, layer_type) for layer_type in tracked}
        after = {layer_type: count_layers(model, layer_type) for layer_type in tracked}

        # check that the model is not modified except for the bayesian layer
        assert model is not None
        assert isinstance(model, type(original))
        assert after[nn.Linear] == 0
        assert after[nn.Conv2d] == 0
        assert after[BayesConv2d] == before[BayesConv2d] + before[nn.Conv2d]
        assert after[BayesLinear] == before[BayesLinear] + before[nn.Linear]
        assert before[nn.Sequential] == after[nn.Sequential]

    def test_custom_network(self, torch_custom_model: nn.Module) -> None:
        """A custom ``nn.Module`` must keep its own type and not be turned into a Sequential."""
        transformed = bayesian(torch_custom_model)

        # check if model type is correct
        assert isinstance(transformed, type(torch_custom_model))
        assert not isinstance(transformed, nn.Sequential)
98 |
--------------------------------------------------------------------------------
/src/probly/representation/sampling/sampler.py:
--------------------------------------------------------------------------------
1 | """Model sampling and sample representer implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from collections.abc import Callable, Iterable
6 | from typing import Any, Literal, Unpack
7 |
8 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE
9 | from probly.predictor import Predictor, predict
10 | from probly.representation.representer import Representer
11 | from probly.traverse_nn import nn_compose
12 | from pytraverse import CLONE, GlobalVariable, lazydispatch_traverser, traverse_with_state
13 |
14 | from .sample import Sample, create_sample
15 |
# "sequential" is the only strategy `sampler_factory` handles.
type SamplingStrategy = Literal["sequential"]


# Lazily-dispatched traverser passed to `nn_compose` in `get_sampling_predictor`;
# the `delayed_register` hooks below import the framework-specific modules.
sampling_preparation_traverser = lazydispatch_traverser[object](name="sampling_preparation_traverser")

# Traversal variable holding zero-argument callables; `get_sampling_predictor` seeds it
# with an empty set and runs whatever it contains in the returned cleanup.
CLEANUP_FUNCS = GlobalVariable[set[Callable[[], Any]]](name="CLEANUP_FUNCS")
22 |
23 |
@sampling_preparation_traverser.delayed_register(TORCH_MODULE)
def _(_: type) -> None:
    """Import `torch_sampler` for its import-time side effects.

    Presumably that module registers torch handlers on
    `sampling_preparation_traverser` - confirm in `torch_sampler`.
    """
    from . import torch_sampler as torch_sampler  # noqa: PLC0414, PLC0415
27 |
28 |
@sampling_preparation_traverser.delayed_register(FLAX_MODULE)
def _(_: type) -> None:
    """Import `flax_sampler` for its import-time side effects.

    Presumably that module registers flax handlers on
    `sampling_preparation_traverser` - confirm in `flax_sampler`.
    """
    from . import flax_sampler as flax_sampler  # noqa: PLC0414, PLC0415
32 |
33 |
34 | def get_sampling_predictor[In, KwIn, Out](
35 | predictor: Predictor[In, KwIn, Out],
36 | ) -> tuple[Predictor[In, KwIn, Out], Callable[[], None]]:
37 | """Get the predictor to be used for sampling."""
38 | predictor, state = traverse_with_state(
39 | predictor,
40 | nn_compose(sampling_preparation_traverser),
41 | init={CLONE: False, CLEANUP_FUNCS: set()},
42 | )
43 | cleanup_funcs = state[CLEANUP_FUNCS]
44 |
45 | def cleanup() -> None:
46 | for func in cleanup_funcs:
47 | func()
48 |
49 | return predictor, cleanup
50 |
51 |
52 | def sampler_factory[In, KwIn, Out](
53 | predictor: Predictor[In, KwIn, Out],
54 | num_samples: int = 1,
55 | strategy: SamplingStrategy = "sequential",
56 | ) -> Predictor[In, KwIn, list[Out]]:
57 | """Sample multiple predictions from the predictor."""
58 |
59 | def sampler(*args: In, **kwargs: Unpack[KwIn]) -> list[Out]:
60 | sampling_predictor, cleanup = get_sampling_predictor(predictor)
61 | try:
62 | if strategy == "sequential":
63 | return [predict(sampling_predictor, *args, **kwargs) for _ in range(num_samples)]
64 | finally:
65 | cleanup()
66 |
67 | msg = f"Unknown sampling strategy: {strategy}"
68 | raise ValueError(msg)
69 |
70 | return sampler
71 |
72 |
class Sampler[In, KwIn, Out](Representer[In, KwIn, Out]):
    """A representation predictor that creates representations from finite samples."""

    # Strategy forwarded to `sampler_factory` on every `predict` call.
    sampling_strategy: SamplingStrategy
    # Turns the iterable of raw predictions into a `Sample`.
    sample_factory: Callable[[Iterable[Out]], Sample[Out]]

    def __init__(
        self,
        predictor: Predictor[In, KwIn, Out],
        sampling_strategy: SamplingStrategy = "sequential",
        sample_factory: Callable[[Iterable[Out]], Sample[Out]] = create_sample,
    ) -> None:
        """Initialize the sampler.

        Args:
            predictor (Predictor[In, KwIn, Out]): The predictor to be used for sampling.
            sampling_strategy (SamplingStrategy, optional): How the samples should be computed.
            sample_factory (Callable[[Iterable[Out]], Sample[Out]], optional): Factory to create the sample.
        """
        super().__init__(predictor)
        self.sampling_strategy = sampling_strategy
        self.sample_factory = sample_factory

    def predict(self, *args: In, num_samples: int, **kwargs: Unpack[KwIn]) -> Sample[Out]:
        """Sample from the predictor for a given input.

        Args:
            *args: Positional inputs forwarded to the predictor.
            num_samples: Required keyword-only number of predictions to draw.
            **kwargs: Keyword inputs forwarded to the predictor.

        Returns:
            The sample built by ``sample_factory`` from the drawn predictions.
        """
        return self.sample_factory(
            sampler_factory(
                self.predictor,
                num_samples=num_samples,
                strategy=self.sampling_strategy,
            )(*args, **kwargs),
        )
105 |
106 |
class EnsembleSampler[In, KwIn, Out](Representer[In, KwIn, Iterable[Out]]):
    """A sampler that creates representations from ensemble predictions."""

    # Turns the iterable of member predictions into a `Sample`.
    sample_factory: Callable[[Iterable[Out]], Sample[Out]]

    def __init__(
        self,
        predictor: Predictor[In, KwIn, Iterable[Out]],
        sample_factory: Callable[[Iterable[Out]], Sample[Out]] = create_sample,
    ) -> None:
        """Initialize the ensemble sampler.

        Args:
            predictor (Predictor[In, KwIn, Iterable[Out]]): The ensemble predictor.
            sample_factory (Callable[[Iterable[Out]], Sample[Out]], optional): Factory to create the sample.
        """
        super().__init__(predictor)
        self.sample_factory = sample_factory

    def sample(self, *args: In, **kwargs: Unpack[KwIn]) -> Sample[Out]:
        """Sample from the ensemble predictor for a given input.

        The predictor is called once with the given arguments and its iterable
        result is passed to ``sample_factory``.
        """
        return self.sample_factory(
            self.predictor(*args, **kwargs),
        )
131 |
--------------------------------------------------------------------------------
/src/probly/layers/flax.py:
--------------------------------------------------------------------------------
1 | """flax layer implementation."""
2 |
3 | from __future__ import annotations
4 |
5 | from flax import nnx
6 | from flax.nnx import rnglib
7 | from flax.nnx.module import first_from
8 | import jax
9 | from jax import lax, random
10 | import jax.numpy as jnp
11 |
12 |
class DropConnectLinear(nnx.Module):
    """Custom Linear layer with DropConnect applied to weights during training.

    Attributes:
        weight: nnx.Param, kernel of the wrapped layer, applied as ``inputs @ weight``.
        bias: nnx.Param or None, bias of the wrapped layer.
        rate: float, the dropconnect probability.
        deterministic: bool, set on every call from the ``deterministic`` argument; if false
            the weights are randomly masked, whereas if true the unmasked weights are used.
        rng_collection: str, the rng collection name to use when requesting a rng key.
        rngs: forked rng stream taken from the constructor argument, or None.

    """

    def __init__(
        self,
        base_layer: nnx.Linear,
        rate: float = 0.25,
        *,
        rng_collection: str = "dropconnect",
        rngs: rnglib.Rngs | rnglib.RngStream | None = None,
    ) -> None:
        """Initialize a DropConnectLinear layer based on a given linear base layer.

        Args:
            base_layer: nnx.Linear, The original linear layer to be wrapped.
            rate: float, the dropconnect probability.
            rng_collection: str, rng collection name to use when requesting a rng key.
            rngs: nnx.Rngs or nnx.RngStream or None, rng key.

        Raises:
            TypeError: If ``rngs`` is neither an Rngs, an RngStream nor None.
        """
        self.weight = base_layer.kernel
        self.bias = base_layer.bias
        self.rate = rate
        self.rng_collection = rng_collection

        if isinstance(rngs, rnglib.Rngs):
            self.rngs = rngs[self.rng_collection].fork()
        elif isinstance(rngs, rnglib.RngStream):
            self.rngs = rngs.fork()
        elif rngs is None:
            self.rngs = nnx.data(None)
        else:
            msg = f"rngs must be a RNGS, RngStream or None, but got {type(rngs)}."
            raise TypeError(msg)

    def _linear(self, inputs: jax.Array, weight: jax.Array) -> jax.Array:
        """Apply ``inputs @ weight`` and add the bias if the layer has one."""
        out = inputs @ weight
        if self.bias is not None:
            out = out + self.bias.value
        return out

    def __call__(
        self,
        inputs: jax.Array,
        *,
        deterministic: bool = False,
        rngs: rnglib.Rngs | rnglib.RngStream | jax.Array | None = None,
    ) -> jax.Array:
        """Forward pass of the DropConnectLinear layer.

        Args:
            inputs: jax.Array, input data of the linear layer.
            deterministic: bool, if false the weights are randomly masked (and the kept
                weights rescaled by ``1 / (1 - rate)``), whereas if true the plain linear
                transformation with the unmasked weights is applied.
            rngs: nnx.Rngs, nnx.RngStream or jax.Array, optional key used to generate the dropconnect mask.

        Returns:
            jax.Array, layer output.

        Raises:
            TypeError: If the resolved ``rngs`` is not an Rngs, RngStream or jax.Array.
        """
        self.deterministic = deterministic

        deterministic = first_from(
            deterministic,
            self.deterministic,
            error_msg="""No `deterministic` argument was provided to DropConnect
                as either a __call__ argument or class attribute.""",
        )

        if (self.rate == 0.0) or deterministic:
            # No masking requested: this is still a linear layer, so apply the
            # unmasked weights instead of handing the inputs back untouched.
            return self._linear(inputs, self.weight.value)

        # Prevent gradient NaNs in 1.0 edge-case.
        if self.rate == 1.0:
            return self._linear(inputs, jnp.zeros_like(self.weight.value))

        rngs = first_from(
            rngs,
            self.rngs,
            error_msg="""`deterministic` is False, but no `rngs` argument was provided
                to DropConnect as either a __call__ argument or class atribute.""",
        )

        if isinstance(rngs, rnglib.Rngs):
            key = rngs[self.rng_collection]()
        elif isinstance(rngs, rnglib.RngStream):
            key = rngs()
        elif isinstance(rngs, jax.Array):
            key = rngs
        else:
            msg = f"rngs must be Rngs, RngStream or jax.Array, but got {type(rngs)}."
            raise TypeError(msg)

        keep_prob = 1.0 - self.rate
        mask = random.bernoulli(key, p=keep_prob, shape=self.weight.value.shape)
        # Kept weights are scaled by 1 / keep_prob; dropped weights become zero.
        masked_weight = lax.select(mask, self.weight.value / keep_prob, jnp.zeros_like(self.weight.value))

        return self._linear(inputs, masked_weight)

    def __repr__(self) -> str:
        """Return a string representation of the layer including its class name and key attributes."""
        return f"{self.__class__.__name__}({self.extra_repr()})"

    def extra_repr(self) -> str:
        """Expose description of in- and out-features of this layer."""
        in_features = self.weight.value.shape[0]
        out_features = self.weight.value.shape[1]
        return f"in_features={in_features}, out_features={out_features}, bias={self.bias is not None}"
128 |
--------------------------------------------------------------------------------
/src/probly/quantification/regression.py:
--------------------------------------------------------------------------------
1 | """Collection of uncertainty quantification measures for regression settings."""
2 |
3 | from __future__ import annotations
4 |
5 | from typing import TYPE_CHECKING
6 |
7 | import numpy as np
8 |
9 | from probly.utils import differential_entropy_gaussian, kl_divergence_gaussian
10 |
11 | if TYPE_CHECKING:
12 | import numpy.typing as npt
13 |
14 |
def total_variance(probs: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]:
    """Compute the total uncertainty as the variance of a mixture of normal distributions.

    The input holds sampled parameters of normal distributions: index 0 of the
    last axis is the mean and index 1 is the variance. The result is the mean of
    the sampled variances plus the variance of the sampled means.

    Args:
        probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2))

    Returns:
        tv: numpy.ndarray, shape (n_instances,)

    """
    means = probs[:, :, 0]
    variances = probs[:, :, 1]
    mean_of_variances = np.mean(variances, axis=1)
    variance_of_means = np.var(means, axis=1)
    return mean_of_variances + variance_of_means
32 |
33 |
def expected_conditional_variance(probs: np.ndarray) -> np.ndarray:
    """Compute the aleatoric uncertainty as the mean of the sampled variances.

    The input holds sampled parameters of normal distributions: index 0 of the
    last axis is the mean and index 1 is the variance.

    Args:
        probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2))

    Returns:
        ecv: numpy.ndarray, shape (n_instances,)

    """
    variances = probs[:, :, 1]
    return np.mean(variances, axis=1)
51 |
52 |
def variance_conditional_expectation(probs: np.ndarray) -> np.ndarray:
    """Compute the epistemic uncertainty as the variance of the sampled means.

    The input holds sampled parameters of normal distributions: index 0 of the
    last axis is the mean and index 1 is the variance.

    Args:
        probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2))

    Returns:
        vce: numpy.ndarray, shape (n_instances,)

    """
    means = probs[:, :, 0]
    return np.var(means, axis=1)
70 |
71 |
def total_differential_entropy(probs: np.ndarray) -> np.ndarray:
    """Compute total differential entropy as the total uncertainty using entropy-based measures.

    The input holds sampled parameters of normal distributions: index 0 of the
    last axis is the mean and index 1 is the variance. The mixture variance (mean
    of the variances plus variance of the means) is passed to
    ``differential_entropy_gaussian``.

    Args:
        probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2))

    Returns:
        tde: numpy.ndarray, shape (n_instances,)

    """
    means = probs[:, :, 0]
    variances = probs[:, :, 1]
    mixture_variance = np.mean(variances, axis=1) + np.var(means, axis=1)
    return differential_entropy_gaussian(mixture_variance)
90 |
91 |
def conditional_differential_entropy(probs: np.ndarray) -> np.ndarray:
    """Compute conditional differential entropy as the aleatoric uncertainty.

    The input holds sampled parameters of normal distributions: index 0 of the
    last axis is the mean and index 1 is the variance. The result is the
    per-instance mean of ``differential_entropy_gaussian`` applied to the
    sampled variances.

    Args:
        probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2))

    Returns:
        cde: numpy.ndarray, shape (n_instances,)

    """
    variances = probs[:, :, 1]
    per_sample_entropy = differential_entropy_gaussian(variances)
    return np.mean(per_sample_entropy, axis=1)
109 |
110 |
def mutual_information(probs: np.ndarray) -> np.ndarray:
    """Compute mutual information as the epistemic uncertainty using entropy-based measures.

    Every sample is read as the parameters of a normal distribution: index 0 of the
    last axis holds the mean and index 1 holds the variance.
    A reference distribution is built per instance from the mean of the sampled means
    and from the mean sampled variance plus the variance of the sampled means. The
    result is the average over samples of ``kl_divergence_gaussian`` between each
    sample and that reference.

    Args:
        probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2))

    Returns:
        mi: numpy.ndarray, shape (n_instances,)

    """
    mus = probs[:, :, 0]
    sigma2s = probs[:, :, 1]
    n_samples = probs.shape[1]

    reference_mu = np.mean(mus, axis=1)
    reference_sigma2 = np.mean(sigma2s, axis=1) + np.var(mus, axis=1)

    # Repeat the per-instance reference parameters so they line up with every sample.
    reference_mu = np.repeat(reference_mu[:, np.newaxis], repeats=n_samples, axis=1)
    reference_sigma2 = np.repeat(reference_sigma2[:, np.newaxis], repeats=n_samples, axis=1)

    divergences = kl_divergence_gaussian(mus, sigma2s, reference_mu, reference_sigma2)
    return np.mean(divergences, axis=1)
133 |
--------------------------------------------------------------------------------
/tests/probly/evaluation/test_metrics.py:
--------------------------------------------------------------------------------
1 | """Tests for the metrics module."""
2 |
3 | from __future__ import annotations
4 |
5 | import math
6 |
7 | import numpy as np
8 | import pytest
9 |
10 | from probly.evaluation.metrics import (
11 | ROUND_DECIMALS,
12 | brier_score,
13 | coverage,
14 | coverage_convex_hull,
15 | covered_efficiency,
16 | efficiency,
17 | expected_calibration_error,
18 | log_loss,
19 | spherical_score,
20 | zero_one_loss,
21 | )
22 | from tests.probly.general_utils import validate_uncertainty
23 |
24 |
@pytest.fixture
def sample_zero_order_data() -> tuple[np.ndarray, np.ndarray]:
    """Provide 10 random 3-class probability vectors together with 10 random integer labels in [0, 3)."""
    n_instances, n_classes = 10, 3
    generator = np.random.default_rng()
    probabilities = generator.dirichlet(np.ones(n_classes), n_instances)
    labels = generator.integers(0, n_classes, n_instances)
    return probabilities, labels
31 |
32 |
@pytest.fixture
def sample_first_order_data() -> tuple[np.ndarray, np.ndarray]:
    """Provide Dirichlet samples of shape (10, 5, 3) together with Dirichlet-distributed targets of shape (10, 3)."""
    n_instances, n_samples, n_classes = 10, 5, 3
    generator = np.random.default_rng()
    concentration = np.ones(n_classes)
    probabilities = generator.dirichlet(concentration, (n_instances, n_samples))
    target_distributions = generator.dirichlet(np.ones(n_classes), n_instances)
    return probabilities, target_distributions
39 |
40 |
@pytest.fixture
def sample_conformal_data() -> tuple[np.ndarray, np.ndarray]:
    """Provide a random boolean array of shape (10, 3) together with 10 random integer labels in [0, 3)."""
    n_instances, n_classes = 10, 3
    generator = np.random.default_rng()
    membership = generator.choice([True, False], (n_instances, n_classes))
    labels = generator.integers(0, n_classes, n_instances)
    return membership, labels
47 |
48 |
def validate_metric(metric: float) -> None:
    """Assert that a metric is a finite, non-negative float.

    Args:
        metric: The value to check.

    Raises:
        AssertionError: If the value is not a float, is NaN or infinite, or is negative.
    """
    assert isinstance(metric, float)
    # isfinite rejects both NaN and +/- infinity in a single check.
    assert math.isfinite(metric)
    assert metric >= 0
54 |
55 |
def test_expected_calibration_error() -> None:
    """Check ECE on one-hot predictions: 0.0 when every label matches, 1.0 when none does."""
    probs = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
    cases = [
        (np.array([0, 1, 2, 0]), 0.0),  # labels equal the predicted classes
        (np.array([1, 2, 0, 1]), 1.0),  # labels never equal the predicted classes
    ]
    for targets, expected in cases:
        ece = expected_calibration_error(probs, targets)
        validate_metric(ece)
        assert ece == expected
67 |
68 |
def test_coverage(
    sample_conformal_data: tuple[np.ndarray, np.ndarray],
    sample_first_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Check that ``coverage`` returns a valid metric for both fixture inputs.

    Args:
        sample_conformal_data: Boolean predictions with integer labels.
        sample_first_order_data: Sampled distributions with distribution-valued targets.
    """
    preds, targets = sample_conformal_data
    cov = coverage(preds, targets)
    validate_metric(cov)

    probs, targets = sample_first_order_data
    cov = coverage(probs, targets)
    validate_metric(cov)
80 |
81 |
def test_efficiency(
    sample_conformal_data: tuple[np.ndarray, np.ndarray],
    sample_first_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Check that ``efficiency`` returns a valid metric for both fixture inputs.

    Only the predictions of each fixture are used; the targets are ignored.

    Args:
        sample_conformal_data: Boolean predictions with integer labels.
        sample_first_order_data: Sampled distributions with distribution-valued targets.
    """
    preds, _ = sample_conformal_data
    eff = efficiency(preds)
    validate_metric(eff)

    probs, _ = sample_first_order_data
    eff = efficiency(probs)
    validate_metric(eff)
93 |
94 |
def test_coverage_convex_hull(sample_first_order_data: tuple[np.ndarray, np.ndarray]) -> None:
    """Check that ``coverage_convex_hull`` returns a valid metric on sampled distributions.

    Args:
        sample_first_order_data: Sampled distributions with distribution-valued targets.
    """
    probs, targets = sample_first_order_data
    cov = coverage_convex_hull(probs, targets)
    validate_metric(cov)
99 |
100 |
def test_covered_efficiency(
    sample_conformal_data: tuple[np.ndarray, np.ndarray],
    sample_first_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Check ``covered_efficiency`` for both fixture inputs.

    The fixtures are random, so the test recomputes which instances cover their
    target and expects NaN when none does and a valid metric otherwise.

    Args:
        sample_conformal_data: Boolean predictions with integer labels.
        sample_first_order_data: Sampled distributions with distribution-valued targets.
    """
    preds, targets = sample_conformal_data
    eff = covered_efficiency(preds, targets)
    # An instance is covered when the prediction entry at its label index is True.
    covered = preds[np.arange(preds.shape[0]), targets]
    # if none of the instances cover the target, the efficiency should be np.nan
    if not np.any(covered):
        assert math.isnan(eff)
    else:
        validate_metric(eff)

    probs, targets = sample_first_order_data
    eff = covered_efficiency(probs, targets)
    # Per-class lower/upper envelope over the sample axis, rounded with ROUND_DECIMALS.
    probs_lower = np.round(np.nanmin(probs, axis=1), decimals=ROUND_DECIMALS)
    probs_upper = np.round(np.nanmax(probs, axis=1), decimals=ROUND_DECIMALS)
    # An instance is covered when every class of the target lies inside the envelope.
    covered = np.all((probs_lower <= targets) & (targets <= probs_upper), axis=1)
    # if none of the instances cover the target, the efficiency should be np.nan
    if not np.any(covered):
        assert math.isnan(eff)
    else:
        validate_metric(eff)
124 |
125 |
def test_log_loss(
    sample_zero_order_data: tuple[np.ndarray, np.ndarray],
    sample_first_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Check ``log_loss`` with integer targets and with ``None`` as targets.

    With targets the result is checked by ``validate_metric``; with ``None`` the
    result is checked by ``validate_uncertainty``.

    Args:
        sample_zero_order_data: Probability vectors with integer labels.
        sample_first_order_data: Sampled distributions; its targets are not used.
    """
    probs, targets = sample_zero_order_data
    loss = log_loss(probs, targets)
    validate_metric(loss)

    probs, _ = sample_first_order_data
    loss = log_loss(probs, None)
    validate_uncertainty(loss)
137 |
138 |
def test_brier_score(
    sample_zero_order_data: tuple[np.ndarray, np.ndarray],
    sample_first_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Check ``brier_score`` with integer targets and with ``None`` as targets.

    With targets the result is checked by ``validate_metric``; with ``None`` the
    result is checked by ``validate_uncertainty``.

    Args:
        sample_zero_order_data: Probability vectors with integer labels.
        sample_first_order_data: Sampled distributions; its targets are not used.
    """
    probs, targets = sample_zero_order_data
    loss = brier_score(probs, targets)
    validate_metric(loss)

    probs, _ = sample_first_order_data
    loss = brier_score(probs, None)
    validate_uncertainty(loss)
150 |
151 |
def test_zero_one_loss(
    sample_zero_order_data: tuple[np.ndarray, np.ndarray],
    sample_first_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Check ``zero_one_loss`` with integer targets and with ``None`` as targets.

    With targets the result is checked by ``validate_metric``; with ``None`` the
    result is checked by ``validate_uncertainty``.

    Args:
        sample_zero_order_data: Probability vectors with integer labels.
        sample_first_order_data: Sampled distributions; its targets are not used.
    """
    probs, targets = sample_zero_order_data
    loss = zero_one_loss(probs, targets)
    validate_metric(loss)

    probs, _ = sample_first_order_data
    loss = zero_one_loss(probs, None)
    validate_uncertainty(loss)
163 |
164 |
def test_spherical_score(
    sample_zero_order_data: tuple[np.ndarray, np.ndarray],
    sample_first_order_data: tuple[np.ndarray, np.ndarray],
) -> None:
    """Check ``spherical_score`` with integer targets and with ``None`` as targets.

    With targets the result is checked by ``validate_metric``; with ``None`` the
    result is checked by ``validate_uncertainty``.

    Args:
        sample_zero_order_data: Probability vectors with integer labels.
        sample_first_order_data: Sampled distributions; its targets are not used.
    """
    probs, targets = sample_zero_order_data
    loss = spherical_score(probs, targets)
    validate_metric(loss)

    probs, _ = sample_first_order_data
    loss = spherical_score(probs, None)
    validate_uncertainty(loss)
176 |
--------------------------------------------------------------------------------
/src/probly/train/calibration/torch.py:
--------------------------------------------------------------------------------
1 | """Collection of torch calibration training functions."""
2 |
3 | from __future__ import annotations
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 |
9 |
class ExpectedCalibrationError(nn.Module):
    """Expected Calibration Error (ECE) :cite:`guoOnCalibration2017`.

    Attributes:
        num_bins: int, number of bins to use for calibration
        bins: torch.Tensor, the ``num_bins + 1`` equally spaced bin edges on [0, 1]
    """

    def __init__(self, num_bins: int = 10) -> None:
        """Initializes an instance of the ExpectedCalibrationError class.

        Args:
            num_bins: int, number of bins to use for calibration
        """
        super().__init__()
        self.num_bins = num_bins
        self.bins = torch.linspace(0, 1, num_bins + 1)

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Forward pass of the expected calibration error.

        Assumes that inputs are probability distributions over classes.
        The confidence of an instance is its largest class probability. Instances are
        grouped into confidence bins, and the loss is the sum over non-empty bins of
        the bin's share of instances times ``|accuracy - mean confidence|``.

        Args:
            inputs: torch.Tensor of size (n_instances, n_classes).
            targets: torch.Tensor of size (n_instances,)

        Returns:
            loss: torch.Tensor, mean loss value
        """
        confs, preds = torch.max(inputs, dim=1)
        bin_indices = torch.bucketize(confs, self.bins.to(inputs.device), right=True) - 1
        # With right=True a confidence equal to the last edge (1.0) is placed past the
        # final bin; fold it back so fully confident predictions are not dropped.
        bin_indices = torch.clamp(bin_indices, max=self.num_bins - 1)
        num_instances = inputs.shape[0]
        loss: torch.Tensor = torch.tensor(0, dtype=torch.float32, device=inputs.device)
        for i in range(self.num_bins):
            _bin = torch.where(bin_indices == i)[0]
            # check if bin is empty
            if _bin.shape[0] == 0:
                continue
            acc_bin = torch.mean((preds[_bin] == targets[_bin]).float())
            conf_bin = torch.mean(confs[_bin])
            weight = _bin.shape[0] / num_instances
            loss += weight * torch.abs(acc_bin - conf_bin)
        return loss
54 |
55 |
class LabelRelaxationLoss(nn.Module):
    """Label Relaxation Loss from :cite:`lienenFromLabel2021`.

    This loss is used to improve the calibration of a neural network. It works by minimizing
    the Kullback-Leibler divergence between the predicted probabilities and the target distribution in the credal set
    defined by the alpha parameter. The target distribution is the distribution in the credal set that minimizes the
    Kullback-Leibler divergence from the predicted probabilities. If the predicted probability distribution
    is in the credal set, the loss is zero.

    Attributes:
        alpha: float, the parameter that controls the amount of label relaxation. Increasing alpha, increases the size
            of the credal set and thus the amount of label relaxation.
    """

    def __init__(self, alpha: float = 0.1) -> None:
        """Initializes an instance of the LabelRelaxationLoss class.

        Args:
            alpha: float, the parameter that controls the amount of label relaxation.
                Increasing alpha, increases the size of the credal set and thus the amount of label relaxation.
        """
        super().__init__()
        self.alpha = alpha

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Forward pass of the label relaxation loss.

        A softmax over dim 1 is applied to ``inputs`` inside this method.

        Args:
            inputs: torch.Tensor of size (n_instances, n_classes)
            targets: torch.Tensor of size (n_instances,)

        Returns:
            loss: torch.Tensor, mean loss value
        """
        inputs_probs = F.softmax(inputs, dim=1)
        # log_softmax stays finite where softmax underflows to 0; taking the log of the
        # softmax output would give -inf there and turn the loss into inf/nan.
        inputs_log_probs = F.log_softmax(inputs, dim=1)

        with torch.no_grad():
            inv_one_hot = 1 - F.one_hot(targets, inputs.shape[1])
            # Probability mass the prediction puts on all classes except the target.
            non_target_mass = torch.sum(inv_one_hot * inputs_probs, dim=1, keepdim=True)
            # Spread alpha over the non-target classes proportionally to the prediction
            # and give the target class the remaining 1 - alpha.
            targets_real = self.alpha * inputs_probs / non_target_mass
            targets_real[torch.arange(targets.shape[0]), targets] = 1 - self.alpha

        kl_div = torch.sum(F.kl_div(inputs_log_probs, targets_real, log_target=False, reduction="none"), dim=1)
        # No loss when the non-target mass is already at most alpha.
        loss = torch.where(non_target_mass.squeeze(1) <= self.alpha, 0, kl_div)
        return loss.mean()
100 |
101 |
class FocalLoss(nn.Module):
    """Focal Loss based on :cite:`linFocalLoss2017`.

    Attributes:
        alpha: float, control importance of minority class
        gamma: float, control loss for hard instances
    """

    def __init__(self, alpha: float = 1, gamma: float = 2) -> None:
        """Initializes an instance of the FocalLoss class.

        Args:
            alpha: float, control importance of minority class
            gamma: float, control loss for hard instances
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Forward pass of the focal loss.

        A softmax over the last dimension is applied to ``inputs`` inside this method.
        Per instance the loss is ``-alpha * (1 - p_t) ** gamma * log(p_t)``, where
        ``p_t`` is the probability assigned to the target class.

        Args:
            inputs: torch.Tensor of size (n_instances, n_classes)
            targets: torch.Tensor of size (n_instances,)

        Returns:
            loss: torch.Tensor, mean loss value

        """
        targets_one_hot = F.one_hot(targets, num_classes=inputs.shape[-1])
        prob = F.softmax(inputs, dim=-1)
        p_t = torch.sum(prob * targets_one_hot, dim=-1)

        # log_softmax stays finite where softmax underflows to 0. Taking the log of the
        # softmax output gives -inf there, and -inf * 0 in the masked sum below is NaN.
        log_prob = F.log_softmax(inputs, dim=-1)
        loss = -self.alpha * (1 - p_t) ** self.gamma * torch.sum(log_prob * targets_one_hot, dim=-1)

        return torch.mean(loss)
140 |
--------------------------------------------------------------------------------
/notebooks/examples/lazy_dispatch_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "af9ba3c1",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "id": "5a090940",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "class A:\n",
22 | " class A:\n",
23 | " \"\"\"A test class.\"\"\"\n",
24 | "\n",
25 | "\n",
26 | "class B:\n",
27 | " class B:\n",
28 | " \"\"\"A test class.\"\"\"\n",
29 | "\n",
30 | "\n",
31 | "class C(A):\n",
32 | " class C(A.A):\n",
33 | " \"\"\"A test class.\"\"\"\n",
34 | "\n",
35 | "\n",
36 | "class D(A, B):\n",
37 | " pass\n",
38 | "\n",
39 | "\n",
40 | "class E(C, D):\n",
41 | " pass"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "id": "458266bd",
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "data": {
52 | "text/plain": [
53 | "torch.Tensor"
54 | ]
55 | },
56 | "execution_count": 24,
57 | "metadata": {},
58 | "output_type": "execute_result"
59 | }
60 | ],
61 | "source": [
62 | "import torch\n",
63 | "\n",
64 | "torch.nn.Module"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 29,
70 | "id": "99f8fb2c",
71 | "metadata": {},
72 | "outputs": [
73 | {
74 | "data": {
75 | "text/plain": [
76 | "True"
77 | ]
78 | },
79 | "execution_count": 29,
80 | "metadata": {},
81 | "output_type": "execute_result"
82 | }
83 | ],
84 | "source": [
85 | "from lazy_dispatch.isinstance import lazy_isinstance\n",
86 | "\n",
87 | "t = torch.tensor([1, 2, 3])\n",
88 | "lazy_isinstance(t, \"torch.Tensor\")"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "id": "149b8309",
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "from typing import TYPE_CHECKING\n",
99 | "\n",
100 | "import lazy_dispatch as ld\n",
101 | "\n",
102 | "if TYPE_CHECKING:\n",
103 | " import keras # type: ignore # noqa: PGH003\n",
104 | " import torch\n",
105 | "\n",
106 | "\n",
107 | "@ld.lazy_singledispatch\n",
108 | "def f(_: object) -> str:\n",
109 | " return \"unknown\"\n",
110 | "\n",
111 | "\n",
112 | "@f.register\n",
113 | "def _(_: int) -> str:\n",
114 | " return \"int\"\n",
115 | "\n",
116 | "\n",
117 | "@f.register\n",
118 | "def _(_: \"torch.nn.modules.module.Module\") -> str:\n",
119 | " return \"torch module\"\n",
120 | "\n",
121 | "\n",
122 | "@f.register\n",
123 | "def _(_: \"torch.nn.modules.linear.Linear\") -> str:\n",
124 | " return \"torch linear\"\n",
125 | "\n",
126 | "\n",
127 | "@f.register\n",
128 | "def _(_: \"keras.engine.training.Model\") -> str:\n",
129 | " return \"keras model\"\n",
130 | "\n",
131 | "\n",
132 | "@f.register\n",
133 | "def _(_: A) -> str:\n",
134 | " return \"A\"\n",
135 | "\n",
136 | "\n",
137 | "@f.register\n",
138 | "def _(_: \"__main__.B\") -> str: # type: ignore # noqa: F821, PGH003\n",
139 | " return \"B\"\n",
140 | "\n",
141 | "\n",
142 | "@f.register\n",
143 | "def _(_: E) -> str:\n",
144 | " return \"E\""
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 4,
150 | "id": "d6dfb9cd",
151 | "metadata": {},
152 | "outputs": [
153 | {
154 | "data": {
155 | "text/plain": [
156 | "{'_': object, 'return': str}"
157 | ]
158 | },
159 | "execution_count": 4,
160 | "metadata": {},
161 | "output_type": "execute_result"
162 | }
163 | ],
164 | "source": [
165 | "f.__annotations__"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "id": "7af81c92",
172 | "metadata": {},
173 | "outputs": [
174 | {
175 | "name": "stdout",
176 | "output_type": "stream",
177 | "text": [
178 | "42 int\n",
179 | "1.5 unknown\n",
180 | "linear: torch linear\n",
181 | "seq: torch module\n",
182 | "A: A\n",
183 | "B: B\n",
184 | "C: A\n",
185 | "D: A\n",
186 | "E: E\n"
187 | ]
188 | }
189 | ],
190 | "source": [
191 | "import torch.nn\n",
192 | "\n",
193 | "layer = torch.nn.Linear(10, 10)\n",
194 | "seq = torch.nn.Sequential(layer, layer)\n",
195 | "\n",
196 | "print(42, f(42))\n",
197 | "print(1.5, f(1.5))\n",
198 | "print(\"linear:\", f(layer))\n",
199 | "print(\"seq:\", f(seq))\n",
200 | "print(\"A:\", f(A()))\n",
201 | "print(\"B:\", f(B()))\n",
202 | "print(\"C:\", f(C()))\n",
203 | "print(\"D:\", f(D()))\n",
204 | "print(\"E:\", f(E()))"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "id": "f35cfff2",
211 | "metadata": {},
212 | "outputs": [
213 | {
214 | "data": {
215 | "text/plain": [
216 | "mappingproxy({'keras.engine.training.Model': str>})"
217 | ]
218 | },
219 | "execution_count": 20,
220 | "metadata": {},
221 | "output_type": "execute_result"
222 | }
223 | ],
224 | "source": [
225 | "f.string_registry"
226 | ]
227 | }
228 | ],
229 | "metadata": {
230 | "kernelspec": {
231 | "display_name": "uncertainty-toolkit",
232 | "language": "python",
233 | "name": "python3"
234 | },
235 | "language_info": {
236 | "codemirror_mode": {
237 | "name": "ipython",
238 | "version": 3
239 | },
240 | "file_extension": ".py",
241 | "mimetype": "text/x-python",
242 | "name": "python",
243 | "nbconvert_exporter": "python",
244 | "pygments_lexer": "ipython3",
245 | "version": "3.13.1"
246 | }
247 | },
248 | "nbformat": 4,
249 | "nbformat_minor": 5
250 | }
251 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | """Configuration file for the Sphinx documentation builder."""
2 |
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # For the full list of built-in configuration values, see the documentation:
6 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
7 | from __future__ import annotations
8 |
9 | import importlib
10 | import inspect
11 | import os
12 | import sys
13 |
14 | import probly
15 |
# -- Path setup --------------------------------------------------------------
# NOTE(review): `import probly` above runs before these entries are inserted, so they
# cannot influence where probly itself is imported from. Presumably the package is
# installed in the docs environment; confirm.
sys.path.insert(0, os.path.abspath("../../src"))
sys.path.insert(0, os.path.abspath("../../examples"))

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = "probly"
copyright = "2025, probly team"  # noqa: A001
author = "probly team"
# Both values are taken from the imported package's version string.
release = probly.__version__
version = probly.__version__

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
    "sphinx.ext.autodoc",  # generates API documentation from docstrings
    "sphinx.ext.autosummary",  # generates .rst files for each module
    "sphinx_autodoc_typehints",  # optional, nice for type hints in docs
    # "sphinx.ext.linkcode",  # adds [source] links to code that link to GitHub. Use when repo is public.  # noqa: E501, ERA001
    "sphinx.ext.viewcode",  # adds [source] links to code that link to the source code in the docs.
    "sphinx.ext.napoleon",  # for Google-style docstrings
    "sphinx.ext.duration",  # optional, show the duration of the build
    "myst_nb",  # for jupyter notebook support, also includes myst_parser
    "sphinx.ext.intersphinx",  # for linking to other projects' docs
    "sphinx.ext.mathjax",  # for math support
    "sphinx.ext.doctest",  # for testing code snippets in the docs
    "sphinx_copybutton",  # adds a copy button to code blocks
    "sphinx.ext.autosectionlabel",  # for auto-generating section labels,
    "sphinxcontrib.bibtex",  # for bibliography support
]

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
bibtex_bibfiles = ["references.bib"]
bibtex_default_style = "alpha"
nb_execution_mode = "off"  # don't run notebooks when building the docs

# Maps each external project name to (documentation base URL, inventory location).
intersphinx_mapping = {
    "python3": ("https://docs.python.org/3", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy", None),
    "matplotlib": ("https://matplotlib.org/stable", None),
    "PIL": ("https://pillow.readthedocs.io/en/stable/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
}
63 |
64 |
def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None:
    """Build the GitHub URL of the source code for a documented Python object.

    This function is required by sphinx.ext.linkcode. The object is looked up by importing
    ``info["module"]`` and following the dotted ``info["fullname"]`` attribute path; the
    link points at the first source line of that object.

    Args:
        domain (str): The domain of the object.
        info (dict[str, str]): The information about the object.

    Returns:
        str | None: The URL to the source code, or None for non-Python domains, an empty
        module name, or an object whose module, attribute, or source cannot be found.
    """
    if domain != "py":
        return None
    if not info["module"]:
        return None

    try:
        target = importlib.import_module(info["module"])
        for attr_name in info["fullname"].split("."):
            target = getattr(target, attr_name)
        source_file = inspect.getsourcefile(target)
        _, first_line = inspect.getsourcelines(target)
        # The repository root is two directories above this configuration file.
        repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
        # A missing source file (None) raises TypeError here and is handled below.
        relative_path = os.path.relpath(source_file, start=repo_root)
    except (ModuleNotFoundError, AttributeError, TypeError, OSError):
        return None

    base = "https://github.com/pwhofman/probly"
    # Version 0.2.0 is linked under its pre-alpha tag; other versions use "v<version>".
    if version == "0.2.0":
        tag = "v0.2.0-pre-alpha"
    else:
        tag = f"v{version}"

    return f"{base}/blob/{tag}/{relative_path}#L{first_line}"
96 |
97 |
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "furo"
html_static_path = ["_static"]
html_css_files = [
    "css/custom.css",
]
# TODO(pwhofman): add favicon Issue: https://github.com/pwhofman/probly/issues/95
# html_favicon = "_static/logo/"  # noqa: ERA001
pygments_dark_style = "monokai"
html_theme_options = {
    "sidebar_hide_name": True,
    "light_logo": "logo/logo_light.png",
    "dark_logo": "logo/logo_dark.png",
}

# The "**" pattern applies this sidebar layout to every page.
html_sidebars = {
    "**": [
        "sidebar/scroll-start.html",
        "sidebar/brand.html",
        "sidebar/search.html",
        "sidebar/navigation.html",
        "sidebar/ethical-ads.html",
        "sidebar/scroll-end.html",
        "sidebar/footer.html",  # to get the github link in the footer of the sidebar
    ],
}

html_show_sourcelink = False  # to remove button next to dark mode showing source in txt format

# -- Autodoc ---------------------------------------------------------------------------------------
autosummary_generate = True
autodoc_default_options = {
    "show-inheritance": True,
    "members": True,
    "member-order": "groupwise",
    "special-members": "__call__",
    "undoc-members": True,
    "exclude-members": "__weakref__",
}
autoclass_content = "class"
# TODO(pwhofman): maybe set this to True, Issue https://github.com/pwhofman/probly/issues/94
autodoc_inherit_docstrings = False

autodoc_typehints = "both"  # to show type hints in the docstring

# -- Copy Paste Button -----------------------------------------------------------------------------
# Ignore >>> when copying code
# The pattern is a regular expression matching the ">>> " and "... " prompts.
copybutton_prompt_text = r">>> |\.\.\. "
copybutton_prompt_is_regexp = True
148 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | # A GitHub Actions workflow that runs on pull requests, pushes to main, and manual dispatch.
2 | name: Pull Request Checks
3 | on:
4 | pull_request:
5 | push:
6 | branches: [main]
7 | workflow_dispatch:
8 |
9 | jobs:
10 | # ----------------------------------------------------------------------------------------------
11 | # Code Quality Checks and Linting
12 | # ----------------------------------------------------------------------------------------------
13 | check_code_quality:
14 | name: Code Quality Checks
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v5
18 | - name: Set up Python and uv
19 | uses: astral-sh/setup-uv@v7
20 | with:
21 | python-version: "3.12"
22 | enable-cache: true
23 | - name: Install pre-commit
24 | run: |
25 | uv sync --only-group lint
26 | uv run --no-sync pre-commit install
27 | - name: Run code-quality checks
28 | run: uv run --no-sync pre-commit run --all-files --show-diff-on-failure
29 | # ----------------------------------------------------------------------------------------------
30 | # Install and Import Check
31 | # ----------------------------------------------------------------------------------------------
32 | install_and_import:
33 | name: Install and Import Check
34 | runs-on: ubuntu-latest
35 | steps:
36 | - uses: actions/checkout@v5
37 | - name: Set up Python and uv
38 | uses: astral-sh/setup-uv@v7
39 | with:
40 | python-version: "3.12"
41 | enable-cache: true
42 | - name: Create uv virtual environment
43 | run: uv venv
44 | - name: Install probly package
45 | run: uv run --no-sync uv pip install .
46 | - name: Test import
47 | run: uv run --no-sync python -c "import probly; print('✅ probly imported successfully')"
48 | # ----------------------------------------------------------------------------------------------
49 | # Unit Tests with Matrix
50 | # ----------------------------------------------------------------------------------------------
51 | run_unit_tests:
52 | name: Run Unit Tests
53 | strategy:
54 | fail-fast: false
55 | matrix:
56 | include:
57 | - os: ubuntu-latest
58 | python-version: "3.12"
59 | dependency-group: "all_ml"
60 | - os: ubuntu-latest
61 | python-version: "3.12"
62 | dependency-group: "torch"
63 | - os: windows-latest
64 | python-version: "3.12"
65 | dependency-group: "all_ml"
66 | - os: macos-latest
67 | python-version: "3.12"
68 | dependency-group: "all_ml"
69 | runs-on: ${{ matrix.os }}
70 | steps:
71 | - uses: actions/checkout@v5
72 | - name: Set up Python and uv
73 | uses: astral-sh/setup-uv@v7
74 | with:
75 | python-version: ${{ matrix.python-version }}
76 | enable-cache: true
77 | - name: Install test dependencies
78 | run: uv sync --no-dev --group test --group ${{ matrix.dependency-group }}
79 | - name: Run unit tests
80 | run: uv run --no-sync pytest "tests/probly" -m "not integration"
81 | # ----------------------------------------------------------------------------------------------
82 | # Test Documentation Build
83 | # ----------------------------------------------------------------------------------------------
84 | doc_build:
85 | name: Test Documentation Build
86 | runs-on: ubuntu-latest
87 | steps:
88 | - uses: actions/checkout@v5
89 | - name: Set up Python and uv
90 | uses: astral-sh/setup-uv@v7
91 | with:
92 | python-version: "3.12"
93 | enable-cache: true
94 | - name: Install pandoc
95 | run: |
96 | sudo apt-get update
97 | sudo apt-get install -y pandoc
98 | - name: Install dependencies
99 | run: uv sync
100 | # TODO(mmshlk) Turn this step to fail on warnings in the future. https://github.com/pwhofman/probly/issues/131
101 | - name: Build documentation
102 | run: |
103 | uv run sphinx-build -b html docs/source docs/build/html
104 |
105 | # ----------------------------------------------------------------------------------------------
106 | # Code Coverage
107 | # ----------------------------------------------------------------------------------------------
108 | run_coverage:
109 | name: Run Test Coverage
110 | runs-on: ubuntu-latest
111 | needs:
112 | - run_unit_tests
113 | - check_code_quality
114 | - install_and_import
115 | steps:
116 | - uses: actions/checkout@v5
117 | - name: Set up Python and uv
118 | uses: astral-sh/setup-uv@v7
119 | with:
120 | python-version: "3.12"
121 | enable-cache: true
122 | - name: Install test dependencies
123 | run: uv sync --no-dev --group test --group all_ml
124 | - name: Measure coverage
125 | run: uv run --no-sync pytest "tests/probly" --cov=probly --cov-report=xml
126 | - name: Upload coverage to codecov
127 | if: ${{ github.repository == 'pwhofman/probly' }} # Only upload reports for main repo and not forks
128 | uses: codecov/codecov-action@v5
129 | with:
130 | token: ${{ secrets.CODECOV_TOKEN }}
131 | fail_ci_if_error: true
132 |
--------------------------------------------------------------------------------
/src/pytraverse/generic.py:
--------------------------------------------------------------------------------
1 | """Traverser for standard Python datatypes (tuples, lists, dicts, sets).
2 |
3 | This module provides a generic traverser that can handle common Python data structures
4 | using single dispatch. It includes configurable behavior for cloning and traversing
5 | dictionary keys.
6 | """
7 |
8 | from __future__ import annotations
9 |
10 | from typing import TYPE_CHECKING, Any
11 |
12 | from pytraverse.composition import SingledispatchTraverser
13 | from pytraverse.core import (
14 | StackVariable,
15 | State,
16 | TraverserCallback,
17 | TraverserResult,
18 | )
19 |
20 | if TYPE_CHECKING:
21 | from collections.abc import Iterable
22 |
# Read by the list and dict traversers to choose between building a new container
# and mutating the original in place.
CLONE = StackVariable[bool](
    "CLONE",
    "Whether to clone datastructures before making changes.",
    default=True,
)
# Read by the dict traverser; when set, keys are passed through the traverse callback too.
TRAVERSE_KEYS = StackVariable[bool](
    "TRAVERSE_KEYS",
    "Whether to traverse the keys of dictionaries.",
    default=False,
)
# Read by the tuple, list and dict traversers to visit children last-to-first.
TRAVERSE_REVERSED = StackVariable[bool](
    "TRAVERSE_REVERSED",
    "Whether to traverse elements in reverse order.",
    default=False,
)


# Dispatching traverser on which the container-specific traversers below are registered.
generic_traverser = SingledispatchTraverser[object](name="generic_traverser")
41 |
42 |
@generic_traverser.register
def _tuple_traverser(
    obj: tuple,
    state: State[tuple[Any]],
    traverse: TraverserCallback[Any],
) -> TraverserResult[tuple[Any]]:
    """Traverse every element of a tuple and build a new tuple from the results.

    A new tuple is always returned because tuples are immutable. Each element is passed
    to ``traverse`` together with its index. When TRAVERSE_REVERSED is set, elements are
    visited last-to-first, but the returned tuple keeps the original element order.

    Args:
        obj: The tuple to traverse.
        state: Current traversal state.
        traverse: Callback for traversing child elements.

    Returns:
        A new tuple with traversed elements and updated state.
    """
    indexed: Iterable[tuple[int, Any]] = enumerate(obj)
    backwards = state[TRAVERSE_REVERSED]
    if backwards:
        indexed = reversed(list(indexed))

    results = []
    for position, item in indexed:
        traversed, state = traverse(item, state, position)
        results.append(traversed)

    if backwards:
        # Results were collected last-to-first; restore the original order.
        results.reverse()
    return tuple(results), state
72 |
73 |
@generic_traverser.register
def _list_traverser(
    obj: list,
    state: State[list[Any]],
    traverse: TraverserCallback[Any],
) -> TraverserResult[list[Any]]:
    """Traverse every element of a list, either into a new list or in place.

    Each element is passed to ``traverse`` together with its index. The CLONE variable
    selects the mode:
    - If True: a new instance of the list's own class is filled with the traversed elements
    - If False: the traversed elements are written back into the original list

    When TRAVERSE_REVERSED is set, elements are visited last-to-first; the resulting
    list keeps the original element order in both modes.

    Args:
        obj: The list to traverse.
        state: Current traversal state.
        traverse: Callback for traversing child elements.

    Returns:
        The modified or new list and updated state.
    """
    indexed: Iterable[tuple[int, Any]] = enumerate(obj)
    backwards = state[TRAVERSE_REVERSED]
    if backwards:
        indexed = reversed(list(indexed))

    if not state[CLONE]:
        # In-place mode: each result overwrites the slot it came from.
        for position, item in indexed:
            traversed, state = traverse(item, state, position)
            obj[position] = traversed
        return obj, state

    clone = obj.__class__()
    for position, item in indexed:
        traversed, state = traverse(item, state, position)
        clone.append(traversed)
    if backwards:
        # Results were appended last-to-first; restore the original order.
        clone.reverse()
    return clone, state
112 |
113 |
@generic_traverser.register
def _dict_traverser(
    obj: dict,
    state: State[dict[Any, Any]],
    traverse: TraverserCallback[Any],
) -> TraverserResult[dict[Any, Any]]:
    """Traverse the values of a dictionary and, if requested, its keys.

    Two state variables control the mode:
        - ``CLONE`` and ``TRAVERSE_KEYS`` both falsy: values are replaced
          in ``obj`` itself, which is then returned;
        - either one truthy: a new instance of ``obj``'s class is built
          and returned;
        - ``TRAVERSE_KEYS`` truthy: each key is additionally passed through
          ``traverse`` (without a metadata argument) before its value, and
          the value callback receives the traversed key.

    When ``TRAVERSE_REVERSED`` is set, entries are visited last-to-first;
    a newly built dictionary is still filled in the original entry order.

    Args:
        obj: The dictionary to traverse.
        state: Traversal state; threaded through every ``traverse`` call.
        traverse: Callback invoked as ``traverse(value, state, key)`` for
            values and ``traverse(key, state)`` for keys.

    Returns:
        The original or the newly built dictionary, plus the final state.
    """
    with_keys = state[TRAVERSE_KEYS]
    entries: Iterable[tuple[Any, Any]] = obj.items()
    reverse_order = state[TRAVERSE_REVERSED]
    if reverse_order:
        entries = reversed(list(entries))

    build_new = state[CLONE] or with_keys
    if not build_new:
        # In-place mode: only values of existing keys are reassigned, so the
        # dictionary's size and key order stay as they were.
        for key, value in entries:
            replacement, state = traverse(value, state, key)
            obj[key] = replacement
        return obj, state

    result = obj.__class__()
    deferred: list[tuple[Any, Any]] = []  # only filled when visiting in reverse
    for key, value in entries:
        new_key = key
        if with_keys:
            new_key, state = traverse(key, state)
        new_value, state = traverse(value, state, new_key)
        if reverse_order:
            deferred.append((new_key, new_value))
        else:
            result[new_key] = new_value

    if reverse_order:
        # Entries were visited back-to-front; insert them front-to-back so
        # the result keeps the original entry order.
        result.update(reversed(deferred))
    return result, state
164 |
165 |
@generic_traverser.register
def _set_traverser(
    obj: set,
    state: State[set[Any]],
    traverse: TraverserCallback[Any],
) -> TraverserResult[set[Any]]:
    """Rebuild a set from its traversed members.

    A new instance of ``obj``'s class is always created: members may be
    replaced by different objects during traversal, so they are collected
    into a fresh set instead of being swapped in place. Members are passed
    to ``traverse`` without a metadata argument.

    Args:
        obj: The set whose members are traversed.
        state: Traversal state; threaded through every ``traverse`` call.
        traverse: Callback invoked as ``traverse(member, state)``.

    Returns:
        The newly built set and the state returned by the last callback.
    """
    rebuilt = obj.__class__()
    for member in obj:
        replacement, state = traverse(member, state)
        rebuilt.add(replacement)
    return rebuilt, state
190 |
--------------------------------------------------------------------------------