├── tests ├── __init__.py ├── probly │ ├── __init__.py │ ├── plot │ │ ├── __init__.py │ │ └── test_credal.py │ ├── train │ │ ├── __init__.py │ │ ├── bayesian │ │ │ ├── __init__.py │ │ │ └── test_torch.py │ │ ├── calibration │ │ │ ├── __init__.py │ │ │ └── test_torch.py │ │ └── evidential │ │ │ ├── __init__.py │ │ │ └── test_torch.py │ ├── utils │ │ ├── __init__.py │ │ ├── test_errors.py │ │ ├── test_sets.py │ │ ├── test_torch.py │ │ └── test_probabilities.py │ ├── datasets │ │ ├── __init__.py │ │ └── test_torch.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── test_tasks.py │ │ └── test_metrics.py │ ├── traverse_nn │ │ └── __init__.py │ ├── quantification │ │ ├── __init__.py │ │ ├── test_regression.py │ │ └── test_classification.py │ ├── representation │ │ └── __init__.py │ ├── transformation │ │ ├── __init__.py │ │ ├── dropout │ │ │ ├── __init__.py │ │ │ └── test_common.py │ │ ├── bayesian │ │ │ ├── __init__.py │ │ │ ├── test_common.py │ │ │ └── test_torch.py │ │ ├── dropconnect │ │ │ ├── __init__.py │ │ │ └── test_common.py │ │ ├── evidential │ │ │ ├── classification │ │ │ │ ├── __init__.py │ │ │ │ ├── test_common.py │ │ │ │ └── test_torch.py │ │ │ └── regression │ │ │ │ ├── test_common.py │ │ │ │ └── test_torch.py │ │ └── ensemble │ │ │ └── test_common.py │ ├── fixtures │ │ ├── common.py │ │ ├── torch_data.py │ │ ├── flax_models.py │ │ └── torch_models.py │ ├── general_utils.py │ ├── tests_deprecations │ │ ├── __init__.py │ │ ├── deprecated_features.py │ │ ├── utils.py │ │ └── test_deprecations.py │ ├── flax_utils.py │ ├── torch_utils.py │ └── markers.py ├── pytraverse │ └── __init__.py ├── lazy_dispatch │ ├── __init__.py │ └── test_isinstance.py └── conftest.py ├── src ├── probly │ ├── train │ │ ├── __init__.py │ │ ├── bayesian │ │ │ ├── __init__.py │ │ │ └── torch.py │ │ ├── evidential │ │ │ └── __init__.py │ │ └── calibration │ │ │ ├── __init__.py │ │ │ └── torch.py │ ├── layers │ │ ├── __init__.py │ │ └── flax.py │ ├── plot │ │ ├── __init__.py │ │ └── credal.py │ ├── datasets │ │ └── __init__.py │ ├── evaluation │ │ ├── __init__.py │ │ └── tasks.py │ ├── quantification │ │ ├── __init__.py │ │ └── regression.py │ ├── representation │ │ ├── distribution │ │ │ └── __init__.py │ │ ├── credal_set │ │ │ ├── __init__.py │ │ │ ├── jax.py │ │ │ ├── torch.py │ │ │ └── credal_set.py │ │ ├── __init__.py │ │ ├── sampling │ │ │ ├── __init__.py │ │ │ ├── flax_sampler.py │ │ │ ├── jax_sample.py │ │ │ ├── torch_sample.py │ │ │ ├── torch_sampler.py │ │ │ ├── sample.py │ │ │ └── sampler.py │ │ └── representer.py │ ├── lazy_types.py │ ├── transformation │ │ ├── evidential │ │ │ ├── __init__.py │ │ │ ├── classification │ │ │ │ ├── torch.py │ │ │ │ ├── __init__.py │ │ │ │ └── common.py │ │ │ └── regression │ │ │ │ ├── __init__.py │ │ │ │ ├── torch.py │ │ │ │ └── common.py │ │ ├── dropout │ │ │ ├── torch.py │ │ │ ├── flax.py │ │ │ ├── __init__.py │ │ │ └── common.py │ │ ├── bayesian │ │ │ ├── __init__.py │ │ │ ├── torch.py │ │ │ └── common.py │ │ ├── dropconnect │ │ │ ├── torch.py │ │ │ ├── flax.py │ │ │ ├── __init__.py │ │ │ └── common.py │ │ ├── __init__.py │ │ └── ensemble │ │ │ ├── __init__.py │ │ │ ├── torch.py │ │ │ ├── common.py │ │ │ └── flax.py │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ ├── errors.py │ │ ├── sets.py │ │ ├── torch.py │ │ └── probabilities.py │ ├── traverse_nn │ │ ├── __init__.py │ │ ├── common.py │ │ ├── torch.py │ │ └── flax.py │ └── predictor.py ├── lazy_dispatch │ ├── __init__.py │ ├── load.py │ └── isinstance.py └── pytraverse │ ├── __init__.py │ └── generic.py ├── 
docs ├── source │ ├── references.rst │ ├── _static │ │ ├── logo │ │ │ ├── logo_dark.png │ │ │ └── logo_light.png │ │ └── github-mark.svg │ ├── _templates │ │ └── sidebar │ │ │ └── footer.html │ ├── api.rst │ ├── methods.md │ ├── index.rst │ └── conf.py ├── Makefile └── make.bat ├── .editorconfig ├── CHANGELOG.md ├── .pre-commit-config.yaml ├── .github ├── PULL_REQUEST_TEMPLATE.md ├── workflows │ ├── deploy-docs.yml │ ├── publish-release.yml │ └── ci.yml └── CONTRIBUTING.md ├── LICENSE ├── README.md └── notebooks └── examples └── lazy_dispatch_test.ipynb /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the packages.""" 2 | -------------------------------------------------------------------------------- /src/probly/train/__init__.py: -------------------------------------------------------------------------------- 1 | """Train module for probly.""" 2 | -------------------------------------------------------------------------------- /tests/probly/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the probly package.""" 2 | -------------------------------------------------------------------------------- /tests/probly/plot/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the plot module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/train/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the train module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the utils module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/utils/test_errors.py: -------------------------------------------------------------------------------- 1 | """Tests for errors functions.""" 2 | -------------------------------------------------------------------------------- /tests/pytraverse/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the pytraverse package.""" 2 | -------------------------------------------------------------------------------- /src/probly/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """Init module for layer implementations.""" 2 | -------------------------------------------------------------------------------- /src/probly/plot/__init__.py: -------------------------------------------------------------------------------- 1 | """Init module for plot implementations.""" 2 | -------------------------------------------------------------------------------- /tests/lazy_dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the lazy_dispatch package.""" 2 | -------------------------------------------------------------------------------- /tests/probly/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the datasets module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for 
the evaluation module.""" 2 | -------------------------------------------------------------------------------- /src/probly/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """Init module for dataset implementations.""" 2 | -------------------------------------------------------------------------------- /tests/probly/traverse_nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the traverse_nn module.""" 2 | -------------------------------------------------------------------------------- /src/probly/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | """Init module for evaluation implementations.""" 2 | -------------------------------------------------------------------------------- /src/probly/train/bayesian/__init__.py: -------------------------------------------------------------------------------- 1 | """Train functionality for Bayesian models.""" 2 | -------------------------------------------------------------------------------- /tests/probly/quantification/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the quantification module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/representation/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the representation module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/train/bayesian/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the Bayesian train module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/transformation/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the transformation module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/transformation/dropout/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the dropout module.""" 2 | -------------------------------------------------------------------------------- /src/probly/train/evidential/__init__.py: -------------------------------------------------------------------------------- 1 | """Train functionality for evidential models.""" 2 | -------------------------------------------------------------------------------- /tests/probly/train/calibration/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the calibration train module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/train/evidential/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the evidential train module.""" 2 | -------------------------------------------------------------------------------- /tests/probly/transformation/bayesian/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the bayesian module.""" 2 | -------------------------------------------------------------------------------- /docs/source/references.rst: -------------------------------------------------------------------------------- 1 | 
References 2 | ========== 3 | 4 | .. bibliography:: 5 | :cited: 6 | -------------------------------------------------------------------------------- /src/probly/train/calibration/__init__.py: -------------------------------------------------------------------------------- 1 | """Train functionality for calibration methods.""" 2 | -------------------------------------------------------------------------------- /tests/probly/transformation/dropconnect/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the dropconnect module.""" 2 | -------------------------------------------------------------------------------- /src/probly/quantification/__init__.py: -------------------------------------------------------------------------------- 1 | """Init module for quantification measure implementations.""" 2 | -------------------------------------------------------------------------------- /src/probly/representation/distribution/__init__.py: -------------------------------------------------------------------------------- 1 | """Second-order distribution representations.""" 2 | -------------------------------------------------------------------------------- /tests/probly/transformation/evidential/classification/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for the classification module.""" 2 | -------------------------------------------------------------------------------- /docs/source/_static/logo/logo_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pwhofman/probly/HEAD/docs/source/_static/logo/logo_dark.png -------------------------------------------------------------------------------- /docs/source/_static/logo/logo_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pwhofman/probly/HEAD/docs/source/_static/logo/logo_light.png -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | charset = utf-8 6 | insert_final_newline = true 7 | indent_style = space 8 | indent_size = 4 9 | trim_trailing_whitespace = true 10 | -------------------------------------------------------------------------------- /src/probly/representation/credal_set/__init__.py: -------------------------------------------------------------------------------- 1 | """Second-order distribution representations.""" 2 | 3 | from __future__ import annotations 4 | 5 | from .credal_set import CredalSet, credal_set_from_sample 6 | 7 | __all__ = [ 8 | "CredalSet", 9 | "credal_set_from_sample", 10 | ] 11 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Test fixtures for probly.""" 2 | 3 | from __future__ import annotations 4 | 5 | pytest_plugins = [ 6 | "tests.probly.fixtures.common", 7 | "tests.probly.fixtures.torch_models", 8 | "tests.probly.fixtures.torch_data", 9 | "tests.probly.fixtures.flax_models", 10 | ] 11 | -------------------------------------------------------------------------------- /src/probly/lazy_types.py: -------------------------------------------------------------------------------- 1 | """Collection of fully-qualified type names for lazy type 
checking.""" 2 | 3 | from __future__ import annotations 4 | 5 | TORCH_MODULE = "torch.nn.modules.module.Module" 6 | TORCH_TENSOR = "torch.Tensor" 7 | 8 | FLAX_MODULE = "flax.nnx.module.Module" 9 | JAX_ARRAY = "jax.Array" 10 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | This changelog is updated with every release of `probly`. 3 | 4 | ## Development 5 | 6 | - Added possiblity to create ensemble of torch models without resetting the weights of each model. 7 | 8 | ## 0.1.0 (2024-03-14) 9 | Initial pre-release of `probly` without functionalities. 10 | -------------------------------------------------------------------------------- /src/probly/transformation/evidential/__init__.py: -------------------------------------------------------------------------------- 1 | """Evidential module for probly.""" 2 | 3 | from probly.transformation.evidential.classification import evidential_classification 4 | from probly.transformation.evidential.regression import evidential_regression 5 | 6 | __all__ = ["evidential_classification", "evidential_regression"] 7 | -------------------------------------------------------------------------------- /tests/probly/utils/test_sets.py: -------------------------------------------------------------------------------- 1 | """Tests for utils.sets functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | from probly.utils.sets import powerset 6 | 7 | 8 | def test_powerset() -> None: 9 | assert powerset([]) == [()] 10 | assert powerset([1]) == [(), (1,)] 11 | assert powerset([1, 2]) == [(), (1,), (2,), (1, 2)] 12 | -------------------------------------------------------------------------------- /docs/source/_templates/sidebar/footer.html: -------------------------------------------------------------------------------- 1 |
2 | 3 | GitHub Logo 4 | 5 |
6 | -------------------------------------------------------------------------------- /src/probly/__init__.py: -------------------------------------------------------------------------------- 1 | """probly: Uncertainty Representation and Quantification for Machine Learning.""" 2 | 3 | __version__ = "0.3.1" 4 | 5 | from probly import ( 6 | datasets as datasets, 7 | evaluation as evaluation, 8 | plot as plot, 9 | quantification as quantification, 10 | representation as representation, 11 | train as train, 12 | transformation as transformation, 13 | utils as utils, 14 | ) 15 | -------------------------------------------------------------------------------- /src/probly/transformation/dropout/torch.py: -------------------------------------------------------------------------------- 1 | """Torch dropout implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from torch import nn 6 | 7 | from .common import register 8 | 9 | 10 | def prepend_torch_dropout(obj: nn.Module, p: float) -> nn.Sequential: 11 | """Prepend a Dropout layer before the given layer.""" 12 | return nn.Sequential(nn.Dropout(p=p), obj) 13 | 14 | 15 | register(nn.Linear, prepend_torch_dropout) 16 | -------------------------------------------------------------------------------- /tests/probly/fixtures/common.py: -------------------------------------------------------------------------------- 1 | """Common fixtures for tests.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from probly.predictor import Predictor 8 | 9 | 10 | @pytest.fixture 11 | def dummy_predictor() -> Predictor: 12 | """Return a dummy predictor.""" 13 | 14 | class DummyPredictor(Predictor): 15 | def __call__(self, x: float) -> float: 16 | return x 17 | 18 | return DummyPredictor() 19 | -------------------------------------------------------------------------------- /src/probly/transformation/bayesian/__init__.py: -------------------------------------------------------------------------------- 1 | """Bayesian implementation for uncertainty quantification.""" 2 | 3 | from __future__ import annotations 4 | 5 | from probly.lazy_types import TORCH_MODULE 6 | 7 | from . import common 8 | 9 | bayesian = common.bayesian 10 | register = common.register 11 | 12 | 13 | ## Torch 14 | @common.bayesian_traverser.delayed_register(TORCH_MODULE) 15 | def _(_: type) -> None: 16 | from . 
import torch as torch # noqa: PLC0415 16 | -------------------------------------------------------------------------------- /src/probly/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils module for probly library.""" 2 | 3 | from .probabilities import differential_entropy_gaussian, intersection_probability, kl_divergence_gaussian 4 | from .sets import capacity, moebius, powerset 5 | 6 | __all__ = [ 7 | "capacity", 8 | "differential_entropy_gaussian", 9 | "intersection_probability", 10 | "kl_divergence_gaussian", 11 | "moebius", 12 | "powerset", 13 | ] 14 | -------------------------------------------------------------------------------- /src/probly/representation/__init__.py: -------------------------------------------------------------------------------- 1 | """Uncertainty representations for models.""" 2 | 3 | from probly.representation import sampling 4 | from probly.representation.credal_set import CredalSet, credal_set_from_sample 5 | from probly.representation.representer import Representer 6 | from probly.representation.sampling import Sample, Sampler 7 | 8 | __all__ = [ 9 | "CredalSet", 10 | "Representer", 11 | "Sample", 12 | "Sampler", 13 | "credal_set_from_sample", 14 | "sampling", 15 | ] 16 | -------------------------------------------------------------------------------- /src/probly/transformation/dropconnect/torch.py: -------------------------------------------------------------------------------- 1 | """Torch dropconnect implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from torch import nn 6 | 7 | from probly.layers.torch import DropConnectLinear 8 | 9 | from .common import register 10 | 11 | 12 | def replace_torch_dropconnect(obj: nn.Linear, p: float) -> DropConnectLinear: 13 | """Replace a given layer by a DropConnectLinear layer.""" 14 | return DropConnectLinear(obj, p=p) 15 | 16 | 17 | register(nn.Linear, replace_torch_dropconnect) 18 | -------------------------------------------------------------------------------- /src/probly/transformation/dropconnect/flax.py: -------------------------------------------------------------------------------- 1 | """Flax dropconnect implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from flax import nnx 6 | 7 | from probly.layers.flax import DropConnectLinear 8 | 9 | from .common import register 10 | 11 | 12 | def replace_flax_dropconnect(obj: nnx.Linear, p: float) -> DropConnectLinear: 13 | """Replace a given layer by a DropConnectLinear layer.""" 14 | return DropConnectLinear(obj, rate=p) 15 | 16 | 17 | register(nnx.Linear, replace_flax_dropconnect) 18 | -------------------------------------------------------------------------------- /src/lazy_dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | """Lazy alternatives to singledispatch and isinstance.""" 2 | 3 | from lazy_dispatch.isinstance import LazyType, lazy_isinstance, lazy_issubclass 4 | from lazy_dispatch.load import lazy_callable, lazy_import 5 | from lazy_dispatch.singledispatch import is_valid_dispatch_type, lazydispatch 6 | 7 | __all__ = [ 8 | "LazyType", 9 | "is_valid_dispatch_type", 10 | "lazy_callable", 11 | "lazy_import", 12 | "lazy_isinstance", 13 | "lazy_issubclass", 14 | "lazydispatch", 15 | ] 16 | -------------------------------------------------------------------------------- /src/probly/transformation/evidential/classification/torch.py: 
-------------------------------------------------------------------------------- 1 | """Torch evidential classification implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from torch import nn 6 | 7 | from probly.transformation.evidential.classification.common import register 8 | 9 | 10 | def append_activation_torch(obj: nn.Module) -> nn.Sequential: 11 | """Register a base model that the activation function will be appended to.""" 12 | return nn.Sequential(obj, nn.Softplus()) 13 | 14 | 15 | register(nn.Module, append_activation_torch) 16 | -------------------------------------------------------------------------------- /src/probly/utils/errors.py: -------------------------------------------------------------------------------- 1 | """All errors and warnings used in the probly package.""" 2 | 3 | from __future__ import annotations 4 | 5 | from warnings import warn 6 | 7 | 8 | def raise_deprecation_warning( 9 | message: str, 10 | deprecated_in: str, 11 | removed_in: str, 12 | ) -> None: 13 | """Raise a deprecation warning with the given details.""" 14 | message += f" This feature is deprecated in version {deprecated_in} and will be removed in version {removed_in}." 15 | warn(message, DeprecationWarning, stacklevel=2) 16 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. autosummary:: 5 | :toctree: api/ 6 | :recursive: 7 | 8 | probly.datasets 9 | probly.evaluation 10 | probly.layers 11 | probly.plot 12 | probly.predictor 13 | probly.quantification 14 | probly.representation 15 | probly.train 16 | probly.transformation 17 | probly.traverse_nn 18 | probly.utils 19 | 20 | 21 | .. automodule:: probly 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | :special-members: __init__, __call__ 26 | -------------------------------------------------------------------------------- /src/probly/transformation/dropout/flax.py: -------------------------------------------------------------------------------- 1 | """Flax dropout implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from flax.nnx import Dropout, Linear, Sequential 8 | 9 | from .common import register 10 | 11 | if TYPE_CHECKING: 12 | from collections.abc import Callable 13 | 14 | 15 | def prepend_flax_dropout(obj: Callable, p: float) -> Sequential: 16 | """Prepend a Dropout layer before the given layer.""" 17 | return Sequential(Dropout(p), obj) 18 | 19 | 20 | register(Linear, prepend_flax_dropout) 21 | -------------------------------------------------------------------------------- /src/probly/transformation/evidential/regression/__init__.py: -------------------------------------------------------------------------------- 1 | """Evidential regression implementation for uncertainty quantification.""" 2 | 3 | from __future__ import annotations 4 | 5 | from probly.lazy_types import TORCH_MODULE 6 | from probly.transformation.evidential.regression import common 7 | 8 | evidential_regression = common.evidential_regression 9 | register = common.register 10 | 11 | 12 | ## Torch 13 | @common.evidential_regression_traverser.delayed_register(TORCH_MODULE) 14 | def _(_: type) -> None: 15 | from . 
import torch as torch # noqa: PLC0415 16 | -------------------------------------------------------------------------------- /src/probly/transformation/evidential/classification/__init__.py: -------------------------------------------------------------------------------- 1 | """Evidential classification implementation for uncertainty quantification.""" 2 | 3 | from __future__ import annotations 4 | 5 | from probly.lazy_types import TORCH_MODULE 6 | from probly.transformation.evidential.classification import common 7 | 8 | evidential_classification = common.evidential_classification 9 | register = common.register 10 | 11 | 12 | ## Torch 13 | @common.evidential_classification_appender.delayed_register(TORCH_MODULE) 14 | def _(_: type) -> None: 15 | from . import torch as torch # noqa: PLC0415 16 | -------------------------------------------------------------------------------- /src/probly/transformation/__init__.py: -------------------------------------------------------------------------------- 1 | """Transformations for models.""" 2 | 3 | from probly.transformation.bayesian import bayesian 4 | from probly.transformation.dropconnect import dropconnect 5 | from probly.transformation.dropout import dropout 6 | from probly.transformation.ensemble import ensemble 7 | from probly.transformation.evidential.classification import evidential_classification 8 | from probly.transformation.evidential.regression import evidential_regression 9 | 10 | __all__ = ["bayesian", "dropconnect", "dropout", "ensemble", "evidential_classification", "evidential_regression"] 11 | -------------------------------------------------------------------------------- /src/probly/representation/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | """Representation predictors that create representations from finite samples.""" 2 | 3 | from __future__ import annotations 4 | 5 | from .sample import ArraySample, ListSample, Sample, create_sample 6 | from .sampler import CLEANUP_FUNCS, Sampler, SamplingStrategy, get_sampling_predictor, sampler_factory 7 | 8 | __all__ = [ 9 | "CLEANUP_FUNCS", 10 | "ArraySample", 11 | "ListSample", 12 | "Sample", 13 | "Sampler", 14 | "SamplingStrategy", 15 | "create_sample", 16 | "get_sampling_predictor", 17 | "sampler_factory", 18 | ] 19 | -------------------------------------------------------------------------------- /tests/probly/general_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for testing.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | 7 | # Machine epsilon for np.float64. We assume (for now) that we don't use more precise floats. 8 | # For more information, see https://numpy.org/doc/stable/reference/generated/numpy.finfo.html. 9 | VALIDATE_EPS = np.finfo(np.float64).eps 10 | 11 | 12 | def validate_uncertainty(uncertainty: np.ndarray) -> None: 13 | assert isinstance(uncertainty, np.ndarray) 14 | assert not np.isnan(uncertainty).any() 15 | assert not np.isinf(uncertainty).any() 16 | assert (uncertainty >= -VALIDATE_EPS).all() 17 | -------------------------------------------------------------------------------- /src/probly/transformation/ensemble/__init__.py: -------------------------------------------------------------------------------- 1 | """Ensemble implementation for uncertainty quantification.""" 2 | 3 | from __future__ import annotations 4 | 5 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE 6 | 7 | from . 
import common 8 | 9 | ensemble = common.ensemble 10 | register = common.register 11 | 12 | 13 | ## Torch 14 | @common.ensemble_generator.delayed_register(TORCH_MODULE) 15 | def _(_: type) -> None: 16 | from . import torch as torch # noqa: PLC0415 17 | 18 | 19 | ## Flax 20 | @common.ensemble_generator.delayed_register(FLAX_MODULE) 21 | def _(_: type) -> None: 22 | from . import flax as flax # noqa: PLC0415 23 | -------------------------------------------------------------------------------- /src/probly/transformation/dropout/__init__.py: -------------------------------------------------------------------------------- 1 | """Dropout ensemble implementation for uncertainty quantification.""" 2 | 3 | from __future__ import annotations 4 | 5 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE 6 | 7 | from . import common 8 | 9 | dropout = common.dropout 10 | register = common.register 11 | 12 | 13 | ## Torch 14 | @common.dropout_traverser.delayed_register(TORCH_MODULE) 15 | def _(_: type) -> None: 16 | from . import torch as torch # noqa: PLC0415 17 | 18 | 19 | ## Flax 20 | @common.dropout_traverser.delayed_register(FLAX_MODULE) 21 | def _(_: type) -> None: 22 | from . import flax as flax # noqa: PLC0415 23 | -------------------------------------------------------------------------------- /src/probly/transformation/dropconnect/__init__.py: -------------------------------------------------------------------------------- 1 | """DropConnect implementation for uncertainty quantification.""" 2 | 3 | from __future__ import annotations 4 | 5 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE 6 | 7 | from . import common 8 | 9 | dropconnect = common.dropconnect 10 | register = common.register 11 | 12 | 13 | ## Torch 14 | @common.dropconnect_traverser.delayed_register(TORCH_MODULE) 15 | def _(_: type) -> None: 16 | from . import torch as torch # noqa: PLC0415 17 | 18 | 19 | ## Flax 20 | @common.dropconnect_traverser.delayed_register(FLAX_MODULE) 21 | def _(_: type) -> None: 22 | from . import flax as flax # noqa: PLC0415 23 | -------------------------------------------------------------------------------- /tests/probly/fixtures/torch_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | torch = pytest.importorskip("torch") 6 | 7 | from torch import Tensor, nn # noqa: E402 8 | 9 | 10 | @pytest.fixture 11 | def sample_classification_data() -> tuple[Tensor, Tensor]: 12 | inputs = torch.randn(2, 3, 5, 5) 13 | targets = torch.randint(0, 2, (2,)) 14 | return inputs, targets 15 | 16 | 17 | @pytest.fixture 18 | def sample_outputs( 19 | torch_conv_linear_model: nn.Module, 20 | ) -> tuple[Tensor, Tensor]: 21 | outputs = torch_conv_linear_model(torch.randn(2, 3, 5, 5)) 22 | targets = torch.randint(0, 2, (2,)) 23 | return outputs, targets 24 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /src/probly/traverse_nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Traverser utilities for neural networks.""" 2 | 3 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE 4 | 5 | from . import common 6 | 7 | ## NN 8 | 9 | LAYER_COUNT = common.LAYER_COUNT 10 | is_first_layer = common.is_first_layer 11 | 12 | layer_count_traverser = common.layer_count_traverser 13 | nn_traverser = common.nn_traverser 14 | 15 | nn_compose = common.compose 16 | 17 | 18 | ## Torch 19 | @nn_traverser.delayed_register(TORCH_MODULE) 20 | def _(_: type) -> None: 21 | from . import torch as torch # noqa: PLC0415 22 | 23 | 24 | ## Flax 25 | @nn_traverser.delayed_register(FLAX_MODULE) 26 | def _(_: type) -> None: 27 | from . import flax as flax # noqa: PLC0415 28 | -------------------------------------------------------------------------------- /src/probly/predictor.py: -------------------------------------------------------------------------------- 1 | """Protocols and ABCs for representation wrappers.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Protocol, Unpack, runtime_checkable 6 | 7 | from lazy_dispatch.singledispatch import lazydispatch 8 | 9 | 10 | @runtime_checkable 11 | class Predictor[In, KwIn, Out](Protocol): 12 | """Protocol for generic predictors.""" 13 | 14 | def __call__(self, *args: In, **kwargs: Unpack[KwIn]) -> Out: 15 | """Call the wrapper with input data.""" 16 | ... 17 | 18 | 19 | @lazydispatch 20 | def predict[In, KwIn, Out](predictor: Predictor[In, KwIn, Out], *args: In, **kwargs: Unpack[KwIn]) -> Out: 21 | """Generic predict function.""" 22 | return predictor(*args, **kwargs) 23 | -------------------------------------------------------------------------------- /src/probly/representation/representer.py: -------------------------------------------------------------------------------- 1 | """Generic representation builder.""" 2 | 3 | from __future__ import annotations 4 | 5 | from abc import ABC 6 | from typing import TYPE_CHECKING 7 | 8 | if TYPE_CHECKING: 9 | from probly.predictor import Predictor 10 | 11 | 12 | class Representer[In, KwIn, Out](ABC): 13 | """Abstract base class for representation builders.""" 14 | 15 | predictor: Predictor[In, KwIn, Out] 16 | 17 | def __init__(self, predictor: Predictor[In, KwIn, Out]) -> None: 18 | """Initialize the representer with a predictor. 19 | 20 | Args: 21 | predictor: Predictor[In, KwIn, Out], the predictor to be used for building representations. 
22 | """ 23 | self.predictor = predictor 24 | -------------------------------------------------------------------------------- /tests/probly/transformation/dropout/test_common.py: -------------------------------------------------------------------------------- 1 | """Test for dropout models.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from probly.predictor import Predictor 8 | from probly.transformation import dropout 9 | 10 | 11 | def test_invalid_p_value(dummy_predictor: Predictor) -> None: 12 | """Tests the behavior of the dropout function when provided with an invalid probability value. 13 | 14 | This function validates that the dropout function raises a ValueError when 15 | the probability parameter `p` is outside the valid range [0, 1]. 16 | 17 | Raises: 18 | ValueError: If the probability `p` is not between 0 and 1. 19 | """ 20 | p = 2 21 | with pytest.raises(ValueError, match=f"The probability p must be between 0 and 1, but got {p} instead."): 22 | dropout(dummy_predictor, p=p) 23 | -------------------------------------------------------------------------------- /tests/probly/transformation/dropconnect/test_common.py: -------------------------------------------------------------------------------- 1 | """Test for dropconnect models.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from probly.predictor import Predictor 8 | from probly.transformation import dropconnect 9 | 10 | 11 | def test_invalid_p_value(dummy_predictor: Predictor) -> None: 12 | """Tests the behavior of the dropconnect function when provided with an invalid probability value. 13 | 14 | This function validates that the dropconnect function raises a ValueError when 15 | the probability parameter `p` is outside the valid range [0, 1]. 16 | 17 | Raises: 18 | ValueError: If the probability `p` is not between 0 and 1. 19 | """ 20 | p = 2 21 | with pytest.raises(ValueError, match=f"The probability p must be between 0 and 1, but got {p} instead."): 22 | dropconnect(dummy_predictor, p=p) 23 | -------------------------------------------------------------------------------- /tests/probly/tests_deprecations/__init__.py: -------------------------------------------------------------------------------- 1 | """A test module which contains tests for deprecations. 2 | 3 | This module is used to ensure that deprecated features are properly flagged and that after 4 | they should be removed, that they indeed are no longer present in the codebase. 5 | 6 | Usage: 7 | - Adding Deprecations: If you add a new deprecation, you should add a test here to ensure that 8 | it is raised properly. Ensure that in each deprecation message, you include the version 9 | of future removal, so that users can easily identify when the feature will be removed. 10 | The test should check that the deprecation warning includes a version. 11 | - Removing Deprecations: If you remove a deprecated feature, ensure that the test for that 12 | feature is also removed. This helps keep the test suite clean and focused on current 13 | features. 
14 | """ 15 | -------------------------------------------------------------------------------- /src/probly/transformation/evidential/regression/torch.py: -------------------------------------------------------------------------------- 1 | """Torch evidential regression implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from torch import nn 8 | 9 | from probly.layers.torch import NormalInverseGammaLinear 10 | from probly.transformation.evidential.regression.common import REPLACED_LAST_LINEAR, register 11 | 12 | if TYPE_CHECKING: 13 | from pytraverse import State, TraverserResult 14 | 15 | 16 | def replace_last_torch_nig(obj: nn.Linear, state: State) -> TraverserResult: 17 | """Register a class to be replaced by the NormalInverseGammaLinear layer.""" 18 | state[REPLACED_LAST_LINEAR] = True 19 | return NormalInverseGammaLinear( 20 | obj.in_features, 21 | obj.out_features, 22 | device=obj.weight.device, 23 | bias=obj.bias is not None, 24 | ), state 25 | 26 | 27 | register(nn.Linear, replace_last_torch_nig) 28 | -------------------------------------------------------------------------------- /src/probly/representation/credal_set/jax.py: -------------------------------------------------------------------------------- 1 | """Torch credal set implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | import jax 6 | 7 | from probly.representation.credal_set.credal_set import CredalSet, credal_set_from_sample 8 | from probly.representation.sampling.jax_sample import JaxArraySample 9 | 10 | 11 | @credal_set_from_sample.register(JaxArraySample) 12 | class JaxArrayCredalSet(CredalSet[jax.Array]): 13 | """A credal set implementation for torch tensors.""" 14 | 15 | def __init__(self, sample: JaxArraySample) -> None: 16 | """Initialize the torch tensor credal set.""" 17 | self.array = sample.array 18 | 19 | def lower(self) -> jax.Array: 20 | """Compute the lower envelope of the credal set.""" 21 | return self.array.min(axis=1) 22 | 23 | def upper(self) -> jax.Array: 24 | """Compute the upper envelope of the credal set.""" 25 | return self.array.max(axis=1) 26 | -------------------------------------------------------------------------------- /tests/probly/train/bayesian/test_torch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from probly.train.bayesian.torch import ELBOLoss, collect_kl_divergence 6 | from probly.transformation import bayesian 7 | from tests.probly.torch_utils import validate_loss 8 | 9 | torch = pytest.importorskip("torch") 10 | 11 | from torch import Tensor, nn # noqa: E402 12 | 13 | 14 | def test_elbo_loss( 15 | sample_classification_data: tuple[Tensor, Tensor], 16 | torch_conv_linear_model: nn.Module, 17 | ) -> None: 18 | inputs, targets = sample_classification_data 19 | model: nn.Module = bayesian(torch_conv_linear_model) 20 | outputs = model(inputs) 21 | 22 | criterion = ELBOLoss() 23 | loss = criterion(outputs, targets, collect_kl_divergence(model)) 24 | validate_loss(loss) 25 | 26 | criterion = ELBOLoss(0.0) 27 | loss = criterion(outputs, targets, collect_kl_divergence(model)) 28 | validate_loss(loss) 29 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set 
SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/github-mark.svg: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/probly/representation/credal_set/torch.py: -------------------------------------------------------------------------------- 1 | """Torch credal set implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch 6 | 7 | from probly.representation.credal_set.credal_set import CredalSet, credal_set_from_sample 8 | from probly.representation.sampling.torch_sample import TorchTensorSample 9 | 10 | 11 | @credal_set_from_sample.register(TorchTensorSample) 12 | class TorchTensorCredalSet(CredalSet[torch.Tensor]): 13 | """A credal set implementation for torch tensors.""" 14 | 15 | def __init__(self, sample: TorchTensorSample) -> None: 16 | """Initialize the torch tensor credal set.""" 17 | self.tensor = sample.tensor 18 | 19 | def lower(self) -> torch.Tensor: 20 | """Compute the lower envelope of the credal set.""" 21 | return self.tensor.min(dim=1).values 22 | 23 | def upper(self) -> torch.Tensor: 24 | """Compute the upper envelope of the credal set.""" 25 | return self.tensor.max(dim=1).values 26 | -------------------------------------------------------------------------------- /tests/probly/transformation/evidential/regression/test_common.py: -------------------------------------------------------------------------------- 1 | """Tests for evidential regression models.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any, cast 6 | 7 | from probly.predictor import Predictor 8 | from probly.transformation.evidential.regression import evidential_regression 9 | 10 | 11 | def test_unknown_base_returns_self(dummy_predictor: Predictor) -> None: 12 | """Tests that the base model is returned if no implementation is registered. 13 | 14 | This uses the provided dummy_predictor fixture instead of a locally defined class. 15 | 16 | Parameters: 17 | dummy_predictor (Predictor): The generic predictor fixture supplied by pytest. 18 | """ 19 | # Use the official fixture which adheres to the Predictor type. 20 | base = dummy_predictor 21 | 22 | transformed = evidential_regression(cast(Any, base)) 23 | 24 | # Assert that the function returned the exact same object instance. 
25 | assert transformed is base 26 | -------------------------------------------------------------------------------- /tests/probly/flax_utils.py: -------------------------------------------------------------------------------- 1 | """Util functions for flax tests.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | flax = pytest.importorskip("flax") 8 | 9 | from flax import nnx # noqa: E402 10 | 11 | 12 | def count_layers(model: nnx.Module, layer_type: type[nnx.Module]) -> int: 13 | """Counts the number of layers of a specific type in a neural network model. 14 | 15 | This function iterates through all the modules in the given model and counts 16 | how many of them match the specified layer type. It's particularly useful 17 | for analyzing the architecture of a neural network or verifying its 18 | composition. 19 | 20 | Parameters: 21 | model: The neural network model containing the layers to be counted. 22 | layer_type: The type of layer to count within the model. 23 | 24 | Returns: 25 | The number of layers of the specified type found in the model. 26 | """ 27 | return sum(1 for _, m in model.iter_modules() if isinstance(m, layer_type)) 28 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-json 6 | - id: check-yaml 7 | - id: check-toml 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - id: mixed-line-ending 11 | - id: check-merge-conflict 12 | - repo: https://github.com/astral-sh/ruff-pre-commit 13 | rev: v0.12.0 14 | hooks: 15 | - id: ruff-check 16 | args: [--fix, --exit-non-zero-on-fix, --no-cache] 17 | - id: ruff-format 18 | args: [--verbose] 19 | 20 | - repo: https://github.com/pre-commit/mirrors-mypy 21 | rev: v1.18.2 22 | hooks: 23 | - id: mypy 24 | exclude: docs 25 | args: 26 | - "--ignore-missing-imports" 27 | - "--scripts-are-modules" 28 | # Since pre-commit env does not have all dependencies installed, 29 | # we ignore unused ignores to avoid false positives: 30 | - "--no-warn-unused-ignores" 31 | -------------------------------------------------------------------------------- /src/probly/representation/sampling/flax_sampler.py: -------------------------------------------------------------------------------- 1 | """Sampling preparation for flax.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from flax import nnx 8 | 9 | from . 
import sampler 10 | 11 | if TYPE_CHECKING: 12 | from lazy_dispatch.isinstance import LazyType 13 | from pytraverse import State 14 | 15 | 16 | def _enforce_train_mode(obj: nnx.Module, state: State) -> tuple[nnx.Module, State]: 17 | if getattr(obj, "deterministic", False): 18 | obj.deterministic = False # type: ignore[attr-defined] 19 | state[sampler.CLEANUP_FUNCS].add(lambda: setattr(obj, "deterministic", True)) 20 | return obj, state 21 | return obj, state 22 | 23 | 24 | def register_forced_train_mode(cls: LazyType) -> None: 25 | """Register a class to be forced into train mode during sampling.""" 26 | sampler.sampling_preparation_traverser.register( 27 | cls, 28 | _enforce_train_mode, 29 | ) 30 | 31 | 32 | register_forced_train_mode( 33 | nnx.Dropout, 34 | ) 35 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Issue 2 | Please link the corresponding GitHub issue. If an issue does not already exist, 3 | please open one to describe the bug or feature request before creating a pull request. 4 | 5 | This allows us to discuss the proposal and helps avoid unnecessary work. 6 | 7 | ## Motivation and Context 8 | 9 | --- 10 | 11 | ## Public API Changes 12 | 13 | - [ ] No Public API changes 14 | - [ ] Yes, Public API changes (Details below) 15 | 16 | --- 17 | 18 | ## How Has This Been Tested? 19 | 20 | --- 21 | 22 | ## Checklist 23 | 24 | - [ ] The changes have been tested locally. 25 | - [ ] Documentation has been updated (if the public API or usage changes). 26 | - [ ] An entry has been added to [`CHANGELOG.md`](https://github.com/pwhofman/probly/blob/main/CHANGELOG.md) (if relevant for users). 27 | - [ ] The code follows the project's [style guidelines](https://github.com/pwhofman/probly/blob/main/.github/CONTRIBUTING.md) and passes minimal code style checks. 28 | - [ ] I have considered the impact of these changes on the public API. 
29 | 30 | --- 31 | -------------------------------------------------------------------------------- /src/probly/representation/sampling/jax_sample.py: -------------------------------------------------------------------------------- 1 | """JAX sample implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from .sample import Sample, create_sample 9 | 10 | 11 | @create_sample.register(jax.Array) 12 | class JaxArraySample(Sample[jax.Array]): 13 | """A sample implementation for JAX arrays.""" 14 | 15 | def __init__(self, samples: list[jax.Array]) -> None: 16 | """Initialize the JAX array sample.""" 17 | self.array = jnp.stack(samples).transpose(1, 0, 2) # we use the convention [instances, samples, classes] 18 | 19 | def mean(self) -> jax.Array: 20 | """Compute the mean of the sample.""" 21 | return jnp.mean(self.array, axis=1) 22 | 23 | def std(self, ddof: int = 1) -> jax.Array: 24 | """Compute the standard deviation of the sample.""" 25 | return jnp.std(self.array, axis=1, ddof=ddof) 26 | 27 | def var(self, ddof: int = 1) -> jax.Array: 28 | """Compute the variance of the sample.""" 29 | return jnp.var(self.array, axis=1, ddof=ddof) 30 | -------------------------------------------------------------------------------- /src/probly/representation/sampling/torch_sample.py: -------------------------------------------------------------------------------- 1 | """Torch sample implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch 6 | 7 | from .sample import Sample, create_sample 8 | 9 | 10 | @create_sample.register(torch.Tensor) 11 | class TorchTensorSample(Sample[torch.Tensor]): 12 | """A sample implementation for torch tensors.""" 13 | 14 | def __init__(self, samples: list[torch.Tensor]) -> None: 15 | """Initialize the torch tensor sample.""" 16 | self.tensor = torch.stack(samples).permute(1, 0, 2) # we use the convention [instances, samples, classes] 17 | 18 | def mean(self) -> torch.Tensor: 19 | """Compute the mean of the sample.""" 20 | return self.tensor.mean(dim=1) 21 | 22 | def std(self, ddof: int = 1) -> torch.Tensor: 23 | """Compute the standard deviation of the sample.""" 24 | return self.tensor.std(dim=1, correction=ddof) 25 | 26 | def var(self, ddof: int = 1) -> torch.Tensor: 27 | """Compute the variance of the sample.""" 28 | return self.tensor.var(dim=1, correction=ddof) 29 | -------------------------------------------------------------------------------- /src/probly/transformation/bayesian/torch.py: -------------------------------------------------------------------------------- 1 | """Torch Bayesian implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from torch import nn 6 | 7 | from probly.layers.torch import BayesConv2d, BayesLinear 8 | 9 | from .common import register 10 | 11 | 12 | def replace_torch_bayesian_linear( 13 | obj: nn.Linear, 14 | use_base_weights: bool, 15 | posterior_std: float, 16 | prior_mean: float, 17 | prior_std: float, 18 | ) -> BayesLinear: 19 | """Replace a given layer by a BayesLinear layer.""" 20 | return BayesLinear(obj, use_base_weights, posterior_std, prior_mean, prior_std) 21 | 22 | 23 | def replace_torch_bayesian_conv2d( 24 | obj: nn.Conv2d, 25 | use_base_weights: bool, 26 | posterior_std: float, 27 | prior_mean: float, 28 | prior_std: float, 29 | ) -> BayesConv2d: 30 | """Replace a given layer by a BayesConv2d layer.""" 31 | return BayesConv2d(obj, use_base_weights, posterior_std, prior_mean, prior_std) 32 | 33 | 34 | register(nn.Linear, 
replace_torch_bayesian_linear) 35 | register(nn.Conv2d, replace_torch_bayesian_conv2d) 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 the probly team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/deploy-docs.yml: -------------------------------------------------------------------------------- 1 | name: Deploy Sphinx Docs to GitHub Pages 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout repository 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up Python and uv 17 | id: setup-uv-python 18 | uses: astral-sh/setup-uv@v5 19 | with: 20 | python-version: "3.12" 21 | 22 | - name: Install dependencies 23 | run: uv sync --dev 24 | 25 | - name: Install pandoc 26 | run: sudo apt-get install -y pandoc 27 | 28 | - name: Copy content to docs/source folder 29 | run: | 30 | cp .github/CONTRIBUTING.md docs/source/contributing.md 31 | cp -r notebooks/examples/ docs/source/ 32 | 33 | - name: Build Sphinx documentation 34 | working-directory: ./docs 35 | run: | 36 | make clean html 37 | 38 | - name: Deploy to GitHub Pages 39 | uses: peaceiris/actions-gh-pages@v3 40 | with: 41 | github_token: ${{ secrets.GITHUB_TOKEN }} 42 | publish_dir: ./docs/build/html 43 | enable_jekyll: false 44 | -------------------------------------------------------------------------------- /src/probly/representation/sampling/torch_sampler.py: -------------------------------------------------------------------------------- 1 | """Sampling preparation for torch.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import torch.nn 8 | 9 | from probly.layers.torch import DropConnectLinear 10 | 11 | from . 
import sampler 12 | 13 | if TYPE_CHECKING: 14 | from lazy_dispatch.isinstance import LazyType 15 | from pytraverse import State 16 | 17 | 18 | def _enforce_train_mode(obj: torch.nn.Module, state: State) -> tuple[torch.nn.Module, State]: 19 | if not obj.training: 20 | obj.train() 21 | state[sampler.CLEANUP_FUNCS].add(lambda: obj.train(False)) 22 | return obj, state 23 | return obj, state 24 | 25 | 26 | def register_forced_train_mode(cls: LazyType) -> None: 27 | """Register a class to be forced into train mode during sampling.""" 28 | sampler.sampling_preparation_traverser.register( 29 | cls, 30 | _enforce_train_mode, 31 | ) 32 | 33 | 34 | register_forced_train_mode( 35 | torch.nn.Dropout 36 | | torch.nn.Dropout1d 37 | | torch.nn.Dropout2d 38 | | torch.nn.Dropout3d 39 | | torch.nn.AlphaDropout 40 | | torch.nn.FeatureAlphaDropout 41 | | DropConnectLinear, 42 | ) 43 | -------------------------------------------------------------------------------- /tests/probly/torch_utils.py: -------------------------------------------------------------------------------- 1 | """Util functions for torch tests.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | torch = pytest.importorskip("torch") 8 | 9 | from torch import Tensor, nn # noqa: E402 10 | 11 | 12 | def validate_loss(loss: Tensor) -> None: 13 | assert isinstance(loss, Tensor) 14 | assert loss.dim() == 0 15 | assert not torch.isnan(loss) 16 | assert not torch.isinf(loss) 17 | assert loss.item() >= 0 18 | 19 | 20 | def count_layers(model: nn.Module, layer_type: type[nn.Module]) -> int: 21 | """Counts the number of layers of a specific type in a neural network model. 22 | 23 | This function iterates through all the modules in the given model and counts 24 | how many of them match the specified layer type. It's particularly useful 25 | for analyzing the architecture of a neural network or verifying its 26 | composition. 27 | 28 | Parameters: 29 | model: The neural network model containing the layers to be counted. 30 | layer_type: The type of layer to count within the model. 31 | 32 | Returns: 33 | The number of layers of the specified type found in the model. 34 | """ 35 | return sum(1 for m in model.modules() if isinstance(m, layer_type)) 36 | -------------------------------------------------------------------------------- /.github/workflows/publish-release.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Publish Release 10 | 11 | on: 12 | release: 13 | types: [published] 14 | branches: [main] 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | deploy: 21 | 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v4 26 | 27 | - name: Set up Python 28 | uses: actions/setup-python@v5 29 | with: 30 | python-version: '3.12' 31 | 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install build 36 | 37 | - name: Build package 38 | run: python -m build 39 | 40 | - name: Publish package 41 | uses: pypa/gh-action-pypi-publish@v1.12.4 42 | with: 43 | user: __token__ 44 | password: ${{ secrets.PYPI_API_TOKEN }} 45 | -------------------------------------------------------------------------------- /tests/probly/plot/test_credal.py: -------------------------------------------------------------------------------- 1 | """Tests for the plot module.""" 2 | 3 | from __future__ import annotations 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pytest 8 | 9 | from probly.plot.credal import credal_set_plot, simplex_plot 10 | 11 | 12 | def test_simplex_plot_outputs() -> None: 13 | probs = np.array([[1 / 3, 1 / 3, 1 / 3], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 14 | fig, ax = simplex_plot(probs) 15 | assert isinstance(fig, plt.Figure) 16 | assert isinstance(ax, plt.Axes) 17 | assert ax.name == "ternary" 18 | assert ax.collections[0].get_offsets().shape[0] == len(probs) 19 | 20 | 21 | def test_credal_set_plot() -> None: 22 | probs = np.array([[1 / 3, 1 / 3, 1 / 3], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) 23 | fig, ax = credal_set_plot(probs) 24 | assert isinstance(fig, plt.Figure) 25 | assert isinstance(ax, plt.Axes) 26 | assert ax.name == "ternary" 27 | assert ax.collections[0].get_offsets().shape[0] == len(probs) 28 | 29 | probs = np.array([[1 / 3, 1 / 3, 1 / 3], [1 / 3, 1 / 3, 1 / 3]]) 30 | with pytest.raises( 31 | ValueError, 32 | match="The set of vertices is empty. 
Please check the probabilities in the credal set.", 33 | ): 34 | credal_set_plot(probs) 35 | -------------------------------------------------------------------------------- /tests/probly/transformation/evidential/classification/test_common.py: -------------------------------------------------------------------------------- 1 | """Tests for evidential classification registration and dispatch.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Callable 6 | from typing import Any, cast 7 | 8 | from probly.transformation.evidential.classification.common import ( 9 | evidential_classification, 10 | register, 11 | ) 12 | 13 | 14 | def test_multiple_types_are_handled_independently() -> None: 15 | class ModelA: 16 | def __init__(self, mid: str) -> None: 17 | self.id = mid 18 | 19 | class ModelB: 20 | def __init__(self, mid: str) -> None: 21 | self.id = mid 22 | 23 | def appender_a(base: ModelA) -> str: 24 | return f"A_Enhanced({base.id})" 25 | 26 | def appender_b(base: ModelB) -> str: 27 | return f"B_Enhanced({base.id})" 28 | 29 | register(ModelA, cast(Callable[..., object], appender_a)) 30 | register(ModelB, cast(Callable[..., object], appender_b)) 31 | 32 | a = ModelA("001") 33 | b = ModelB("002") 34 | 35 | result_a = evidential_classification(cast(Any, a)) 36 | result_b = evidential_classification(cast(Any, b)) 37 | 38 | assert cast(str, result_a) == "A_Enhanced(001)" 39 | assert cast(str, result_b) == "B_Enhanced(002)" 40 | -------------------------------------------------------------------------------- /src/pytraverse/__init__.py: -------------------------------------------------------------------------------- 1 | """Generic functional datastructure traverser utilities.""" 2 | 3 | from . import composition, core, decorators, generic 4 | 5 | ## Core 6 | 7 | Variable = core.Variable 8 | GlobalVariable = core.GlobalVariable 9 | StackVariable = core.StackVariable 10 | ComputedVariable = core.ComputedVariable 11 | computed = ComputedVariable # Alias for convenience (intended to be used as a decorator) 12 | 13 | type State[T] = core.State[T] 14 | type TraverserResult[T] = core.TraverserResult[T] 15 | type TraverserCallback[T] = core.TraverserCallback[T] 16 | type Traverser[T] = core.Traverser[T] 17 | 18 | traverse_with_state = core.traverse_with_state 19 | traverse = core.traverse 20 | 21 | ## Traverser Decorator 22 | 23 | traverser = decorators.traverser 24 | 25 | ## Composition 26 | 27 | sequential = composition.sequential 28 | top_sequential = composition.top_sequential 29 | SingledispatchTraverser = composition.SingledispatchTraverser 30 | LazySingledispatchTraverser = composition.LazydispatchTraverser 31 | 32 | singledispatch_traverser = composition.SingledispatchTraverser 33 | lazydispatch_traverser = composition.LazydispatchTraverser 34 | 35 | ## Generic traverser 36 | 37 | generic_traverser = generic.generic_traverser 38 | CLONE = generic.CLONE 39 | TRAVERSE_KEYS = generic.TRAVERSE_KEYS 40 | TRAVERSE_REVERSED = generic.TRAVERSE_REVERSED 41 | -------------------------------------------------------------------------------- /src/probly/transformation/ensemble/torch.py: -------------------------------------------------------------------------------- 1 | """Torch ensemble implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from torch import nn 6 | 7 | from probly.traverse_nn import nn_compose, nn_traverser 8 | from pytraverse import CLONE, singledispatch_traverser, traverse 9 | 10 | from .common import register 11 | 12 | reset_traverser = 
singledispatch_traverser[nn.Module](name="reset_traverser") 13 | 14 | 15 | @reset_traverser.register 16 | def _(obj: nn.Module) -> nn.Module: 17 | if hasattr(obj, "reset_parameters"): 18 | obj.reset_parameters() # type: ignore[operator] 19 | return obj 20 | 21 | 22 | def _reset_copy(module: nn.Module) -> nn.Module: 23 | return traverse(module, nn_compose(reset_traverser), init={CLONE: True}) 24 | 25 | 26 | def _copy(module: nn.Module) -> nn.Module: 27 | return traverse(module, nn_traverser, init={CLONE: True}) 28 | 29 | 30 | def generate_torch_ensemble( 31 | obj: nn.Module, 32 | num_members: int, 33 | reset_params: bool = True, 34 | ) -> nn.ModuleList: 35 | """Build a torch ensemble by copying the base model num_members times, resetting the parameters of each member.""" 36 | if reset_params: 37 | return nn.ModuleList([_reset_copy(obj) for _ in range(num_members)]) 38 | return nn.ModuleList([_copy(obj) for _ in range(num_members)]) 39 | 40 | 41 | register(nn.Module, generate_torch_ensemble) 42 | -------------------------------------------------------------------------------- /tests/probly/transformation/ensemble/test_common.py: -------------------------------------------------------------------------------- 1 | """Tests for common ensemble generation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from unittest.mock import Mock 6 | 7 | import pytest 8 | 9 | from probly.predictor import Predictor 10 | from probly.transformation import ensemble 11 | from probly.transformation.ensemble.common import ensemble_generator, register 12 | 13 | 14 | def test_unregistered_type_raises(dummy_predictor: Predictor) -> None: 15 | """No ensemble generator is registered for type, NotImplementedError must occur.""" 16 | base = dummy_predictor 17 | with pytest.raises( 18 | NotImplementedError, 19 | match=f"No ensemble generator is registered for type {type(base)}", 20 | ): 21 | ensemble_generator(dummy_predictor) 22 | 23 | 24 | def test_registered_generator_called(dummy_predictor: Predictor) -> None: 25 | """If the type is registered, the appropriate generator is called and its result is returned.""" 26 | mock_generator = Mock() 27 | expected_result = object() 28 | mock_generator.return_value = expected_result 29 | 30 | register(type(dummy_predictor), mock_generator) 31 | 32 | result = ensemble(dummy_predictor, num_members=4) 33 | 34 | mock_generator.assert_called_once_with( 35 | dummy_predictor, 36 | num_members=4, 37 | reset_params=True, 38 | ) 39 | assert result is expected_result 40 | -------------------------------------------------------------------------------- /src/probly/transformation/evidential/classification/common.py: -------------------------------------------------------------------------------- 1 | """Shared evidential classification implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from lazy_dispatch import lazydispatch 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import Callable 11 | 12 | from lazy_dispatch.isinstance import LazyType 13 | from probly.predictor import Predictor 14 | 15 | 16 | @lazydispatch 17 | def evidential_classification_appender[In, KwIn, Out](base: Predictor[In, KwIn, Out]) -> Predictor[In, KwIn, Out]: 18 | """Append an evidential classification activation function to a base model.""" 19 | msg = f"No evidential classification appender registered for type {type(base)}" 20 | raise NotImplementedError(msg) 21 | 22 | 23 | def register(cls: LazyType, appender: Callable) -> None: 24 | """Register a 
base model that the activation function will be appended to.""" 25 | evidential_classification_appender.register(cls=cls, func=appender) 26 | 27 | 28 | def evidential_classification[T: Predictor](base: T) -> T: 29 | """Create an evidential classification predictor from a base predictor. 30 | 31 | Args: 32 | base: Predictor, The base model to be used for evidential classification. 33 | 34 | Returns: 35 | Predictor, The evidential classification predictor. 36 | """ 37 | return evidential_classification_appender(base) 38 | -------------------------------------------------------------------------------- /src/probly/transformation/ensemble/common.py: -------------------------------------------------------------------------------- 1 | """Shared ensemble implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from lazy_dispatch import lazydispatch 8 | 9 | if TYPE_CHECKING: 10 | from collections.abc import Callable 11 | 12 | from lazy_dispatch.isinstance import LazyType 13 | from probly.predictor import Predictor 14 | 15 | 16 | @lazydispatch 17 | def ensemble_generator[In, KwIn, Out](base: Predictor[In, KwIn, Out]) -> Predictor[In, KwIn, Out]: 18 | """Generate an ensemble from a base model.""" 19 | msg = f"No ensemble generator is registered for type {type(base)}" 20 | raise NotImplementedError(msg) 21 | 22 | 23 | def register(cls: LazyType, generator: Callable) -> None: 24 | """Register a class which can be used as a base for an ensemble.""" 25 | ensemble_generator.register(cls=cls, func=generator) 26 | 27 | 28 | def ensemble[T: Predictor](base: T, num_members: int, reset_params: bool = True) -> T: 29 | """Create an ensemble predictor from a base predictor. 30 | 31 | Args: 32 | base: Predictor, The base model to be used for the ensemble. 33 | num_members: The number of members in the ensemble. 34 | reset_params: Whether to reset the parameters of each member. 35 | 36 | Returns: 37 | Predictor, The ensemble predictor. 38 | """ 39 | return ensemble_generator(base, num_members=num_members, reset_params=reset_params) 40 | -------------------------------------------------------------------------------- /tests/probly/tests_deprecations/deprecated_features.py: -------------------------------------------------------------------------------- 1 | """Collects all deprecated behaviour tests.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from probly.utils.errors import raise_deprecation_warning 8 | 9 | if TYPE_CHECKING: 10 | import pytest 11 | 12 | from .utils import register_deprecated 13 | 14 | 15 | @register_deprecated( 16 | name="ExampleFeature(NeverRemoveThis)", 17 | deprecated_in="1.0.0", 18 | removed_in="9.9.9", 19 | ) 20 | def example_always_warn(requests: pytest.FixtureRequest) -> None: 21 | """An example feature that always raises a deprecation warning using a fixture. 22 | 23 | This "behavior" is just an example of how to handle deprecations in the codebase. This 24 | "feature" needs a fixture (in this case an example game) which is passed via the `request` 25 | fixture. The warning which is raised contains the same version information as the 26 | `DeprecatedFeature` instance in the `DEPRECATED_FEATURES` list. 27 | 28 | Note: 29 | Do not delete this "feature", as it is a good example of how to handle deprecations and 30 | write your own deprecation tests. 
31 | 32 | """ 33 | raise_deprecation_warning( 34 | "This is an example of a deprecated feature that always raises a warning.", 35 | deprecated_in="1.0.0", 36 | removed_in="9.9.9", 37 | ) 38 | _ = requests.getfixturevalue("torch_conv_linear_model") # gets a fixture as an example of usage 39 | -------------------------------------------------------------------------------- /tests/probly/quantification/test_regression.py: -------------------------------------------------------------------------------- 1 | """Tests for the regression module.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import numpy as np 8 | import pytest 9 | 10 | from tests.probly.general_utils import validate_uncertainty 11 | 12 | if TYPE_CHECKING: 13 | from collections.abc import Callable 14 | 15 | 16 | from probly.quantification.regression import ( 17 | conditional_differential_entropy, 18 | expected_conditional_variance, 19 | mutual_information, 20 | total_differential_entropy, 21 | total_variance, 22 | variance_conditional_expectation, 23 | ) 24 | 25 | 26 | @pytest.fixture 27 | def sample_second_order_params() -> np.ndarray: 28 | rng = np.random.default_rng() 29 | mu = rng.uniform(0, 1000, (5, 10)) 30 | sigma2 = rng.uniform(1e-20, 100, (5, 10)) 31 | params = np.stack((mu, sigma2), axis=2) 32 | return params 33 | 34 | 35 | @pytest.mark.parametrize( 36 | "uncertainty_fn", 37 | [ 38 | total_variance, 39 | expected_conditional_variance, 40 | variance_conditional_expectation, 41 | total_differential_entropy, 42 | conditional_differential_entropy, 43 | mutual_information, 44 | ], 45 | ) 46 | def test_uncertainty_function( 47 | uncertainty_fn: Callable[[np.ndarray], np.ndarray], 48 | sample_second_order_params: np.ndarray, 49 | ) -> None: 50 | uncertainty = uncertainty_fn(sample_second_order_params) 51 | validate_uncertainty(uncertainty) 52 | -------------------------------------------------------------------------------- /tests/probly/utils/test_torch.py: -------------------------------------------------------------------------------- 1 | """Tests for utils.torch functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch 6 | from torch.utils.data import DataLoader, TensorDataset 7 | 8 | from probly.utils.torch import temperature_softmax, torch_collect_outputs, torch_reset_all_parameters 9 | 10 | 11 | def test_torch_reset_all_parameters(torch_conv_linear_model: torch.nn.Module) -> None: 12 | def flatten_params(model: torch.nn.Module) -> torch.Tensor: 13 | return torch.cat([param.flatten() for param in model.parameters()]) 14 | 15 | before = flatten_params(torch_conv_linear_model) 16 | torch_reset_all_parameters(torch_conv_linear_model) 17 | after = flatten_params(torch_conv_linear_model) 18 | assert not torch.equal(before, after) 19 | 20 | 21 | def test_torch_collect_outputs(torch_conv_linear_model: torch.nn.Module) -> None: 22 | loader = DataLoader( 23 | TensorDataset( 24 | torch.randn(2, 3, 5, 5), 25 | torch.randn( 26 | 2, 27 | ), 28 | ), 29 | ) 30 | outputs, targets = torch_collect_outputs(torch_conv_linear_model, loader, torch.device("cpu")) 31 | assert outputs.shape == (2, 2) 32 | assert targets.shape == (2,) 33 | 34 | 35 | def test_temperature_softmax() -> None: 36 | x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) 37 | assert torch.equal(temperature_softmax(x, 2.0), torch.softmax(x / 2.0, dim=1)) 38 | assert torch.equal(temperature_softmax(x, torch.tensor(1.0)), torch.softmax(x, dim=1)) 39 | 
-------------------------------------------------------------------------------- /src/probly/transformation/dropout/common.py: -------------------------------------------------------------------------------- 1 | """Shared dropout implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from probly.traverse_nn import is_first_layer, nn_compose 8 | from pytraverse import CLONE, GlobalVariable, lazydispatch_traverser, traverse 9 | 10 | if TYPE_CHECKING: 11 | from lazy_dispatch.isinstance import LazyType 12 | from probly.predictor import Predictor 13 | from pytraverse.composition import RegisteredLooseTraverser 14 | 15 | P = GlobalVariable[float]("P", "The probability of dropout.") 16 | 17 | dropout_traverser = lazydispatch_traverser[object](name="dropout_traverser") 18 | 19 | 20 | def register(cls: LazyType, traverser: RegisteredLooseTraverser) -> None: 21 | """Register a class to be prepended by Dropout layers.""" 22 | dropout_traverser.register(cls=cls, traverser=traverser, skip_if=is_first_layer, vars={"p": P}) 23 | 24 | 25 | def dropout[T: Predictor](base: T, p: float = 0.25) -> T: 26 | """Create a Dropout predictor from a base predictor. 27 | 28 | Args: 29 | base: Predictor, The base model to be used for dropout. 30 | p: float, The probability of dropping out a neuron. Default is 0.25. 31 | 32 | Returns: 33 | Predictor, The DropOut predictor. 34 | """ 35 | if p < 0 or p > 1: 36 | msg = f"The probability p must be between 0 and 1, but got {p} instead." 37 | raise ValueError(msg) 38 | return traverse(base, nn_compose(dropout_traverser), init={P: p, CLONE: True}) 39 | -------------------------------------------------------------------------------- /src/probly/transformation/dropconnect/common.py: -------------------------------------------------------------------------------- 1 | """Shared DropConnect implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from probly.traverse_nn import is_first_layer, nn_compose 8 | from pytraverse import CLONE, GlobalVariable, lazydispatch_traverser, traverse 9 | 10 | if TYPE_CHECKING: 11 | from lazy_dispatch.isinstance import LazyType 12 | from probly.predictor import Predictor 13 | from pytraverse.composition import RegisteredLooseTraverser 14 | 15 | P = GlobalVariable[float]("P", "The probability of dropconnect.") 16 | 17 | dropconnect_traverser = lazydispatch_traverser[object](name="dropconnect_traverser") 18 | 19 | 20 | def register(cls: LazyType, traverser: RegisteredLooseTraverser) -> None: 21 | """Register a class to be replaced by DropConnect layers.""" 22 | dropconnect_traverser.register(cls=cls, traverser=traverser, skip_if=is_first_layer, vars={"p": P}) 23 | 24 | 25 | def dropconnect[T: Predictor](base: T, p: float = 0.25) -> T: 26 | """Create a DropConnect predictor from a base predictor. 27 | 28 | Args: 29 | base: The base model to be used for dropout. 30 | p: The probability of dropping out a neuron. Default is 0.25. 31 | 32 | Returns: 33 | The DropConnect predictor. 34 | """ 35 | if p < 0 or p > 1: 36 | msg = f"The probability p must be between 0 and 1, but got {p} instead." 
37 | raise ValueError(msg) 38 | return traverse(base, nn_compose(dropconnect_traverser), init={P: p, CLONE: True}) 39 | -------------------------------------------------------------------------------- /tests/probly/train/calibration/test_torch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from probly.train.calibration.torch import ExpectedCalibrationError, FocalLoss, LabelRelaxationLoss 6 | from tests.probly.torch_utils import validate_loss 7 | 8 | torch = pytest.importorskip("torch") 9 | from torch import Tensor # noqa: E402 10 | 11 | 12 | def test_focal_loss(sample_outputs: tuple[Tensor, Tensor]) -> None: 13 | outputs, targets = sample_outputs 14 | criterion = FocalLoss() 15 | loss = criterion(outputs, targets) 16 | validate_loss(loss) 17 | # TODO(pwhofman): Add tests for different values of alpha and gamma 18 | # https://github.com/pwhofman/probly/issues/92 19 | 20 | 21 | def test_expected_calibration_error( 22 | sample_outputs: tuple[Tensor, Tensor], 23 | ) -> None: 24 | outputs, targets = sample_outputs 25 | outputs = torch.softmax(outputs, dim=1) 26 | criterion = ExpectedCalibrationError() 27 | loss = criterion(outputs, targets) 28 | validate_loss(loss) 29 | 30 | criterion = ExpectedCalibrationError(num_bins=1) 31 | loss = criterion(outputs, targets) 32 | validate_loss(loss) 33 | 34 | 35 | def test_label_relaxation_loss( 36 | sample_outputs: tuple[Tensor, Tensor], 37 | ) -> None: 38 | outputs, targets = sample_outputs 39 | criterion = LabelRelaxationLoss() 40 | loss = criterion(outputs, targets) 41 | validate_loss(loss) 42 | 43 | criterion = LabelRelaxationLoss(alpha=1.0) 44 | loss = criterion(outputs, targets) 45 | validate_loss(loss) 46 | -------------------------------------------------------------------------------- /src/probly/transformation/evidential/regression/common.py: -------------------------------------------------------------------------------- 1 | """Shared evidential regression implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from probly.traverse_nn import nn_compose 8 | from pytraverse import CLONE, TRAVERSE_REVERSED, GlobalVariable, lazydispatch_traverser, traverse 9 | 10 | if TYPE_CHECKING: 11 | from lazy_dispatch.isinstance import LazyType 12 | from probly.predictor import Predictor 13 | from pytraverse.composition import RegisteredLooseTraverser 14 | 15 | REPLACED_LAST_LINEAR = GlobalVariable[bool]( 16 | "REPLACED_LAST_LINEAR", 17 | "Whether the last linear layer has been replaced with a NormalInverseGammaLinear layer.", 18 | default=False, 19 | ) 20 | 21 | evidential_regression_traverser = lazydispatch_traverser[object](name="evidential_regression_traverser") 22 | 23 | 24 | def register(cls: LazyType, traverser: RegisteredLooseTraverser) -> None: 25 | """Register a class to be replaced by a normal inverse gamma layer.""" 26 | evidential_regression_traverser.register(cls=cls, traverser=traverser, skip_if=lambda s: s[REPLACED_LAST_LINEAR]) 27 | 28 | 29 | def evidential_regression[T: Predictor](base: T) -> T: 30 | """Create an evidential regression predictor from a base predictor. 31 | 32 | Args: 33 | base: Predictor, The base model to be used for evidential regression. 34 | 35 | Returns: 36 | Predictor, The evidential regression predictor. 
37 | """ 38 | return traverse(base, nn_compose(evidential_regression_traverser), init={TRAVERSE_REVERSED: True, CLONE: True}) 39 | -------------------------------------------------------------------------------- /tests/probly/evaluation/test_tasks.py: -------------------------------------------------------------------------------- 1 | """Tests for the tasks module.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | from probly.evaluation.tasks import out_of_distribution_detection, selective_prediction 9 | 10 | 11 | def test_selective_prediction_shapes() -> None: 12 | rng = np.random.default_rng() 13 | auroc, bin_losses = selective_prediction(rng.random(10), rng.random(10), n_bins=5) 14 | assert isinstance(auroc, float) 15 | assert isinstance(bin_losses, np.ndarray) 16 | assert bin_losses.shape == (5,) 17 | 18 | 19 | def test_selective_prediction_order() -> None: 20 | criterion = np.linspace(0, 1, 10) 21 | losses = np.linspace(0, 1, 10) 22 | _, bin_losses = selective_prediction(criterion, losses, n_bins=5) 23 | assert np.all(np.diff(bin_losses) <= 0) 24 | 25 | 26 | def test_selective_prediction_too_many_bins() -> None: 27 | rng = np.random.default_rng() 28 | with pytest.raises(ValueError, match="The number of bins can not be larger than the number of elements criterion"): 29 | selective_prediction(rng.random(5), rng.random(5), n_bins=10) 30 | 31 | 32 | def test_out_of_distribution_detection_shape() -> None: 33 | rng = np.random.default_rng() 34 | auroc = out_of_distribution_detection(rng.random(10), rng.random(10)) 35 | assert isinstance(auroc, float) 36 | 37 | 38 | def test_out_of_distribution_detection_order() -> None: 39 | in_distribution = np.linspace(0, 1, 10) 40 | out_distribution = np.linspace(0, 1, 10) + 1 41 | auroc = out_of_distribution_detection(in_distribution, out_distribution) 42 | assert auroc == 0.995 43 | -------------------------------------------------------------------------------- /docs/source/methods.md: -------------------------------------------------------------------------------- 1 | # Implemented methods 2 | The following methods are currently implemented in `probly`. 3 | 4 | ## Representation 5 | ### Second-order Distributions 6 | These methods represent (epistemic) uncertainty by a second-order distribution over distributions. 7 | #### Bayesian Neural Networks 8 | {cite:p}`blundellWeightUncertainty2015` 9 | 10 | #### Dropout 11 | {cite:p}`galDropoutBayesian2016` 12 | 13 | #### DropConnect 14 | {cite:p}`mobinyDropConnectEffective2019` 15 | 16 | #### Deep Ensembles 17 | {cite:p}`lakshminarayananSimpleScalable2017` 18 | 19 | #### Evidential Deep Learning 20 | {cite:p}`sensoyEvidentialDeep2018` 21 | {cite:p}`aminiDeepEvidential2020` 22 | 23 | ### Credal Sets 24 | These methods represent (epistemic) uncertainty by a convex set of distributions. 25 | #### Credal Ensembling 26 | {cite:p}`nguyenCredalEnsembling2025` 27 | 28 | ## Quantification 29 | #### Upper / lower entropy 30 | {cite:p}`abellanDisaggregatedTotal2006` 31 | 32 | #### Generalized Hartley 33 | {cite:p}`abellanNonSpecificity2000` 34 | 35 | #### Entropy-based 36 | {cite:p}`depewegDecompositionUncertainty2018` 37 | 38 | #### Distance-based 39 | {cite:p}`saleSecondOrder2024` 40 | 41 | ### Conformal Prediction 42 | These methods represent uncertainty by a set of predictions. 
43 | #### Split Conformal Prediction 44 | {cite:p}`angelopoulosGentleIntroduction2021` 45 | ## Calibration 46 | These methods adjust the model's probabilities to better reflect the true probabilities. 47 | #### Focal Loss 48 | {cite:p}`linFocalLoss2017` 49 | 50 | #### Label Relaxation 51 | {cite:p}`lienenFromLabel2021` 52 | 53 | #### Temperature Scaling 54 | {cite:p}`guoOnCalibration2017` 55 | -------------------------------------------------------------------------------- /tests/probly/utils/test_probabilities.py: -------------------------------------------------------------------------------- 1 | """Tests for utils.probabilities functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | 7 | from probly.utils.probabilities import differential_entropy_gaussian, intersection_probability, kl_divergence_gaussian 8 | 9 | 10 | def test_differential_entropy_gaussian() -> None: 11 | assert np.isclose(differential_entropy_gaussian(0.5), 1.54709559) 12 | assert np.allclose(differential_entropy_gaussian(np.array([1, 2]), base=np.e), np.array([1.41893853, 1.76551212])) 13 | 14 | 15 | def test_kl_divergence_gaussian() -> None: 16 | mu1 = np.array([0.0, 1.0]) 17 | mu2 = np.array([1.0, 0.0]) 18 | sigma21 = np.array([0.1, 0.1]) 19 | sigma22 = np.array([0.1, 0.1]) 20 | assert np.isclose(kl_divergence_gaussian(1.0, 1.0, 1.0, 1.0), 0.0) 21 | assert np.isclose(kl_divergence_gaussian(1.0, 1.0, 1.0, 1.0, base=np.e), 0.0) 22 | assert np.allclose(kl_divergence_gaussian(mu1, sigma21, mu2, sigma22, base=np.e), np.array([5.0, 5.0])) 23 | 24 | 25 | def test_intersection_probability() -> None: 26 | rng = np.random.default_rng() 27 | 28 | probs = rng.dirichlet(np.ones(2), size=(5, 5)) 29 | int_prob = intersection_probability(probs) 30 | assert int_prob.shape == (5, 2) 31 | assert np.allclose(np.sum(int_prob, axis=1), 1.0) 32 | 33 | probs = rng.dirichlet(np.ones(10), size=(5, 5)) 34 | int_prob = intersection_probability(probs) 35 | assert int_prob.shape == (5, 10) 36 | assert np.allclose(np.sum(int_prob, axis=1), 1.0) 37 | 38 | probs = np.array([1 / 3, 1 / 3, 1 / 3] * 5).reshape(1, 5, 3) 39 | int_prob = intersection_probability(probs) 40 | assert int_prob.shape == (1, 3) 41 | assert np.allclose(np.sum(int_prob, axis=1), 1.0) 42 | -------------------------------------------------------------------------------- /tests/probly/transformation/bayesian/test_common.py: -------------------------------------------------------------------------------- 1 | """Tests for the bayesian transformation.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from probly.predictor import Predictor 8 | from probly.transformation import bayesian 9 | 10 | 11 | def test_invalid_prior_std_value(dummy_predictor: Predictor) -> None: 12 | """Tests the behavior of the bayesian function when provided with an invalid prior standard deviation value. 13 | 14 | This function validates that the bayesian function raises a ValueError when 15 | the prior standard deviation parameter is not positive. 16 | 17 | Raises: 18 | ValueError: If the prior standard deviation is zero or negative. 19 | """ 20 | prior_std = -1.0 21 | msg = f"The prior standard deviation prior_std must be greater than 0, but got {prior_std} instead." 
22 | with pytest.raises(ValueError, match=msg): 23 | bayesian(dummy_predictor, prior_std=prior_std) 24 | 25 | 26 | def test_invalid_posterior_std_value(dummy_predictor: Predictor) -> None: 27 | """Tests the behavior of the bayesian function when provided with an invalid posterior standard deviation value. 28 | 29 | This function validates that the bayesian function raises a ValueError when 30 | the posterior standard deviation parameter is not positive. 31 | 32 | Raises: 33 | ValueError: If the posterior standard deviation is zero or negative. 34 | """ 35 | posterior_std = -1.0 36 | msg = ( 37 | "The initial posterior standard deviation posterior_std must be greater than 0, " 38 | f"but got {posterior_std} instead." 39 | ) 40 | with pytest.raises(ValueError, match=msg): 41 | bayesian(dummy_predictor, posterior_std=posterior_std) 42 | -------------------------------------------------------------------------------- /src/probly/traverse_nn/common.py: -------------------------------------------------------------------------------- 1 | """Generic traverser helpers for neural networks.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytraverse as t 6 | 7 | LAYER_COUNT = t.GlobalVariable[int]( 8 | "LAYER_COUNT", 9 | "The DFS index of the current layer/module.", 10 | default=0, 11 | ) 12 | FLATTEN_SEQUENTIAL = t.StackVariable[bool]( 13 | "FLATTEN_SEQUENTIAL", 14 | "Whether to flatten sequential modules after making changes.", 15 | default=True, 16 | ) 17 | 18 | 19 | @t.computed 20 | def is_first_layer(state: t.State) -> bool: 21 | """Whether the current layer is the first layer.""" 22 | return state[LAYER_COUNT] == 0 23 | 24 | 25 | layer_count_traverser = t.singledispatch_traverser[object](name="layer_count_traverser") 26 | 27 | nn_traverser = t.lazydispatch_traverser[object](name="nn_traverser") 28 | 29 | 30 | def compose( 31 | traverser: t.Traverser, 32 | nn_traverser: t.Traverser = nn_traverser, 33 | name: str | None = None, 34 | ) -> t.Traverser: 35 | """Compose a custom traverser with neural network traversal functionality. 36 | 37 | This function creates a sequential traverser that combines neural network traversal, 38 | a custom traverser, and layer counting capabilities in a specific order. 39 | 40 | Args: 41 | traverser: A custom traverser function to be composed with the NN traverser. 42 | nn_traverser: The neural network traverser to use. Defaults to the module's 43 | nn_traverser. 44 | name: Optional name for the composed traverser. 45 | 46 | Returns: 47 | A composed sequential traverser that applies NN traversal, custom traversal, 48 | and layer counting in sequence. 49 | """ 50 | return t.sequential(nn_traverser, traverser, layer_count_traverser, name=name) 51 | -------------------------------------------------------------------------------- /src/probly/utils/sets.py: -------------------------------------------------------------------------------- 1 | """Utility functions regarding sets.""" 2 | 3 | from __future__ import annotations 4 | 5 | import itertools 6 | from typing import TYPE_CHECKING 7 | 8 | import numpy as np 9 | 10 | if TYPE_CHECKING: 11 | from collections.abc import Iterable 12 | 13 | 14 | def powerset(iterable: Iterable[int]) -> list[tuple[int, ...]]: 15 | """Generate the power set of a given iterable. 
16 | 17 | Args: 18 | iterable: Iterable 19 | Returns: 20 | List[tuple], power set of the given iterable 21 | 22 | """ 23 | s = list(iterable) 24 | return list(itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s) + 1))) 25 | 26 | 27 | def capacity(q: np.ndarray, a: Iterable[int]) -> np.ndarray: 28 | """Compute the capacity of set q given set a. 29 | 30 | Args: 31 | q: numpy.ndarray, shape (n_instances, n_samples, n_classes) 32 | a: Iterable, shape (n_classes,), indices indicating subset of classes 33 | Returns: 34 | min_capacity: numpy.ndarray, shape (n_instances,), capacity of q given a 35 | 36 | """ 37 | selected_sum = np.sum(q[:, :, a], axis=2) 38 | min_capacity = np.min(selected_sum, axis=1) 39 | return min_capacity 40 | 41 | 42 | def moebius(q: np.ndarray, a: Iterable[int]) -> np.ndarray: 43 | """Compute the Moebius function of a set q given a set a. 44 | 45 | Args: 46 | q: numpy.ndarray of shape (num_samples, num_members, num_classes) 47 | a: numpy.ndarray, shape (n_classes,), indices indicating subset of classes 48 | Returns: 49 | m_a: numpy.ndarray, shape (n_instances,), moebius value of q given a 50 | 51 | """ 52 | ps_a = powerset(a) # powerset of A 53 | ps_a.pop(0) # remove empty set 54 | m_a = np.zeros(q.shape[0]) 55 | for b in ps_a: 56 | dl = len(set(a) - set(b)) 57 | m_a += ((-1) ** dl) * capacity(q, b) 58 | return m_a 59 | -------------------------------------------------------------------------------- /src/probly/representation/credal_set/credal_set.py: -------------------------------------------------------------------------------- 1 | """Classes representing credal sets.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from lazy_dispatch.singledispatch import lazydispatch 8 | from probly.lazy_types import JAX_ARRAY, TORCH_TENSOR 9 | from probly.representation.sampling.sample import ArraySample 10 | 11 | if TYPE_CHECKING: 12 | import numpy as np 13 | 14 | 15 | class CredalSet[T]: 16 | """Base class for credal sets.""" 17 | 18 | def lower(self) -> T: 19 | """Compute the lower envelope of the credal set.""" 20 | msg = "lower method not implemented." 21 | raise NotImplementedError(msg) 22 | 23 | def upper(self) -> T: 24 | """Compute the upper envelope of the credal set.""" 25 | msg = "upper method not implemented." 26 | raise NotImplementedError(msg) 27 | 28 | 29 | credal_set_from_sample = lazydispatch[type[CredalSet], CredalSet](CredalSet) 30 | 31 | 32 | @credal_set_from_sample.register(ArraySample) 33 | class ArrayCredalSet[T](CredalSet[T]): 34 | """A credal set of predictions stored in a numpy array.""" 35 | 36 | def __init__(self, sample: ArraySample) -> None: 37 | """Initialize the array credal set.""" 38 | self.array: np.ndarray = sample.array 39 | 40 | def lower(self) -> T: 41 | """Compute the lower envelope of the credal set.""" 42 | return self.array.min(axis=1) # type: ignore[no-any-return] 43 | 44 | def upper(self) -> T: 45 | """Compute the upper envelope of the credal set.""" 46 | return self.array.max(axis=1) # type: ignore[no-any-return] 47 | 48 | 49 | @credal_set_from_sample.delayed_register(TORCH_TENSOR) 50 | def _(_: type) -> None: 51 | from . import torch as torch # noqa: PLC0414, PLC0415 52 | 53 | 54 | @credal_set_from_sample.delayed_register(JAX_ARRAY) 55 | def _(_: type) -> None: 56 | from . 
import jax as jax # noqa: PLC0414, PLC0415 57 | -------------------------------------------------------------------------------- /src/lazy_dispatch/load.py: -------------------------------------------------------------------------------- 1 | """Lazy import utilities.""" 2 | 3 | from __future__ import annotations 4 | 5 | import importlib 6 | import importlib.util 7 | import sys 8 | from typing import TYPE_CHECKING, Any, overload 9 | 10 | if TYPE_CHECKING: 11 | from collections.abc import Callable, Iterable 12 | from types import ModuleType 13 | 14 | 15 | def lazy_import(name: str, package: str | None = None, register: bool = False) -> ModuleType: 16 | """Lazily import a module.""" 17 | if name in sys.modules: 18 | return sys.modules[name] 19 | 20 | spec = importlib.util.find_spec(name, package=package) 21 | 22 | if spec is None or spec.loader is None: 23 | msg = f"Module {name} not found" 24 | raise ImportError(msg) 25 | 26 | loader: importlib.util.LazyLoader = importlib.util.LazyLoader(spec.loader) 27 | spec.loader = loader 28 | module = importlib.util.module_from_spec(spec) 29 | if register: 30 | sys.modules[name] = module 31 | loader.exec_module(module) 32 | return module 33 | 34 | 35 | @overload 36 | def lazy_callable( 37 | module: ModuleType | str, 38 | attrs: str, 39 | package: str | None = None, 40 | register: bool = False, 41 | ) -> Callable: ... 42 | 43 | 44 | @overload 45 | def lazy_callable( 46 | module: ModuleType | str, 47 | attrs: Iterable[str], 48 | package: str | None = None, 49 | register: bool = False, 50 | ) -> list[Callable]: ... 51 | 52 | 53 | def lazy_callable( 54 | module: ModuleType | str, 55 | attrs: str | Iterable[str], 56 | package: str | None = None, 57 | register: bool = False, 58 | ) -> Callable | list[Callable]: 59 | """Lazily get a callable attribute from a module or module name.""" 60 | if isinstance(module, str): 61 | module = lazy_import(module, package=package, register=register) 62 | 63 | if isinstance(attrs, str): 64 | 65 | def fn(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 66 | return getattr(module, attrs)(*args, **kwargs) 67 | 68 | return fn 69 | 70 | return [lazy_callable(module, attr) for attr in attrs] 71 | -------------------------------------------------------------------------------- /tests/probly/markers.py: -------------------------------------------------------------------------------- 1 | """Markers for optional dependencies. 2 | 3 | Markers can be imported in the test files to skip tests if certain packages are not installed. 4 | Note that we want to avoid marked tests as best as possible, but sometimes dependencies are lagging 5 | behind the latest version of the package or the python version. 6 | 7 | Markers should only be used on optional dependencies, i.e. packages that are not required to run 8 | `probly`. 
9 | 10 | """ 11 | 12 | from __future__ import annotations 13 | 14 | import importlib.util 15 | 16 | import pytest 17 | 18 | __all__ = [ 19 | "skip_if_no_keras", 20 | "skip_if_no_lightgbm", 21 | "skip_if_no_sklearn", 22 | "skip_if_no_tabpfn", 23 | "skip_if_no_tensorflow", 24 | "skip_if_no_torchvision", 25 | "skip_if_no_xgboost", 26 | ] 27 | 28 | 29 | def is_installed(pkg_name: str) -> bool: 30 | """Check if a package is installed without importing it.""" 31 | return importlib.util.find_spec(pkg_name) is not None 32 | 33 | 34 | # torch related: (torch itself is a main dependency) ----------------------------------------------- 35 | skip_if_no_torchvision = pytest.mark.skipif(not is_installed("torchvision"), reason="torchvision is not installed") 36 | 37 | # sklearn-like: ------------------------------------------------------------------------------------ 38 | skip_if_no_sklearn = pytest.mark.skipif(not is_installed("sklearn"), reason="sklearn is not installed") 39 | 40 | skip_if_no_xgboost = pytest.mark.skipif(not is_installed("xgboost"), reason="xgboost is not installed") 41 | 42 | skip_if_no_lightgbm = pytest.mark.skipif(not is_installed("lightgbm"), reason="lightgbm is not installed") 43 | 44 | # tensorflow related: ------------------------------------------------------------------------------ 45 | skip_if_no_tensorflow = pytest.mark.skipif(not is_installed("tensorflow"), reason="tensorflow is not installed") 46 | 47 | skip_if_no_keras = pytest.mark.skipif(not is_installed("keras"), reason="keras is not installed") 48 | 49 | # misc: -------------------------------------------------------------------------------------------- 50 | skip_if_no_tabpfn = pytest.mark.skipif(not is_installed("tabpfn"), reason="TabPFN is not available.") 51 | -------------------------------------------------------------------------------- /src/probly/utils/torch.py: -------------------------------------------------------------------------------- 1 | """Utility functions for PyTorch models.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | 10 | @torch.no_grad() 11 | def torch_collect_outputs( 12 | model: torch.nn.Module, 13 | loader: torch.utils.data.DataLoader, 14 | device: torch.device, 15 | ) -> tuple[torch.Tensor, torch.Tensor]: 16 | """Collect outputs and targets from a model for a given data loader. 17 | 18 | Args: 19 | model: torch.nn.Module, model to collect outputs from 20 | loader: torch.utils.data.DataLoader, data loader to collect outputs from 21 | device: torch.device, device to move data to 22 | Returns: 23 | outputs: torch.Tensor, shape (n_instances, n_classes), model outputs 24 | targets: torch.Tensor, shape (n_instances,), target labels 25 | """ 26 | outputs = torch.empty(0, device=device) 27 | targets = torch.empty(0, device=device) 28 | for inpt, target in tqdm(loader, desc="Batches"): 29 | outputs = torch.cat((outputs, model(inpt.to(device))), dim=0) 30 | targets = torch.cat((targets, target.to(device)), dim=0) 31 | return outputs, targets 32 | 33 | 34 | def torch_reset_all_parameters(module: torch.nn.Module) -> None: 35 | """Reset all parameters of a torch module. 
36 | 37 | Args: 38 | module: torch.nn.Module, module to reset parameters 39 | 40 | """ 41 | if hasattr(module, "reset_parameters"): 42 | module.reset_parameters() 43 | for child in module.children(): 44 | if hasattr(child, "reset_parameters"): 45 | child.reset_parameters() 46 | 47 | 48 | def temperature_softmax(logits: torch.Tensor, temperature: float | torch.Tensor) -> torch.Tensor: 49 | """Compute the softmax of logits with temperature scaling applied. 50 | 51 | Computes the softmax based on the logits divided by the temperature. Assumes that the last dimension 52 | of logits is the class dimension. 53 | 54 | Args: 55 | logits: torch.Tensor, shape (n_instances, n_classes), logits to apply softmax on 56 | temperature: float, temperature scaling factor 57 | Returns: 58 | ts: torch.Tensor, shape (n_instances, n_classes), softmax of logits with temperature scaling applied 59 | """ 60 | ts = F.softmax(logits / temperature, dim=-1) 61 | return ts 62 | -------------------------------------------------------------------------------- /src/probly/train/bayesian/torch.py: -------------------------------------------------------------------------------- 1 | """Collection of torch Bayesian training functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | if TYPE_CHECKING: 8 | from probly.predictor import Predictor 9 | 10 | import torch 11 | from torch import nn 12 | import torch.nn.functional as F 13 | 14 | from probly.layers.torch import BayesConv2d, BayesLinear # noqa: TC001, required by traverser 15 | from probly.traverse_nn import nn_compose 16 | from pytraverse import GlobalVariable, State, TraverserResult, singledispatch_traverser, traverse_with_state 17 | 18 | KL_DIVERGENCE = GlobalVariable[torch.Tensor]("KL_DIVERGENCE", default=0.0) 19 | 20 | 21 | @singledispatch_traverser[object] 22 | def kl_divergence_traverser( 23 | obj: BayesLinear | BayesConv2d, 24 | state: State, 25 | ) -> TraverserResult[BayesLinear | BayesConv2d]: 26 | """Traverser to compute the KL divergence of a Bayesian layer.""" 27 | state[KL_DIVERGENCE] += obj.kl_divergence 28 | return obj, state 29 | 30 | 31 | def collect_kl_divergence(model: Predictor) -> torch.Tensor: 32 | """Collect the KL divergence of the Bayesian model by summing the KL divergence of each Bayesian layer.""" 33 | _, state = traverse_with_state(model, nn_compose(kl_divergence_traverser)) 34 | return state[KL_DIVERGENCE] 35 | 36 | 37 | class ELBOLoss(nn.Module): 38 | """Evidence lower bound loss based on :cite:`blundellWeightUncertainty2015`. 39 | 40 | Attributes: 41 | kl_penalty: float, weight for KL divergence term 42 | """ 43 | 44 | def __init__(self, kl_penalty: float = 1e-5) -> None: 45 | """Initializes an instance of the ELBOLoss class. 46 | 47 | Args: 48 | kl_penalty: float, weight for KL divergence term 49 | """ 50 | super().__init__() 51 | self.kl_penalty = kl_penalty 52 | 53 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor, kl: torch.Tensor) -> torch.Tensor: 54 | """Forward pass of the ELBO loss. 
55 | 56 | Args: 57 | inputs: torch.Tensor of size (n_instances, n_classes) 58 | targets: torch.Tensor of size (n_instances,) 59 | kl: torch.Tensor, KL divergence of the model 60 | Returns: 61 | loss: torch.Tensor, mean loss value 62 | """ 63 | loss = F.cross_entropy(inputs, targets) + self.kl_penalty * kl 64 | return loss 65 | -------------------------------------------------------------------------------- /src/probly/transformation/ensemble/flax.py: -------------------------------------------------------------------------------- 1 | """Ensemble Flax implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import asdict, is_dataclass 6 | from typing import Any 7 | 8 | from flax import linen as nn 9 | import jax 10 | import jax.numpy as jnp 11 | 12 | from .common import register 13 | 14 | 15 | class FlaxEnsemble(nn.Module): 16 | """FlaxEnsemble class.""" 17 | 18 | base_module: type 19 | base_kwargs: dict[str, Any] | None = None 20 | n_members: int = 10 21 | 22 | @nn.compact 23 | def __call__( 24 | self, 25 | x: jnp.ndarray, 26 | *, 27 | return_all: bool = False, 28 | **call_kwargs: object, 29 | ) -> jnp.ndarray | dict[str, jnp.ndarray]: 30 | """Apply each ensemble member and aggregate.""" 31 | outputs: list[object] = [] 32 | for i in range(self.n_members): 33 | ctor_kwargs = self.base_kwargs or {} 34 | member = self.base_module(**ctor_kwargs, name=f"member_{i}") 35 | y = member(x, **call_kwargs) 36 | outputs.append(y) 37 | 38 | stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=1), *outputs) 39 | if return_all: 40 | return stacked # type: ignore[return-value] 41 | averaged = jax.tree_util.tree_map(lambda arr: jnp.mean(arr, axis=1), stacked) 42 | return averaged # type: ignore[return-value] 43 | 44 | 45 | def generate_flax_ensemble( 46 | model: nn.Module | type[nn.Module] | str, 47 | n_members: int, 48 | ) -> FlaxEnsemble: 49 | """Create a class:FlaxEnsemble from a module instance or class.""" 50 | base_module = model.__class__ if not isinstance(model, type) else model 51 | base_kwargs: dict[str, Any] | None = None 52 | if not isinstance(model, type) and is_dataclass(model): 53 | try: 54 | # mypy: after is_dataclass check, model is a dataclass instance 55 | all_fields = asdict(model) # type: ignore[arg-type] 56 | filtered = {k: v for k, v in all_fields.items() if not k.startswith("_") and k not in ("name", "parent")} 57 | if filtered: 58 | base_kwargs = filtered 59 | except (TypeError, ValueError): 60 | base_kwargs = None 61 | return FlaxEnsemble(base_module=base_module, base_kwargs=base_kwargs, n_members=n_members) 62 | 63 | 64 | register(nn.Module, generate_flax_ensemble) 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # probly: Uncertainty Representation and Quantification for Machine Learning 2 |
3 | 4 | 5 | 6 | probly logo 7 | 8 | 9 | [![PyPI version](https://badge.fury.io/py/probly.svg)](https://badge.fury.io/py/probly) 10 | [![PyPI status](https://img.shields.io/pypi/status/probly.svg?color=blue)](https://pypi.org/project/probly) 11 | [![PePy](https://static.pepy.tech/badge/probly?style=flat-square)](https://pepy.tech/project/probly) 12 | [![codecov](https://codecov.io/gh/pwhofman/probly/branch/main/graph/badge.svg)](https://codecov.io/gh/pwhofman/probly) 13 | [![Contributions Welcome](https://img.shields.io/badge/contributions-welcome-brightgreen)](.github/CONTRIBUTING.md) 14 | [![License](https://img.shields.io/badge/License-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) 15 |
16 | 17 | ## 🛠️ Install 18 | `probly` is intended to work with **Python 3.12 and above**. Installation can be done via `pip` and 19 | or `uv`: 20 | 21 | ```sh 22 | pip install probly 23 | ``` 24 | 25 | ```sh 26 | uv add probly 27 | ``` 28 | 29 | ## ⭐ Quickstart 30 | 31 | `probly` makes it very easy to make models uncertainty-aware and perform several downstream tasks: 32 | 33 | ```python 34 | import probly 35 | import torch.nn.functional as F 36 | 37 | net = ... # get neural network 38 | model = probly.transformation.dropout(net) # make neural network a Dropout model 39 | train(model) # train model as usual 40 | 41 | data = ... # get data 42 | data_ood = ... # get out of distribution data 43 | sampler = probly.representation.Sampler(model, num_samples=20) 44 | sample = sampler.predict(data) # predict an uncertainty representation 45 | sample_ood = sampler.predict(data_ood) 46 | 47 | eu = probly.quantification.classification.mutual_information(sample) # quantify model's epistemic uncertainty 48 | eu_ood = probly.quantification.classification.mutual_information(sample_ood) 49 | 50 | auroc = probly.evaluation.tasks.out_of_distribution_detection(eu, eu_ood) # evaluate model's uncertainty 51 | ``` 52 | 53 | ## 📜 License 54 | This project is licensed under the [MIT License](https://github.com/pwhofman/probly/blob/main/LICENSE). 55 | 56 | --- 57 | Built with ❤️ by the probly team. 58 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. probly documentation master file, created by 2 | sphinx-quickstart on Tue Apr 8 15:10:07 2025. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | The `probly` Python package 7 | =========================== 8 | 9 | `probly` is a Python package for **uncertainty representation** and **quantification** for machine learning. 10 | 11 | Installation 12 | ~~~~~~~~~~~~ 13 | `probly` is intended to work with **Python 3.10 and above**. Installation can be done via `pip` and 14 | or `uv`: 15 | 16 | .. code-block:: sh 17 | 18 | pip install probly 19 | 20 | or 21 | 22 | .. code-block:: sh 23 | 24 | uv add probly 25 | 26 | Quickstart 27 | ~~~~~~~~~~~~ 28 | `probly` makes it very easy to make models uncertainty-aware and perform several downstream tasks: 29 | 30 | .. code-block:: python 31 | 32 | import probly 33 | import torch.nn.functional as F 34 | 35 | net = ... # get neural network 36 | model = probly.representation.Dropout(net) # make neural network a Dropout model 37 | train(model) # train model as usual 38 | 39 | data = ... # get data 40 | preds = model.predict_representation(data) # predict an uncertainty representation 41 | eu = probly.quantification.classification.mutual_information(preds) # compute model's epistemic uncertainty 42 | 43 | data_ood = ... # get out of distribution data 44 | preds_ood = model.predict_representation(data_ood) 45 | eu_ood = probly.quantification.classification.mutual_information(preds_ood) 46 | auroc = probly.tasks.out_of_distribution_detection(eu, eu_ood) # compute the AUROC score for out of distribution detection 47 | 48 | 49 | 50 | .. toctree:: 51 | :maxdepth: 1 52 | :caption: Content 53 | :hidden: 54 | 55 | methods 56 | references 57 | 58 | .. 
toctree:: 59 | :maxdepth: 0 60 | :caption: Tutorials 61 | :hidden: 62 | 63 | examples/fashionmnist_ood_ensemble 64 | examples/label_relaxation_calibration 65 | examples/sklearn_selective_prediction 66 | examples/synthetic_regression_dropout 67 | examples/temperature_scaling_calibration 68 | examples/train_bnn_classification 69 | examples/train_evidential_classification 70 | examples/train_evidential_regression 71 | 72 | .. toctree:: 73 | :maxdepth: 2 74 | :caption: API 75 | :hidden: 76 | 77 | api 78 | 79 | .. toctree:: 80 | :maxdepth: 1 81 | :caption: Development 82 | :hidden: 83 | 84 | contributing 85 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to probly 🏔️ 2 | Probly is still in early development and we welcome contributions 3 | in many forms. If you have an idea for a new feature, a bug fix, or 4 | any other suggestion for improvement, please open an issue on GitHub. 5 | If you would like to contribute code, keep reading! 6 | 7 | ## What to work on ❓ 8 | We want to offer support for PyTorch, HuggingFace, and sklearn models. We are interested in 9 | any contributions that translate existing features to these libraries. Furthermore, we are 10 | interested in any new features within the following scope: 11 | - Representation methods; 12 | - Quantification methods; 13 | - Calibration methods; 14 | - Downstream tasks; 15 | - Datasets and dataloaders. 16 | 17 | If you have other suggestions, please open an issue on GitHub such that we can discuss it. 18 | 19 | ## Development setup 🔬 20 | The recommended workflow for contributing to probly is: 21 | 1. Fork the `main` branch of repository on GitHub. 22 | 2. Clone the fork locally. 23 | 3. Commit the changes. 24 | 4. Push the changes to the fork. 25 | 5. Create a pull request to the `main` branch of the repository. 26 | 27 | ### Setting up the development environment 28 | Once you have cloned the fork, you can set up your development environment. 29 | You will need a Python 3.10+ environment. We recommend using [uv](https://docs.astral.sh/uv/) as a package manager. 30 | To set up the development environment, run the following command: 31 | ```sh 32 | uv sync --dev 33 | ``` 34 | This will install all the required dependencies in your Python environment. 35 | 36 | ## Guidelines 🚓 37 | Here are some guidelines to follow when contributing to probly. 38 | 39 | ### General 40 | If you use code from other sources, make sure to carefully look at the license and give credit to the original author(s). 41 | If the feature you are implementing is based on a paper, make sure to include a reference 42 | in the docstring. 43 | 44 | ### Code style 45 | We use [Ruff](https://docs.astral.sh/ruff/) for linting and formatting, the rules of which can be found in the [pyproject.toml](https://github.com/pwhofman/probly/blob/main/pyproject.toml) file. 46 | However, if your development environment is set up correctly, the Ruff pre-commit hook should take care of this for you. 47 | 48 | ### Documentation 49 | If you are adding new features, make sure to document them in the docstring; 50 | our docstrings follow the [Google Style Guide](https://google.github.io/styleguide/pyguide.html#docstrings). 
51 | -------------------------------------------------------------------------------- /tests/probly/tests_deprecations/utils.py: -------------------------------------------------------------------------------- 1 | """A module containing all deprecated features / behaviors in probly. Adapted from the shapiq package. 2 | 3 | Usage: 4 | To add a new deprecated feature or behavior, create an instance of the `DeprecatedFeature` and 5 | append it to the `DEPRECATED_FEATURES` list. The `call` attribute should be a callable that 6 | triggers the deprecation warning when executed. The `deprecated_in` and `removed_in` attributes 7 | should specify the version in which the feature was deprecated and the version in which it will 8 | be removed, respectively. 9 | 10 | """ 11 | 12 | from __future__ import annotations 13 | 14 | import importlib 15 | import pathlib 16 | import pkgutil 17 | from typing import TYPE_CHECKING, NamedTuple 18 | 19 | if TYPE_CHECKING: 20 | from collections.abc import Callable 21 | 22 | import pytest 23 | 24 | DeprecatedTestFunc = Callable[[pytest.FixtureRequest], None] 25 | 26 | 27 | class DeprecatedFeature(NamedTuple): 28 | """A named tuple to represent a deprecated feature.""" 29 | 30 | name: str 31 | deprecated_in: str 32 | removed_in: str 33 | call: Callable[[pytest.FixtureRequest], None] 34 | 35 | 36 | DEPRECATED_FEATURES: list[DeprecatedFeature] = [] 37 | 38 | 39 | def register_deprecated( 40 | name: str, 41 | deprecated_in: str, 42 | removed_in: str, 43 | ) -> Callable[[DeprecatedTestFunc], DeprecatedTestFunc]: 44 | """Decorator to register a deprecated feature. 45 | 46 | Args: 47 | name: The name of the deprecated feature. 48 | deprecated_in: The version in which the feature was deprecated. 49 | removed_in: The version in which the feature will be removed. 50 | 51 | Returns: 52 | A decorator that registers the deprecated feature. 53 | """ 54 | 55 | def _decorator(func: Callable[[pytest.FixtureRequest], None]) -> Callable[[pytest.FixtureRequest], None]: 56 | DEPRECATED_FEATURES.append(DeprecatedFeature(name, deprecated_in, removed_in, func)) 57 | return func 58 | 59 | return _decorator 60 | 61 | 62 | # auto-import all deprecated modules from this folder in the current package 63 | def _auto_import_deprecated_modules() -> None: 64 | current_path = pathlib.Path(__file__).parent 65 | for module_info in pkgutil.iter_modules([str(current_path)]): 66 | name = module_info.name 67 | if name not in {"__init__", "features"}: 68 | importlib.import_module(f"{__package__}.{name}") 69 | 70 | 71 | _auto_import_deprecated_modules() 72 | -------------------------------------------------------------------------------- /src/probly/evaluation/tasks.py: -------------------------------------------------------------------------------- 1 | """Collection of downstream tasks to evaluate the performance of uncertainty pipelines.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | 7 | # TODO(mmshlk): remove sklearn dependency - https://github.com/pwhofman/probly/issues/132 8 | import sklearn.metrics as sm 9 | 10 | 11 | def selective_prediction(criterion: np.ndarray, losses: np.ndarray, n_bins: int = 50) -> tuple[float, np.ndarray]: 12 | """Selective prediction downstream task for evaluation. 13 | 14 | Perform selective prediction based on criterion and losses. 15 | The criterion is used to sort the losses. In line with uncertainty 16 | literature the sorting is done in descending order, i.e. 17 | the losses with the largest criterion are rejected first. 
18 | 19 | Args: 20 | criterion: numpy.ndarray shape (n_instances,), criterion values 21 | losses: numpy.ndarray shape (n_instances,), loss values 22 | n_bins: int, number of bins 23 | Returns: 24 | auroc: float, area under the loss curve 25 | bin_losses: numpy.ndarray shape (n_bins,), loss per bin 26 | 27 | """ 28 | if n_bins > len(losses): 29 | msg = "The number of bins can not be larger than the number of elements criterion" 30 | raise ValueError(msg) 31 | sort_idxs = np.argsort(criterion)[::-1] 32 | losses_sorted = losses[sort_idxs] 33 | bin_len = len(losses) // n_bins 34 | bin_losses = np.empty(n_bins) 35 | for i in range(n_bins): 36 | bin_losses[i] = np.mean(losses_sorted[(i * bin_len) :]) 37 | 38 | # Also compute the area under the loss curve based on the bin losses. 39 | auroc = sm.auc(np.linspace(0, 1, n_bins), bin_losses) 40 | return auroc, bin_losses 41 | 42 | 43 | def out_of_distribution_detection(in_distribution: np.ndarray, out_distribution: np.ndarray) -> float: 44 | """Perform out-of-distribution detection using prediction functionals from id and ood data. 45 | 46 | This can be epistemic uncertainty, as is common, but also e.g. softmax confidence. 47 | 48 | Args: 49 | in_distribution: in-distribution prediction functionals 50 | out_distribution: out-of-distribution prediction functionals 51 | Returns: 52 | auroc: float, area under the roc curve 53 | 54 | """ 55 | preds = np.concatenate((in_distribution, out_distribution)) 56 | labels = np.concatenate((np.zeros(len(in_distribution)), np.ones(len(out_distribution)))) 57 | auroc = sm.roc_auc_score(labels, preds) 58 | return float(auroc) 59 | -------------------------------------------------------------------------------- /src/probly/plot/credal.py: -------------------------------------------------------------------------------- 1 | """Collection of credal plotting functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import matplotlib.pyplot as plt 6 | import mpltern # noqa: F401, required for ternary projection, do not remove 7 | import numpy as np 8 | 9 | 10 | def simplex_plot(probs: np.ndarray) -> tuple[plt.Figure, plt.Axes]: 11 | """Plot probability distributions on the simplex. 12 | 13 | Args: 14 | probs: numpy.ndarray of shape (n_instances, n_classes) 15 | 16 | Returns: 17 | fig: matplotlib figure 18 | ax: matplotlib axes 19 | """ 20 | fig = plt.figure() 21 | ax = fig.add_subplot(projection="ternary") 22 | ax.scatter(probs[:, 0], probs[:, 1], probs[:, 2]) 23 | return fig, ax 24 | 25 | 26 | def credal_set_plot(probs: np.ndarray) -> tuple[plt.Figure, plt.Axes]: 27 | """Plot credal sets based on intervals of lower and upper probabilities. 
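A short usage sketch for the two evaluation tasks defined in `probly.evaluation.tasks` above, on purely synthetic scores and losses:

```python
import numpy as np

from probly.evaluation.tasks import out_of_distribution_detection, selective_prediction

rng = np.random.default_rng(0)

# Selective prediction: reject the highest-criterion instances first and track the remaining loss.
uncertainty = rng.uniform(size=200)                    # stand-in for a per-instance uncertainty score
losses = rng.uniform(size=200) * (0.5 + uncertainty)   # losses loosely correlated with uncertainty
area, bin_losses = selective_prediction(uncertainty, losses, n_bins=20)

# OOD detection: higher scores on out-of-distribution data should give an AUROC above 0.5.
id_scores = rng.normal(loc=0.0, scale=1.0, size=100)
ood_scores = rng.normal(loc=1.0, scale=1.0, size=100)
auroc = out_of_distribution_detection(id_scores, ood_scores)
```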
28 | 29 | Args: 30 | probs: numpy.ndarray of shape (n_samples, n_classes) 31 | 32 | Returns: 33 | fig: matplotlib figure 34 | ax: matplotlib axes 35 | """ 36 | fig = plt.figure() 37 | ax = fig.add_subplot(projection="ternary") 38 | 39 | lower_probs = np.min(probs, axis=0) 40 | upper_probs = np.max(probs, axis=0) 41 | lower_idxs = np.argmin(probs, axis=0) 42 | upper_idxs = np.argmax(probs, axis=0) 43 | edge_probs = np.vstack((probs[lower_idxs], probs[upper_idxs])) 44 | 45 | vertices_ = [] 46 | for i, j, k in [(0, 1, 2), (1, 2, 0), (0, 2, 1)]: 47 | for x in [lower_probs[i], upper_probs[i]]: 48 | for y in [lower_probs[j], upper_probs[j]]: 49 | z = 1 - x - y 50 | if lower_probs[k] <= z <= upper_probs[k]: 51 | prob = [0, 0, 0] 52 | prob[i] = x 53 | prob[j] = y 54 | prob[k] = z 55 | vertices_.append(prob) 56 | vertices = np.array(vertices_) 57 | 58 | if len(vertices) > 0: 59 | center = np.mean(vertices, axis=0) 60 | angles = np.arctan2(vertices[:, 1] - center[1], vertices[:, 0] - center[0]) 61 | vertices = vertices[np.argsort(angles)] 62 | ax.scatter(probs[:, 0], probs[:, 1], probs[:, 2]) 63 | vertices_closed = np.vstack([vertices, vertices[0]]) 64 | ax.fill(vertices_closed[:, 0], vertices_closed[:, 1], vertices_closed[:, 2]) 65 | ax.plot(vertices_closed[:, 0], vertices_closed[:, 1], vertices_closed[:, 2]) 66 | ax.scatter(edge_probs[:, 0], edge_probs[:, 1], edge_probs[:, 2]) 67 | else: 68 | msg = "The set of vertices is empty. Please check the probabilities in the credal set." 69 | raise ValueError(msg) 70 | 71 | return fig, ax 72 | -------------------------------------------------------------------------------- /src/probly/representation/sampling/sample.py: -------------------------------------------------------------------------------- 1 | """Classes representing prediction samples.""" 2 | 3 | from __future__ import annotations 4 | 5 | from abc import ABC, abstractmethod 6 | 7 | import numpy as np 8 | 9 | from lazy_dispatch.singledispatch import lazydispatch 10 | from probly.lazy_types import JAX_ARRAY, TORCH_TENSOR 11 | 12 | 13 | class Sample[T](ABC): 14 | """Abstract base class for samples.""" 15 | 16 | @abstractmethod 17 | def __init__(self, samples: list[T]) -> None: 18 | """Initialize the sample.""" 19 | ... 20 | 21 | def mean(self) -> T: 22 | """Compute the mean of the sample.""" 23 | msg = "mean method not implemented." 24 | raise NotImplementedError(msg) 25 | 26 | def std(self, ddof: int = 1) -> T: 27 | """Compute the standard deviation of the sample.""" 28 | msg = "std method not implemented." 29 | raise NotImplementedError(msg) 30 | 31 | def var(self, ddof: int = 1) -> T: 32 | """Compute the variance of the sample.""" 33 | msg = "var method not implemented." 
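The plotting helpers in `probly.plot.credal` above expect probability vectors on the three-class simplex (the mpltern ternary projection is imported by the module itself). A minimal sketch with Dirichlet-sampled distributions:

```python
import matplotlib.pyplot as plt
import numpy as np

from probly.plot.credal import credal_set_plot, simplex_plot

rng = np.random.default_rng(0)
probs = rng.dirichlet(alpha=[5.0, 3.0, 2.0], size=50)  # shape (50, 3), rows sum to one

fig, ax = simplex_plot(probs)     # scatter the distributions on the ternary simplex
fig, ax = credal_set_plot(probs)  # additionally shade the induced lower/upper probability region
plt.show()
```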
34 | raise NotImplementedError(msg) 35 | 36 | 37 | class ListSample[T](list[T], Sample[T]): 38 | """A sample of predictions stored in a list.""" 39 | 40 | def __add__[S](self, other: list[S]) -> ListSample[T | S]: 41 | """Add two samples together.""" 42 | return type(self)(super().__add__(other)) # type: ignore[operator] 43 | 44 | 45 | create_sample = lazydispatch[type[Sample], Sample](ListSample, dispatch_on=lambda s: s[0]) 46 | 47 | 48 | @create_sample.register(np.number | np.ndarray | float | int) 49 | class ArraySample[T](Sample[T]): 50 | """A sample of predictions stored in a numpy array.""" 51 | 52 | def __init__(self, samples: list[T]) -> None: 53 | """Initialize the array sample.""" 54 | self.array: np.ndarray = np.array(samples).transpose(1, 0, 2) # we use [instances, samples, classes] 55 | 56 | def mean(self) -> T: 57 | """Compute the mean of the sample.""" 58 | return self.array.mean(axis=1) # type: ignore[no-any-return] 59 | 60 | def std(self, ddof: int = 1) -> T: 61 | """Compute the standard deviation of the sample.""" 62 | return self.array.std(axis=1, ddof=ddof) # type: ignore[no-any-return] 63 | 64 | def var(self, ddof: int = 1) -> T: 65 | """Compute the variance of the sample.""" 66 | return self.array.var(axis=1, ddof=ddof) # type: ignore[no-any-return] 67 | 68 | 69 | @create_sample.delayed_register(TORCH_TENSOR) 70 | def _(_: type) -> None: 71 | from . import torch_sample as torch_sample # noqa: PLC0414, PLC0415 72 | 73 | 74 | @create_sample.delayed_register(JAX_ARRAY) 75 | def _(_: type) -> None: 76 | from . import jax_sample as jax_sample # noqa: PLC0414, PLC0415 77 | -------------------------------------------------------------------------------- /tests/probly/tests_deprecations/test_deprecations.py: -------------------------------------------------------------------------------- 1 | """This module contains tests for deprecations.""" 2 | 3 | from __future__ import annotations 4 | 5 | from importlib.metadata import version 6 | import inspect 7 | import re 8 | import warnings 9 | 10 | from packaging.version import parse 11 | import pytest 12 | 13 | from .utils import DEPRECATED_FEATURES, DeprecatedFeature 14 | 15 | 16 | def feature_should_be_removed(feature: DeprecatedFeature) -> None: 17 | """Fails if the feature is still accessible after its scheduled removal date.""" 18 | source_file = inspect.getsourcefile(feature.call) 19 | source_line = inspect.getsourcelines(feature.call)[1] 20 | if parse(version("probly")) >= parse(feature.removed_in): 21 | pytest.fail( 22 | f"{feature.name} was scheduled for removal in {feature.removed_in} " 23 | f"but is still accessible. Remove the deprecated behavior and this test.\n" 24 | f"Feature registered at: {source_file}:{source_line}", 25 | ) 26 | 27 | 28 | def feature_raises_deprecation_warning( 29 | feature: DeprecatedFeature, 30 | request: pytest.FixtureRequest, 31 | ) -> None: 32 | """Fails if the feature does not raise a deprecation warning.""" 33 | expected_msg = f"deprecated in version {feature.deprecated_in} and will be removed in version {feature.removed_in}." 
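The `create_sample` dispatcher in `probly.representation.sampling.sample` above picks a `Sample` container based on the type of the first element, so a list of numpy predictions ends up in an `ArraySample`. A small sketch, assuming the dispatcher is called as a factory on the list of forward passes (as the `Sampler` class does):

```python
import numpy as np

from probly.representation.sampling.sample import create_sample

# Five stochastic forward passes, each of shape (n_instances, n_classes).
rng = np.random.default_rng(0)
samples = [rng.dirichlet(np.ones(3), size=10) for _ in range(5)]

sample = create_sample(samples)  # dispatches on type(samples[0]) -> ArraySample
mean_probs = sample.mean()       # shape (10, 3): average over the 5 forward passes
std_probs = sample.std(ddof=1)   # per-class standard deviation across the passes
```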
34 | 35 | with warnings.catch_warnings(record=True) as caught: 36 | warnings.simplefilter("always", DeprecationWarning) 37 | feature.call(request) 38 | 39 | deprecation_warnings = [w for w in caught if issubclass(w.category, DeprecationWarning)] 40 | 41 | # check if any deprecation warning matches the expected regex 42 | if not any(re.search(expected_msg, str(w.message)) for w in deprecation_warnings): 43 | formatted = "\n".join(f"{w.category.__name__}: {w.message}" for w in deprecation_warnings) 44 | source_file = inspect.getsourcefile(feature.call) 45 | source_line = inspect.getsourcelines(feature.call)[1] 46 | pytest.fail( 47 | f"No matching DeprecationWarning for feature '{feature.name}'.\n" 48 | f"Expected regex: {expected_msg}\n" 49 | f"Warnings captured: {formatted or 'None'}\n" 50 | f"Feature registered at: {source_file}:{source_line}\n", 51 | ) 52 | 53 | 54 | @pytest.mark.parametrize("feature", DEPRECATED_FEATURES, ids=lambda f: f.name) 55 | def test_deprecated_features(feature: DeprecatedFeature, request: pytest.FixtureRequest) -> None: 56 | """Tests the deprecated initialization with path_to_values.""" 57 | # check if the feature should already be removed 58 | feature_should_be_removed(feature) 59 | 60 | # check if the feature raises a correct deprecation warning 61 | feature_raises_deprecation_warning(feature, request) 62 | -------------------------------------------------------------------------------- /src/probly/transformation/bayesian/common.py: -------------------------------------------------------------------------------- 1 | """Shared Bayesian implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | from probly.traverse_nn import nn_compose 8 | from pytraverse import CLONE, GlobalVariable, lazydispatch_traverser, traverse 9 | 10 | if TYPE_CHECKING: 11 | from lazy_dispatch.isinstance import LazyType 12 | from probly.predictor import Predictor 13 | from pytraverse.composition import RegisteredLooseTraverser 14 | 15 | USE_BASE_WEIGHTS = GlobalVariable[bool]("USE_BASE_WEIGHTS", default=False) 16 | POSTERIOR_STD = GlobalVariable[float]("POSTERIOR_STD", default=0.05) 17 | PRIOR_MEAN = GlobalVariable[float]("PRIOR_MEAN", default=0.0) 18 | PRIOR_STD = GlobalVariable[float]("PRIOR_STD", default=1.0) 19 | 20 | bayesian_traverser = lazydispatch_traverser[object](name="bayesian_traverser") 21 | 22 | 23 | def register(cls: LazyType, traverser: RegisteredLooseTraverser) -> None: 24 | """Register a class to be replaced by Bayesian layers.""" 25 | bayesian_traverser.register( 26 | cls=cls, 27 | traverser=traverser, 28 | vars={ 29 | "use_base_weights": USE_BASE_WEIGHTS, 30 | "posterior_std": POSTERIOR_STD, 31 | "prior_mean": PRIOR_MEAN, 32 | "prior_std": PRIOR_STD, 33 | }, 34 | ) 35 | 36 | 37 | def bayesian[T: Predictor]( 38 | base: T, 39 | use_base_weights: bool = USE_BASE_WEIGHTS.default, 40 | posterior_std: float = POSTERIOR_STD.default, 41 | prior_mean: float = PRIOR_MEAN.default, 42 | prior_std: float = PRIOR_STD.default, 43 | ) -> T: 44 | """Create a Bayesian predictor from a base predictor. 45 | 46 | Args: 47 | base: The base model to be used for the Bayesian neural network. 48 | use_base_weights: bool, If True, the weights of the base model are used as the prior mean. 49 | posterior_std: float, The initial posterior standard deviation. 50 | prior_mean: float, The prior mean. 51 | prior_std: float, The prior standard deviation. 52 | 53 | Returns: 54 | The Bayesian predictor. 
55 | """ 56 | if posterior_std <= 0: 57 | msg = ( 58 | "The initial posterior standard deviation posterior_std must be greater than 0, " 59 | f"but got {posterior_std} instead." 60 | ) 61 | raise ValueError(msg) 62 | if prior_std <= 0: 63 | msg = f"The prior standard deviation prior_std must be greater than 0, but got {prior_std} instead." 64 | raise ValueError(msg) 65 | return traverse( 66 | base, 67 | nn_compose(bayesian_traverser), 68 | init={ 69 | USE_BASE_WEIGHTS: use_base_weights, 70 | POSTERIOR_STD: posterior_std, 71 | PRIOR_MEAN: prior_mean, 72 | PRIOR_STD: prior_std, 73 | CLONE: True, 74 | }, 75 | ) 76 | -------------------------------------------------------------------------------- /src/probly/utils/probabilities.py: -------------------------------------------------------------------------------- 1 | """General utility functions for all other modules.""" 2 | 3 | from __future__ import annotations 4 | 5 | import numpy as np 6 | 7 | 8 | def differential_entropy_gaussian(sigma2: float | np.ndarray, base: float = 2) -> float | np.ndarray: 9 | """Compute the differential entropy of a Gaussian distribution given the variance. 10 | 11 | https://en.wikipedia.org/wiki/Differential_entropy 12 | Args: 13 | sigma2: float or numpy.ndarray shape (n_instances,), variance of the Gaussian distribution 14 | base: float, base of the logarithm 15 | Returns: 16 | diff_ent: float or numpy.ndarray shape (n_instances,), differential entropy of the Gaussian distribution 17 | """ 18 | return 0.5 * np.log(2 * np.pi * np.e * sigma2) / np.log(base) 19 | 20 | 21 | def kl_divergence_gaussian( 22 | mu1: float | np.ndarray, 23 | sigma21: float | np.ndarray, 24 | mu2: float | np.ndarray, 25 | sigma22: float | np.ndarray, 26 | base: float = 2, 27 | ) -> float | np.ndarray: 28 | """Compute the KL-divergence between two Gaussian distributions. 29 | 30 | https://en.wikipedia.org/wiki/Kullback-Leibler_divergence#Examples 31 | Args: 32 | mu1: float or numpy.ndarray shape (n_instances,), mean of the first Gaussian distribution 33 | sigma21: float or numpy.ndarray shape (n_instances,), variance of the first Gaussian distribution 34 | mu2: float or numpy.ndarray shape (n_instances,), mean of the second Gaussian distribution 35 | sigma22: float or numpy.ndarray shape (n_instances,), variance of the second Gaussian distribution 36 | base: float, base of the logarithm 37 | Returns: 38 | kl_div: float or numpy.ndarray shape (n_instances,), KL-divergence between the two Gaussian distributions 39 | """ 40 | kl_div = 0.5 * np.log(sigma22 / sigma21) / np.log(base) + (sigma21 + (mu1 - mu2) ** 2) / (2 * sigma22) - 0.5 41 | return kl_div 42 | 43 | 44 | def intersection_probability(probs: np.ndarray) -> np.ndarray: 45 | """Compute the intersection probability of a credal set based on intervals of lower and upper probabilities. 46 | 47 | Computes the intersection probability from :cite:`cuzzolinIntersectionProbability2022`. 48 | 49 | Args: 50 | probs: numpy.ndarray, shape (n_instances, n_samples, n_classes), credal sets 51 | Returns: 52 | int_probs: numpy.ndarray, shape (n_instances, n_classes), intersection probability of the credal sets 53 | """ 54 | lower = np.min(probs, axis=1) 55 | upper = np.max(probs, axis=1) 56 | diff = upper - lower 57 | diff_sum = np.sum(diff, axis=1) 58 | lower_sum = np.sum(lower, axis=1) 59 | # Compute alpha for instances for which probability intervals are not empty, otherwise set alpha to 0. 
60 | alpha = np.zeros(probs.shape[0]) 61 | nonzero_idxs = diff_sum != 0 62 | alpha[nonzero_idxs] = (1 - lower_sum[nonzero_idxs]) / diff_sum[nonzero_idxs] 63 | int_probs = lower + alpha[:, None] * diff 64 | return int_probs 65 | -------------------------------------------------------------------------------- /tests/lazy_dispatch/test_isinstance.py: -------------------------------------------------------------------------------- 1 | """Tests for the probly.lazy_dispatch.isinstance module.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from lazy_dispatch.isinstance import LazyType, lazy_isinstance 8 | 9 | 10 | class TestIsInstance: 11 | """Tests for the lazy_isinstance function.""" 12 | 13 | @pytest.mark.parametrize( 14 | ("obj", "classinfo", "expected"), 15 | [ 16 | (5, int, True), 17 | (5, (int, str), True), 18 | (5, int | str, True), 19 | ("hello", str, True), 20 | ("hello", (int, str), True), 21 | ("hello", int | str, True), 22 | (5.0, int, False), 23 | (5.0, (int, str), False), 24 | (5.0, int | str, False), 25 | ([1, 2, 3], list, True), 26 | ([1, 2, 3], (list, dict), True), 27 | ([1, 2, 3], list | dict, True), 28 | ({"a": 1}, dict, True), 29 | ({"a": 1}, (list, dict), True), 30 | ({"a": 1}, list | (dict | list), True), 31 | (None, type(None), True), 32 | (None, (int, type(None)), True), 33 | (None, int | type(None), True), 34 | ], 35 | ) 36 | def test_eager_builtin_types(self, obj: object, classinfo: LazyType, expected: bool) -> None: 37 | """Test lazy_isinstance with real types.""" 38 | assert lazy_isinstance(obj, classinfo) == expected 39 | 40 | @pytest.mark.parametrize( 41 | ("obj", "classinfo", "expected"), 42 | [ 43 | (5, "int", True), 44 | (5, ("builtins.int", "builtins.str"), True), 45 | ("hello", "str", True), 46 | ("hello", ("int", "str"), True), 47 | (5.0, "builtins.int", False), 48 | (5.0, ("int", "str"), False), 49 | ([1, 2, 3], "builtins.list", True), 50 | ([1, 2, 3], ("list", "dict"), True), 51 | ({"a": 1}, "dict", True), 52 | ({"a": 1}, ("list", "builtins.dict"), True), 53 | ], 54 | ) 55 | def test_lazy_builtin_types(self, obj: object, classinfo: LazyType, expected: bool) -> None: 56 | """Test lazy_isinstance with stringified types.""" 57 | assert lazy_isinstance(obj, classinfo) == expected 58 | 59 | @pytest.mark.parametrize( 60 | ("obj", "classinfo", "expected"), 61 | [ 62 | (5, (str, "builtins.int"), True), 63 | (5, (int, "str"), True), 64 | ("hello", (int, "str"), True), 65 | ("hello", (str, "int"), True), 66 | (5.0, (str, "builtins.int"), False), 67 | (5.0, ("str", int), False), 68 | ([1, 2, 3], (dict, "builtins.list"), True), 69 | ([1, 2, 3], (list, "dict"), True), 70 | ({"a": 1}, (dict, "builtins.dict"), True), 71 | ({"a": 1}, ("dict", list), True), 72 | ], 73 | ) 74 | def test_mixed_builtin_types(self, obj: object, classinfo: LazyType, expected: bool) -> None: 75 | """Test lazy_isinstance with mixtures of real and stringified types.""" 76 | assert lazy_isinstance(obj, classinfo) == expected 77 | -------------------------------------------------------------------------------- /src/lazy_dispatch/isinstance.py: -------------------------------------------------------------------------------- 1 | """A lazy version of isinstance.""" 2 | 3 | from __future__ import annotations 4 | 5 | from types import UnionType 6 | from typing import Any, Union, get_args, get_origin 7 | 8 | type LazyType = type | str | UnionType | tuple[LazyType, ...] 
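A tiny worked example for the `intersection_probability` helper in `probly.utils.probabilities` shown above: with one instance whose credal set is spanned by two distributions, the lower bounds are [0.2, 0.2, 0.3], the upper bounds [0.4, 0.5, 0.4], so alpha = (1 - 0.7) / 0.6 = 0.5 and the result lies inside every class interval:

```python
import numpy as np

from probly.utils.probabilities import intersection_probability

# One instance, credal set spanned by two distributions over three classes.
probs = np.array([[[0.2, 0.5, 0.3],
                   [0.4, 0.2, 0.4]]])  # shape (1, 2, 3)

int_probs = intersection_probability(probs)
# lower = [0.2, 0.2, 0.3], upper = [0.4, 0.5, 0.4], alpha = (1 - 0.7) / 0.6 = 0.5,
# so the intersection probability is [0.3, 0.35, 0.35], which again sums to one.
print(int_probs)
```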
9 | 10 | 11 | def _is_union_type(cls: Any) -> bool: # noqa: ANN401 12 | return get_origin(cls) in {Union, UnionType} 13 | 14 | 15 | def _split_lazy_type(lazy_type: LazyType) -> tuple[set[type], set[str]]: 16 | """Split classinfo into a set of types and a set of strings.""" 17 | if isinstance(lazy_type, str): 18 | return set(), {t.strip() for t in lazy_type.split("|")} 19 | if isinstance(lazy_type, type): 20 | return {lazy_type}, set() 21 | if isinstance(lazy_type, tuple): 22 | types: set[type] = set() 23 | strings: set[str] = set() 24 | for item in lazy_type: 25 | t, s = _split_lazy_type(item) 26 | types.update(t) 27 | strings.update(s) 28 | return types, strings 29 | if _is_union_type(lazy_type): 30 | types = set() 31 | strings = set() 32 | for arg in get_args(lazy_type): 33 | t, s = _split_lazy_type(arg) 34 | types.update(t) 35 | strings.update(s) 36 | return types, strings 37 | 38 | msg = f"Invalid classinfo: {lazy_type!r}" 39 | raise TypeError(msg) 40 | 41 | 42 | def _find_matching_string_type(cls: type, string_types: set[str] | dict[str, Any]) -> str | None: 43 | """Check if the type's name matches any of the strings.""" 44 | module = cls.__module__ 45 | qualname = cls.__qualname__ 46 | 47 | if module == "builtins": 48 | for s in string_types: 49 | if qualname == s: 50 | return s 51 | 52 | for s in string_types: 53 | if f"{module}.{qualname}" == s: 54 | return s 55 | 56 | return None 57 | 58 | 59 | def _find_closest_string_type(cls: type, string_types: set[str] | dict[str, Any]) -> tuple[type, str] | None: 60 | """Check if any type in the MRO matches any of the strings.""" 61 | mro = cls.__mro__ 62 | for super_cls in mro: 63 | matching_type = _find_matching_string_type(super_cls, string_types) 64 | if matching_type is not None: 65 | return super_cls, matching_type 66 | return None 67 | 68 | 69 | def lazy_isinstance(obj: object, class_or_tuple: LazyType, /) -> bool: 70 | """A lazy version of isinstance.""" 71 | types, strings = _split_lazy_type(class_or_tuple) 72 | if len(types) > 0 and isinstance(obj, tuple(types)): 73 | return True 74 | 75 | if len(strings) > 0: 76 | return _find_closest_string_type(type(obj), strings) is not None 77 | 78 | return False 79 | 80 | 81 | def lazy_issubclass(cls: type, class_or_tuple: LazyType, /) -> bool: 82 | """A lazy version of issubclass.""" 83 | types, strings = _split_lazy_type(class_or_tuple) 84 | if len(types) > 0 and issubclass(cls, tuple(types)): 85 | return True 86 | 87 | if len(strings) > 0: 88 | return _find_closest_string_type(cls, strings) is not None 89 | 90 | return False 91 | -------------------------------------------------------------------------------- /tests/probly/datasets/test_torch.py: -------------------------------------------------------------------------------- 1 | """Tests for the data module.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | if TYPE_CHECKING: 8 | from collections.abc import Callable 9 | from typing import Any 10 | 11 | from torchvision.datasets import CIFAR10, ImageNet 12 | 13 | import json 14 | from pathlib import Path 15 | from unittest.mock import MagicMock, mock_open, patch 16 | 17 | import numpy as np 18 | import torch 19 | 20 | from probly.datasets.torch import CIFAR10H, Benthic, ImageNetReaL, Plankton, Treeversity1, Treeversity6 21 | 22 | 23 | def patch_cifar10_init(self: CIFAR10, root: str, train: bool, transform: Callable[..., Any], download: bool) -> None: # noqa: ARG001, the init requires these arguments 24 | self.root = root 25 | 26 | 27 | 
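The point of the string form in `lazy_isinstance` / `lazy_issubclass` above is that heavyweight frameworks never have to be imported just to ask the question; matching follows the `module.qualname` lookup over the MRO. A short sketch using only numpy:

```python
import numpy as np

from lazy_dispatch.isinstance import lazy_isinstance, lazy_issubclass

# Builtins match by bare name, everything else by "module.qualname".
assert lazy_isinstance(np.zeros(3), "numpy.ndarray")
assert lazy_issubclass(np.float64, "numpy.floating")
# A check against a type from an unimported (or uninstalled) framework simply returns False.
assert not lazy_isinstance(np.zeros(3), "torch.Tensor")
```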
@patch("probly.datasets.torch.np.load") 28 | @patch("probly.datasets.torch.torchvision.datasets.CIFAR10.__init__", new=patch_cifar10_init) 29 | def test_cifar10h(mock_np_load: MagicMock, tmp_path: Path) -> None: 30 | counts = np.ones((5, 10)) 31 | mock_np_load.return_value = counts 32 | dataset = CIFAR10H(root=str(tmp_path)) 33 | dataset.data = [np.zeros((3, 32, 32))] * 2 34 | assert torch.allclose(torch.sum(dataset.targets, dim=1), torch.ones(5)) 35 | 36 | 37 | def patch_imagenet_init(self: ImageNet, root: str, split: str, transform: Callable[..., Any]) -> None: # noqa: ARG001, the init requires these arguments 38 | self.samples = [ 39 | ("some/path/ILSVRC2012_val_00000001.JPEG", 0), 40 | ("some/path/ILSVRC2012_val_00000002.JPEG", 1), 41 | ("some/path/ILSVRC2012_val_00000003.JPEG", 2), 42 | ] 43 | self.classes = [0, 1, 2] 44 | 45 | 46 | @patch("probly.datasets.torch.torchvision.datasets.ImageNet.__init__", new=patch_imagenet_init) 47 | @patch("pathlib.Path.open", new_callable=mock_open, read_data=json.dumps([[], [1], [1, 2]])) 48 | def test_imagenetreal(tmp_path: Path) -> None: 49 | dataset = ImageNetReaL(str(tmp_path)) 50 | for dist in dataset.dists: 51 | assert torch.isclose(torch.sum(dist), torch.tensor(1.0)) 52 | 53 | 54 | @patch("probly.datasets.torch.DCICDataset.__init__", return_value=None) 55 | def test_benthic(mock_dcic_init: MagicMock) -> None: 56 | root = "some/path" 57 | _ = Benthic(root, first_order=False) 58 | expected = Path(root) / "Benthic" 59 | mock_dcic_init.assert_called_once_with(expected, None, first_order=False) 60 | 61 | 62 | @patch("probly.datasets.torch.DCICDataset.__init__", return_value=None) 63 | def test_plankton(mock_dcic_init: MagicMock) -> None: 64 | root = "some/path" 65 | _ = Plankton(root, first_order=False) 66 | expected = Path(root) / "Plankton" 67 | mock_dcic_init.assert_called_once_with(expected, None, first_order=False) 68 | 69 | 70 | @patch("probly.datasets.torch.DCICDataset.__init__", return_value=None) 71 | def test_treeversity1(mock_dcic_init: MagicMock) -> None: 72 | root = "some/path" 73 | _ = Treeversity1(root, first_order=False) 74 | expected = Path(root) / "Treeversity#1" 75 | mock_dcic_init.assert_called_once_with(expected, None, first_order=False) 76 | 77 | 78 | @patch("probly.datasets.torch.DCICDataset.__init__", return_value=None) 79 | def test_treeversity6(mock_dcic_init: MagicMock) -> None: 80 | root = "some/path" 81 | _ = Treeversity6(root, first_order=False) 82 | expected = Path(root) / "Treeversity#6" 83 | mock_dcic_init.assert_called_once_with(expected, None, first_order=False) 84 | -------------------------------------------------------------------------------- /tests/probly/transformation/evidential/classification/test_torch.py: -------------------------------------------------------------------------------- 1 | """Test for torch classification models.""" 2 | 3 | from __future__ import annotations 4 | 5 | from torch import nn 6 | 7 | from probly.transformation.evidential.classification.common import ( 8 | evidential_classification, 9 | ) 10 | from tests.probly.torch_utils import count_layers 11 | 12 | 13 | def test_evidential_classification_appends_softplus_on_linear(torch_model_small_2d_2d: nn.Sequential) -> None: 14 | model = evidential_classification(torch_model_small_2d_2d) 15 | 16 | # count number of nn.Linear layers in original model 17 | count_linear_original = count_layers(torch_model_small_2d_2d, nn.Linear) 18 | # count number of softplus layers in original model 19 | count_softplus_original = 
count_layers(torch_model_small_2d_2d, nn.Softplus) 20 | # count number of nn.Sequential layers in original model 21 | count_sequential_original = count_layers(torch_model_small_2d_2d, nn.Sequential) 22 | 23 | # count number of nn.Linear layers in modified model 24 | count_linear_modified = count_layers(model, nn.Linear) 25 | # count number of softplus layers in modified model 26 | count_softplus_modified = count_layers(model, nn.Softplus) 27 | # count number of nn.Sequential layers in modified model 28 | count_sequential_modified = count_layers(model, nn.Sequential) 29 | 30 | # check that the model is not modified except for the softplus layer at the end of the new sequence layer 31 | assert model is not None 32 | assert isinstance(model, type(torch_model_small_2d_2d)) 33 | assert count_linear_original == count_linear_modified 34 | assert count_softplus_original == (count_softplus_modified - 1) 35 | assert count_sequential_original == (count_sequential_modified - 1) 36 | 37 | 38 | def test_evidential_classification_appends_softplus_on_conv(torch_conv_linear_model: nn.Sequential) -> None: 39 | model = evidential_classification(torch_conv_linear_model) 40 | 41 | # count number of nn.Linear layers in original model 42 | count_linear_original = count_layers(torch_conv_linear_model, nn.Linear) 43 | # count number of softplus layers in original model 44 | count_softplus_original = count_layers(torch_conv_linear_model, nn.Softplus) 45 | # count number of nn.Sequential layers in original model 46 | count_sequential_original = count_layers(torch_conv_linear_model, nn.Sequential) 47 | # count number of nn.Conv2d layers in original model 48 | count_conv_original = count_layers(torch_conv_linear_model, nn.Conv2d) 49 | 50 | # count number of nn.Linear layers in modified model 51 | count_linear_modified = count_layers(model, nn.Linear) 52 | # count number of softplus layers in modified model 53 | count_softplus_modified = count_layers(model, nn.Softplus) 54 | # count number of nn.Sequential layers in modified model 55 | count_sequential_modified = count_layers(model, nn.Sequential) 56 | # count number of nn.Conv2d layers in modified model 57 | count_conv_modified = count_layers(model, nn.Conv2d) 58 | 59 | # check that the model is not modified except for the softplus layer at the end of the new sequence layer 60 | assert model is not None 61 | assert isinstance(model, type(torch_conv_linear_model)) 62 | assert count_linear_original == count_linear_modified 63 | assert count_softplus_original == (count_softplus_modified - 1) 64 | assert count_sequential_original == (count_sequential_modified - 1) 65 | assert count_conv_original == count_conv_modified 66 | -------------------------------------------------------------------------------- /tests/probly/fixtures/flax_models.py: -------------------------------------------------------------------------------- 1 | """Fixtures for models used in tests.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | torch = pytest.importorskip("flax") 8 | from flax import nnx # noqa: E402 9 | from flax.typing import Array # noqa: E402 10 | 11 | 12 | @pytest.fixture 13 | def flax_rngs() -> nnx.Rngs: 14 | """Return a random number generator for flax models.""" 15 | return nnx.Rngs(0) 16 | 17 | 18 | @pytest.fixture 19 | def flax_model_small_2d_2d(flax_rngs: nnx.Rngs) -> nnx.Module: 20 | """Return a small linear model with 2 input and 2 output neurons.""" 21 | model = nnx.Sequential( 22 | nnx.Linear(2, 2, rngs=flax_rngs), 23 | nnx.Linear(2, 2, 
rngs=flax_rngs), 24 | nnx.Linear(2, 2, rngs=flax_rngs), 25 | ) 26 | return model 27 | 28 | 29 | @pytest.fixture 30 | def flax_conv_linear_model(flax_rngs: nnx.Rngs) -> nnx.Module: 31 | """Return a small convolutional model with 3 input channels and 2 output neurons.""" 32 | model = nnx.Sequential( 33 | nnx.Conv(3, 5, (5, 5), rngs=flax_rngs), 34 | nnx.relu, 35 | nnx.flatten, 36 | nnx.Linear(5, 2, rngs=flax_rngs), 37 | ) 38 | return model 39 | 40 | 41 | @pytest.fixture 42 | def flax_regression_model_1d(flax_rngs: nnx.Rngs) -> nnx.Module: 43 | """Return a small regression model with 2 input and 1 output neurons.""" 44 | model = nnx.Sequential( 45 | nnx.Linear(2, 2, rngs=flax_rngs), 46 | nnx.relu, 47 | nnx.Linear(2, 1, rngs=flax_rngs), 48 | ) 49 | return model 50 | 51 | 52 | @pytest.fixture 53 | def flax_regression_model_2d(flax_rngs: nnx.Rngs) -> nnx.Module: 54 | """Return a small regression model with 4 input and 2 output neurons.""" 55 | model = nnx.Sequential( 56 | nnx.Linear(4, 4, rngs=flax_rngs), 57 | nnx.relu, 58 | nnx.Linear(4, 2, rngs=flax_rngs), 59 | ) 60 | return model 61 | 62 | 63 | @pytest.fixture 64 | def flax_custom_model(flax_rngs: nnx.Rngs) -> nnx.Module: 65 | """Return a small custom model.""" 66 | 67 | class TinyModel(nnx.Module): 68 | """A simple neural network model with two linear layers and activation functions. 69 | 70 | Attributes: 71 | linear1 : The first linear layer with input size 100 and output size 200. 72 | activation : The ReLU activation function applied after the first linear layer. 73 | linear2 : The second linear layer with input size 200 and output size 10. 74 | softmax : The softmax function for normalizing the output into probabilities. 75 | """ 76 | 77 | def __init__(self, rngs: nnx.Rngs) -> None: 78 | """Initialize the TinyModel class.""" 79 | super().__init__() 80 | 81 | self.linear1 = nnx.Linear(10, 20, rngs=rngs) 82 | self.activation = nnx.relu 83 | self.linear2 = nnx.Linear(20, 4, rngs=rngs) 84 | self.softmax = nnx.softmax 85 | 86 | def __call__(self, x: Array) -> Array: 87 | """Forward pass of the TinyModel model. 88 | 89 | Parameters: 90 | x: Input tensor to be processed by the forward method. 91 | 92 | Returns: 93 | Output tensor after being processed through the layers and activation functions. 
94 | """ 95 | x = self.linear1(x) 96 | x = self.activation(x) 97 | x = self.linear2(x) 98 | x = self.softmax(x) 99 | return x 100 | 101 | return TinyModel(rngs=flax_rngs) 102 | -------------------------------------------------------------------------------- /tests/probly/transformation/evidential/regression/test_torch.py: -------------------------------------------------------------------------------- 1 | """Tests for torch evidential regression models.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | import torch as th 7 | from torch import nn 8 | 9 | from probly.layers.torch import NormalInverseGammaLinear 10 | from probly.transformation.evidential.regression import evidential_regression 11 | from tests.probly.torch_utils import count_layers 12 | 13 | torch = pytest.importorskip("torch") 14 | 15 | 16 | class TestEvidentialRegression: 17 | """Test class for torch evidential regression models.""" 18 | 19 | def test_returns_a_clone(self, torch_model_small_2d_2d: nn.Sequential) -> None: 20 | """Tests if evidential_regression returns a clone of the input model.""" 21 | original_model = torch_model_small_2d_2d 22 | 23 | new_model = evidential_regression(original_model) 24 | 25 | assert new_model is not original_model 26 | 27 | def test_replaces_only_last_linear_layer(self, torch_model_small_2d_2d: nn.Sequential) -> None: 28 | """Tests if evidential_regression *only* replaces the last linear layer. 29 | 30 | This function verifies that the new model has exactly one LESS nn.Linear layer 31 | than the original, and one NormalInverseGammaLinear (NIG) layer. 32 | 33 | Parameters: 34 | torch_model_small_2d_2d: The torch model to be tested. 35 | """ 36 | original_model = torch_model_small_2d_2d 37 | new_model = evidential_regression(original_model) 38 | 39 | # Layer Count Checks 40 | count_linear_original = count_layers(original_model, nn.Linear) 41 | count_linear_modified = count_layers(new_model, nn.Linear) 42 | 43 | # Check the core logic: The new model should have one LESS nn.Linear layer 44 | assert count_linear_modified == (count_linear_original - 1) 45 | 46 | # The modified model should have exactly one NIG layer 47 | count_nig_modified = count_layers(new_model, NormalInverseGammaLinear) 48 | assert count_nig_modified == 1 49 | 50 | def test_last_layer_replacement_and_integrity(self, torch_model_small_2d_2d: nn.Sequential) -> None: 51 | """Tests replacement of the last layer and verifies model integrity. 52 | 53 | Parameters: 54 | torch_model_small_2d_2d: The torch model to be tested. 55 | """ 56 | original_model = torch_model_small_2d_2d 57 | 58 | # 1. Forward Pass & Shape Check 59 | input_data = th.randn(1, 2) 60 | 61 | # The output of the original model (Tensor) is needed to check the shape 62 | original_output = original_model(input_data) 63 | expected_output_shape = original_output.shape 64 | 65 | # Transformation 66 | new_model = evidential_regression(original_model) 67 | 68 | # 1a. Shape Check 69 | new_output = new_model(input_data) 70 | # KORREKTUR: Der NIG-Layer gibt ein Dictionary zurück, wir prüfen die Shape des 'gamma'-Tensors (mean). 71 | assert new_output["gamma"].shape == expected_output_shape 72 | 73 | # 1b. Last Layer Replacement Check 74 | # Check if the last layer in the new model is the NIG layer. 75 | last_layer_index = len(original_model) - 1 76 | last_layer_modified = new_model[last_layer_index] 77 | 78 | assert isinstance(last_layer_modified, NormalInverseGammaLinear) 79 | 80 | # 1c. 
Rest Model Integrity Check 81 | # The layers BEFORE the last layer must retain their original type. 82 | for i in range(last_layer_index): 83 | original_layer = original_model[i] 84 | modified_layer = new_model[i] 85 | assert type(original_layer) is type(modified_layer) 86 | -------------------------------------------------------------------------------- /tests/probly/train/evidential/test_torch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pytest 4 | 5 | from probly.predictor import Predictor 6 | from probly.train.evidential.torch import ( 7 | EvidentialCELoss, 8 | EvidentialKLDivergence, 9 | EvidentialLogLoss, 10 | EvidentialMSELoss, 11 | EvidentialNIGNLLLoss, 12 | EvidentialRegressionRegularization, 13 | ) 14 | from probly.transformation import evidential_regression 15 | from tests.probly.torch_utils import validate_loss 16 | 17 | torch = pytest.importorskip("torch") 18 | 19 | from torch import Tensor, nn # noqa: E402 20 | 21 | 22 | def test_evidential_log_loss( 23 | sample_classification_data: tuple[Tensor, Tensor], 24 | evidential_classification_model: nn.Module, 25 | ) -> None: 26 | inputs, targets = sample_classification_data 27 | outputs = evidential_classification_model(inputs) 28 | criterion = EvidentialLogLoss() 29 | loss = criterion(outputs, targets) 30 | validate_loss(loss) 31 | 32 | 33 | def test_evidential_ce_loss( 34 | sample_classification_data: tuple[Tensor, Tensor], 35 | evidential_classification_model: nn.Module, 36 | ) -> None: 37 | inputs, targets = sample_classification_data 38 | outputs = evidential_classification_model(inputs) 39 | criterion = EvidentialCELoss() 40 | loss = criterion(outputs, targets) 41 | validate_loss(loss) 42 | 43 | 44 | def test_evidential_mse_loss( 45 | sample_classification_data: tuple[Tensor, Tensor], 46 | evidential_classification_model: nn.Module, 47 | ) -> None: 48 | inputs, targets = sample_classification_data 49 | outputs = evidential_classification_model(inputs) 50 | criterion = EvidentialMSELoss() 51 | loss = criterion(outputs, targets) 52 | validate_loss(loss) 53 | 54 | 55 | def test_evidential_kl_divergence( 56 | sample_classification_data: tuple[Tensor, Tensor], 57 | evidential_classification_model: nn.Module, 58 | ) -> None: 59 | inputs, targets = sample_classification_data 60 | outputs = evidential_classification_model(inputs) 61 | criterion = EvidentialKLDivergence() 62 | loss = criterion(outputs, targets) 63 | validate_loss(loss) 64 | 65 | 66 | def test_evidential_nig_nll_loss( 67 | torch_regression_model_1d: nn.Module, 68 | torch_regression_model_2d: nn.Module, 69 | ) -> None: 70 | inputs = torch.randn(2, 2) 71 | targets = torch.randn(2, 1) 72 | model: Predictor = evidential_regression(torch_regression_model_1d) 73 | outputs = model(inputs) 74 | criterion = EvidentialNIGNLLLoss() 75 | loss = criterion(outputs, targets) 76 | validate_loss(loss) 77 | 78 | inputs = torch.randn(2, 4) 79 | targets = torch.randn(2, 2) 80 | model = evidential_regression(torch_regression_model_2d) 81 | outputs = model(inputs) 82 | criterion = EvidentialNIGNLLLoss() 83 | loss = criterion(outputs, targets) 84 | validate_loss(loss) 85 | 86 | 87 | def test_evidential_regression_regularization( 88 | torch_regression_model_1d: nn.Module, 89 | torch_regression_model_2d: nn.Module, 90 | ) -> None: 91 | inputs = torch.randn(2, 2) 92 | targets = torch.randn(2, 1) 93 | model: Predictor = evidential_regression(torch_regression_model_1d) 94 | outputs = model(inputs) 95 | 
criterion = EvidentialRegressionRegularization() 96 | loss = criterion(outputs, targets) 97 | validate_loss(loss) 98 | 99 | inputs = torch.randn(2, 4) 100 | targets = torch.randn(2, 2) 101 | model = evidential_regression(torch_regression_model_2d) 102 | outputs = model(inputs) 103 | criterion = EvidentialRegressionRegularization() 104 | loss = criterion(outputs, targets) 105 | validate_loss(loss) 106 | -------------------------------------------------------------------------------- /tests/probly/fixtures/torch_models.py: -------------------------------------------------------------------------------- 1 | """Fixtures for models used in tests.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from probly.predictor import Predictor 8 | from probly.transformation import evidential_classification 9 | 10 | torch = pytest.importorskip("torch") 11 | from torch import Tensor, nn # noqa: E402 12 | 13 | 14 | @pytest.fixture 15 | def torch_model_small_2d_2d() -> nn.Module: 16 | """Return a small linear model with 2 input and 2 output neurons.""" 17 | model = nn.Sequential( 18 | nn.Linear(2, 2), 19 | nn.Linear(2, 2), 20 | nn.Linear(2, 2), 21 | ) 22 | return model 23 | 24 | 25 | @pytest.fixture 26 | def torch_conv_linear_model() -> nn.Module: 27 | """Return a small convolutional model with 3 input channels and 2 output neurons.""" 28 | model = nn.Sequential( 29 | nn.Conv2d(3, 5, 5), 30 | nn.ReLU(), 31 | nn.Flatten(), 32 | nn.Linear(5, 2), 33 | ) 34 | return model 35 | 36 | 37 | @pytest.fixture 38 | def torch_regression_model_1d() -> nn.Module: 39 | """Return a small regression model with 2 input and 1 output neurons.""" 40 | model = nn.Sequential( 41 | nn.Linear(2, 2), 42 | nn.ReLU(), 43 | nn.Linear(2, 1), 44 | ) 45 | return model 46 | 47 | 48 | @pytest.fixture 49 | def torch_regression_model_2d() -> nn.Module: 50 | """Return a small regression model with 4 input and 2 output neurons.""" 51 | model = nn.Sequential( 52 | nn.Linear(4, 4), 53 | nn.ReLU(), 54 | nn.Linear(4, 2), 55 | ) 56 | return model 57 | 58 | 59 | @pytest.fixture 60 | def torch_dropout_model() -> nn.Module: 61 | """Return a small dropout model with 2 input and 2 output neurons.""" 62 | model = nn.Sequential( 63 | nn.Linear(2, 2), 64 | nn.ReLU(), 65 | nn.Dropout(0.5), 66 | nn.Linear(2, 2), 67 | ) 68 | return model 69 | 70 | 71 | @pytest.fixture 72 | def torch_custom_model() -> nn.Module: 73 | """Return a small custom model.""" 74 | 75 | class TinyModel(nn.Module): 76 | """A simple neural network model with two linear layers and activation functions. 77 | 78 | Attributes: 79 | linear1 : The first linear layer with input size 100 and output size 200. 80 | activation : The ReLU activation function applied after the first linear layer. 81 | linear2 : The second linear layer with input size 200 and output size 10. 82 | softmax : The softmax function for normalizing the output into probabilities. 83 | """ 84 | 85 | def __init__(self) -> None: 86 | """Initialize the TinyModel class.""" 87 | super().__init__() 88 | 89 | self.linear1 = nn.Linear(10, 20) 90 | self.activation = nn.ReLU() 91 | self.linear2 = nn.Linear(20, 4) 92 | self.softmax = nn.Softmax() 93 | 94 | def forward(self, x: Tensor) -> Tensor: 95 | """Forward pass of the TinyModel model. 96 | 97 | Parameters: 98 | x: Input tensor to be processed by the forward method. 99 | 100 | Returns: 101 | Output tensor after being processed through the layers and activation functions. 
102 | """ 103 | x = self.linear1(x) 104 | x = self.activation(x) 105 | x = self.linear2(x) 106 | x = self.softmax(x) 107 | return x 108 | 109 | return TinyModel() 110 | 111 | 112 | @pytest.fixture 113 | def evidential_classification_model( 114 | torch_conv_linear_model: nn.Module, 115 | ) -> Predictor: 116 | model: Predictor = evidential_classification(torch_conv_linear_model) 117 | return model 118 | -------------------------------------------------------------------------------- /tests/probly/quantification/test_classification.py: -------------------------------------------------------------------------------- 1 | """Tests for the classification module.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import numpy as np 8 | import pytest 9 | 10 | from tests.probly.general_utils import validate_uncertainty 11 | 12 | if TYPE_CHECKING: 13 | from collections.abc import Callable 14 | 15 | from probly.evaluation.metrics import brier_score, log_loss, spherical_score, zero_one_loss 16 | from probly.quantification.classification import ( 17 | aleatoric_uncertainty_distance, 18 | conditional_entropy, 19 | epistemic_uncertainty_distance, 20 | evidential_uncertainty, 21 | expected_conditional_variance, 22 | expected_divergence, 23 | expected_entropy, 24 | expected_loss, 25 | generalized_hartley, 26 | lower_entropy, 27 | lower_entropy_convex_hull, 28 | mutual_information, 29 | total_entropy, 30 | total_uncertainty_distance, 31 | total_variance, 32 | upper_entropy, 33 | upper_entropy_convex_hull, 34 | variance_conditional_expectation, 35 | ) 36 | 37 | 38 | @pytest.fixture 39 | def sample_second_order_data() -> tuple[np.ndarray, np.ndarray]: 40 | rng = np.random.default_rng() 41 | probs2d = rng.dirichlet(np.ones(2), (10, 5)) 42 | probs3d = rng.dirichlet(np.ones(3), (10, 5)) 43 | return probs2d, probs3d 44 | 45 | 46 | @pytest.fixture 47 | def simplex_uniform() -> np.ndarray: 48 | return np.array([[[1 / 3, 1 / 3, 1 / 3], [1 / 3, 1 / 3, 1 / 3], [1 / 3, 1 / 3, 1 / 3]]]) 49 | 50 | 51 | @pytest.fixture 52 | def simplex_vertices() -> np.ndarray: 53 | return np.array([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]) 54 | 55 | 56 | @pytest.mark.parametrize( 57 | "uncertainty_fn", 58 | [ 59 | total_entropy, 60 | conditional_entropy, 61 | mutual_information, 62 | total_variance, 63 | expected_conditional_variance, 64 | variance_conditional_expectation, 65 | total_uncertainty_distance, 66 | aleatoric_uncertainty_distance, 67 | epistemic_uncertainty_distance, 68 | upper_entropy, 69 | lower_entropy, 70 | upper_entropy_convex_hull, 71 | lower_entropy_convex_hull, 72 | generalized_hartley, 73 | ], 74 | ) 75 | def test_uncertainty_function( 76 | uncertainty_fn: Callable[[np.ndarray], np.ndarray], 77 | sample_second_order_data: tuple[np.ndarray, np.ndarray], 78 | ) -> None: 79 | probs2d, probs3d = sample_second_order_data 80 | uncertainty = uncertainty_fn(probs2d) 81 | validate_uncertainty(uncertainty) 82 | 83 | uncertainty = uncertainty_fn(probs3d) 84 | validate_uncertainty(uncertainty) 85 | 86 | 87 | @pytest.mark.parametrize("uncertainty_fn", [expected_loss, expected_entropy, expected_divergence]) 88 | def test_loss_uncertainty_function( 89 | uncertainty_fn: Callable[[np.ndarray, Callable[[np.ndarray, np.ndarray | None], np.ndarray]], np.ndarray], 90 | sample_second_order_data: tuple[np.ndarray, np.ndarray], 91 | ) -> None: 92 | probs2d, probs3d = sample_second_order_data 93 | for loss_fn in [log_loss, brier_score, zero_one_loss, spherical_score]: 94 | uncertainty = 
uncertainty_fn(probs2d, loss_fn) 95 | validate_uncertainty(uncertainty) 96 | 97 | uncertainty = uncertainty_fn(probs3d, loss_fn) 98 | validate_uncertainty(uncertainty) 99 | 100 | 101 | def test_lower_entropy(simplex_vertices: np.ndarray, simplex_uniform: np.ndarray) -> None: 102 | le = lower_entropy(simplex_vertices) 103 | assert le == pytest.approx(0.0) 104 | 105 | le = lower_entropy(simplex_uniform) 106 | assert le == pytest.approx(1.5849625007) 107 | 108 | 109 | def test_upper_entropy(simplex_vertices: np.ndarray, simplex_uniform: np.ndarray) -> None: 110 | ue = upper_entropy(simplex_vertices) 111 | assert ue == pytest.approx(1.5849625007) 112 | 113 | ue = upper_entropy(simplex_uniform) 114 | assert ue == pytest.approx(1.5849625007) 115 | 116 | 117 | def test_evidential_uncertainty() -> None: 118 | rng = np.random.default_rng() 119 | evidence = rng.uniform(0, 100, (10, 3)) 120 | uncertainty = evidential_uncertainty(evidence) 121 | validate_uncertainty(uncertainty) 122 | -------------------------------------------------------------------------------- /src/probly/traverse_nn/torch.py: -------------------------------------------------------------------------------- 1 | """Traversal implementation for PyTorch modules.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections import OrderedDict 6 | import copy 7 | 8 | from torch.nn import Module, Sequential 9 | 10 | import pytraverse as t 11 | from pytraverse import generic 12 | from pytraverse.decorators import traverser 13 | 14 | from . import common as tnn 15 | 16 | # Torch traversal variables 17 | 18 | ROOT = t.StackVariable[Module | None]("ROOT", "A reference to the outermost module.") 19 | CLONE = t.StackVariable[bool]( 20 | "CLONE", 21 | "Whether to clone torch modules before making changes.", 22 | default=generic.CLONE, 23 | ) 24 | TRAVERSE_REVERSED = t.StackVariable[bool]( 25 | "TRAVERSE_REVERSED", 26 | "Whether to traverse elements in reverse order.", 27 | default=generic.TRAVERSE_REVERSED, 28 | ) 29 | FLATTEN_SEQUENTIAL = t.StackVariable[bool]( 30 | "FLATTEN_SEQUENTIAL", 31 | "Whether to flatten sequential torch modules after making changes.", 32 | default=tnn.FLATTEN_SEQUENTIAL, 33 | ) 34 | 35 | # Torch model cloning 36 | 37 | 38 | @traverser(type=Module) 39 | def _clone_traverser( 40 | obj: Module, 41 | state: t.State[Module], 42 | ) -> t.TraverserResult[Module]: 43 | if state[CLONE]: 44 | obj = copy.deepcopy(obj) 45 | # Do not clone the module twice: 46 | state[CLONE] = False 47 | # After deepcopy, generic datastructures will have been cloned as well: 48 | state[generic.CLONE] = False 49 | 50 | return obj, state 51 | 52 | 53 | # Torch model root tracking 54 | 55 | 56 | @traverser(type=Module) 57 | def _root_traverser( 58 | obj: Module, 59 | state: t.State[Module], 60 | ) -> t.TraverserResult[Module]: 61 | if state[ROOT] is None: 62 | state[ROOT] = obj 63 | state[tnn.LAYER_COUNT] = 0 64 | return obj, state 65 | 66 | 67 | # Torch model layer counting 68 | 69 | 70 | @tnn.layer_count_traverser.register(vars={"count": tnn.LAYER_COUNT}, update_vars=True) 71 | def _module_counter(obj: Module, count: int) -> tuple[Module, dict[str, int]]: 72 | return obj, { 73 | "count": count + 1, # Increment LAYER_COUNT for each traversed module. 74 | } 75 | 76 | 77 | @tnn.layer_count_traverser.register 78 | def _sequential_counter(obj: Sequential) -> Sequential: 79 | return obj # Don't count sequential modules as layers. 
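A usage sketch for the second-order quantification functions exercised in the classification tests above; "second-order" here means a set of sampled first-order distributions per instance, and the aleatoric/epistemic readings in the comments are the common interpretation rather than something the library enforces:

```python
import numpy as np

from probly.quantification.classification import (
    conditional_entropy,
    mutual_information,
    total_entropy,
)

rng = np.random.default_rng(0)
# 100 instances, 20 sampled distributions each, 3 classes: shape (100, 20, 3).
probs = rng.dirichlet(np.ones(3), size=(100, 20))

tu = total_entropy(probs)        # one score per instance
au = conditional_entropy(probs)  # commonly read as the aleatoric part (expected entropy)
eu = mutual_information(probs)   # commonly read as the epistemic part (disagreement between samples)
```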
80 | 81 | 82 | # Torch model traverser 83 | 84 | _torch_traverser = t.singledispatch_traverser[Module](name="_torch_traverser") 85 | 86 | 87 | @_torch_traverser.register 88 | def _module_traverser( 89 | obj: Module, 90 | state: t.State[Module], 91 | traverse: t.TraverserCallback[Module], 92 | ) -> t.TraverserResult[Module]: 93 | children = obj.named_children() 94 | if state[TRAVERSE_REVERSED]: 95 | children = reversed(list(children)) 96 | for name, module in children: 97 | new_module, state = traverse(module, state, name) 98 | setattr(obj, name, new_module) 99 | 100 | return obj, state 101 | 102 | 103 | @_torch_traverser.register 104 | def _sequential_traverser( 105 | obj: Sequential, 106 | state: t.State[Module], 107 | traverse: t.TraverserCallback[Module], 108 | ) -> t.TraverserResult[Module]: 109 | if not state[FLATTEN_SEQUENTIAL]: 110 | return _module_traverser(obj, state, traverse) 111 | 112 | seq = [] 113 | children = obj.named_children() 114 | traverse_reversed = state[TRAVERSE_REVERSED] 115 | if traverse_reversed: 116 | children = reversed(list(children)) 117 | 118 | for name, module in children: 119 | new_module, state = traverse(module, state, name) 120 | if isinstance(new_module, Sequential): 121 | sub_children = new_module.named_children() 122 | if traverse_reversed: 123 | sub_children = reversed(list(sub_children)) 124 | for sub_name, sub_module in sub_children: 125 | seq.append((f"{name}_{sub_name}", sub_module)) 126 | else: 127 | seq.append((name, new_module)) 128 | 129 | new_obj = Sequential(OrderedDict(reversed(seq) if traverse_reversed else seq)) 130 | 131 | return new_obj, state 132 | 133 | 134 | # Public API combining cloning, root tracking, and module traversing 135 | 136 | torch_traverser: t.Traverser[Module] = t.sequential( 137 | _clone_traverser, 138 | _root_traverser, 139 | _torch_traverser, 140 | name="torch_traverser", 141 | ) 142 | torch_traverser.register = _torch_traverser.register # type: ignore[attr-defined] 143 | 144 | tnn.nn_traverser.register(Module, torch_traverser) 145 | -------------------------------------------------------------------------------- /src/probly/traverse_nn/flax.py: -------------------------------------------------------------------------------- 1 | """Traversal implementation for flax modules.""" 2 | 3 | from __future__ import annotations 4 | 5 | import copy 6 | from typing import TYPE_CHECKING 7 | 8 | from flax.nnx.helpers import Sequential 9 | from flax.nnx.module import Module 10 | 11 | import pytraverse as t 12 | from pytraverse import generic 13 | from pytraverse.decorators import traverser 14 | 15 | from . 
import common as tnn 16 | 17 | if TYPE_CHECKING: 18 | from collections.abc import Iterable, Iterator 19 | 20 | # Torch traversal variables 21 | 22 | ROOT = t.StackVariable[Module | None]("ROOT", "A reference to the outermost module.") 23 | CLONE = t.StackVariable[bool]( 24 | "CLONE", 25 | "Whether to clone torch modules before making changes.", 26 | default=generic.CLONE, 27 | ) 28 | TRAVERSE_REVERSED = t.StackVariable[bool]( 29 | "TRAVERSE_REVERSED", 30 | "Whether to traverse elements in reverse order.", 31 | default=generic.TRAVERSE_REVERSED, 32 | ) 33 | FLATTEN_SEQUENTIAL = t.StackVariable[bool]( 34 | "FLATTEN_SEQUENTIAL", 35 | "Whether to flatten sequential flax modules after making changes.", 36 | default=tnn.FLATTEN_SEQUENTIAL, 37 | ) 38 | 39 | # Torch model cloning 40 | 41 | 42 | @traverser(type=Module) 43 | def _clone_traverser( 44 | obj: Module, 45 | state: t.State[Module], 46 | ) -> t.TraverserResult[Module]: 47 | if state[CLONE]: 48 | obj = copy.deepcopy(obj) 49 | # Do not clone the module twice: 50 | state[CLONE] = False 51 | # After deepcopy, generic datastructures will have been cloned as well: 52 | state[generic.CLONE] = False 53 | 54 | return obj, state 55 | 56 | 57 | # Flax model root tracking 58 | 59 | 60 | @traverser(type=Module) 61 | def _root_traverser( 62 | obj: Module, 63 | state: t.State[Module], 64 | ) -> t.TraverserResult[Module]: 65 | if state[ROOT] is None: 66 | state[ROOT] = obj 67 | state[tnn.LAYER_COUNT] = 0 68 | return obj, state 69 | 70 | 71 | # Flax model layer counting 72 | 73 | 74 | @tnn.layer_count_traverser.register(vars={"count": tnn.LAYER_COUNT}, update_vars=True) 75 | def _module_counter(obj: Module, count: int) -> tuple[Module, dict[str, int]]: 76 | return obj, { 77 | "count": count + 1, # Increment LAYER_COUNT for each traversed module. 78 | } 79 | 80 | 81 | @tnn.layer_count_traverser.register 82 | def _sequential_counter(obj: Sequential) -> Sequential: 83 | return obj # Don't count sequential modules as layers. 
84 | 85 | 86 | # Flax model traverser 87 | 88 | _torch_traverser = t.singledispatch_traverser[Module](name="_torch_traverser") 89 | 90 | 91 | @_torch_traverser.register 92 | def _module_traverser( 93 | obj: Module, 94 | state: t.State[Module], 95 | traverse: t.TraverserCallback[Module], 96 | ) -> t.TraverserResult[Module]: 97 | children: Iterator[tuple[str, Module]] = obj.iter_children() # type: ignore[assignment] 98 | if state[TRAVERSE_REVERSED]: 99 | children = reversed(list(children)) 100 | for name, module in children: 101 | new_module, state = traverse(module, state, name) 102 | setattr(obj, name, new_module) 103 | 104 | return obj, state 105 | 106 | 107 | @_torch_traverser.register 108 | def _sequential_traverser( 109 | obj: Sequential, 110 | state: t.State[Module], 111 | traverse: t.TraverserCallback[Module], 112 | ) -> t.TraverserResult[Module]: 113 | if not state[FLATTEN_SEQUENTIAL]: 114 | return _module_traverser(obj, state, traverse) 115 | 116 | seq: list[Module] = [] 117 | children: Iterable[Module] = obj.layers # type: ignore[assignment] 118 | traverse_reversed = state[TRAVERSE_REVERSED] 119 | if traverse_reversed: 120 | children = reversed(list(children)) 121 | 122 | for module in children: 123 | new_module, state = traverse(module, state) 124 | if isinstance(new_module, Sequential): 125 | sub_children: Iterable[Module] = new_module.layers # type: ignore[assignment] 126 | if traverse_reversed: 127 | sub_children = reversed(list(sub_children)) 128 | seq += sub_children 129 | else: 130 | seq.append(new_module) 131 | 132 | new_obj = Sequential(*(reversed(seq) if traverse_reversed else seq)) 133 | 134 | return new_obj, state 135 | 136 | 137 | # Public API combining cloning, root tracking, and module traversing 138 | 139 | torch_traverser: t.Traverser[Module] = t.sequential( 140 | _clone_traverser, 141 | _root_traverser, 142 | _torch_traverser, 143 | name="torch_traverser", 144 | ) 145 | torch_traverser.register = _torch_traverser.register # type: ignore[attr-defined] 146 | 147 | tnn.nn_traverser.register(Module, torch_traverser) 148 | -------------------------------------------------------------------------------- /tests/probly/transformation/bayesian/test_torch.py: -------------------------------------------------------------------------------- 1 | """Test for torch bayesian models.""" 2 | 3 | from __future__ import annotations 4 | 5 | import pytest 6 | 7 | from probly.layers.torch import BayesConv2d, BayesLinear 8 | from probly.transformation import bayesian 9 | from tests.probly.torch_utils import count_layers 10 | 11 | torch = pytest.importorskip("torch") 12 | 13 | from torch import nn # noqa: E402 14 | 15 | 16 | class TestNetworkArchitectures: 17 | """Test class for different network architectures.""" 18 | 19 | def test_linear_network_replacement( 20 | self, 21 | torch_model_small_2d_2d: nn.Sequential, 22 | ) -> None: 23 | """Tests if a model incorporates a bayesian layer correctly when a linear layer is present. 24 | 25 | This function verifies that: 26 | - A standard linear layer is replaced with a bayesian linear layer. 
27 | """ 28 | model = bayesian(torch_model_small_2d_2d) 29 | 30 | # count number of nn.Linear layers in original model 31 | count_linear_original = count_layers(torch_model_small_2d_2d, nn.Linear) 32 | # count number of BayesLinear layers in original model 33 | count_bayesian_original = count_layers(torch_model_small_2d_2d, BayesLinear) 34 | # count number of nn.Sequential layers in original model 35 | count_sequential_original = count_layers(torch_model_small_2d_2d, nn.Sequential) 36 | 37 | # count number of nn.Linear layers in modified model 38 | count_linear_modified = count_layers(model, nn.Linear) 39 | # count number of BayesLinear layers in modified model 40 | count_bayesian_modified = count_layers(model, BayesLinear) 41 | # count number of nn.Sequential layers in modified model 42 | count_sequential_modified = count_layers(model, nn.Sequential) 43 | 44 | # check that the model is not modified except for the bayesian layer 45 | assert model is not None 46 | assert isinstance(model, type(torch_model_small_2d_2d)) 47 | assert count_bayesian_modified == count_bayesian_original + count_linear_original 48 | assert count_linear_modified == 0 49 | assert count_sequential_original == count_sequential_modified 50 | 51 | def test_convolutional_network(self, torch_conv_linear_model: nn.Sequential) -> None: 52 | """Tests the convolutional neural network modification with added bayesian layers. 53 | 54 | This function evaluates whether the given convolutional neural network model 55 | has been correctly modified to include bayesian layers without altering the 56 | number of other components such as linear, sequential, or convolutional layers. 57 | """ 58 | model = bayesian(torch_conv_linear_model) 59 | 60 | # count number of nn.Linear layers in original model 61 | count_linear_original = count_layers(torch_conv_linear_model, nn.Linear) 62 | # count number of BayesConv2d layers in original model 63 | count_bayesian_conv_original = count_layers(torch_conv_linear_model, BayesConv2d) 64 | # count number of nn.Sequential layers in original model 65 | count_sequential_original = count_layers(torch_conv_linear_model, nn.Sequential) 66 | # count number of nn.Conv2d layers in original model 67 | count_conv_original = count_layers(torch_conv_linear_model, nn.Conv2d) 68 | # count number of BayesLinear layers in original model 69 | count_bayesian_linear_original = count_layers(torch_conv_linear_model, BayesLinear) 70 | 71 | # count number of nn.Linear layers in modified model 72 | count_linear_modified = count_layers(model, nn.Linear) 73 | # count number of BayesConv2d layers in modified model 74 | count_bayesian_conv_modified = count_layers(model, BayesConv2d) 75 | # count number of nn.Sequential layers in modified model 76 | count_sequential_modified = count_layers(model, nn.Sequential) 77 | # count number of nn.Conv2d layers in modified model 78 | count_conv_modified = count_layers(model, nn.Conv2d) 79 | # count number of BayesLinear layers in modified model 80 | count_bayesian_linear_modified = count_layers(model, BayesLinear) 81 | 82 | # check that the model is not modified except for the bayesian layer 83 | assert model is not None 84 | assert isinstance(model, type(torch_conv_linear_model)) 85 | assert count_linear_modified == 0 86 | assert count_conv_modified == 0 87 | assert count_bayesian_conv_modified == count_bayesian_conv_original + count_conv_original 88 | assert count_bayesian_linear_modified == count_bayesian_linear_original + count_linear_original 89 | assert count_sequential_original == 
count_sequential_modified 90 | 91 | def test_custom_network(self, torch_custom_model: nn.Module) -> None: 92 | """Tests the custom model modification with added bayesian layers.""" 93 | model = bayesian(torch_custom_model) 94 | 95 | # check if model type is correct 96 | assert isinstance(model, type(torch_custom_model)) 97 | assert not isinstance(model, nn.Sequential) 98 | -------------------------------------------------------------------------------- /src/probly/representation/sampling/sampler.py: -------------------------------------------------------------------------------- 1 | """Model sampling and sample representer implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from collections.abc import Callable, Iterable 6 | from typing import Any, Literal, Unpack 7 | 8 | from probly.lazy_types import FLAX_MODULE, TORCH_MODULE 9 | from probly.predictor import Predictor, predict 10 | from probly.representation.representer import Representer 11 | from probly.traverse_nn import nn_compose 12 | from pytraverse import CLONE, GlobalVariable, lazydispatch_traverser, traverse_with_state 13 | 14 | from .sample import Sample, create_sample 15 | 16 | type SamplingStrategy = Literal["sequential"] 17 | 18 | 19 | sampling_preparation_traverser = lazydispatch_traverser[object](name="sampling_preparation_traverser") 20 | 21 | CLEANUP_FUNCS = GlobalVariable[set[Callable[[], Any]]](name="CLEANUP_FUNCS") 22 | 23 | 24 | @sampling_preparation_traverser.delayed_register(TORCH_MODULE) 25 | def _(_: type) -> None: 26 | from . import torch_sampler as torch_sampler # noqa: PLC0414, PLC0415 27 | 28 | 29 | @sampling_preparation_traverser.delayed_register(FLAX_MODULE) 30 | def _(_: type) -> None: 31 | from . import flax_sampler as flax_sampler # noqa: PLC0414, PLC0415 32 | 33 | 34 | def get_sampling_predictor[In, KwIn, Out]( 35 | predictor: Predictor[In, KwIn, Out], 36 | ) -> tuple[Predictor[In, KwIn, Out], Callable[[], None]]: 37 | """Get the predictor to be used for sampling.""" 38 | predictor, state = traverse_with_state( 39 | predictor, 40 | nn_compose(sampling_preparation_traverser), 41 | init={CLONE: False, CLEANUP_FUNCS: set()}, 42 | ) 43 | cleanup_funcs = state[CLEANUP_FUNCS] 44 | 45 | def cleanup() -> None: 46 | for func in cleanup_funcs: 47 | func() 48 | 49 | return predictor, cleanup 50 | 51 | 52 | def sampler_factory[In, KwIn, Out]( 53 | predictor: Predictor[In, KwIn, Out], 54 | num_samples: int = 1, 55 | strategy: SamplingStrategy = "sequential", 56 | ) -> Predictor[In, KwIn, list[Out]]: 57 | """Sample multiple predictions from the predictor.""" 58 | 59 | def sampler(*args: In, **kwargs: Unpack[KwIn]) -> list[Out]: 60 | sampling_predictor, cleanup = get_sampling_predictor(predictor) 61 | try: 62 | if strategy == "sequential": 63 | return [predict(sampling_predictor, *args, **kwargs) for _ in range(num_samples)] 64 | finally: 65 | cleanup() 66 | 67 | msg = f"Unknown sampling strategy: {strategy}" 68 | raise ValueError(msg) 69 | 70 | return sampler 71 | 72 | 73 | class Sampler[In, KwIn, Out](Representer[In, KwIn, Out]): 74 | """A representation predictor that creates representations from finite samples.""" 75 | 76 | sampling_strategy: SamplingStrategy 77 | sample_factory: Callable[[Iterable[Out]], Sample[Out]] 78 | 79 | def __init__( 80 | self, 81 | predictor: Predictor[In, KwIn, Out], 82 | sampling_strategy: SamplingStrategy = "sequential", 83 | sample_factory: Callable[[Iterable[Out]], Sample[Out]] = create_sample, 84 | ) -> None: 85 | """Initialize the sampler. 
86 | 87 | Args: 88 | predictor (Predictor[In, KwIn, Out]): The predictor to be used for sampling. 89 | sampling_strategy (SamplingStrategy, optional): How the samples should be computed. 90 | sample_factory (Callable[[Iterable[Out]], Sample[Out]], optional): Factory to create the sample. 91 | """ 92 | super().__init__(predictor) 93 | self.sampling_strategy = sampling_strategy 94 | self.sample_factory = sample_factory 95 | 96 | def predict(self, *args: In, num_samples: int, **kwargs: Unpack[KwIn]) -> Sample[Out]: 97 | """Sample from the predictor for a given input.""" 98 | return self.sample_factory( 99 | sampler_factory( 100 | self.predictor, 101 | num_samples=num_samples, 102 | strategy=self.sampling_strategy, 103 | )(*args, **kwargs), 104 | ) 105 | 106 | 107 | class EnsembleSampler[In, KwIn, Out](Representer[In, KwIn, Iterable[Out]]): 108 | """A sampler that creates representations from ensemble predictions.""" 109 | 110 | sample_factory: Callable[[Iterable[Out]], Sample[Out]] 111 | 112 | def __init__( 113 | self, 114 | predictor: Predictor[In, KwIn, Iterable[Out]], 115 | sample_factory: Callable[[Iterable[Out]], Sample[Out]] = create_sample, 116 | ) -> None: 117 | """Initialize the ensemble sampler. 118 | 119 | Args: 120 | predictor (Predictor[In, KwIn, Iterable[Out]]): The ensemble predictor. 121 | sample_factory (Callable[[Iterable[Out]], Sample[Out]], optional): Factory to create the sample. 122 | """ 123 | super().__init__(predictor) 124 | self.sample_factory = sample_factory 125 | 126 | def sample(self, *args: In, **kwargs: Unpack[KwIn]) -> Sample[Out]: 127 | """Sample from the ensemble predictor for a given input.""" 128 | return self.sample_factory( 129 | self.predictor(*args, **kwargs), 130 | ) 131 | -------------------------------------------------------------------------------- /src/probly/layers/flax.py: -------------------------------------------------------------------------------- 1 | """flax layer implementation.""" 2 | 3 | from __future__ import annotations 4 | 5 | from flax import nnx 6 | from flax.nnx import rnglib 7 | from flax.nnx.module import first_from 8 | import jax 9 | from jax import lax, random 10 | import jax.numpy as jnp 11 | 12 | 13 | class DropConnectLinear(nnx.Module): 14 | """Custom Linear layer with DropConnect applied to weights during training. 15 | 16 | Attributes: 17 | weight: nnx.Param, weight matrix of shape (in_features, out_features). 18 | bias: nnx.Param, bias vector of shape (out_features,). 19 | rate: float, the dropconnect probability. 20 | deterministic: bool, if false the weights are randomly masked, whereas if true, no mask 21 | is applied and the layer acts as a standard linear layer. 22 | rng_collection: str, the rng collection name to use when requesting a rng key. 23 | rngs: nnx.Rngs or nnx.RngStream or None, rng key. 24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | base_layer: nnx.Linear, 30 | rate: float = 0.25, 31 | *, 32 | rng_collection: str = "dropconnect", 33 | rngs: rnglib.Rngs | rnglib.RngStream | None = None, 34 | ) -> None: 35 | """Initialize a DropConnectLinear layer based on a given linear base layer. 36 | 37 | Args: 38 | base_layer: nnx.Linear, The original linear layer to be wrapped. 39 | rate: float, the dropconnect probability. 40 | rng_collection: str, rng collection name to use when requesting a rng key. 41 | rngs: nnx.Rngs or nnx.RngStream or None, rng key. 
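Example (illustrative only; the layer sizes and seeds are arbitrary, and `nnx`/`jnp` are the module-level imports above):
    base = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
    layer = DropConnectLinear(base, rate=0.5, rngs=nnx.Rngs(dropconnect=1))
    out = layer(jnp.ones((8, 4)))  # the DropConnect mask is applied, since deterministic defaults to False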
42 | """ 43 | self.weight = base_layer.kernel 44 | self.bias = base_layer.bias if base_layer.bias is not None else None 45 | self.rate = rate 46 | self.rng_collection = rng_collection 47 | 48 | if isinstance(rngs, rnglib.Rngs): 49 | self.rngs = rngs[self.rng_collection].fork() 50 | elif isinstance(rngs, rnglib.RngStream): 51 | self.rngs = rngs.fork() 52 | elif rngs is None: 53 | self.rngs = nnx.data(None) 54 | else: 55 | msg = f"rngs must be Rngs, RngStream or None, but got {type(rngs)}." 56 | raise TypeError(msg) 57 | 58 | def __call__( 59 | self, 60 | inputs: jax.Array, 61 | *, 62 | deterministic: bool = False, 63 | rngs: rnglib.Rngs | rnglib.RngStream | jax.Array | None = None, 64 | ) -> jax.Array: 65 | """Forward pass of the DropConnectLinear layer. 66 | 67 | Args: 68 | inputs: jax.Array, input data for the layer. 69 | deterministic: bool, if false the weights are randomly masked, whereas if true, no mask 70 | is applied and the layer acts as a standard linear layer. 71 | rngs: nnx.Rngs, nnx.RngStream or jax.Array, optional key used to generate the dropconnect mask. 72 | 73 | Returns: 74 | jax.Array, layer output. 75 | """ 76 | self.deterministic = deterministic 77 | 78 | deterministic = first_from( 79 | deterministic, 80 | self.deterministic, 81 | error_msg="""No `deterministic` argument was provided to DropConnect 82 | as either a __call__ argument or class attribute.""", 83 | ) 84 | 85 | if (self.rate == 0.0) or deterministic: 86 | return inputs 87 | 88 | # Prevent gradient NaNs in 1.0 edge-case. 89 | if self.rate == 1.0: 90 | out = inputs @ jnp.zeros_like(self.weight.value) 91 | return out if self.bias is None else out + self.bias.value 92 | 93 | rngs = first_from( 94 | rngs, 95 | self.rngs, 96 | error_msg="""`deterministic` is False, but no `rngs` argument was provided 97 | to DropConnect as either a __call__ argument or class attribute.""", 98 | ) 99 | 100 | if isinstance(rngs, rnglib.Rngs): 101 | key = rngs[self.rng_collection]() 102 | elif isinstance(rngs, rnglib.RngStream): 103 | key = rngs() 104 | elif isinstance(rngs, jax.Array): 105 | key = rngs 106 | else: 107 | msg = f"rngs must be Rngs, RngStream or jax.Array, but got {type(rngs)}." 
108 | raise TypeError(msg) 109 | 110 | keep_prob = 1.0 - self.rate 111 | mask = random.bernoulli(key, p=keep_prob, shape=self.weight.value.shape) 112 | masked_weight = lax.select(mask, self.weight.value / keep_prob, jnp.zeros_like(self.weight.value)) 113 | 114 | out = inputs @ masked_weight 115 | if self.bias is not None: 116 | out = out + self.bias.value 117 | return out 118 | 119 | def __repr__(self) -> str: 120 | """Return a string representation of the layer including its class name and key attributes.""" 121 | return f"{self.__class__.__name__}({self.extra_repr()})" 122 | 123 | def extra_repr(self) -> str: 124 | """Expose description of in- and out-features of this layer.""" 125 | in_features = self.weight.value.shape[0] 126 | out_features = self.weight.value.shape[1] 127 | return f"in_features={in_features}, out_features={out_features}, bias={self.bias is not None}" 128 | -------------------------------------------------------------------------------- /src/probly/quantification/regression.py: -------------------------------------------------------------------------------- 1 | """Collection of uncertainty quantification measures for regression settings.""" 2 | 3 | from __future__ import annotations 4 | 5 | from typing import TYPE_CHECKING 6 | 7 | import numpy as np 8 | 9 | from probly.utils import differential_entropy_gaussian, kl_divergence_gaussian 10 | 11 | if TYPE_CHECKING: 12 | import numpy.typing as npt 13 | 14 | 15 | def total_variance(probs: npt.NDArray[np.floating]) -> npt.NDArray[np.floating]: 16 | """Compute total variance as the total uncertainty using variance-based measures. 17 | 18 | Assumes that the input is from a distribution over parameters of 19 | a normal distribution. The first element of the parameter vector is the mean 20 | and the second element is the variance. 21 | The total uncertainty is the variance of the mixture of normal distributions. 22 | 23 | Args: 24 | probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2)) 25 | 26 | Returns: 27 | tv: numpy.ndarray, shape (n_instances,) 28 | 29 | """ 30 | tv = np.mean(probs[:, :, 1], axis=1) + np.var(probs[:, :, 0], axis=1) 31 | return tv 32 | 33 | 34 | def expected_conditional_variance(probs: np.ndarray) -> np.ndarray: 35 | """Compute expected conditional variance as the aleatoric uncertainty using variance-based measures. 36 | 37 | Assume that the input is from a distribution over parameters of 38 | a normal distribution. The first element of the parameter vector is the mean 39 | and the second element is the variance. 40 | The aleatoric uncertainty is the mean of the variance of the samples. 41 | 42 | Args: 43 | probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2)) 44 | 45 | Returns: 46 | ecv: numpy.ndarray, shape (n_instances,) 47 | 48 | """ 49 | ecv = np.mean(probs[:, :, 1], axis=1) 50 | return ecv 51 | 52 | 53 | def variance_conditional_expectation(probs: np.ndarray) -> np.ndarray: 54 | """Compute variance of conditional expectation as the epistemic uncertainty using variance-based measures. 55 | 56 | Assume that the input is from a distribution over parameters of 57 | a normal distribution. The first element of the parameter vector is the mean 58 | and the second element is the variance. 59 | The epistemic uncertainty is the variance of the mean of the samples. 
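For example, two samples with predicted means 0 and 2 give an epistemic uncertainty of Var([0, 2]) = 1, regardless of the predicted variances.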
60 | 61 | Args: 62 | probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2)) 63 | 64 | Returns: 65 | vce: numpy.ndarray, shape (n_instances,) 66 | 67 | """ 68 | vce = np.var(probs[:, :, 0], axis=1) 69 | return vce 70 | 71 | 72 | def total_differential_entropy(probs: np.ndarray) -> np.ndarray: 73 | """Compute total differential entropy as the total uncertainty using entropy-based measures. 74 | 75 | Assume that the input is from a distribution over parameters of 76 | a normal distribution. The first element of the parameter vector is the mean 77 | and the second element is the variance. 78 | The total uncertainty is the differential entropy of the mixture of normal distributions. 79 | 80 | Args: 81 | probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2)) 82 | 83 | Returns: 84 | tde: numpy.ndarray, shape (n_instances,) 85 | 86 | """ 87 | sigma2_mean = np.mean(probs[:, :, 1], axis=1) + np.var(probs[:, :, 0], axis=1) 88 | tde = differential_entropy_gaussian(sigma2_mean) 89 | return tde 90 | 91 | 92 | def conditional_differential_entropy(probs: np.ndarray) -> np.ndarray: 93 | """Compute conditional differential entropy as the aleatoric uncertainty using entropy-based measures. 94 | 95 | Assume that the input is from a distribution over parameters of 96 | a normal distribution. The first element of the parameter vector is the mean 97 | and the second element is the variance. 98 | The aleatoric uncertainty is the mean of the differential entropy of the samples. 99 | 100 | Args: 101 | probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2)) 102 | 103 | Returns: 104 | cde: numpy.ndarray, shape (n_instances,) 105 | 106 | """ 107 | cde = np.mean(differential_entropy_gaussian(probs[:, :, 1]), axis=1) 108 | return cde 109 | 110 | 111 | def mutual_information(probs: np.ndarray) -> np.ndarray: 112 | """Compute mutual information as the epistemic uncertainty using entropy-based measures. 113 | 114 | Assume that the input is from a distribution over parameters of 115 | a normal distribution. The first element of the parameter vector is the mean 116 | and the second element is the variance. 117 | The epistemic uncertainty is the expected KL-divergence of the samples 118 | to the mean distribution. 
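For Gaussian samples, this KL-divergence has the standard closed form KL(N(mu_i, s_i^2) || N(m, s^2)) = 0.5 * (log(s^2 / s_i^2) + (s_i^2 + (mu_i - m)^2) / s^2 - 1), where N(m, s^2) denotes the mean distribution.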
119 | 120 | Args: 121 | probs: numpy.ndarray, shape (n_instances, n_samples, (mu, sigma^2)) 122 | 123 | Returns: 124 | mi: numpy.ndarray, shape (n_instances,) 125 | 126 | """ 127 | mu_mean = np.mean(probs[:, :, 0], axis=1) 128 | sigma2_mean = np.mean(probs[:, :, 1], axis=1) + np.var(probs[:, :, 0], axis=1) 129 | mu_mean = np.repeat(np.expand_dims(mu_mean, 1), repeats=probs.shape[1], axis=1) 130 | sigma2_mean = np.repeat(np.expand_dims(sigma2_mean, 1), repeats=probs.shape[1], axis=1) 131 | mi = np.mean(kl_divergence_gaussian(probs[:, :, 0], probs[:, :, 1], mu_mean, sigma2_mean), axis=1) 132 | return mi 133 | -------------------------------------------------------------------------------- /tests/probly/evaluation/test_metrics.py: -------------------------------------------------------------------------------- 1 | """Tests for the metrics module.""" 2 | 3 | from __future__ import annotations 4 | 5 | import math 6 | 7 | import numpy as np 8 | import pytest 9 | 10 | from probly.evaluation.metrics import ( 11 | ROUND_DECIMALS, 12 | brier_score, 13 | coverage, 14 | coverage_convex_hull, 15 | covered_efficiency, 16 | efficiency, 17 | expected_calibration_error, 18 | log_loss, 19 | spherical_score, 20 | zero_one_loss, 21 | ) 22 | from tests.probly.general_utils import validate_uncertainty 23 | 24 | 25 | @pytest.fixture 26 | def sample_zero_order_data() -> tuple[np.ndarray, np.ndarray]: 27 | rng = np.random.default_rng() 28 | probs = rng.dirichlet(np.ones(3), 10) 29 | targets = rng.integers(0, 3, 10) 30 | return probs, targets 31 | 32 | 33 | @pytest.fixture 34 | def sample_first_order_data() -> tuple[np.ndarray, np.ndarray]: 35 | rng = np.random.default_rng() 36 | probs = rng.dirichlet(np.ones(3), (10, 5)) 37 | targets = rng.dirichlet(np.ones(3), 10) 38 | return probs, targets 39 | 40 | 41 | @pytest.fixture 42 | def sample_conformal_data() -> tuple[np.ndarray, np.ndarray]: 43 | rng = np.random.default_rng() 44 | probs = rng.choice([True, False], (10, 3)) 45 | targets = rng.integers(0, 3, 10) 46 | return probs, targets 47 | 48 | 49 | def validate_metric(metric: float) -> None: 50 | assert isinstance(metric, float) 51 | assert not math.isnan(metric) 52 | assert not math.isinf(metric) 53 | assert metric >= 0 54 | 55 | 56 | def test_expected_calibration_error() -> None: 57 | probs = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) 58 | targets = np.array([0, 1, 2, 0]) 59 | ece = expected_calibration_error(probs, targets) 60 | validate_metric(ece) 61 | assert ece == 0.0 62 | 63 | targets = np.array([1, 2, 0, 1]) 64 | ece = expected_calibration_error(probs, targets) 65 | validate_metric(ece) 66 | assert ece == 1.0 67 | 68 | 69 | def test_coverage( 70 | sample_conformal_data: tuple[np.array, np.array], 71 | sample_first_order_data: tuple[np.array, np.array], 72 | ) -> None: 73 | preds, targets = sample_conformal_data 74 | cov = coverage(preds, targets) 75 | validate_metric(cov) 76 | 77 | probs, targets = sample_first_order_data 78 | cov = coverage(probs, targets) 79 | validate_metric(cov) 80 | 81 | 82 | def test_efficiency( 83 | sample_conformal_data: tuple[np.array, np.array], 84 | sample_first_order_data: tuple[np.array, np.array], 85 | ) -> None: 86 | preds, _ = sample_conformal_data 87 | eff = efficiency(preds) 88 | validate_metric(eff) 89 | 90 | probs, _ = sample_first_order_data 91 | eff = efficiency(probs) 92 | validate_metric(eff) 93 | 94 | 95 | def test_coverage_convex_hull(sample_first_order_data: tuple[np.array, np.array]) -> None: 96 | probs, targets = 
sample_first_order_data 97 | cov = coverage_convex_hull(probs, targets) 98 | validate_metric(cov) 99 | 100 | 101 | def test_covered_efficiency( 102 | sample_conformal_data: tuple[np.array, np.array], 103 | sample_first_order_data: tuple[np.array, np.array], 104 | ) -> None: 105 | preds, targets = sample_conformal_data 106 | eff = covered_efficiency(preds, targets) 107 | covered = preds[np.arange(preds.shape[0]), targets] 108 | # if none of the instances cover the target, the efficiency should be np.nan 109 | if not np.any(covered): 110 | assert math.isnan(eff) 111 | else: 112 | validate_metric(eff) 113 | 114 | probs, targets = sample_first_order_data 115 | eff = covered_efficiency(probs, targets) 116 | probs_lower = np.round(np.nanmin(probs, axis=1), decimals=ROUND_DECIMALS) 117 | probs_upper = np.round(np.nanmax(probs, axis=1), decimals=ROUND_DECIMALS) 118 | covered = np.all((probs_lower <= targets) & (targets <= probs_upper), axis=1) 119 | # if none of the instances cover the target, the efficiency should be np.nan 120 | if not np.any(covered): 121 | assert math.isnan(eff) 122 | else: 123 | validate_metric(eff) 124 | 125 | 126 | def test_log_loss( 127 | sample_zero_order_data: tuple[np.array, np.array], 128 | sample_first_order_data: tuple[np.array, np.array], 129 | ) -> None: 130 | probs, targets = sample_zero_order_data 131 | loss = log_loss(probs, targets) 132 | validate_metric(loss) 133 | 134 | probs, _ = sample_first_order_data 135 | loss = log_loss(probs, None) 136 | validate_uncertainty(loss) 137 | 138 | 139 | def test_brier_score( 140 | sample_zero_order_data: tuple[np.array, np.array], 141 | sample_first_order_data: tuple[np.array, np.array], 142 | ) -> None: 143 | probs, targets = sample_zero_order_data 144 | loss = brier_score(probs, targets) 145 | validate_metric(loss) 146 | 147 | probs, _ = sample_first_order_data 148 | loss = brier_score(probs, None) 149 | validate_uncertainty(loss) 150 | 151 | 152 | def test_zero_one_loss( 153 | sample_zero_order_data: tuple[np.array, np.array], 154 | sample_first_order_data: tuple[np.array, np.array], 155 | ) -> None: 156 | probs, targets = sample_zero_order_data 157 | loss = zero_one_loss(probs, targets) 158 | validate_metric(loss) 159 | 160 | probs, _ = sample_first_order_data 161 | loss = zero_one_loss(probs, None) 162 | validate_uncertainty(loss) 163 | 164 | 165 | def test_spherical_score( 166 | sample_zero_order_data: tuple[np.array, np.array], 167 | sample_first_order_data: tuple[np.array, np.array], 168 | ) -> None: 169 | probs, targets = sample_zero_order_data 170 | loss = spherical_score(probs, targets) 171 | validate_metric(loss) 172 | 173 | probs, _ = sample_first_order_data 174 | loss = spherical_score(probs, None) 175 | validate_uncertainty(loss) 176 | -------------------------------------------------------------------------------- /src/probly/train/calibration/torch.py: -------------------------------------------------------------------------------- 1 | """Collection of torch calibration training functions.""" 2 | 3 | from __future__ import annotations 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class ExpectedCalibrationError(nn.Module): 11 | """Expected Calibration Error (ECE) :cite:`guoOnCalibration2017`. 
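The predictions are grouped into equal-width confidence bins, and the score is the bin-size-weighted average of the absolute gap between accuracy and mean confidence in each bin: ECE = sum_b (|B_b| / n) * |acc(B_b) - conf(B_b)|.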
12 | 13 | Attributes: 14 | num_bins: int, number of bins to use for calibration 15 | self.bins: torch.Tensor, the actual bins for calibration 16 | """ 17 | 18 | def __init__(self, num_bins: int = 10) -> None: 19 | """Initializes an instance of the ExpectedCalibrationError class. 20 | 21 | Args: 22 | num_bins: int, number of bins to use for calibration 23 | """ 24 | super().__init__() 25 | self.num_bins = num_bins 26 | self.bins = torch.linspace(0, 1, num_bins + 1) 27 | 28 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 29 | """Forward pass of the expected calibration error. 30 | 31 | Assumes that inputs are probability distributions over classes. 32 | 33 | Args: 34 | inputs: torch.Tensor of size (n_instances, n_classes). 35 | targets: torch.Tensor of size (n_instances,) 36 | 37 | Returns: 38 | loss: torch.Tensor, mean loss value 39 | """ 40 | confs, preds = torch.max(inputs, dim=1) 41 | bin_indices = torch.bucketize(confs, self.bins.to(inputs.device), right=True) - 1 42 | num_instances = inputs.shape[0] 43 | loss: torch.Tensor = torch.tensor(0, dtype=torch.float32, device=inputs.device) 44 | for i in range(self.num_bins): 45 | _bin = torch.where(bin_indices == i)[0] 46 | # check if bin is empty 47 | if _bin.shape[0] == 0: 48 | continue 49 | acc_bin = torch.mean((preds[_bin] == targets[_bin]).float()) 50 | conf_bin = torch.mean(confs[_bin]) 51 | weight = _bin.shape[0] / num_instances 52 | loss += weight * torch.abs(acc_bin - conf_bin) 53 | return loss 54 | 55 | 56 | class LabelRelaxationLoss(nn.Module): 57 | """Label Relaxation Loss from :cite:`lienenFromLabel2021`. 58 | 59 | This loss is used to improve the calibration of a neural network. It works by minimizing 60 | the Kullback-Leibler divergence between the predicted probabilities and the target distribution in the credal set 61 | defined by the alpha parameter. The target distribution is the distribution in the credal set that minimizes the 62 | Kullback-Leibler divergence from the predicted probabilities. If the predicted probability distribution 63 | is in the credal set, the loss is zero. 64 | 65 | Attributes: 66 | alpha: float, the parameter that controls the amount of label relaxation. Increasing alpha, increases the size 67 | of the credal set and thus the amount of label relaxation. 68 | """ 69 | 70 | def __init__(self, alpha: float = 0.1) -> None: 71 | """Initializes an instance of the LabelRelaxationLoss class. 72 | 73 | Args: 74 | alpha: float, the parameter that controls the amount of label relaxation. 75 | Increasing alpha, increases the size of the credal set and thus the amount of label relaxation. 76 | """ 77 | super().__init__() 78 | self.alpha = alpha 79 | 80 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 81 | """Forward pass of the label relaxation loss. 
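The relaxed target assigns probability 1 - alpha to the true class and alpha * p_j / sum_{k != y} p_k to every other class j, where p is the predicted distribution; the loss is KL(target || prediction), set to zero whenever the predicted probability mass outside the true class is at most alpha.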
82 | 83 | Args: 84 | inputs: torch.Tensor of size (n_instances, n_classes) 85 | targets: torch.Tensor of size (n_instances,) 86 | 87 | Returns: 88 | loss: torch.Tensor, mean loss value 89 | """ 90 | inputs_probs = F.softmax(inputs, dim=1) 91 | 92 | with torch.no_grad(): 93 | inv_one_hot = 1 - F.one_hot(targets, inputs.shape[1]) 94 | targets_real = self.alpha * inputs_probs / torch.sum(inv_one_hot * inputs_probs, dim=1, keepdim=True) 95 | targets_real[torch.arange(targets.shape[0]), targets] = 1 - self.alpha 96 | 97 | kl_div = torch.sum(F.kl_div(inputs_probs.log(), targets_real, log_target=False, reduction="none"), dim=1) 98 | loss = torch.where(torch.sum(inv_one_hot * inputs_probs, dim=1) <= self.alpha, 0, kl_div) 99 | return loss.mean() 100 | 101 | 102 | class FocalLoss(nn.Module): 103 | """Focal Loss based on :cite:`linFocalLoss2017`. 104 | 105 | Attributes: 106 | alpha: float, control importance of minority class 107 | gamma: float, control loss for hard instances 108 | """ 109 | 110 | def __init__(self, alpha: float = 1, gamma: float = 2) -> None: 111 | """Initializes an instance of the FocalLoss class. 112 | 113 | Args: 114 | alpha: float, control importance of minority class 115 | gamma: float, control loss for hard instances 116 | """ 117 | super().__init__() 118 | self.alpha = alpha 119 | self.gamma = gamma 120 | 121 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: 122 | """Forward pass of the focal loss. 123 | 124 | Args: 125 | inputs: torch.Tensor of size (n_instances, n_classes) 126 | targets: torch.Tensor of size (n_instances,) 127 | 128 | Returns: 129 | loss: torch.Tensor, mean loss value 130 | 131 | """ 132 | targets_one_hot = F.one_hot(targets, num_classes=inputs.shape[-1]) 133 | prob = F.softmax(inputs, dim=-1) 134 | p_t = torch.sum(prob * targets_one_hot, dim=-1) 135 | 136 | log_prob = torch.log(prob) 137 | loss = -self.alpha * (1 - p_t) ** self.gamma * torch.sum(log_prob * targets_one_hot, dim=-1) 138 | 139 | return torch.mean(loss) 140 | -------------------------------------------------------------------------------- /notebooks/examples/lazy_dispatch_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "af9ba3c1", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "5a090940", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "class A:\n", 22 | " class A:\n", 23 | " \"\"\"A test class.\"\"\"\n", 24 | "\n", 25 | "\n", 26 | "class B:\n", 27 | " class B:\n", 28 | " \"\"\"A test class.\"\"\"\n", 29 | "\n", 30 | "\n", 31 | "class C(A):\n", 32 | " class C(A.A):\n", 33 | " \"\"\"A test class.\"\"\"\n", 34 | "\n", 35 | "\n", 36 | "class D(A, B):\n", 37 | " pass\n", 38 | "\n", 39 | "\n", 40 | "class E(C, D):\n", 41 | " pass" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "458266bd", 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "torch.Tensor" 54 | ] 55 | }, 56 | "execution_count": 24, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "import torch\n", 63 | "\n", 64 | "torch.nn.Module" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 29, 70 | "id": "99f8fb2c", 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "data": { 75 
| "text/plain": [ 76 | "True" 77 | ] 78 | }, 79 | "execution_count": 29, 80 | "metadata": {}, 81 | "output_type": "execute_result" 82 | } 83 | ], 84 | "source": [ 85 | "from lazy_dispatch.isinstance import lazy_isinstance\n", 86 | "\n", 87 | "t = torch.tensor([1, 2, 3])\n", 88 | "lazy_isinstance(t, \"torch.Tensor\")" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "149b8309", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "from typing import TYPE_CHECKING\n", 99 | "\n", 100 | "import lazy_dispatch as ld\n", 101 | "\n", 102 | "if TYPE_CHECKING:\n", 103 | " import keras # type: ignore # noqa: PGH003\n", 104 | " import torch\n", 105 | "\n", 106 | "\n", 107 | "@ld.lazy_singledispatch\n", 108 | "def f(_: object) -> str:\n", 109 | " return \"unknown\"\n", 110 | "\n", 111 | "\n", 112 | "@f.register\n", 113 | "def _(_: int) -> str:\n", 114 | " return \"int\"\n", 115 | "\n", 116 | "\n", 117 | "@f.register\n", 118 | "def _(_: \"torch.nn.modules.module.Module\") -> str:\n", 119 | " return \"torch module\"\n", 120 | "\n", 121 | "\n", 122 | "@f.register\n", 123 | "def _(_: \"torch.nn.modules.linear.Linear\") -> str:\n", 124 | " return \"torch linear\"\n", 125 | "\n", 126 | "\n", 127 | "@f.register\n", 128 | "def _(_: \"keras.engine.training.Model\") -> str:\n", 129 | " return \"keras model\"\n", 130 | "\n", 131 | "\n", 132 | "@f.register\n", 133 | "def _(_: A) -> str:\n", 134 | " return \"A\"\n", 135 | "\n", 136 | "\n", 137 | "@f.register\n", 138 | "def _(_: \"__main__.B\") -> str: # type: ignore # noqa: F821, PGH003\n", 139 | " return \"B\"\n", 140 | "\n", 141 | "\n", 142 | "@f.register\n", 143 | "def _(_: E) -> str:\n", 144 | " return \"E\"" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 4, 150 | "id": "d6dfb9cd", 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "data": { 155 | "text/plain": [ 156 | "{'_': object, 'return': str}" 157 | ] 158 | }, 159 | "execution_count": 4, 160 | "metadata": {}, 161 | "output_type": "execute_result" 162 | } 163 | ], 164 | "source": [ 165 | "f.__annotations__" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "7af81c92", 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "42 int\n", 179 | "1.5 unknown\n", 180 | "linear: torch linear\n", 181 | "seq: torch module\n", 182 | "A: A\n", 183 | "B: B\n", 184 | "C: A\n", 185 | "D: A\n", 186 | "E: E\n" 187 | ] 188 | } 189 | ], 190 | "source": [ 191 | "import torch.nn\n", 192 | "\n", 193 | "layer = torch.nn.Linear(10, 10)\n", 194 | "seq = torch.nn.Sequential(layer, layer)\n", 195 | "\n", 196 | "print(42, f(42))\n", 197 | "print(1.5, f(1.5))\n", 198 | "print(\"linear:\", f(layer))\n", 199 | "print(\"seq:\", f(seq))\n", 200 | "print(\"A:\", f(A()))\n", 201 | "print(\"B:\", f(B()))\n", 202 | "print(\"C:\", f(C()))\n", 203 | "print(\"D:\", f(D()))\n", 204 | "print(\"E:\", f(E()))" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "f35cfff2", 211 | "metadata": {}, 212 | "outputs": [ 213 | { 214 | "data": { 215 | "text/plain": [ 216 | "mappingproxy({'keras.engine.training.Model': str>})" 217 | ] 218 | }, 219 | "execution_count": 20, 220 | "metadata": {}, 221 | "output_type": "execute_result" 222 | } 223 | ], 224 | "source": [ 225 | "f.string_registry" 226 | ] 227 | } 228 | ], 229 | "metadata": { 230 | "kernelspec": { 231 | "display_name": "uncertainty-toolkit", 232 | "language": 
"python", 233 | "name": "python3" 234 | }, 235 | "language_info": { 236 | "codemirror_mode": { 237 | "name": "ipython", 238 | "version": 3 239 | }, 240 | "file_extension": ".py", 241 | "mimetype": "text/x-python", 242 | "name": "python", 243 | "nbconvert_exporter": "python", 244 | "pygments_lexer": "ipython3", 245 | "version": "3.13.1" 246 | } 247 | }, 248 | "nbformat": 4, 249 | "nbformat_minor": 5 250 | } 251 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | """Configuration file for the Sphinx documentation builder.""" 2 | 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # For the full list of built-in configuration values, see the documentation: 6 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 7 | from __future__ import annotations 8 | 9 | import importlib 10 | import inspect 11 | import os 12 | import sys 13 | 14 | import probly 15 | 16 | # -- Path setup -------------------------------------------------------------- 17 | sys.path.insert(0, os.path.abspath("../../src")) 18 | sys.path.insert(0, os.path.abspath("../../examples")) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 22 | 23 | project = "probly" 24 | copyright = "2025, probly team" # noqa: A001 25 | author = "probly team" 26 | release = probly.__version__ 27 | version = probly.__version__ 28 | 29 | # -- General configuration --------------------------------------------------- 30 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 31 | 32 | extensions = [ 33 | "sphinx.ext.autodoc", # generates API documentation from docstrings 34 | "sphinx.ext.autosummary", # generates .rst files for each module 35 | "sphinx_autodoc_typehints", # optional, nice for type hints in docs 36 | # "sphinx.ext.linkcode", # adds [source] links to code that link to GitHub. Use when repo is public. # noqa: E501, ERA001 37 | "sphinx.ext.viewcode", # adds [source] links to code that link to the source code in the docs. 
38 | "sphinx.ext.napoleon", # for Google-style docstrings 39 | "sphinx.ext.duration", # optional, show the duration of the build 40 | "myst_nb", # for jupyter notebook support, also includes myst_parser 41 | "sphinx.ext.intersphinx", # for linking to other projects' docs 42 | "sphinx.ext.mathjax", # for math support 43 | "sphinx.ext.doctest", # for testing code snippets in the docs 44 | "sphinx_copybutton", # adds a copy button to code blocks 45 | "sphinx.ext.autosectionlabel", # for auto-generating section labels, 46 | "sphinxcontrib.bibtex", # for bibliography support 47 | ] 48 | 49 | templates_path = ["_templates"] 50 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 51 | bibtex_bibfiles = ["references.bib"] 52 | bibtex_default_style = "alpha" 53 | nb_execution_mode = "off" # don't run notebooks when building the docs 54 | 55 | intersphinx_mapping = { 56 | "python3": ("https://docs.python.org/3", None), 57 | "numpy": ("https://numpy.org/doc/stable/", None), 58 | "scipy": ("https://docs.scipy.org/doc/scipy", None), 59 | "matplotlib": ("https://matplotlib.org/stable", None), 60 | "PIL": ("https://pillow.readthedocs.io/en/stable/", None), 61 | "torch": ("https://pytorch.org/docs/stable/", None), 62 | } 63 | 64 | 65 | def linkcode_resolve(domain: str, info: dict[str, str]) -> str | None: 66 | """Resolve the link to the source code in GitHub. 67 | 68 | This function is required by sphinx.ext.linkcode and is used to generate links to the source code on GitHub. 69 | 70 | Args: 71 | domain (str): The domain of the object. 72 | info (dict[str, str]): The information about the object. 73 | 74 | Returns: 75 | str | None: The URL to the source code or None if not found. 76 | """ 77 | if domain != "py" or not info["module"]: 78 | return None 79 | 80 | try: 81 | module = importlib.import_module(info["module"]) 82 | obj = module 83 | for part in info["fullname"].split("."): 84 | obj = getattr(obj, part) 85 | fn = inspect.getsourcefile(obj) 86 | source, lineno = inspect.getsourcelines(obj) 87 | root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) 88 | relpath = os.path.relpath(fn, start=root) 89 | except (ModuleNotFoundError, AttributeError, TypeError, OSError): 90 | return None 91 | 92 | base = "https://github.com/pwhofman/probly" 93 | tag = "v0.2.0-pre-alpha" if version == "0.2.0" else f"v{version}" 94 | 95 | return f"{base}/blob/{tag}/{relpath}#L{lineno}" 96 | 97 | 98 | # -- Options for HTML output ------------------------------------------------- 99 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 100 | html_theme = "furo" 101 | html_static_path = ["_static"] 102 | html_css_files = [ 103 | "css/custom.css", 104 | ] 105 | # TODO(pwhofman): add favicon Issue: https://github.com/pwhofman/probly/issues/95 106 | # html_favicon = "_static/logo/" # noqa: ERA001 107 | pygments_dark_style = "monokai" 108 | html_theme_options = { 109 | "sidebar_hide_name": True, 110 | "light_logo": "logo/logo_light.png", 111 | "dark_logo": "logo/logo_dark.png", 112 | } 113 | 114 | html_sidebars = { 115 | "**": [ 116 | "sidebar/scroll-start.html", 117 | "sidebar/brand.html", 118 | "sidebar/search.html", 119 | "sidebar/navigation.html", 120 | "sidebar/ethical-ads.html", 121 | "sidebar/scroll-end.html", 122 | "sidebar/footer.html", # to get the github link in the footer of the sidebar 123 | ], 124 | } 125 | 126 | html_show_sourcelink = False # to remove button next to dark mode showing source in txt format 127 | 128 | # -- Autodoc 
--------------------------------------------------------------------------------------- 129 | autosummary_generate = True 130 | autodoc_default_options = { 131 | "show-inheritance": True, 132 | "members": True, 133 | "member-order": "groupwise", 134 | "special-members": "__call__", 135 | "undoc-members": True, 136 | "exclude-members": "__weakref__", 137 | } 138 | autoclass_content = "class" 139 | # TODO(pwhofman): maybe set this to True, Issue https://github.com/pwhofman/probly/issues/94 140 | autodoc_inherit_docstrings = False 141 | 142 | autodoc_typehints = "both" # to show type hints in the docstring 143 | 144 | # -- Copy Paste Button ----------------------------------------------------------------------------- 145 | # Ignore >>> when copying code 146 | copybutton_prompt_text = r">>> |\.\.\. " 147 | copybutton_prompt_is_regexp = True 148 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # A GitHub Actions workflow file to run on pull request events. 2 | name: Pull Request Checks 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | workflow_dispatch: 8 | 9 | jobs: 10 | # ---------------------------------------------------------------------------------------------- 11 | # Code Quality Checks and Linting 12 | # ---------------------------------------------------------------------------------------------- 13 | check_code_quality: 14 | name: Code Quality Checks 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v5 18 | - name: Set up Python and uv 19 | uses: astral-sh/setup-uv@v7 20 | with: 21 | python-version: "3.12" 22 | enable-cache: true 23 | - name: Install pre-commit 24 | run: | 25 | uv sync --only-group lint 26 | uv run --no-sync pre-commit install 27 | - name: Run code-quality checks 28 | run: uv run --no-sync pre-commit run --all-files --show-diff-on-failure 29 | # ---------------------------------------------------------------------------------------------- 30 | # Install and Import Check 31 | # ---------------------------------------------------------------------------------------------- 32 | install_and_import: 33 | name: Install and Import Check 34 | runs-on: ubuntu-latest 35 | steps: 36 | - uses: actions/checkout@v5 37 | - name: Set up Python and uv 38 | uses: astral-sh/setup-uv@v7 39 | with: 40 | python-version: "3.12" 41 | enable-cache: true 42 | - name: Create uv virtual environment 43 | run: uv venv 44 | - name: Install probly package 45 | run: uv run --no-sync uv pip install . 
46 | - name: Test import 47 | run: uv run --no-sync python -c "import probly; print('✅ probly imported successfully')" 48 | # ---------------------------------------------------------------------------------------------- 49 | # Unit Tests with Matrix 50 | # ---------------------------------------------------------------------------------------------- 51 | run_unit_tests: 52 | name: Run Unit Tests 53 | strategy: 54 | fail-fast: false 55 | matrix: 56 | include: 57 | - os: ubuntu-latest 58 | python-version: "3.12" 59 | dependency-group: "all_ml" 60 | - os: ubuntu-latest 61 | python-version: "3.12" 62 | dependency-group: "torch" 63 | - os: windows-latest 64 | python-version: "3.12" 65 | dependency-group: "all_ml" 66 | - os: macos-latest 67 | python-version: "3.12" 68 | dependency-group: "all_ml" 69 | runs-on: ${{ matrix.os }} 70 | steps: 71 | - uses: actions/checkout@v5 72 | - name: Set up Python and uv 73 | uses: astral-sh/setup-uv@v7 74 | with: 75 | python-version: ${{ matrix.python-version }} 76 | enable-cache: true 77 | - name: Install test dependencies 78 | run: uv sync --no-dev --group test --group ${{ matrix.dependency-group }} 79 | - name: Run unit tests 80 | run: uv run --no-sync pytest "tests/probly" -m "not integration" 81 | # ---------------------------------------------------------------------------------------------- 82 | # Test Documentation Build 83 | # ---------------------------------------------------------------------------------------------- 84 | doc_build: 85 | name: Test Documentation Build 86 | runs-on: ubuntu-latest 87 | steps: 88 | - uses: actions/checkout@v5 89 | - name: Set up Python and uv 90 | uses: astral-sh/setup-uv@v7 91 | with: 92 | python-version: "3.12" 93 | enable-cache: true 94 | - name: Install pandoc 95 | run: | 96 | sudo apt-get update 97 | sudo apt-get install -y pandoc 98 | - name: Install dependencies 99 | run: uv sync 100 | # TODO(mmshlk) Turn this step to fail on warnings in the future. https://github.com/pwhofman/probly/issues/131 101 | - name: Build documentation 102 | run: | 103 | uv run sphinx-build -b html docs/source docs/build/html 104 | 105 | # ---------------------------------------------------------------------------------------------- 106 | # Code Coverage 107 | # ---------------------------------------------------------------------------------------------- 108 | run_coverage: 109 | name: Run Test Coverage 110 | runs-on: ubuntu-latest 111 | needs: 112 | - run_unit_tests 113 | - check_code_quality 114 | - install_and_import 115 | steps: 116 | - uses: actions/checkout@v5 117 | - name: Set up Python and uv 118 | uses: astral-sh/setup-uv@v7 119 | with: 120 | python-version: "3.12" 121 | enable-cache: true 122 | - name: Install test dependencies 123 | run: uv sync --no-dev --group test --group all_ml 124 | - name: Measure coverage 125 | run: uv run --no-sync pytest "tests/probly" --cov=probly --cov-report=xml 126 | - name: Upload coverage to codecov 127 | if: ${{ github.repository == 'pwhofman/probly' }} # Only upload reports for main repo and not forks 128 | uses: codecov/codecov-action@v5 129 | with: 130 | token: ${{ secrets.CODECOV_TOKEN }} 131 | fail_ci_if_error: true 132 | -------------------------------------------------------------------------------- /src/pytraverse/generic.py: -------------------------------------------------------------------------------- 1 | """Traverser for standard Python datatypes (tuples, lists, dicts, sets). 
2 | 3 | This module provides a generic traverser that can handle common Python data structures 4 | using single dispatch. It includes configurable behavior for cloning and traversing 5 | dictionary keys. 6 | """ 7 | 8 | from __future__ import annotations 9 | 10 | from typing import TYPE_CHECKING, Any 11 | 12 | from pytraverse.composition import SingledispatchTraverser 13 | from pytraverse.core import ( 14 | StackVariable, 15 | State, 16 | TraverserCallback, 17 | TraverserResult, 18 | ) 19 | 20 | if TYPE_CHECKING: 21 | from collections.abc import Iterable 22 | 23 | CLONE = StackVariable[bool]( 24 | "CLONE", 25 | "Whether to clone datastructures before making changes.", 26 | default=True, 27 | ) 28 | TRAVERSE_KEYS = StackVariable[bool]( 29 | "TRAVERSE_KEYS", 30 | "Whether to traverse the keys of dictionaries.", 31 | default=False, 32 | ) 33 | TRAVERSE_REVERSED = StackVariable[bool]( 34 | "TRAVERSE_REVERSED", 35 | "Whether to traverse elements in reverse order.", 36 | default=False, 37 | ) 38 | 39 | 40 | generic_traverser = SingledispatchTraverser[object](name="generic_traverser") 41 | 42 | 43 | @generic_traverser.register 44 | def _tuple_traverser( 45 | obj: tuple, 46 | state: State[tuple[Any]], 47 | traverse: TraverserCallback[Any], 48 | ) -> TraverserResult[tuple[Any]]: 49 | """Traverse tuple elements and reconstruct the tuple. 50 | 51 | Always creates a new tuple since tuples are immutable. 52 | 53 | Args: 54 | obj: The tuple to traverse. 55 | state: Current traversal state. 56 | traverse: Callback for traversing child elements. 57 | 58 | Returns: 59 | A new tuple with traversed elements and updated state. 60 | """ 61 | new_obj = [] 62 | items: Iterable[tuple[int, Any]] = enumerate(obj) 63 | traverse_reversed = state[TRAVERSE_REVERSED] 64 | if traverse_reversed: 65 | items = reversed(list(items)) 66 | for i, element in items: 67 | new_element, state = traverse(element, state, i) 68 | new_obj.append(new_element) 69 | if traverse_reversed: 70 | return tuple(reversed(new_obj)), state 71 | return tuple(new_obj), state 72 | 73 | 74 | @generic_traverser.register 75 | def _list_traverser( 76 | obj: list, 77 | state: State[list[Any]], 78 | traverse: TraverserCallback[Any], 79 | ) -> TraverserResult[list[Any]]: 80 | """Traverse list elements, optionally cloning the list. 81 | 82 | Behavior depends on the CLONE variable: 83 | - If True: Creates a new list with traversed elements 84 | - If False: Modifies the original list in-place 85 | 86 | Args: 87 | obj: The list to traverse. 88 | state: Current traversal state. 89 | traverse: Callback for traversing child elements. 90 | 91 | Returns: 92 | The modified or new list and updated state. 93 | """ 94 | items: Iterable[tuple[int, Any]] = enumerate(obj) 95 | traverse_reversed = state[TRAVERSE_REVERSED] 96 | if traverse_reversed: 97 | items = reversed(list(items)) 98 | 99 | if state[CLONE]: 100 | new_obj = obj.__class__() 101 | for i, element in items: 102 | new_element, state = traverse(element, state, i) 103 | new_obj.append(new_element) 104 | if traverse_reversed: 105 | new_obj.reverse() 106 | return new_obj, state 107 | 108 | for i, element in items: 109 | new_element, state = traverse(element, state, i) 110 | obj[i] = new_element 111 | return obj, state 112 | 113 | 114 | @generic_traverser.register 115 | def _dict_traverser( 116 | obj: dict, 117 | state: State[dict[Any, Any]], 118 | traverse: TraverserCallback[Any], 119 | ) -> TraverserResult[dict[Any, Any]]: 120 | """Traverse dictionary values and optionally keys. 
121 | 122 | Behavior depends on CLONE and TRAVERSE_KEYS variables: 123 | - CLONE=True or TRAVERSE_KEYS=True: Creates a new dictionary 124 | - CLONE=False and TRAVERSE_KEYS=False: Modifies original dictionary 125 | - TRAVERSE_KEYS=True: Also traverses dictionary keys 126 | 127 | Args: 128 | obj: The dictionary to traverse. 129 | state: Current traversal state. 130 | traverse: Callback for traversing child elements. 131 | 132 | Returns: 133 | The modified or new dictionary and updated state. 134 | """ 135 | traverse_keys = state[TRAVERSE_KEYS] 136 | items: Iterable[tuple[Any, Any]] = obj.items() 137 | traverse_reversed = state[TRAVERSE_REVERSED] 138 | if traverse_reversed: 139 | items = reversed(list(items)) 140 | 141 | if state[CLONE] or traverse_keys: 142 | new_obj = obj.__class__() 143 | if traverse_reversed: 144 | additions = [] 145 | for key, value in items: 146 | if traverse_keys: 147 | new_key, state = traverse(key, state) 148 | else: 149 | new_key = key 150 | new_value, state = traverse(value, state, new_key) 151 | if traverse_reversed: 152 | additions.append((new_key, new_value)) 153 | else: 154 | new_obj[new_key] = new_value 155 | if traverse_reversed: 156 | new_obj.update(reversed(additions)) 157 | return new_obj, state 158 | 159 | for key, value in items: 160 | new_value, state = traverse(value, state, key) 161 | obj[key] = new_value 162 | 163 | return obj, state 164 | 165 | 166 | @generic_traverser.register 167 | def _set_traverser( 168 | obj: set, 169 | state: State[set[Any]], 170 | traverse: TraverserCallback[Any], 171 | ) -> TraverserResult[set[Any]]: 172 | """Traverse set elements and reconstruct the set. 173 | 174 | Always creates a new set since sets are unordered and elements 175 | may change during traversal. 176 | 177 | Args: 178 | obj: The set to traverse. 179 | state: Current traversal state. 180 | traverse: Callback for traversing child elements. 181 | 182 | Returns: 183 | A new set with traversed elements and updated state. 184 | """ 185 | new_obj = obj.__class__() 186 | for element in obj: 187 | new_element, state = traverse(element, state) 188 | new_obj.add(new_element) 189 | return new_obj, state 190 | --------------------------------------------------------------------------------